From 5a6bc313ce3f53f724ed1e05626ad1ea87d7ad10 Mon Sep 17 00:00:00 2001 From: Oleg Skutte Date: Fri, 16 Jan 2026 18:28:51 +0400 Subject: [PATCH 1/4] Add taef2 support --- tae.hpp | 73 +++++++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 58 insertions(+), 15 deletions(-) diff --git a/tae.hpp b/tae.hpp index 2cfd0a190..018fd431a 100644 --- a/tae.hpp +++ b/tae.hpp @@ -17,22 +17,43 @@ class TAEBlock : public UnaryBlock { protected: int n_in; int n_out; + bool use_midblock_gn; public: - TAEBlock(int n_in, int n_out) - : n_in(n_in), n_out(n_out) { + TAEBlock(int n_in, int n_out, bool use_midblock_gn = false) + : n_in(n_in), n_out(n_out), use_midblock_gn(use_midblock_gn) { blocks["conv.0"] = std::shared_ptr(new Conv2d(n_in, n_out, {3, 3}, {1, 1}, {1, 1})); blocks["conv.2"] = std::shared_ptr(new Conv2d(n_out, n_out, {3, 3}, {1, 1}, {1, 1})); blocks["conv.4"] = std::shared_ptr(new Conv2d(n_out, n_out, {3, 3}, {1, 1}, {1, 1})); if (n_in != n_out) { blocks["skip"] = std::shared_ptr(new Conv2d(n_in, n_out, {1, 1}, {1, 1}, {1, 1}, {1, 1}, false)); } + if (use_midblock_gn) { + int n_gn = n_in * 4; + blocks["pool.0"] = std::shared_ptr(new Conv2d(n_in, n_gn, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false)); + blocks["pool.1"] = std::shared_ptr(new GroupNorm(4, n_gn)); + // pool.2 is ReLU, handled in forward + blocks["pool.3"] = std::shared_ptr(new Conv2d(n_gn, n_in, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false)); + } } struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override { // x: [n, n_in, h, w] // return: [n, n_out, h, w] + if (use_midblock_gn) { + auto pool_0 = std::dynamic_pointer_cast(blocks["pool.0"]); + auto pool_1 = std::dynamic_pointer_cast(blocks["pool.1"]); + auto pool_3 = std::dynamic_pointer_cast(blocks["pool.3"]); + + auto p = pool_0->forward(ctx, x); + p = pool_1->forward(ctx, p); + p = ggml_relu_inplace(ctx->ggml_ctx, p); + p = pool_3->forward(ctx, p); + + x = ggml_add(ctx->ggml_ctx, x, p); + } + auto conv_0 = std::dynamic_pointer_cast(blocks["conv.0"]); auto conv_2 = std::dynamic_pointer_cast(blocks["conv.2"]); auto conv_4 = std::dynamic_pointer_cast(blocks["conv.4"]); @@ -62,7 +83,7 @@ class TinyEncoder : public UnaryBlock { int num_blocks = 3; public: - TinyEncoder(int z_channels = 4) + TinyEncoder(int z_channels = 4, bool use_midblock_gn = false) : z_channels(z_channels) { int index = 0; blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(in_channels, channels, {3, 3}, {1, 1}, {1, 1})); @@ -80,7 +101,7 @@ class TinyEncoder : public UnaryBlock { blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(channels, channels, {3, 3}, {2, 2}, {1, 1}, {1, 1}, false)); for (int i = 0; i < num_blocks; i++) { - blocks[std::to_string(index++)] = std::shared_ptr(new TAEBlock(channels, channels)); + blocks[std::to_string(index++)] = std::shared_ptr(new TAEBlock(channels, channels, use_midblock_gn)); } blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(channels, z_channels, {3, 3}, {1, 1}, {1, 1})); @@ -105,17 +126,19 @@ class TinyDecoder : public UnaryBlock { int channels = 64; int out_channels = 3; int num_blocks = 3; + int index_offset = 0; public: - TinyDecoder(int z_channels = 4) + TinyDecoder(int z_channels = 4, bool use_midblock_gn = false) : z_channels(z_channels) { - int index = 0; + index_offset = use_midblock_gn ? 1 : 0; + int index = index_offset; blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(z_channels, channels, {3, 3}, {1, 1}, {1, 1})); index++; // nn.ReLU() for (int i = 0; i < num_blocks; i++) { - blocks[std::to_string(index++)] = std::shared_ptr(new TAEBlock(channels, channels)); + blocks[std::to_string(index++)] = std::shared_ptr(new TAEBlock(channels, channels, use_midblock_gn)); } index++; // nn.Upsample() blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(channels, channels, {3, 3}, {1, 1}, {1, 1}, {1, 1}, false)); @@ -144,9 +167,9 @@ class TinyDecoder : public UnaryBlock { h = ggml_tanh_inplace(ctx->ggml_ctx, h); h = ggml_scale(ctx->ggml_ctx, h, 3.0f); - for (int i = 0; i < num_blocks * 3 + 10; i++) { + for (int i = index_offset; i < num_blocks * 3 + 10 + index_offset; i++) { if (blocks.find(std::to_string(i)) == blocks.end()) { - if (i == 1) { + if (i == 1 + index_offset) { h = ggml_relu_inplace(ctx->ggml_ctx, h); } else { h = ggml_upscale(ctx->ggml_ctx, h, 2, GGML_SCALE_MODE_NEAREST); @@ -470,29 +493,48 @@ class TAEHV : public GGMLBlock { class TAESD : public GGMLBlock { protected: bool decode_only; + bool taef2 = false; public: TAESD(bool decode_only = true, SDVersion version = VERSION_SD1) : decode_only(decode_only) { - int z_channels = 4; + int z_channels = 4; + bool use_midblock_gn = false; + taef2 = sd_version_is_flux2(version); + if (sd_version_is_dit(version)) { z_channels = 16; } - blocks["decoder.layers"] = std::shared_ptr(new TinyDecoder(z_channels)); + if (taef2) { + z_channels = 32; + use_midblock_gn = true; + } + std::string decoder_key = "decoder.layers"; + if (taef2 && decode_only) { + decoder_key = ""; + } + blocks[decoder_key] = std::shared_ptr(new TinyDecoder(z_channels, use_midblock_gn)); if (!decode_only) { - blocks["encoder.layers"] = std::shared_ptr(new TinyEncoder(z_channels)); + blocks["encoder.layers"] = std::shared_ptr(new TinyEncoder(z_channels, use_midblock_gn)); } } struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) { - auto decoder = std::dynamic_pointer_cast(blocks["decoder.layers"]); + auto decoder = std::dynamic_pointer_cast(blocks.begin()->second); + if (taef2) { + z = unpatchify(ctx->ggml_ctx, z, 2); + } return decoder->forward(ctx, z); } struct ggml_tensor* encode(GGMLRunnerContext* ctx, struct ggml_tensor* x) { auto encoder = std::dynamic_pointer_cast(blocks["encoder.layers"]); - return encoder->forward(ctx, x); + auto z = encoder->forward(ctx, x); + if (taef2) { + z = patchify(ctx->ggml_ctx, z, 2); + } + return z; } }; @@ -522,7 +564,8 @@ struct TinyImageAutoEncoder : public TinyAutoEncoder { : decode_only(decoder_only), taesd(decoder_only, version), TinyAutoEncoder(backend, offload_params_to_cpu) { - taesd.init(params_ctx, tensor_storage_map, prefix); + bool taef2 = sd_version_is_flux2(version); + taesd.init(params_ctx, tensor_storage_map, taef2 ? "" : prefix); } std::string get_desc() override { From a073df27c9029c898a621a179b49c32bce7eb68c Mon Sep 17 00:00:00 2001 From: Oleg Skutte Date: Sun, 18 Jan 2026 23:46:56 +0400 Subject: [PATCH 2/4] Use diffusers format --- tae.hpp | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/tae.hpp b/tae.hpp index 018fd431a..d43925b2e 100644 --- a/tae.hpp +++ b/tae.hpp @@ -509,11 +509,7 @@ class TAESD : public GGMLBlock { z_channels = 32; use_midblock_gn = true; } - std::string decoder_key = "decoder.layers"; - if (taef2 && decode_only) { - decoder_key = ""; - } - blocks[decoder_key] = std::shared_ptr(new TinyDecoder(z_channels, use_midblock_gn)); + blocks["decoder.layers"] = std::shared_ptr(new TinyDecoder(z_channels, use_midblock_gn)); if (!decode_only) { blocks["encoder.layers"] = std::shared_ptr(new TinyEncoder(z_channels, use_midblock_gn)); @@ -521,7 +517,7 @@ class TAESD : public GGMLBlock { } struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) { - auto decoder = std::dynamic_pointer_cast(blocks.begin()->second); + auto decoder = std::dynamic_pointer_cast(blocks["decoder.layers"]); if (taef2) { z = unpatchify(ctx->ggml_ctx, z, 2); } @@ -564,8 +560,7 @@ struct TinyImageAutoEncoder : public TinyAutoEncoder { : decode_only(decoder_only), taesd(decoder_only, version), TinyAutoEncoder(backend, offload_params_to_cpu) { - bool taef2 = sd_version_is_flux2(version); - taesd.init(params_ctx, tensor_storage_map, taef2 ? "" : prefix); + taesd.init(params_ctx, tensor_storage_map, prefix); } std::string get_desc() override { From 989a51ec6fe614740a5f0ecc71d0f4b429f277d2 Mon Sep 17 00:00:00 2001 From: Oleg Skutte Date: Mon, 19 Jan 2026 01:46:07 +0400 Subject: [PATCH 3/4] Use official weights --- tae.hpp | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tae.hpp b/tae.hpp index d43925b2e..11a19136b 100644 --- a/tae.hpp +++ b/tae.hpp @@ -126,13 +126,11 @@ class TinyDecoder : public UnaryBlock { int channels = 64; int out_channels = 3; int num_blocks = 3; - int index_offset = 0; public: TinyDecoder(int z_channels = 4, bool use_midblock_gn = false) : z_channels(z_channels) { - index_offset = use_midblock_gn ? 1 : 0; - int index = index_offset; + int index = 0; blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(z_channels, channels, {3, 3}, {1, 1}, {1, 1})); index++; // nn.ReLU() @@ -167,9 +165,9 @@ class TinyDecoder : public UnaryBlock { h = ggml_tanh_inplace(ctx->ggml_ctx, h); h = ggml_scale(ctx->ggml_ctx, h, 3.0f); - for (int i = index_offset; i < num_blocks * 3 + 10 + index_offset; i++) { + for (int i = 0; i < num_blocks * 3 + 10; i++) { if (blocks.find(std::to_string(i)) == blocks.end()) { - if (i == 1 + index_offset) { + if (i == 1) { h = ggml_relu_inplace(ctx->ggml_ctx, h); } else { h = ggml_upscale(ctx->ggml_ctx, h, 2, GGML_SCALE_MODE_NEAREST); From ea0234b394a6a702e497dcc396cc7c435cd79163 Mon Sep 17 00:00:00 2001 From: leejet Date: Mon, 19 Jan 2026 23:29:20 +0800 Subject: [PATCH 4/4] format code --- tae.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tae.hpp b/tae.hpp index 11a19136b..a22db1967 100644 --- a/tae.hpp +++ b/tae.hpp @@ -29,7 +29,7 @@ class TAEBlock : public UnaryBlock { blocks["skip"] = std::shared_ptr(new Conv2d(n_in, n_out, {1, 1}, {1, 1}, {1, 1}, {1, 1}, false)); } if (use_midblock_gn) { - int n_gn = n_in * 4; + int n_gn = n_in * 4; blocks["pool.0"] = std::shared_ptr(new Conv2d(n_in, n_gn, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false)); blocks["pool.1"] = std::shared_ptr(new GroupNorm(4, n_gn)); // pool.2 is ReLU, handled in forward