diff --git a/docs/distilled_sd.md b/docs/distilled_sd.md index 232c02288..3174b18f8 100644 --- a/docs/distilled_sd.md +++ b/docs/distilled_sd.md @@ -1,8 +1,8 @@ -# Running distilled models: SSD1B and SDx.x with tiny U-Nets +# Running distilled models: SSD1B, Vega and SDx.x with tiny U-Nets ## Preface -These models feature a reduced U-Net architecture. Unlike standard SDXL models, the SSD-1B U-Net contains only one middle block and fewer attention layers in its up- and down-blocks, resulting in significantly smaller file sizes. Using these models can reduce inference time by more than 33%. For more details, refer to Segmind's paper: https://arxiv.org/abs/2401.02677v1. +These models feature a reduced U-Net architecture. Unlike standard SDXL models, the SSD-1B and Vega U-Net contains only one middle block and fewer attention layers in its up- and down-blocks, resulting in significantly smaller file sizes. Using these models can reduce inference time by more than 33%. For more details, refer to Segmind's paper: https://arxiv.org/abs/2401.02677v1. Similarly, SD1.x- and SD2.x-style models with a tiny U-Net consist of only 6 U-Net blocks, leading to very small files and time savings of up to 50%. For more information, see the paper: https://arxiv.org/pdf/2305.15798.pdf. ## SSD1B @@ -17,7 +17,17 @@ Useful LoRAs are also available: * https://huggingface.co/seungminh/lora-swarovski-SSD-1B/resolve/main/pytorch_lora_weights.safetensors * https://huggingface.co/kylielee505/mylcmlorassd/resolve/main/pytorch_lora_weights.safetensors -These files can be used out-of-the-box, unlike the models described in the next section. +## Vega + +Segmind's Vega model is available online here: + + * https://huggingface.co/segmind/Segmind-Vega/resolve/main/segmind-vega.safetensors + +VegaRT is an example for an LCM-LoRA: + + * https://huggingface.co/segmind/Segmind-VegaRT/resolve/main/pytorch_lora_weights.safetensors + +Both files can be used out-of-the-box, unlike the models described in next sections. ## SD1.x, SD2.x with tiny U-Nets diff --git a/model.cpp b/model.cpp index e05d31468..12cf44c8c 100644 --- a/model.cpp +++ b/model.cpp @@ -1038,6 +1038,7 @@ SDVersion ModelLoader::get_sd_version() { int64_t patch_embedding_channels = 0; bool has_img_emb = false; bool has_middle_block_1 = false; + bool has_output_block_311 = false; bool has_output_block_71 = false; for (auto& [name, tensor_storage] : tensor_storage_map) { @@ -1095,6 +1096,9 @@ SDVersion ModelLoader::get_sd_version() { tensor_storage.name.find("unet.mid_block.resnets.1.") != std::string::npos) { has_middle_block_1 = true; } + if (tensor_storage.name.find("model.diffusion_model.output_blocks.3.1.transformer_blocks.1") != std::string::npos) { + has_output_block_311 = true; + } if (tensor_storage.name.find("model.diffusion_model.output_blocks.7.1") != std::string::npos) { has_output_block_71 = true; } @@ -1133,6 +1137,9 @@ SDVersion ModelLoader::get_sd_version() { return VERSION_SDXL_PIX2PIX; } if (!has_middle_block_1) { + if (!has_output_block_311) { + return VERSION_SDXL_VEGA; + } return VERSION_SDXL_SSD1B; } return VERSION_SDXL; diff --git a/model.h b/model.h index e52766cc0..536867936 100644 --- a/model.h +++ b/model.h @@ -32,6 +32,7 @@ enum SDVersion { VERSION_SDXL, VERSION_SDXL_INPAINT, VERSION_SDXL_PIX2PIX, + VERSION_SDXL_VEGA, VERSION_SDXL_SSD1B, VERSION_SVD, VERSION_SD3, @@ -65,7 +66,7 @@ static inline bool sd_version_is_sd2(SDVersion version) { } static inline bool sd_version_is_sdxl(SDVersion version) { - if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT || version == VERSION_SDXL_PIX2PIX || version == VERSION_SDXL_SSD1B) { + if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT || version == VERSION_SDXL_PIX2PIX || version == VERSION_SDXL_SSD1B || version == VERSION_SDXL_VEGA) { return true; } return false; diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 060b85302..feeb6ea4e 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -35,6 +35,7 @@ const char* model_version_to_str[] = { "SDXL", "SDXL Inpaint", "SDXL Instruct-Pix2Pix", + "SDXL (Vega)", "SDXL (SSD1B)", "SVD", "SD3.x", diff --git a/unet.hpp b/unet.hpp index 9fe24e243..6e15e1f45 100644 --- a/unet.hpp +++ b/unet.hpp @@ -201,6 +201,9 @@ class UnetModelBlock : public GGMLBlock { num_head_channels = 64; num_heads = -1; use_linear_projection = true; + if (version == VERSION_SDXL_VEGA) { + transformer_depth = {1, 1, 2}; + } } else if (version == VERSION_SVD) { in_channels = 8; out_channels = 4; @@ -319,7 +322,7 @@ class UnetModelBlock : public GGMLBlock { } if (!tiny_unet) { blocks["middle_block.0"] = std::shared_ptr(get_resblock(ch, time_embed_dim, ch)); - if (version != VERSION_SDXL_SSD1B) { + if (version != VERSION_SDXL_SSD1B && version != VERSION_SDXL_VEGA) { blocks["middle_block.1"] = std::shared_ptr(get_attention_layer(ch, n_head, d_head, @@ -520,7 +523,7 @@ class UnetModelBlock : public GGMLBlock { // middle_block if (!tiny_unet) { h = resblock_forward("middle_block.0", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8] - if (version != VERSION_SDXL_SSD1B) { + if (version != VERSION_SDXL_SSD1B && version != VERSION_SDXL_VEGA) { h = attention_layer_forward("middle_block.1", ctx, h, context, num_video_frames); // [N, 4*model_channels, h/8, w/8] h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8] }