diff --git a/README.md b/README.md
index 4e536880d..89e0b024a 100644
--- a/README.md
+++ b/README.md
@@ -43,8 +43,8 @@ API and command-line option may change frequently.***
 - SDXL, [SDXL-Turbo](https://huggingface.co/stabilityai/sdxl-turbo)
 - [Some SD1.x and SDXL distilled models](./docs/distilled_sd.md)
 - [SD3/SD3.5](./docs/sd3.md)
- - [FlUX.1-dev/FlUX.1-schnell](./docs/flux.md)
- - [FLUX.2-dev](./docs/flux2.md)
+ - [FLUX.1-dev/FLUX.1-schnell](./docs/flux.md)
+ - [FLUX.2-dev/FLUX.2-klein](./docs/flux2.md)
 - [Chroma](./docs/chroma.md)
 - [Chroma1-Radiance](./docs/chroma_radiance.md)
 - [Qwen Image](./docs/qwen_image.md)
@@ -127,8 +127,8 @@ If you want to improve performance or reduce VRAM/RAM usage, please refer to [pe
 - [SD1.x/SD2.x/SDXL](./docs/sd.md)
 - [SD3/SD3.5](./docs/sd3.md)
-- [FlUX.1-dev/FlUX.1-schnell](./docs/flux.md)
-- [FLUX.2-dev](./docs/flux2.md)
+- [FLUX.1-dev/FLUX.1-schnell](./docs/flux.md)
+- [FLUX.2-dev/FLUX.2-klein](./docs/flux2.md)
 - [FLUX.1-Kontext-dev](./docs/kontext.md)
 - [Chroma](./docs/chroma.md)
 - [🔥Qwen Image](./docs/qwen_image.md)
diff --git a/assets/flux2/flux2-klein-4b-edit.png b/assets/flux2/flux2-klein-4b-edit.png
new file mode 100644
index 000000000..481a0a6fd
Binary files /dev/null and b/assets/flux2/flux2-klein-4b-edit.png differ
diff --git a/assets/flux2/flux2-klein-4b.png b/assets/flux2/flux2-klein-4b.png
new file mode 100644
index 000000000..2809752cb
Binary files /dev/null and b/assets/flux2/flux2-klein-4b.png differ
diff --git a/assets/flux2/flux2-klein-9b-edit.png b/assets/flux2/flux2-klein-9b-edit.png
new file mode 100644
index 000000000..41228f1d2
Binary files /dev/null and b/assets/flux2/flux2-klein-9b-edit.png differ
diff --git a/assets/flux2/flux2-klein-9b.png b/assets/flux2/flux2-klein-9b.png
new file mode 100644
index 000000000..48adea2a9
Binary files /dev/null and b/assets/flux2/flux2-klein-9b.png differ
diff --git a/assets/flux2/flux2-klein-base-4b.png b/assets/flux2/flux2-klein-base-4b.png
new file mode 100644
index 000000000..f29a123d9
Binary files /dev/null and b/assets/flux2/flux2-klein-base-4b.png differ
diff --git a/assets/flux2/flux2-klein-base-9b.png b/assets/flux2/flux2-klein-base-9b.png
new file mode 100644
index 000000000..6241f425c
Binary files /dev/null and b/assets/flux2/flux2-klein-base-9b.png differ
diff --git a/conditioner.hpp b/conditioner.hpp
index b6d5646a7..a4e84aa3b 100644
--- a/conditioner.hpp
+++ b/conditioner.hpp
@@ -1614,9 +1614,9 @@ struct LLMEmbedder : public Conditioner {
                 bool enable_vision = false)
        : version(version) {
        LLM::LLMArch arch = LLM::LLMArch::QWEN2_5_VL;
-        if (sd_version_is_flux2(version)) {
+        if (version == VERSION_FLUX2) {
            arch = LLM::LLMArch::MISTRAL_SMALL_3_2;
-        } else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE) {
+        } else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE || version == VERSION_FLUX2_KLEIN) {
            arch = LLM::LLMArch::QWEN3;
        }
        if (arch == LLM::LLMArch::MISTRAL_SMALL_3_2) {
@@ -1708,6 +1708,9 @@ struct LLMEmbedder : public Conditioner {
        int prompt_template_encode_start_idx = 34;
        int max_length = 0;
        std::set<int> out_layers;
+        std::vector<int> tokens;
+        std::vector<float> weights;
+        std::vector<float> mask;
        if (llm->enable_vision && conditioner_params.ref_images.size() > 0) {
            LOG_INFO("QwenImageEditPlusPipeline");
            prompt_template_encode_start_idx = 64;
@@ -1771,7 +1774,7 @@
            prompt_attn_range.second = static_cast<int>(prompt.size());

            prompt += "<|im_end|>\n<|im_start|>assistant\n";
-        } else if (sd_version_is_flux2(version)) {
+        } else if (version == VERSION_FLUX2) {
            prompt_template_encode_start_idx = 0;
            out_layers = {10, 20, 30};
@@ -1793,17 +1796,28 @@
            prompt_attn_range.second = static_cast<int>(prompt.size());

            prompt += "<|im_end|>\n<|im_start|>assistant\n";
-        } else if (sd_version_is_flux2(version)) {
+        } else if (version == VERSION_FLUX2_KLEIN) {
            prompt_template_encode_start_idx = 0;
-            out_layers = {10, 20, 30};
+            max_length = 512;
+            out_layers = {9, 18, 27};

-            prompt = "[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]";
+            prompt = "<|im_start|>user\n";
            prompt_attn_range.first = static_cast<int>(prompt.size());
            prompt += conditioner_params.text;
            prompt_attn_range.second = static_cast<int>(prompt.size());
-            prompt += "[/INST]";
+            prompt += "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
+
+            auto tokens_and_weights = tokenize(prompt, prompt_attn_range, 0, false);
+            tokens = std::get<0>(tokens_and_weights);
+            weights = std::get<1>(tokens_and_weights);
+
+            mask.insert(mask.end(), tokens.size(), 1.f);
+            if (tokens.size() < max_length) {
+                mask.insert(mask.end(), max_length - tokens.size(), 0.f);
+                tokenizer->pad_tokens(tokens, weights, max_length, true);
+            }
        } else if (version == VERSION_OVIS_IMAGE) {
            prompt_template_encode_start_idx = 28;
            max_length = prompt_template_encode_start_idx + 256;
@@ -1827,17 +1841,34 @@
            prompt += "<|im_end|>\n<|im_start|>assistant\n";
        }

-        auto tokens_and_weights = tokenize(prompt, prompt_attn_range, max_length, max_length > 0);
-        auto& tokens = std::get<0>(tokens_and_weights);
-        auto& weights = std::get<1>(tokens_and_weights);
+        if (tokens.empty()) {
+            auto tokens_and_weights = tokenize(prompt, prompt_attn_range, max_length, max_length > 0);
+            tokens = std::get<0>(tokens_and_weights);
+            weights = std::get<1>(tokens_and_weights);
+        }

        int64_t t0 = ggml_time_ms();
        struct ggml_tensor* hidden_states = nullptr;  // [N, n_token, 3584]
        auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);
+        ggml_tensor* attention_mask = nullptr;
+        if (!mask.empty()) {
+            attention_mask = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, mask.size(), mask.size());
+            ggml_ext_tensor_iter(attention_mask, [&](ggml_tensor* attention_mask, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
+                float value = 0.f;
+                if (mask[i0] == 0.f) {
+                    value = -INFINITY;
+                } else if (i0 > i1) {
+                    value = -INFINITY;
+                }
+                ggml_ext_tensor_set_f32(attention_mask, value, i0, i1, i2, i3);
+            });
+        }
+
        llm->compute(n_threads,
                     input_ids,
+                     attention_mask,
                     image_embeds,
                     out_layers,
                     &hidden_states,
@@ -1861,7 +1892,7 @@
        GGML_ASSERT(hidden_states->ne[1] > prompt_template_encode_start_idx);

        int64_t min_length = 0;
-        if (sd_version_is_flux2(version)) {
+        if (version == VERSION_FLUX2) {
            min_length = 512;
        }
diff --git a/docs/flux2.md b/docs/flux2.md
index 0c2c6d2b7..1524478cc 100644
--- a/docs/flux2.md
+++ b/docs/flux2.md
@@ -1,6 +1,8 @@
 # How to Use
 
-## Download weights
+## FLUX.2-dev
+
+### Download weights
 
 - Download FLUX.2-dev
   - gguf: https://huggingface.co/city96/FLUX.2-dev-gguf/tree/main
@@ -9,7 +11,7 @@
 - Download Mistral-Small-3.2-24B-Instruct-2506-GGUF
   - gguf: https://huggingface.co/unsloth/Mistral-Small-3.2-24B-Instruct-2506-GGUF/tree/main
 
-## Examples
+### Examples
 
 ```
 .\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux2-dev-Q4_K_S.gguf --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\Mistral-Small-3.2-24B-Instruct-2506-Q4_K_M.gguf -r .\kontext_input.png -p "change 'flux.cpp' to 'flux2-dev.cpp'" --cfg-scale 1.0 --sampling-method euler -v --diffusion-fa --offload-to-cpu
@@ -17,5 +19,74 @@
 
 flux2 example
 
+## FLUX.2-klein-4B / FLUX.2-klein-base-4B
+
+### Download weights
+
+- Download FLUX.2-klein-4B
+  - safetensors: https://huggingface.co/black-forest-labs/FLUX.2-klein-4B
+  - gguf: https://huggingface.co/leejet/FLUX.2-klein-4B-GGUF/tree/main
+- Download FLUX.2-klein-base-4B
+  - safetensors: https://huggingface.co/black-forest-labs/FLUX.2-klein-base-4B
+  - gguf: https://huggingface.co/leejet/FLUX.2-klein-base-4B-GGUF/tree/main
+- Download VAE
+  - safetensors: https://huggingface.co/black-forest-labs/FLUX.2-dev/tree/main
+- Download Qwen3 4B
+  - safetensors: https://huggingface.co/Comfy-Org/flux2-klein-4B/tree/main/split_files/text_encoders
+  - gguf: https://huggingface.co/unsloth/Qwen3-4B-GGUF/tree/main
+
+### Examples
+
+```
+.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux-2-klein-4b.safetensors --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_3_4b.safetensors -p "a lovely cat" --cfg-scale 1.0 --steps 4 -v --offload-to-cpu --diffusion-fa
+```
+
+![flux2-klein-4b](../assets/flux2/flux2-klein-4b.png)
+
+```
+.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux-2-klein-4b.safetensors --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_3_4b.safetensors -r .\kontext_input.png -p "change 'flux.cpp' to 'klein.cpp'" --cfg-scale 1.0 --sampling-method euler -v --diffusion-fa --offload-to-cpu --steps 4
+```
+
+![flux2-klein-4b-edit](../assets/flux2/flux2-klein-4b-edit.png)
+
+```
+.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux-2-klein-base-4b.safetensors --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_3_4b.safetensors -p "a lovely cat" --cfg-scale 4.0 --steps 20 -v --offload-to-cpu --diffusion-fa
+```
+
+![flux2-klein-base-4b](../assets/flux2/flux2-klein-base-4b.png)
+
+## FLUX.2-klein-9B / FLUX.2-klein-base-9B
+
+### Download weights
+
+- Download FLUX.2-klein-9B
+  - safetensors: https://huggingface.co/black-forest-labs/FLUX.2-klein-9B
+  - gguf: https://huggingface.co/leejet/FLUX.2-klein-9B-GGUF/tree/main
+- Download FLUX.2-klein-base-9B
+  - safetensors: https://huggingface.co/black-forest-labs/FLUX.2-klein-base-9B
+  - gguf: https://huggingface.co/leejet/FLUX.2-klein-base-9B-GGUF/tree/main
+- Download VAE
+  - safetensors: https://huggingface.co/black-forest-labs/FLUX.2-dev/tree/main
+- Download Qwen3 8B
+  - safetensors: https://huggingface.co/Comfy-Org/flux2-klein-9B/tree/main/split_files/text_encoders
+  - gguf: https://huggingface.co/unsloth/Qwen3-8B-GGUF/tree/main
+
+### Examples
+
+```
+.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux-2-klein-9b.safetensors --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_3_8b.safetensors -p "a lovely cat" --cfg-scale 1.0 --steps 4 -v --offload-to-cpu --diffusion-fa
+```
+
+![flux2-klein-9b](../assets/flux2/flux2-klein-9b.png)
+
+```
+.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux-2-klein-9b.safetensors --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_3_8b.safetensors -r .\kontext_input.png -p "change 'flux.cpp' to 'klein.cpp'" --cfg-scale 1.0 --sampling-method euler -v --diffusion-fa --offload-to-cpu --steps 4
+```
+
+![flux2-klein-9b-edit](../assets/flux2/flux2-klein-9b-edit.png)
+
+```
+.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux-2-klein-base-9b.safetensors --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_3_8b.safetensors -p "a lovely cat" --cfg-scale 4.0 --steps 20 -v --offload-to-cpu --diffusion-fa
+```
+
+![flux2-klein-base-9b](../assets/flux2/flux2-klein-base-9b.png)
\ No newline at end of file
diff --git a/flux.hpp b/flux.hpp
index 5d94fc85d..9826fadee 100644
--- a/flux.hpp
+++ b/flux.hpp
@@ -1288,13 +1288,9 @@ namespace Flux {
        } else if (version == VERSION_OVIS_IMAGE) {
            flux_params.semantic_txt_norm = true;
            flux_params.use_yak_mlp = true;
-            flux_params.context_in_dim = 2048;
            flux_params.vec_in_dim = 0;
        } else if (sd_version_is_flux2(version)) {
-            flux_params.context_in_dim = 15360;
            flux_params.in_channels = 128;
-            flux_params.hidden_size = 6144;
-            flux_params.num_heads = 48;
            flux_params.patch_size = 1;
            flux_params.out_channels = 128;
            flux_params.mlp_ratio = 3.f;
@@ -1307,12 +1303,12 @@
            flux_params.ref_index_scale = 10.f;
            flux_params.use_mlp_silu_act = true;
        }
+        int64_t head_dim = 0;
        for (auto pair : tensor_storage_map) {
            std::string tensor_name = pair.first;
            if (!starts_with(tensor_name, prefix))
                continue;
            if (tensor_name.find("guidance_in.in_layer.weight") != std::string::npos) {
-                // not schnell
                flux_params.guidance_embed = true;
            }
            if (tensor_name.find("__x0__") != std::string::npos) {
@@ -1344,13 +1340,30 @@
                    flux_params.depth_single_blocks = block_depth + 1;
                }
            }
+            if (ends_with(tensor_name, "txt_in.weight")) {
+                flux_params.context_in_dim = pair.second.ne[0];
+                flux_params.hidden_size = pair.second.ne[1];
+            }
+            if (ends_with(tensor_name, "single_blocks.0.norm.key_norm.scale")) {
+                head_dim = pair.second.ne[0];
+            }
+            if (ends_with(tensor_name, "double_blocks.0.txt_attn.norm.key_norm.scale")) {
+                head_dim = pair.second.ne[0];
+            }
        }
-        LOG_INFO("Flux blocks: %d double, %d single", flux_params.depth, flux_params.depth_single_blocks);
+        flux_params.num_heads = static_cast<int>(flux_params.hidden_size / head_dim);
+
+        LOG_INFO("flux: depth = %d, depth_single_blocks = %d, guidance_embed = %s, context_in_dim = %" PRId64
+                 ", hidden_size = %" PRId64 ", num_heads = %d",
+                 flux_params.depth,
+                 flux_params.depth_single_blocks,
+                 flux_params.guidance_embed ? "true" : "false",
"true" : "false", + flux_params.context_in_dim, + flux_params.hidden_size, + flux_params.num_heads); if (flux_params.is_chroma) { LOG_INFO("Using pruned modulation (Chroma)"); - } else if (!flux_params.guidance_embed) { - LOG_INFO("Flux guidance is disabled (Schnell mode)"); } flux = Flux(flux_params); diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 6f498ffa7..9d5ea316b 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -1348,7 +1348,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context v = ggml_reshape_3d(ctx, v, L_k, d_head, n_kv_head * N); // [N * n_kv_head, d_head, L_k] auto kq = ggml_mul_mat(ctx, k, q); // [N * n_head, L_q, L_k] - kq = ggml_scale_inplace(ctx, kq, scale); + ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + kq = ggml_scale_inplace(ctx, kq, scale); if (mask) { kq = ggml_add_inplace(ctx, kq, mask); } diff --git a/llm.hpp b/llm.hpp index 67b1ea165..781774db7 100644 --- a/llm.hpp +++ b/llm.hpp @@ -837,7 +837,8 @@ namespace LLM { struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, - struct ggml_tensor* input_pos) { + struct ggml_tensor* input_pos, + struct ggml_tensor* attention_mask = nullptr) { // x: [N, n_token, hidden_size] int64_t n_token = x->ne[1]; int64_t N = x->ne[2]; @@ -880,7 +881,7 @@ namespace LLM { k = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, k, 0, 2, 1, 3)); // [N, num_kv_heads, n_token, head_dim] k = ggml_reshape_3d(ctx->ggml_ctx, k, k->ne[0], k->ne[1], k->ne[2] * k->ne[3]); // [N*num_kv_heads, n_token, head_dim] - x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, true, true, false); // [N, n_token, hidden_size] + x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, attention_mask, false, true, false); // [N, n_token, hidden_size] x = out_proj->forward(ctx, x); // [N, n_token, hidden_size] return x; @@ -898,7 +899,8 @@ namespace LLM { struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, - struct ggml_tensor* input_pos) { + struct ggml_tensor* input_pos, + struct ggml_tensor* attention_mask = nullptr) { // x: [N, n_token, hidden_size] auto self_attn = std::dynamic_pointer_cast(blocks["self_attn"]); auto mlp = std::dynamic_pointer_cast(blocks["mlp"]); @@ -907,7 +909,7 @@ namespace LLM { auto residual = x; x = input_layernorm->forward(ctx, x); - x = self_attn->forward(ctx, x, input_pos); + x = self_attn->forward(ctx, x, input_pos, attention_mask); x = ggml_add_inplace(ctx->ggml_ctx, x, residual); residual = x; @@ -936,6 +938,7 @@ namespace LLM { struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* input_ids, struct ggml_tensor* input_pos, + struct ggml_tensor* attention_mask, std::vector> image_embeds, std::set out_layers) { // input_ids: [N, n_token] @@ -990,7 +993,7 @@ namespace LLM { for (int i = 0; i < num_layers; i++) { auto block = std::dynamic_pointer_cast(blocks["layers." 
-                x = block->forward(ctx, x, input_pos);
+                x = block->forward(ctx, x, input_pos, attention_mask);
                if (out_layers.find(i + 1) != out_layers.end()) {
                    intermediate_outputs.push_back(x);
                }
@@ -1036,12 +1039,13 @@
        struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                    struct ggml_tensor* input_ids,
                                    struct ggml_tensor* input_pos,
+                                    struct ggml_tensor* attention_mask,
                                    std::vector<std::pair<int, ggml_tensor*>> image_embeds,
                                    std::set<int> out_layers) {
            // input_ids: [N, n_token]
            auto model = std::dynamic_pointer_cast(blocks["model"]);

-            auto x = model->forward(ctx, input_ids, input_pos, image_embeds, out_layers);
+            auto x = model->forward(ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers);
            return x;
        }
@@ -1063,6 +1067,7 @@
        LLM model;

        std::vector<int> input_pos_vec;
+        std::vector<float> attention_mask_vec;
        std::vector<float> window_mask_vec;
        std::vector<int> window_index_vec;
        std::vector<int> window_inverse_index_vec;
@@ -1157,9 +1162,10 @@
        struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                    struct ggml_tensor* input_ids,
                                    struct ggml_tensor* input_pos,
+                                    struct ggml_tensor* attention_mask,
                                    std::vector<std::pair<int, ggml_tensor*>> image_embeds,
                                    std::set<int> out_layers) {
-            auto hidden_states = model.forward(ctx, input_ids, input_pos, image_embeds, out_layers);  // [N, n_token, hidden_size]
+            auto hidden_states = model.forward(ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers);  // [N, n_token, hidden_size]
            return hidden_states;
        }
@@ -1174,6 +1180,7 @@
        }

        struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids,
+                                        struct ggml_tensor* attention_mask,
                                        std::vector<std::pair<int, ggml_tensor*>> image_embeds,
                                        std::set<int> out_layers) {
            struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
@@ -1205,9 +1212,26 @@
                                               input_pos_vec.size());
            set_backend_tensor_data(input_pos, input_pos_vec.data());

+            if (attention_mask != nullptr) {
+                attention_mask = to_backend(attention_mask);
+            } else {
+                attention_mask_vec.resize(n_tokens * n_tokens);
+                for (int i0 = 0; i0 < n_tokens; i0++) {
+                    for (int i1 = 0; i1 < n_tokens; i1++) {
+                        float value = 0.f;
+                        if (i0 > i1) {
+                            value = -INFINITY;
+                        }
+                        attention_mask_vec[i1 * n_tokens + i0] = value;
+                    }
+                }
+                attention_mask = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, n_tokens, n_tokens);
+                set_backend_tensor_data(attention_mask, attention_mask_vec.data());
+            }
+
            auto runner_ctx = get_context();

-            struct ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, input_pos, image_embeds, out_layers);
+            struct ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers);

            ggml_build_forward_expand(gf, hidden_states);

@@ -1216,12 +1240,13 @@
        bool compute(const int n_threads,
                     struct ggml_tensor* input_ids,
+                     struct ggml_tensor* attention_mask,
                     std::vector<std::pair<int, ggml_tensor*>> image_embeds,
                     std::set<int> out_layers,
                     ggml_tensor** output,
                     ggml_context* output_ctx = nullptr) {
            auto get_graph = [&]() -> struct ggml_cgraph* {
-                return build_graph(input_ids, image_embeds, out_layers);
+                return build_graph(input_ids, attention_mask, image_embeds, out_layers);
            };
            return GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
        }
@@ -1525,7 +1550,7 @@
        struct ggml_tensor* out = nullptr;

        int64_t t0 = ggml_time_ms();
-        model.compute(8, input_ids, image_embeds, {}, &out, work_ctx);
+        model.compute(8, input_ids, nullptr, image_embeds, {}, &out, work_ctx);
        int64_t t1 = ggml_time_ms();

        print_ggml_tensor(out);
@@ -1565,7 +1590,7 @@
        struct ggml_tensor* out = nullptr;

        int64_t t0 = ggml_time_ms();
-        model.compute(8, input_ids, {}, {10, 20, 30}, &out, work_ctx);
+        model.compute(8, input_ids, nullptr, {}, {10, 20, 30}, &out, work_ctx);
        int64_t t1 = ggml_time_ms();

        print_ggml_tensor(out);
@@ -1588,7 +1613,7 @@
        struct ggml_tensor* out = nullptr;

        int64_t t0 = ggml_time_ms();
-        model.compute(8, input_ids, {}, {35}, &out, work_ctx);
+        model.compute(8, input_ids, nullptr, {}, {35}, &out, work_ctx);
        int64_t t1 = ggml_time_ms();

        print_ggml_tensor(out);
@@ -1611,7 +1636,7 @@
        struct ggml_tensor* out = nullptr;

        int64_t t0 = ggml_time_ms();
-        model.compute(8, input_ids, {}, {}, &out, work_ctx);
+        model.compute(8, input_ids, nullptr, {}, {}, &out, work_ctx);
        int64_t t1 = ggml_time_ms();

        print_ggml_tensor(out);
diff --git a/model.cpp b/model.cpp
index e05d31468..c14f255ac 100644
--- a/model.cpp
+++ b/model.cpp
@@ -1034,6 +1034,8 @@ SDVersion ModelLoader::get_sd_version() {
    bool is_xl = false;
    bool is_flux = false;
+    bool is_flux2 = false;
+    bool has_single_block_47 = false;
    bool is_wan = false;
    int64_t patch_embedding_channels = 0;
    bool has_img_emb = false;
@@ -1055,7 +1057,10 @@
            return VERSION_QWEN_IMAGE;
        }
        if (tensor_storage.name.find("model.diffusion_model.double_stream_modulation_img.lin.weight") != std::string::npos) {
-            return VERSION_FLUX2;
+            is_flux2 = true;
+        }
+        if (tensor_storage.name.find("single_blocks.47.linear1.weight") != std::string::npos) {
+            has_single_block_47 = true;
        }
        if (tensor_storage.name.find("model.diffusion_model.double_blocks.0.img_mlp.gate_proj.weight") != std::string::npos) {
            return VERSION_OVIS_IMAGE;
@@ -1138,7 +1143,7 @@
        return VERSION_SDXL;
    }

-    if (is_flux) {
+    if (is_flux && !is_flux2) {
        if (input_block_weight.ne[0] == 384) {
            return VERSION_FLUX_FILL;
        }
@@ -1151,6 +1156,13 @@
        return VERSION_FLUX;
    }

+    if (is_flux2) {
+        if (has_single_block_47) {
+            return VERSION_FLUX2;
+        }
+        return VERSION_FLUX2_KLEIN;
+    }
+
    if (token_embedding_weight.ne[0] == 768) {
        if (is_inpaint) {
            return VERSION_SD1_INPAINT;
diff --git a/model.h b/model.h
index e52766cc0..3f054c46d 100644
--- a/model.h
+++ b/model.h
@@ -45,6 +45,7 @@ enum SDVersion {
    VERSION_WAN2_2_TI2V,
    VERSION_QWEN_IMAGE,
    VERSION_FLUX2,
+    VERSION_FLUX2_KLEIN,
    VERSION_Z_IMAGE,
    VERSION_OVIS_IMAGE,
    VERSION_COUNT,
@@ -100,7 +101,7 @@ static inline bool sd_version_is_flux(SDVersion version) {
 }
 
 static inline bool sd_version_is_flux2(SDVersion version) {
-    if (version == VERSION_FLUX2) {
+    if (version == VERSION_FLUX2 || version == VERSION_FLUX2_KLEIN) {
        return true;
    }
    return false;
diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index 060b85302..2d9b6e6fa 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -48,6 +48,7 @@ const char* model_version_to_str[] = {
    "Wan 2.2 TI2V",
    "Qwen Image",
    "Flux.2",
+    "Flux.2 klein",
    "Z-Image",
    "Ovis Image",
};
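
A note on the attention-mask plumbing above: the klein path pads prompts to a fixed 512 tokens, so `conditioner.hpp` now hands the Qwen3 encoder an explicit `[n_token, n_token]` additive bias (0 where attention is allowed, `-INFINITY` for padded keys and future positions) instead of what appears to be the built-in causal flag in `ggml_ext_attention_ext`, which is now passed as `false`. The snippet below is a minimal standalone sketch of that mask layout, not the PR code itself; `build_additive_mask` is a hypothetical helper mirroring the 0 / `-INFINITY` convention the diff writes via `ggml_ext_tensor_iter`.

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Expand a per-token pad mask (1 = real token, 0 = padding) into an
// [n, n] additive attention bias: 0 where attention is allowed,
// -INFINITY where it is blocked (padded keys, or future positions).
std::vector<float> build_additive_mask(const std::vector<float>& pad_mask) {
    const size_t n = pad_mask.size();
    std::vector<float> bias(n * n, 0.f);
    for (size_t q = 0; q < n; q++) {      // query position (row)
        for (size_t k = 0; k < n; k++) {  // key position (column)
            if (pad_mask[k] == 0.f || k > q) {
                bias[q * n + k] = -INFINITY;
            }
        }
    }
    return bias;
}

int main() {
    // 3 real tokens padded to length 5.
    std::vector<float> pad_mask = {1.f, 1.f, 1.f, 0.f, 0.f};
    auto bias = build_additive_mask(pad_mask);
    for (size_t q = 0; q < pad_mask.size(); q++) {
        for (size_t k = 0; k < pad_mask.size(); k++) {
            printf("%5s ", std::isinf(bias[q * pad_mask.size() + k]) ? "-inf" : "0");
        }
        printf("\n");
    }
    return 0;
}
```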
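The `flux.hpp` change replaces hard-coded FLUX.2 hyperparameters with shape-driven detection: `context_in_dim` and `hidden_size` are read from the `txt_in.weight` shape, `head_dim` from a QK-norm scale vector, and `num_heads` is derived as `hidden_size / head_dim`. A self-contained sketch of the idea, using a plain map as a stand-in for the loader's `tensor_storage_map` (the shapes shown are FLUX.2-dev's, which recover the previously hard-coded 15360 / 6144 / 48):

```cpp
#include <cinttypes>
#include <cstdint>
#include <cstdio>
#include <map>
#include <string>
#include <vector>

int main() {
    // Stand-in for the loader's tensor_storage_map: tensor name -> ne[] shape.
    std::map<std::string, std::vector<int64_t>> shapes = {
        {"txt_in.weight", {15360, 6144}},                         // ne[0] = context_in_dim, ne[1] = hidden_size
        {"double_blocks.0.txt_attn.norm.key_norm.scale", {128}},  // ne[0] = head_dim
    };

    int64_t context_in_dim = shapes["txt_in.weight"][0];
    int64_t hidden_size = shapes["txt_in.weight"][1];
    int64_t head_dim = shapes["double_blocks.0.txt_attn.norm.key_norm.scale"][0];
    int num_heads = static_cast<int>(hidden_size / head_dim);

    // Prints: context_in_dim = 15360, hidden_size = 6144, num_heads = 48
    printf("context_in_dim = %" PRId64 ", hidden_size = %" PRId64 ", num_heads = %d\n",
           context_in_dim, hidden_size, num_heads);
    return 0;
}
```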
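Model detection in `model.cpp` follows the same tensor-driven approach: both FLUX.2 variants carry the `double_stream_modulation_img.lin.weight` tensor, and the deeper dev checkpoint is told apart from the shallower klein ones by whether `single_blocks.47.linear1.weight` exists. A simplified sketch of that logic; the tensor names are the ones the diff checks, while the enum and function are illustrative stand-ins for `get_sd_version()`:

```cpp
#include <set>
#include <string>

enum Version { FLUX2, FLUX2_KLEIN, UNKNOWN };

// Classify a checkpoint from its tensor names: FLUX.2 marker tensor first,
// then single block 47 to separate dev (48 single blocks) from klein.
Version detect(const std::set<std::string>& tensor_names) {
    bool is_flux2 = false;
    bool has_single_block_47 = false;
    for (const auto& name : tensor_names) {
        if (name.find("double_stream_modulation_img.lin.weight") != std::string::npos) {
            is_flux2 = true;
        }
        if (name.find("single_blocks.47.linear1.weight") != std::string::npos) {
            has_single_block_47 = true;
        }
    }
    if (!is_flux2) {
        return UNKNOWN;
    }
    return has_single_block_47 ? FLUX2 : FLUX2_KLEIN;
}
```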