Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ API and command-line option may change frequently.***
- SDXL, [SDXL-Turbo](https://huggingface.co/stabilityai/sdxl-turbo)
- [Some SD1.x and SDXL distilled models](./docs/distilled_sd.md)
- [SD3/SD3.5](./docs/sd3.md)
- [FlUX.1-dev/FlUX.1-schnell](./docs/flux.md)
- [FLUX.2-dev](./docs/flux2.md)
- [FLUX.1-dev/FLUX.1-schnell](./docs/flux.md)
- [FLUX.2-dev/FLUX.2-klein](./docs/flux2.md)
- [Chroma](./docs/chroma.md)
- [Chroma1-Radiance](./docs/chroma_radiance.md)
- [Qwen Image](./docs/qwen_image.md)
Expand Down Expand Up @@ -127,8 +127,8 @@ If you want to improve performance or reduce VRAM/RAM usage, please refer to [pe

- [SD1.x/SD2.x/SDXL](./docs/sd.md)
- [SD3/SD3.5](./docs/sd3.md)
- [FlUX.1-dev/FlUX.1-schnell](./docs/flux.md)
- [FLUX.2-dev](./docs/flux2.md)
- [FLUX.1-dev/FLUX.1-schnell](./docs/flux.md)
- [FLUX.2-dev/FLUX.2-klein](./docs/flux2.md)
- [FLUX.1-Kontext-dev](./docs/kontext.md)
- [Chroma](./docs/chroma.md)
- [🔥Qwen Image](./docs/qwen_image.md)
Expand Down
Binary file added assets/flux2/flux2-klein-4b-edit.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/flux2/flux2-klein-4b.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/flux2/flux2-klein-9b-edit.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/flux2/flux2-klein-9b.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/flux2/flux2-klein-base-4b.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/flux2/flux2-klein-base-9b.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
53 changes: 42 additions & 11 deletions conditioner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1614,9 +1614,9 @@ struct LLMEmbedder : public Conditioner {
bool enable_vision = false)
: version(version) {
LLM::LLMArch arch = LLM::LLMArch::QWEN2_5_VL;
if (sd_version_is_flux2(version)) {
if (version == VERSION_FLUX2) {
arch = LLM::LLMArch::MISTRAL_SMALL_3_2;
} else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE) {
} else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE || version == VERSION_FLUX2_KLEIN) {
arch = LLM::LLMArch::QWEN3;
}
if (arch == LLM::LLMArch::MISTRAL_SMALL_3_2) {
Expand Down Expand Up @@ -1708,6 +1708,9 @@ struct LLMEmbedder : public Conditioner {
int prompt_template_encode_start_idx = 34;
int max_length = 0;
std::set<int> out_layers;
std::vector<int> tokens;
std::vector<float> weights;
std::vector<float> mask;
if (llm->enable_vision && conditioner_params.ref_images.size() > 0) {
LOG_INFO("QwenImageEditPlusPipeline");
prompt_template_encode_start_idx = 64;
Expand Down Expand Up @@ -1771,7 +1774,7 @@ struct LLMEmbedder : public Conditioner {
prompt_attn_range.second = static_cast<int>(prompt.size());

prompt += "<|im_end|>\n<|im_start|>assistant\n";
} else if (sd_version_is_flux2(version)) {
} else if (version == VERSION_FLUX2) {
prompt_template_encode_start_idx = 0;
out_layers = {10, 20, 30};

Expand All @@ -1793,17 +1796,28 @@ struct LLMEmbedder : public Conditioner {
prompt_attn_range.second = static_cast<int>(prompt.size());

prompt += "<|im_end|>\n<|im_start|>assistant\n";
} else if (sd_version_is_flux2(version)) {
} else if (version == VERSION_FLUX2_KLEIN) {
prompt_template_encode_start_idx = 0;
out_layers = {10, 20, 30};
max_length = 512;
out_layers = {9, 18, 27};

prompt = "[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]";
prompt = "<|im_start|>user\n";

prompt_attn_range.first = static_cast<int>(prompt.size());
prompt += conditioner_params.text;
prompt_attn_range.second = static_cast<int>(prompt.size());

prompt += "[/INST]";
prompt += "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";

auto tokens_and_weights = tokenize(prompt, prompt_attn_range, 0, false);
tokens = std::get<0>(tokens_and_weights);
weights = std::get<1>(tokens_and_weights);

mask.insert(mask.end(), tokens.size(), 1.f);
if (tokens.size() < max_length) {
mask.insert(mask.end(), max_length - tokens.size(), 0.f);
tokenizer->pad_tokens(tokens, weights, max_length, true);
}
} else if (version == VERSION_OVIS_IMAGE) {
prompt_template_encode_start_idx = 28;
max_length = prompt_template_encode_start_idx + 256;
Expand All @@ -1827,17 +1841,34 @@ struct LLMEmbedder : public Conditioner {
prompt += "<|im_end|>\n<|im_start|>assistant\n";
}

auto tokens_and_weights = tokenize(prompt, prompt_attn_range, max_length, max_length > 0);
auto& tokens = std::get<0>(tokens_and_weights);
auto& weights = std::get<1>(tokens_and_weights);
if (tokens.empty()) {
auto tokens_and_weights = tokenize(prompt, prompt_attn_range, max_length, max_length > 0);
tokens = std::get<0>(tokens_and_weights);
weights = std::get<1>(tokens_and_weights);
}

int64_t t0 = ggml_time_ms();
struct ggml_tensor* hidden_states = nullptr; // [N, n_token, 3584]

auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);

ggml_tensor* attention_mask = nullptr;
if (!mask.empty()) {
attention_mask = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, mask.size(), mask.size());
ggml_ext_tensor_iter(attention_mask, [&](ggml_tensor* attention_mask, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value = 0.f;
if (mask[i0] == 0.f) {
value = -INFINITY;
} else if (i0 > i1) {
value = -INFINITY;
}
ggml_ext_tensor_set_f32(attention_mask, value, i0, i1, i2, i3);
});
}

llm->compute(n_threads,
input_ids,
attention_mask,
image_embeds,
out_layers,
&hidden_states,
Expand All @@ -1861,7 +1892,7 @@ struct LLMEmbedder : public Conditioner {
GGML_ASSERT(hidden_states->ne[1] > prompt_template_encode_start_idx);

int64_t min_length = 0;
if (sd_version_is_flux2(version)) {
if (version == VERSION_FLUX2) {
min_length = 512;
}

Expand Down
75 changes: 73 additions & 2 deletions docs/flux2.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# How to Use

## Download weights
## Flux.2-dev

### Download weights

- Download FLUX.2-dev
- gguf: https://huggingface.co/city96/FLUX.2-dev-gguf/tree/main
Expand All @@ -9,13 +11,82 @@
- Download Mistral-Small-3.2-24B-Instruct-2506-GGUF
- gguf: https://huggingface.co/unsloth/Mistral-Small-3.2-24B-Instruct-2506-GGUF/tree/main

## Examples
### Examples

```
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux2-dev-Q4_K_S.gguf --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\Mistral-Small-3.2-24B-Instruct-2506-Q4_K_M.gguf -r .\kontext_input.png -p "change 'flux.cpp' to 'flux2-dev.cpp'" --cfg-scale 1.0 --sampling-method euler -v --diffusion-fa --offload-to-cpu
```

<img alt="flux2 example" src="../assets/flux2/example.png" />

## Flux.2 klein 4B / Flux.2 klein base 4B

### Download weights

- Download FLUX.2-klein-4B
- safetensors: https://huggingface.co/black-forest-labs/FLUX.2-klein-4B
- gguf: https://huggingface.co/leejet/FLUX.2-klein-4B-GGUF/tree/main
- Download FLUX.2-klein-base-4B
- safetensors: https://huggingface.co/black-forest-labs/FLUX.2-klein-base-4B
- gguf: https://huggingface.co/leejet/FLUX.2-klein-base-4B-GGUF/tree/main
- Download vae
- safetensors: https://huggingface.co/black-forest-labs/FLUX.2-dev/tree/main
- Download Qwen3 4b
- safetensors: https://huggingface.co/Comfy-Org/flux2-klein-4B/tree/main/split_files/text_encoders
- gguf: https://huggingface.co/unsloth/Qwen3-4B-GGUF/tree/main

### Examples

```
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux-2-klein-4b.safetensors --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_3_4b.safetensors -p "a lovely cat" --cfg-scale 1.0 --steps 4 -v --offload-to-cpu --diffusion-fa
```

<img alt="flux2-klein-4b" src="../assets/flux2/flux2-klein-4b.png" />

```
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux-2-klein-4b.safetensors --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_3_4b.safetensors -r .\kontext_input.png -p "change 'flux.cpp' to 'klein.cpp'" --cfg-scale 1.0 --sampling-method euler -v --diffusion-fa --offload-to-cpu --steps 4
```

<img alt="flux2-klein-4b-edit" src="../assets/flux2/flux2-klein-4b-edit.png" />

```
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux-2-klein-base-4b.safetensors --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_3_4b.safetensors -p "a lovely cat" --cfg-scale 4.0 --steps 20 -v --offload-to-cpu --diffusion-fa
```

<img alt="flux2-klein-base-4b" src="../assets/flux2/flux2-klein-base-4b.png" />

## Flux.2 klein 9B / Flux.2 klein base 9B

### Download weights

- Download FLUX.2-klein-9B
- safetensors: https://huggingface.co/black-forest-labs/FLUX.2-klein-9B
- gguf: https://huggingface.co/leejet/FLUX.2-klein-9B-GGUF/tree/main
- Download FLUX.2-klein-base-9B
- safetensors: https://huggingface.co/black-forest-labs/FLUX.2-klein-base-9B
- gguf: https://huggingface.co/leejet/FLUX.2-klein-base-9B-GGUF/tree/main
- Download vae
- safetensors: https://huggingface.co/black-forest-labs/FLUX.2-dev/tree/main
- Download Qwen3 8B
- safetensors: https://huggingface.co/Comfy-Org/flux2-klein-9B/tree/main/split_files/text_encoders
- gguf: https://huggingface.co/unsloth/Qwen3-8B-GGUF/tree/main

### Examples

```
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux-2-klein-9b.safetensors --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_3_8b.safetensors -p "a lovely cat" --cfg-scale 1.0 --steps 4 -v --offload-to-cpu --diffusion-fa
```

<img alt="flux2-klein-9b" src="../assets/flux2/flux2-klein-9b.png" />

```
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux-2-klein-9b.safetensors --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_3_8b.safetensors -r .\kontext_input.png -p "change 'flux.cpp' to 'klein.cpp'" --cfg-scale 1.0 --sampling-method euler -v --diffusion-fa --offload-to-cpu --steps 4
```

<img alt="flux2-klein-9b-edit" src="../assets/flux2/flux2-klein-9b-edit.png" />

```
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux-2-klein-base-9b.safetensors --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_3_8b.safetensors -p "a lovely cat" --cfg-scale 4.0 --steps 20 -v --offload-to-cpu --diffusion-fa
```

<img alt="flux2-klein-base-9b" src="../assets/flux2/flux2-klein-base-9b.png" />
29 changes: 21 additions & 8 deletions flux.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1288,13 +1288,9 @@ namespace Flux {
} else if (version == VERSION_OVIS_IMAGE) {
flux_params.semantic_txt_norm = true;
flux_params.use_yak_mlp = true;
flux_params.context_in_dim = 2048;
flux_params.vec_in_dim = 0;
} else if (sd_version_is_flux2(version)) {
flux_params.context_in_dim = 15360;
flux_params.in_channels = 128;
flux_params.hidden_size = 6144;
flux_params.num_heads = 48;
flux_params.patch_size = 1;
flux_params.out_channels = 128;
flux_params.mlp_ratio = 3.f;
Expand All @@ -1307,12 +1303,12 @@ namespace Flux {
flux_params.ref_index_scale = 10.f;
flux_params.use_mlp_silu_act = true;
}
int64_t head_dim = 0;
for (auto pair : tensor_storage_map) {
std::string tensor_name = pair.first;
if (!starts_with(tensor_name, prefix))
continue;
if (tensor_name.find("guidance_in.in_layer.weight") != std::string::npos) {
// not schnell
flux_params.guidance_embed = true;
}
if (tensor_name.find("__x0__") != std::string::npos) {
Expand Down Expand Up @@ -1344,13 +1340,30 @@ namespace Flux {
flux_params.depth_single_blocks = block_depth + 1;
}
}
if (ends_with(tensor_name, "txt_in.weight")) {
flux_params.context_in_dim = pair.second.ne[0];
flux_params.hidden_size = pair.second.ne[1];
}
if (ends_with(tensor_name, "single_blocks.0.norm.key_norm.scale")) {
head_dim = pair.second.ne[0];
}
if (ends_with(tensor_name, "double_blocks.0.txt_attn.norm.key_norm.scale")) {
head_dim = pair.second.ne[0];
}
}

LOG_INFO("Flux blocks: %d double, %d single", flux_params.depth, flux_params.depth_single_blocks);
flux_params.num_heads = static_cast<int>(flux_params.hidden_size / head_dim);

LOG_INFO("flux: depth = %d, depth_single_blocks = %d, guidance_embed = %s, context_in_dim = %" PRId64
", hidden_size = %" PRId64 ", num_heads = %d",
flux_params.depth,
flux_params.depth_single_blocks,
flux_params.guidance_embed ? "true" : "false",
flux_params.context_in_dim,
flux_params.hidden_size,
flux_params.num_heads);
if (flux_params.is_chroma) {
LOG_INFO("Using pruned modulation (Chroma)");
} else if (!flux_params.guidance_embed) {
LOG_INFO("Flux guidance is disabled (Schnell mode)");
}

flux = Flux(flux_params);
Expand Down
3 changes: 2 additions & 1 deletion ggml_extend.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1348,7 +1348,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
v = ggml_reshape_3d(ctx, v, L_k, d_head, n_kv_head * N); // [N * n_kv_head, d_head, L_k]

auto kq = ggml_mul_mat(ctx, k, q); // [N * n_head, L_q, L_k]
kq = ggml_scale_inplace(ctx, kq, scale);
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
kq = ggml_scale_inplace(ctx, kq, scale);
if (mask) {
kq = ggml_add_inplace(ctx, kq, mask);
}
Expand Down
Loading
Loading