diff --git a/flux.hpp b/flux.hpp index 9826fadee..77a65c557 100644 --- a/flux.hpp +++ b/flux.hpp @@ -748,7 +748,7 @@ namespace Flux { int nerf_depth = 4; int nerf_max_freqs = 8; bool use_x0 = false; - bool use_patch_size_32 = false; + bool fake_patch_size_x2 = false; }; struct FluxParams { @@ -786,8 +786,11 @@ namespace Flux { Flux(FluxParams params) : params(params) { if (params.version == VERSION_CHROMA_RADIANCE) { - std::pair kernel_size = {16, 16}; - std::pair stride = kernel_size; + std::pair kernel_size = {params.patch_size, params.patch_size}; + if (params.chroma_radiance_params.fake_patch_size_x2) { + kernel_size = {params.patch_size / 2, params.patch_size / 2}; + } + std::pair stride = kernel_size; blocks["img_in_patch"] = std::make_shared(params.in_channels, params.hidden_size, @@ -1082,7 +1085,7 @@ namespace Flux { auto img = pad_to_patch_size(ctx, x); auto orig_img = img; - if (params.chroma_radiance_params.use_patch_size_32) { + if (params.chroma_radiance_params.fake_patch_size_x2) { // It's supposed to be using GGML_SCALE_MODE_NEAREST, but this seems more stable // Maybe the implementation of nearest-neighbor interpolation in ggml behaves differently than the one in PyTorch? 
// img = F.interpolate(img, size=(H//2, W//2), mode="nearest") @@ -1303,7 +1306,8 @@ namespace Flux { flux_params.ref_index_scale = 10.f; flux_params.use_mlp_silu_act = true; } - int64_t head_dim = 0; + int64_t head_dim = 0; + int64_t actual_radiance_patch_size = -1; for (auto pair : tensor_storage_map) { std::string tensor_name = pair.first; if (!starts_with(tensor_name, prefix)) @@ -1316,9 +1320,12 @@ namespace Flux { flux_params.chroma_radiance_params.use_x0 = true; } if (tensor_name.find("__32x32__") != std::string::npos) { - LOG_DEBUG("using patch size 32 prediction"); - flux_params.chroma_radiance_params.use_patch_size_32 = true; - flux_params.patch_size = 32; + LOG_DEBUG("using patch size 32"); + flux_params.patch_size = 32; + } + if (tensor_name.find("img_in_patch.weight") != std::string::npos) { + actual_radiance_patch_size = pair.second.ne[0]; + LOG_DEBUG("actual radiance patch size: %lld", static_cast<long long>(actual_radiance_patch_size)); } if (tensor_name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) { // Chroma @@ -1351,6 +1358,11 @@ namespace Flux { head_dim = pair.second.ne[0]; } } + if (actual_radiance_patch_size > 0 && actual_radiance_patch_size != flux_params.patch_size) { + GGML_ASSERT(flux_params.patch_size == 2 * actual_radiance_patch_size); + LOG_DEBUG("using fake x2 patch size"); + flux_params.chroma_radiance_params.fake_patch_size_x2 = true; + } flux_params.num_heads = static_cast(flux_params.hidden_size / head_dim);