From 4533474043fbe10d6641ed85aea75fe1c1cc104a Mon Sep 17 00:00:00 2001 From: HsiaWinter Date: Thu, 2 Apr 2026 16:39:42 +0800 Subject: [PATCH 01/17] Add ERNIE-Image --- .../api/models/ernie_image_transformer2d.md | 19 + docs/source/en/api/pipelines/ernie_image.md | 57 +++ src/diffusers/__init__.py | 4 + src/diffusers/models/__init__.py | 2 + src/diffusers/models/transformers/__init__.py | 1 + .../transformers/transformer_ernie_image.py | 311 ++++++++++++ src/diffusers/pipelines/__init__.py | 2 + .../pipelines/ernie_image/__init__.py | 47 ++ .../ernie_image/pipeline_ernie_image.py | 457 ++++++++++++++++++ src/diffusers/utils/dummy_pt_objects.py | 15 + .../dummy_torch_and_transformers_objects.py | 15 + .../test_models_transformer_ernie_image.py | 199 ++++++++ 12 files changed, 1129 insertions(+) create mode 100644 docs/source/en/api/models/ernie_image_transformer2d.md create mode 100644 docs/source/en/api/pipelines/ernie_image.md create mode 100644 src/diffusers/models/transformers/transformer_ernie_image.py create mode 100644 src/diffusers/pipelines/ernie_image/__init__.py create mode 100644 src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py create mode 100644 tests/models/transformers/test_models_transformer_ernie_image.py diff --git a/docs/source/en/api/models/ernie_image_transformer2d.md b/docs/source/en/api/models/ernie_image_transformer2d.md new file mode 100644 index 000000000000..058616c47814 --- /dev/null +++ b/docs/source/en/api/models/ernie_image_transformer2d.md @@ -0,0 +1,19 @@ + + +# ErnieImageTransformer2DModel + +A Transformer model for image-like data from [Ernie-Image](https://huggingface.co/Tongyi-MAI/Ernie-Image-Turbo). 
+ +## ErnieImageTransformer2DModel + +[[autodoc]] ErnieImageTransformer2DModel \ No newline at end of file diff --git a/docs/source/en/api/pipelines/ernie_image.md b/docs/source/en/api/pipelines/ernie_image.md new file mode 100644 index 000000000000..34c1dc4cd489 --- /dev/null +++ b/docs/source/en/api/pipelines/ernie_image.md @@ -0,0 +1,57 @@ + + +# Ernie-Image + +
+ LoRA +
 +[Ernie-Image](https://huggingface.co/papers/2511.22699) is a powerful and highly efficient image generation model with 6B parameters. Currently there's only one model with two more to be released: + +|Model|Hugging Face| +|---|---| +|Ernie-Image|https://huggingface.co/Tongyi-MAI/Ernie-Image-Turbo| + +## Ernie-Image + +Ernie-Image-Turbo is a distilled version of Ernie-Image that matches or exceeds leading competitors with only 8 NFEs (Number of Function Evaluations). It offers sub-second inference latency on enterprise-grade H800 GPUs and fits comfortably within 16G VRAM consumer devices. It excels in photorealistic image generation, bilingual text rendering (English & Chinese), and robust instruction adherence. + +## Usage example + +Use [`ErnieImagePipeline`] to generate an image from a text prompt. + +```python +import torch +from diffusers import ErnieImagePipeline + +pipe = ErnieImagePipeline.from_pretrained("Tongyi-MAI/Ernie-Image-Turbo", torch_dtype=torch.bfloat16) +pipe.to("cuda") + +prompt = "一只黑白相间的中华田园犬" +generator = torch.Generator("cuda").manual_seed(42) +images = pipe( + prompt=prompt, + height=1024, + width=1024, + num_inference_steps=50, + guidance_scale=5.0, + generator=generator, +).images +images[0].save("ernie-image-output.png") +``` + +## ErnieImagePipeline + +[[autodoc]] ErnieImagePipeline + - all + - __call__ diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 7d966452d1a2..8fea3482d1c3 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -292,6 +292,7 @@ "ZImageControlNetModel", "ZImageTransformer2DModel", "attention_backend", + "ErnieImageTransformer2DModel" ] ) _import_structure["modular_pipelines"].extend( @@ -732,6 +733,7 @@ "ZImageInpaintPipeline", "ZImageOmniPipeline", "ZImagePipeline", + "ErnieImagePipeline", ] ) @@ -1079,6 +1081,7 @@ ZImageControlNetModel, ZImageTransformer2DModel, attention_backend, + ErnieImageTransformer2DModel, ) from .modular_pipelines import ( AutoPipelineBlocks, @@
-1493,6 +1496,7 @@ ZImageInpaintPipeline, ZImageOmniPipeline, ZImagePipeline, + ErnieImagePipeline, ) try: diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 7ded56049833..6c62db841b14 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -101,6 +101,7 @@ _import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"] _import_structure["transformers.transformer_cosmos"] = ["CosmosTransformer3DModel"] _import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"] + _import_structure["transformers.transformer_ernie_image"] = ["ErnieImageTransformer2DModel"] _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"] _import_structure["transformers.transformer_flux2"] = ["Flux2Transformer2DModel"] _import_structure["transformers.transformer_glm_image"] = ["GlmImageTransformer2DModel"] @@ -218,6 +219,7 @@ DiTTransformer2DModel, DualTransformer2DModel, EasyAnimateTransformer3DModel, + ErnieImageTransformer2DModel, Flux2Transformer2DModel, FluxTransformer2DModel, GlmImageTransformer2DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 45157ee91808..bf9bd49881e4 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -52,3 +52,4 @@ from .transformer_wan_animate import WanAnimateTransformer3DModel from .transformer_wan_vace import WanVACETransformer3DModel from .transformer_z_image import ZImageTransformer2DModel + from .transformer_ernie_image import ErnieImageTransformer2DModel diff --git a/src/diffusers/models/transformers/transformer_ernie_image.py b/src/diffusers/models/transformers/transformer_ernie_image.py new file mode 100644 index 000000000000..63fe3c47b811 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_ernie_image.py @@ -0,0 +1,311 @@ +# Copyright (c) 2025, Baidu Inc. 
All rights reserved. +# Author: fengzhida (fengzhida@baidu.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Ernie-Image Transformer2DModel for HuggingFace Diffusers. +""" + +import math +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from ...configuration_utils import ConfigMixin, register_to_config +from ..embeddings import Timesteps +from ..modeling_utils import ModelMixin +from ...utils import BaseOutput + + +@dataclass +class ErnieImageTransformer2DModelOutput(BaseOutput): + sample: torch.Tensor + + +def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: + assert dim % 2 == 0 + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta ** scale) + out = torch.einsum("...n,d->...nd", pos, omega) + return out.float() + + +class EmbedND3(nn.Module): + def __init__(self, dim: int, theta: int, axes_dim: Tuple[int, int, int]): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = list(axes_dim) + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(3)], dim=-1) + emb = emb.unsqueeze(1).permute(2, 0, 1, 3) + return torch.stack([emb, emb], dim=-1).reshape(*emb.shape[:-1], -1) + + +class PatchEmbedDynamic(nn.Module): + def __init__(self, in_channels: int, embed_dim: int, patch_size: 
int): + super().__init__() + self.patch_size = patch_size + self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + B, D, Hp, Wp = x.shape + return x.reshape(B, D, Hp * Wp).transpose(1, 2).contiguous() + + +class TimestepEmbedding(nn.Module): + def __init__(self, in_channels: int, time_embed_dim: int): + super().__init__() + self.linear_1 = nn.Linear(in_channels, time_embed_dim) + self.act = nn.SiLU() + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim) + + def forward(self, sample: torch.Tensor) -> torch.Tensor: + sample = self.linear_1(sample.to(self.linear_1.weight.dtype)) + return self.linear_2(self.act(sample).to(self.linear_2.weight.dtype)) + + +class RMSNorm(nn.Module): + """RMSNorm implementation matching Megatron's TENorm.""" + + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x: torch.Tensor) -> torch.Tensor: + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # 内部计算转换为FP32,对齐transform engine的TENorm计算精度 + x_norm = self._norm(x.float()) + output = x_norm * self.weight.float() + return output.to(x.dtype) + + +class Attention(nn.Module): + def __init__(self, hidden_size: int, num_heads: int, eps: float = 1e-6, qk_layernorm: bool = True): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + # Separate Q, K, V projections (matches converted weights) + self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False) + self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False) + self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False) + self.linear_proj = nn.Linear(hidden_size, hidden_size, bias=False) + self.qk_layernorm = qk_layernorm + if qk_layernorm: + # self.q_layernorm = 
RMSNorm(self.head_dim, eps=eps) + # self.k_layernorm = RMSNorm(self.head_dim, eps=eps) + self.q_layernorm = torch.nn.RMSNorm(self.head_dim, eps=eps) + self.k_layernorm = torch.nn.RMSNorm(self.head_dim, eps=eps) + + def forward(self, x: torch.Tensor, rotary_pos_emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + S, B, H = x.shape + # Separate Q, K, V projections + q = self.q_proj(x).view(S, B, self.num_heads, self.head_dim).contiguous() + k = self.k_proj(x).view(S, B, self.num_heads, self.head_dim).contiguous() + v = self.v_proj(x).view(S, B, self.num_heads, self.head_dim).contiguous() + if self.qk_layernorm: + q, k = self.q_layernorm(q), self.k_layernorm(k) + q, k = self._apply_rotary(q, rotary_pos_emb), self._apply_rotary(k, rotary_pos_emb) + q, k, v = q.permute(1, 2, 0, 3), k.permute(1, 2, 0, 3), v.permute(1, 2, 0, 3) + attn_mask = ~attention_mask if attention_mask is not None else None + out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0) + return self.linear_proj(out.permute(2, 0, 1, 3).reshape(S, B, H)) + + def _apply_rotary(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + """Apply rotary position embedding. + + Matches Megatron's _apply_rotary_pos_emb_bshd with rotary_interleaved=False. + freqs: [S, B, 1, dim] containing angles [θ0, θ0, θ1, θ1, ...] 
+ """ + rot_dim = freqs.shape[-1] + x, x_pass = x[..., :rot_dim], x[..., rot_dim:] + + cos_ = torch.cos(freqs).to(x.dtype) + sin_ = torch.sin(freqs).to(x.dtype) + + # Non-interleaved rotate_half: [-x2, x1] + x1, x2 = x.chunk(2, dim=-1) + x_rotated = torch.cat((-x2, x1), dim=-1) + + x = x * cos_ + x_rotated * sin_ + return torch.cat((x, x_pass), dim=-1) + + +class FeedForward(nn.Module): + def __init__(self, hidden_size: int, ffn_hidden_size: int): + super().__init__() + # Separate gate and up projections (matches converted weights) + self.gate_proj = nn.Linear(hidden_size, ffn_hidden_size, bias=False) + self.up_proj = nn.Linear(hidden_size, ffn_hidden_size, bias=False) + self.linear_fc2 = nn.Linear(ffn_hidden_size, hidden_size, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear_fc2(self.up_proj(x) * F.gelu(self.gate_proj(x))) + + +class SharedAdaLNBlock(nn.Module): + def __init__(self, hidden_size: int, num_heads: int, ffn_hidden_size: int, eps: float = 1e-6, qk_layernorm: bool = True): + super().__init__() + # self.adaLN_sa_ln = RMSNorm(hidden_size, eps=eps) + self.adaLN_sa_ln = torch.nn.RMSNorm(hidden_size, eps=eps) + self.self_attention = Attention(hidden_size, num_heads, eps=eps, qk_layernorm=qk_layernorm) + # self.adaLN_mlp_ln = RMSNorm(hidden_size, eps=eps) + self.adaLN_mlp_ln = torch.nn.RMSNorm(hidden_size, eps=eps) + self.mlp = FeedForward(hidden_size, ffn_hidden_size) + + def forward(self, x, rotary_pos_emb, shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, attention_mask=None): + residual = x + x = self.adaLN_sa_ln(x) + x = self._modulate(x, shift_msa, scale_msa) + attn_out = self.self_attention(x, rotary_pos_emb, attention_mask) + x = residual + self._apply_gate(gate_msa, attn_out) + residual = x + x = self._modulate(self.adaLN_mlp_ln(x), shift_mlp, scale_mlp) + return residual + self._apply_gate(gate_mlp, self.mlp(x)) + + def _modulate(self, x, shift, scale): + """AdaLN modulation: x * (1 + scale) + 
shift,在FP32下计算确保数值稳定""" + x_fp32 = x.float() + shift_fp32 = shift.float() + scale_fp32 = scale.float() + out = x_fp32 * (1 + scale_fp32) + shift_fp32 + return out.to(x.dtype) + + def _apply_gate(self, gate, x): + """Gate乘法在FP32下计算,对齐TE精度""" + return (gate.float() * x.float()).to(x.dtype) + +class AdaLNContinuous(nn.Module): + def __init__(self, hidden_size: int, eps: float = 1e-6): + super().__init__() + self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=eps) + self.linear = nn.Linear(hidden_size, hidden_size * 2) + # 对齐 Megatron 实现:zero init + nn.init.zeros_(self.linear.weight) + nn.init.zeros_(self.linear.bias) + + def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor: + scale, shift = self.linear(conditioning).chunk(2, dim=-1) + x = self.norm(x) + # Broadcast conditioning to sequence dimension + x = x * (1 + scale.unsqueeze(0)) + shift.unsqueeze(0) + return x + + +class ErnieImageTransformer2DModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + hidden_size: int = 3072, + num_attention_heads: int = 24, + num_layers: int = 24, + ffn_hidden_size: int = 8192, + in_channels: int = 128, + out_channels: int = 128, + patch_size: int = 1, + text_in_dim: int = 2560, + rope_theta: int = 256, + rope_axes_dim: Tuple[int, int, int] = (32, 48, 48), + eps: float = 1e-6, + qk_layernorm: bool = True, + ): + super().__init__() + self.gradient_checkpointing = False + self.hidden_size = hidden_size + self.num_heads = num_attention_heads + self.head_dim = hidden_size // num_attention_heads + self.num_layers = num_layers + self.patch_size = patch_size + self.in_channels = in_channels + self.out_channels = out_channels + self.text_in_dim = text_in_dim + + self.x_embedder = PatchEmbedDynamic(in_channels, hidden_size, patch_size) + self.text_proj = nn.Linear(text_in_dim, hidden_size, bias=False) if text_in_dim != hidden_size else None + self.time_proj = Timesteps(hidden_size, 
flip_sin_to_cos=False, downscale_freq_shift=0) + self.time_embedding = TimestepEmbedding(hidden_size, hidden_size) + self.pos_embed = EmbedND3(dim=self.head_dim, theta=rope_theta, axes_dim=rope_axes_dim) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size)) + nn.init.zeros_(self.adaLN_modulation[-1].weight) + nn.init.zeros_(self.adaLN_modulation[-1].bias) + self.layers = nn.ModuleList([SharedAdaLNBlock(hidden_size, num_attention_heads, ffn_hidden_size, eps, qk_layernorm=qk_layernorm) for _ in range(num_layers)]) + self.final_norm = AdaLNContinuous(hidden_size, eps) + self.final_linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels) + nn.init.zeros_(self.final_linear.weight) + nn.init.zeros_(self.final_linear.bias) + + def forward(self, hidden_states: torch.Tensor, timestep: torch.Tensor, encoder_hidden_states: List[torch.Tensor], return_dict: bool = True): + device, dtype = hidden_states.device, hidden_states.dtype + B, C, H, W = hidden_states.shape + p, Hp, Wp = self.patch_size, H // self.patch_size, W // self.patch_size + N_img = Hp * Wp + + img_sbh = self.x_embedder(hidden_states).transpose(0, 1).contiguous() + text_bth, text_lens = self._pad_text(encoder_hidden_states, device, dtype) + if self.text_proj is not None and text_bth.numel() > 0: + text_bth = self.text_proj(text_bth) + Tmax = text_bth.shape[1] + text_sbh = text_bth.transpose(0, 1).contiguous() + + x = torch.cat([img_sbh, text_sbh], dim=0) + S = x.shape[0] + + # Position IDs + text_ids = torch.cat([torch.arange(Tmax, device=device, dtype=torch.float32).view(1, Tmax, 1).expand(B, -1, -1), torch.zeros((B, Tmax, 2), device=device)], dim=-1) if Tmax > 0 else torch.zeros((B, 0, 3), device=device) + grid_yx = torch.stack(torch.meshgrid(torch.arange(Hp, device=device, dtype=torch.float32), torch.arange(Wp, device=device, dtype=torch.float32), indexing="ij"), dim=-1).reshape(-1, 2) + image_ids = torch.cat([text_lens.float().view(B, 1, 1).expand(-1, N_img, 
-1), grid_yx.view(1, N_img, 2).expand(B, -1, -1)], dim=-1) + rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1)) + + # Attention mask + valid_text = torch.arange(Tmax, device=device).view(1, Tmax) < text_lens.view(B, 1) if Tmax > 0 else torch.zeros((B, 0), device=device, dtype=torch.bool) + attention_mask = (~torch.cat([torch.ones((B, N_img), device=device, dtype=torch.bool), valid_text], dim=1))[:, None, None, :] + + # AdaLN + c = self.time_embedding(self.time_proj(timestep.to(dtype))) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [t.unsqueeze(0).expand(S, -1, -1).contiguous() for t in self.adaLN_modulation(c).chunk(6, dim=-1)] + for layer in self.layers: + if self.gradient_checkpointing and self.training: + x = self._gradient_checkpointing_func( + layer.__call__, + x, rotary_pos_emb, shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, attention_mask + ) + else: + x = layer(x, rotary_pos_emb, shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, attention_mask) + x = self.final_norm(x, c).type_as(x) + patches = self.final_linear(x)[:N_img].transpose(0, 1).contiguous() + output = patches.view(B, Hp, Wp, p, p, self.out_channels).permute(0, 5, 1, 3, 2, 4).contiguous().view(B, self.out_channels, H, W) + + return ErnieImageTransformer2DModelOutput(sample=output) if return_dict else (output,) + + def _pad_text(self, text_hiddens: List[torch.Tensor], device: torch.device, dtype: torch.dtype): + B = len(text_hiddens) + if B == 0: + return torch.zeros((0, 0, self.text_in_dim), device=device, dtype=dtype), torch.zeros((0,), device=device, dtype=torch.long) + normalized = [th.squeeze(1).to(device).to(dtype) if th.dim() == 3 else th.to(device).to(dtype) for th in text_hiddens] + lens = torch.tensor([t.shape[0] for t in normalized], device=device, dtype=torch.long) + Tmax = int(lens.max().item()) + text_bth = torch.zeros((B, Tmax, self.text_in_dim), device=device, dtype=dtype) + for i, t in enumerate(normalized): 
+ text_bth[i, :t.shape[0], :] = t + return text_bth, lens diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 3dafb56fdd65..1985a12846e9 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -313,6 +313,7 @@ _import_structure["mochi"] = ["MochiPipeline"] _import_structure["musicldm"] = ["MusicLDMPipeline"] _import_structure["omnigen"] = ["OmniGenPipeline"] + _import_structure["ernie_image"] = ["ErnieImagePipeline"] _import_structure["ovis_image"] = ["OvisImagePipeline"] _import_structure["visualcloze"] = ["VisualClozePipeline", "VisualClozeGenerationPipeline"] _import_structure["paint_by_example"] = ["PaintByExamplePipeline"] @@ -750,6 +751,7 @@ from .mochi import MochiPipeline from .musicldm import MusicLDMPipeline from .omnigen import OmniGenPipeline + from .ernie_image import ErnieImagePipeline from .ovis_image import OvisImagePipeline from .pag import ( AnimateDiffPAGPipeline, diff --git a/src/diffusers/pipelines/ernie_image/__init__.py b/src/diffusers/pipelines/ernie_image/__init__.py new file mode 100644 index 000000000000..97355fb609f3 --- /dev/null +++ b/src/diffusers/pipelines/ernie_image/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa: F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_ernie_image"] = ["ErnieImagePipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and 
is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_ernie_image import ErnieImagePipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py new file mode 100644 index 000000000000..29c1b9fd9bab --- /dev/null +++ b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py @@ -0,0 +1,457 @@ +# Copyright (c) 2025, Baidu Inc. All rights reserved. +# Author: fengzhida (fengzhida@baidu.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Ernie-Image Pipeline for HuggingFace Diffusers. 
+""" + +import json +import os +import torch +from PIL import Image +from dataclasses import dataclass +from typing import Callable, List, Optional, Tuple, Union +from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer + +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import BaseOutput +from ...models import AutoencoderKLFlux2 +from ...models.transformers import ErnieImageTransformer2DModel + + +@dataclass +class ErnieImagePipelineOutput(BaseOutput): + images: List[Image.Image] + + +class ErnieImagePipeline(DiffusionPipeline): + """ + Pipeline for text-to-image generation using ErnieImageTransformer2DModel. + + This pipeline uses: + - A custom DiT transformer model + - A Flux2-style VAE for encoding/decoding latents + - A text encoder (e.g., Qwen) for text conditioning + - Flow Matching Euler Discrete Scheduler + """ + + model_cpu_offload_seq = "pe->text_encoder->transformer->vae" + + def __init__( + self, + transformer, + vae, + text_encoder, + tokenizer, + scheduler: FlowMatchEulerDiscreteScheduler, + pe=None, + pe_tokenizer=None, + ): + super().__init__() + self.register_modules( + transformer=transformer, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + pe=pe, + pe_tokenizer=pe_tokenizer, + ) + self.vae_scale_factor = 16 # VAE downsample factor + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs): + """ + Load pipeline from a pretrained model directory. 
+ + Args: + pretrained_model_name_or_path: Path to the saved pipeline directory + **kwargs: Additional arguments passed to component loaders + - torch_dtype: Data type for model weights (default: torch.bfloat16) + - device_map: Device map for model loading + - trust_remote_code: Whether to trust remote code for text encoder + + Returns: + ErnieImagePipeline instance + """ + + torch_dtype = kwargs.pop("torch_dtype", torch.bfloat16) + trust_remote_code = kwargs.pop("trust_remote_code", True) + + # Determine whether this is a local directory or a Hub repo ID. + # For local paths we join sub-directories; for Hub IDs we use `subfolder`. + is_local = os.path.isdir(pretrained_model_name_or_path) + + def _path_or_subfolder(subfolder: str): + if is_local: + return {"pretrained_model_name_or_path": os.path.join(pretrained_model_name_or_path, subfolder)} + return {"pretrained_model_name_or_path": pretrained_model_name_or_path, "subfolder": subfolder} + + # Load transformer + transformer = ErnieImageTransformer2DModel.from_pretrained( + **_path_or_subfolder("transformer"), + torch_dtype=torch_dtype, + ) + + # Load VAE + vae = AutoencoderKLFlux2.from_pretrained( + **_path_or_subfolder("vae"), + torch_dtype=torch_dtype, + ) + + # Load text encoder + text_encoder = AutoModel.from_pretrained( + **_path_or_subfolder("text_encoder"), + torch_dtype=torch_dtype, + trust_remote_code=trust_remote_code, + ) + + # Load tokenizer + tokenizer = AutoTokenizer.from_pretrained( + **_path_or_subfolder("tokenizer"), + trust_remote_code=trust_remote_code, + ) + + # Load PE + pe = AutoModelForCausalLM.from_pretrained( + **_path_or_subfolder("pe"), + torch_dtype=torch_dtype, + trust_remote_code=trust_remote_code, + ) + + # Load PE tokenizer (auto-picks up chat_template.jinja in the same dir) + pe_tokenizer = AutoTokenizer.from_pretrained( + **_path_or_subfolder("pe"), + trust_remote_code=trust_remote_code, + ) + + # Load scheduler + scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + 
**_path_or_subfolder("scheduler"), + ) + + return cls( + transformer=transformer, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + pe=pe, + pe_tokenizer=pe_tokenizer, + scheduler=scheduler, + ) + + @torch.no_grad() + def _enhance_prompt_with_pe( + self, + prompt: str, + device: torch.device, + width: int = 1024, + height: int = 1024, + system_prompt: Optional[str] = None, + max_length: int = 1536, + temperature: float = 0.6, + top_p: float = 0.95, + ) -> str: + """Use PE model to rewrite/enhance a short prompt via chat_template.""" + # Build user message as JSON carrying prompt text and target resolution + user_content = json.dumps( + {"prompt": prompt, "width": width, "height": height}, + ensure_ascii=False, + ) + messages = [] + if system_prompt is not None: + messages.append({"role": "system", "content": system_prompt}) + messages.append({"role": "user", "content": user_content}) + + # apply_chat_template picks up the chat_template.jinja loaded with pe_tokenizer + input_text = self.pe_tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=False, # "Output:" is already in the user block + ) + # pe_device = next(self.pe.parameters()).device + inputs = self.pe_tokenizer(input_text, return_tensors="pt").to(device) + + output_ids = self.pe.generate( + **inputs, + max_new_tokens=max_length, + do_sample=temperature != 1.0 or top_p != 1.0, + temperature=temperature, + top_p=top_p, + pad_token_id=self.pe_tokenizer.pad_token_id, + eos_token_id=self.pe_tokenizer.eos_token_id, + ) + # Decode only newly generated tokens + generated_ids = output_ids[0][inputs["input_ids"].shape[1]:] + return self.pe_tokenizer.decode(generated_ids, skip_special_tokens=True).strip() + + @torch.no_grad() + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: torch.device, + num_images_per_prompt: int = 1, + max_length: int = 64, + ) -> List[torch.Tensor]: + """Encode text prompts to embeddings.""" + if isinstance(prompt, str): + 
prompt = [prompt] + + text_hiddens = [] + + for p in prompt: + ids = self.tokenizer( + p, + add_special_tokens=True, + truncation=True, + max_length=max_length, + padding=False, + )["input_ids"] + + if len(ids) == 0: + if self.tokenizer.bos_token_id is not None: + ids = [self.tokenizer.bos_token_id] + else: + ids = [0] + + input_ids = torch.tensor([ids], device=device) + with torch.no_grad(): + outputs = self.text_encoder( + input_ids=input_ids, + output_hidden_states=True, + ) + # Use second to last hidden state (matches training) + hidden = outputs.hidden_states[-2][0] # [T, H] + + # Repeat for num_images_per_prompt + for _ in range(num_images_per_prompt): + text_hiddens.append(hidden) + + return text_hiddens + + @torch.no_grad() + def _encode_negative_prompt( + self, + negative_prompt: List[str], + device: torch.device, + num_images_per_prompt: int = 1, + max_length: int = 64, + ) -> List[torch.Tensor]: + """Encode negative prompts for CFG.""" + text_hiddens = [] + + for np in negative_prompt: + ids = self.tokenizer( + np, + add_special_tokens=True, + truncation=True, + max_length=max_length, + padding=False, + )["input_ids"] + + if len(ids) == 0: + if self.tokenizer.bos_token_id is not None: + ids = [self.tokenizer.bos_token_id] + else: + ids = [0] + + input_ids = torch.tensor([ids], device=device) + with torch.no_grad(): + outputs = self.text_encoder( + input_ids=input_ids, + output_hidden_states=True, + ) + hidden = outputs.hidden_states[-2][0] + + for _ in range(num_images_per_prompt): + text_hiddens.append(hidden) + + return text_hiddens + + @staticmethod + def _patchify_latents(latents: torch.Tensor) -> torch.Tensor: + """2x2 patchify: [B, 32, H, W] -> [B, 128, H/2, W/2]""" + b, c, h, w = latents.shape + latents = latents.view(b, c, h // 2, 2, w // 2, 2) + latents = latents.permute(0, 1, 3, 5, 2, 4) + return latents.reshape(b, c * 4, h // 2, w // 2) + + @staticmethod + def _unpatchify_latents(latents: torch.Tensor) -> torch.Tensor: + """Reverse patchify: 
[B, 128, H/2, W/2] -> [B, 32, H, W]""" + b, c, h, w = latents.shape + latents = latents.reshape(b, c // 4, 2, 2, h, w) + latents = latents.permute(0, 1, 4, 2, 5, 3) + return latents.reshape(b, c // 4, h * 2, w * 2) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = "", + height: int = 256, + width: int = 256, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + num_images_per_prompt: int = 1, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + output_type: str = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + max_length: int = 1536, + use_pe: bool = True, # 默认使用PE进行改写 + ): + """ + Generate images from text prompts. + + Args: + prompt: Text prompt(s) + negative_prompt: Negative prompt(s) for CFG. Default is "". + height: Image height (must be divisible by 16) + width: Image width (must be divisible by 16) + num_inference_steps: Number of denoising steps + guidance_scale: CFG scale (1.0 = no guidance) + num_images_per_prompt: Number of images per prompt + generator: Random generator for reproducibility + latents: Pre-generated latents (optional) + output_type: "pil" or "latent" + return_dict: Whether to return a dataclass + callback: Optional callback function + callback_steps: Steps between callbacks + max_length: Max token length for text encoding + + Returns: + Generated images + """ + device = self._execution_device + dtype = self.transformer.dtype + + self.pe.to(device) + # Validate dimensions + if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: + raise ValueError(f"Height and width must be divisible by {self.vae_scale_factor}") + + # Handle prompts + if isinstance(prompt, str): + prompt = [prompt] + + # Enhance prompts with PE if enabled + if use_pe and self.pe is not None and self.pe_tokenizer is not None: + prompt = [ 
+ self._enhance_prompt_with_pe(p, device, width=width, height=height, max_length=max_length) + for p in prompt + ] + + batch_size = len(prompt) + total_batch_size = batch_size * num_images_per_prompt + + # Handle negative prompt + if negative_prompt is None: + negative_prompt = "" + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * batch_size + if len(negative_prompt) != batch_size: + raise ValueError(f"negative_prompt must have same length as prompt ({batch_size})") + + # Encode prompts + text_hiddens = self.encode_prompt(prompt, device, num_images_per_prompt, max_length) + + # CFG with negative prompt + do_cfg = guidance_scale > 1.0 + if do_cfg: + uncond_text_hiddens = self._encode_negative_prompt( + negative_prompt, device, num_images_per_prompt, max_length + ) + + # Latent dimensions + latent_h = height // self.vae_scale_factor + latent_w = width // self.vae_scale_factor + latent_channels = 128 # After patchify + + # Initialize latents + if latents is None: + latents = torch.randn( + (total_batch_size, latent_channels, latent_h, latent_w), + device=device, + dtype=dtype, + generator=generator, + ) + + # Setup scheduler + self.scheduler.set_timesteps(num_inference_steps, device=device) + + # Denoising loop + if do_cfg: + cfg_text_hiddens = list(uncond_text_hiddens) + list(text_hiddens) + else: + cfg_text_hiddens = text_hiddens + + for i, t in enumerate(self.scheduler.timesteps): + if do_cfg: + latent_model_input = torch.cat([latents, latents], dim=0) + t_batch = torch.full((total_batch_size * 2,), t.item(), device=device, dtype=dtype) + else: + latent_model_input = latents + t_batch = torch.full((total_batch_size,), t.item(), device=device, dtype=dtype) + + # Model prediction + pred = self.transformer( + hidden_states=latent_model_input, + timestep=t_batch, + encoder_hidden_states=cfg_text_hiddens, + return_dict=False, + )[0] + + # Apply CFG + if do_cfg: + pred_uncond, pred_cond = pred.chunk(2, dim=0) + pred = pred_uncond + 
guidance_scale * (pred_cond - pred_uncond) + + # Scheduler step + latents = self.scheduler.step(pred, t, latents).prev_sample + + # Callback + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if output_type == "latent": + return latents + + # Decode latents to images + # Unnormalize latents using VAE's BN stats + bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(device) + bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + 1e-5).to(device) + latents = latents * bn_std + bn_mean + + # Unpatchify + latents = self._unpatchify_latents(latents) + + # Decode + images = self.vae.decode(latents, return_dict=False)[0] + + # Post-process + images = (images.clamp(-1, 1) + 1) / 2 + images = images.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + images = [Image.fromarray((img * 255).astype("uint8")) for img in images] + + if not return_dict: + return (images,) + + return ErnieImagePipelineOutput(images=images) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index fa37388fe75a..2a5b2bd6b8a2 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1016,6 +1016,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class ErnieImageTransformer2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class Flux2Transformer2DModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 1e4d14566160..aa1c9fbb3c10 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ 
b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -2582,6 +2582,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class ErnieImagePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class OvisImagePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/models/transformers/test_models_transformer_ernie_image.py b/tests/models/transformers/test_models_transformer_ernie_image.py new file mode 100644 index 000000000000..7ef855609ed8 --- /dev/null +++ b/tests/models/transformers/test_models_transformer_ernie_image.py @@ -0,0 +1,199 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import gc +import os +import unittest + +import torch + +from diffusers import ErnieImageTransformer2DModel + +from ...testing_utils import IS_GITHUB_ACTIONS, torch_device +from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin + + +# Ernie-Image requires torch.use_deterministic_algorithms(False) due to complex64 RoPE operations +# Cannot use enable_full_determinism() which sets it to True +os.environ["CUDA_LAUNCH_BLOCKING"] = "1" +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" +torch.use_deterministic_algorithms(False) +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False +if hasattr(torch.backends, "cuda"): + torch.backends.cuda.matmul.allow_tf32 = False + + +@unittest.skipIf( + IS_GITHUB_ACTIONS, + reason="Skipping test-suite inside the CI because the model has `torch.empty()` inside of it during init and we don't have a clear way to override it in the modeling tests.", +) +class ErnieImageTransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = ErnieImageTransformer2DModel + main_input_name = "hidden_states" + # We override the items here because the transformer under consideration is small. 
+ model_split_percents = [0.9, 0.9, 0.9] + + def prepare_dummy_input(self, height=16, width=16): + batch_size = 1 + num_channels = 16 + embedding_dim = 16 + sequence_length = 16 + + hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device) + encoder_hidden_states = [ + torch.randn((sequence_length, embedding_dim)).to(torch_device) for _ in range(batch_size) + ] + timestep = torch.tensor([1.0]).to(torch_device) + + return {"hidden_states": hidden_states, "encoder_hidden_states": encoder_hidden_states, "timestep": timestep} + + @property + def dummy_input(self): + return self.prepare_dummy_input() + + @property + def input_shape(self): + return (16, 16, 16) + + @property + def output_shape(self): + return (16, 16, 16) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "hidden_size": 16, + "num_attention_heads": 1, + "num_layers": 1, + "ffn_hidden_size": 16, + "in_channels": 16, + "out_channels": 16, + "patch_size": 1, + "text_in_dim": 16, + "rope_theta": 256, + "rope_axes_dim": (8, 4, 4), + "eps": 1e-6, + "qk_layernorm": True, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def setUp(self): + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + def tearDown(self): + super().tearDown() + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"ErnieImageTransformer2DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + @unittest.skip("Test is not supported for handling main inputs that are lists.") + def test_training(self): + super().test_training() + + @unittest.skip("Test is not supported for handling main inputs that are 
lists.") + def test_ema_training(self): + super().test_ema_training() + + @unittest.skip("Test is not supported for handling main inputs that are lists.") + def test_effective_gradient_checkpointing(self): + super().test_effective_gradient_checkpointing() + + @unittest.skip( + "Test needs to be revisited. But we need to ensure `x_pad_token` and `cap_pad_token` are cast to the same dtype as the destination tensor before they are assigned to the padding indices." + ) + def test_layerwise_casting_training(self): + super().test_layerwise_casting_training() + + @unittest.skip( + "TimestepEmbedding uses explicit dtype casting that conflicts with float8 layerwise casting hooks." + ) + def test_layerwise_casting_inference(self): + super().test_layerwise_casting_inference() + + @unittest.skip( + "TimestepEmbedding uses explicit dtype casting that conflicts with float8 layerwise casting hooks." + ) + def test_layerwise_casting_memory(self): + super().test_layerwise_casting_memory() + + @unittest.skip( + "TimestepEmbedding uses explicit dtype casting that conflicts with float8 layerwise casting hooks." + ) + def test_group_offloading_with_layerwise_casting(self): + super().test_group_offloading_with_layerwise_casting() + + @unittest.skip( + "TimestepEmbedding uses explicit dtype casting that conflicts with float8 layerwise casting hooks." + ) + def test_group_offloading_with_layerwise_casting_0(self): + pass + + @unittest.skip( + "TimestepEmbedding uses explicit dtype casting that conflicts with float8 layerwise casting hooks." 
+    )
+    def test_group_offloading_with_layerwise_casting_1(self):
+        pass
+
+    @unittest.skip("Test is not supported for handling main inputs that are lists.")
+    def test_outputs_equivalence(self):
+        super().test_outputs_equivalence()
+
+    @unittest.skip("Test will pass if we change to deterministic values instead of empty in the DiT.")
+    def test_group_offloading(self):
+        super().test_group_offloading()
+
+    @unittest.skip("Test will pass if we change to deterministic values instead of empty in the DiT.")
+    def test_group_offloading_with_disk(self):
+        super().test_group_offloading_with_disk()
+
+
+class ErnieImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
+    model_class = ErnieImageTransformer2DModel
+    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
+
+    def prepare_init_args_and_inputs_for_common(self):
+        return ErnieImageTransformerTests().prepare_init_args_and_inputs_for_common()
+
+    def prepare_dummy_input(self, height, width):
+        return ErnieImageTransformerTests().prepare_dummy_input(height=height, width=width)
+
+    @unittest.skip(
+        "The repeated block in this model is SharedAdaLNBlock, which is reused across all transformer layers. As a consequence of this, the inputs recorded for the block would vary during compilation and full compilation with fullgraph=True would trigger recompilation at least thrice."
+ ) + def test_torch_compile_recompilation_and_graph_break(self): + super().test_torch_compile_recompilation_and_graph_break() + + @unittest.skip("Fullgraph AoT is broken") + def test_compile_works_with_aot(self): + super().test_compile_works_with_aot() + + @unittest.skip("Fullgraph is broken") + def test_compile_on_different_shapes(self): + super().test_compile_on_different_shapes() From 4049a2072901ea38e9ad664c36944df6fec79f02 Mon Sep 17 00:00:00 2001 From: HsiaWinter Date: Thu, 2 Apr 2026 17:51:24 +0800 Subject: [PATCH 02/17] Update doc --- .../api/models/ernie_image_transformer2d.md | 2 +- docs/source/en/api/pipelines/ernie_image.md | 24 +++++++------------ 2 files changed, 10 insertions(+), 16 deletions(-) diff --git a/docs/source/en/api/models/ernie_image_transformer2d.md b/docs/source/en/api/models/ernie_image_transformer2d.md index 058616c47814..8be37d56bf42 100644 --- a/docs/source/en/api/models/ernie_image_transformer2d.md +++ b/docs/source/en/api/models/ernie_image_transformer2d.md @@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License. # ErnieImageTransformer2DModel -A Transformer model for image-like data from [Ernie-Image](https://huggingface.co/Tongyi-MAI/Z-Image-Turbo). +A Transformer model for image-like data from [ERNIE-Image-Turbo](https://huggingface.co/baidu/ERNIE-Image-Turbo). ## ErnieImageTransformer2DModel diff --git a/docs/source/en/api/pipelines/ernie_image.md b/docs/source/en/api/pipelines/ernie_image.md index 34c1dc4cd489..6162713f2c75 100644 --- a/docs/source/en/api/pipelines/ernie_image.md +++ b/docs/source/en/api/pipelines/ernie_image.md @@ -16,26 +16,26 @@ specific language governing permissions and limitations under the License. LoRA -[Ernie-Image](https://huggingface.co/papers/2511.22699) is a powerful and highly efficient image generation model with 6B parameters. 
Currently there's only one model with two more to be released:
+[ERNIE-Image](https://huggingface.co/baidu/ERNIE-Image-Turbo) is a powerful and highly efficient image generation model with 8B parameters. Currently there's only one model with two more to be released:
 
 |Model|Hugging Face|
 |---|---|
-|Ernie-Image|https://huggingface.co/Tongyi-MAI/Ernie-Image-Turbo|
+|ERNIE-Image-Turbo|https://huggingface.co/baidu/ERNIE-Image-Turbo|
 
 ## Ernie-Image
 
-Ernie-Image-Turbo is a distilled version of Ernie-Image that matches or exceeds leading competitors with only 8 NFEs (Number of Function Evaluations). It offers sub-second inference latency on enterprise-grade H800 GPUs and fits comfortably within 16G VRAM consumer devices. It excels in photorealistic image generation, bilingual text rendering (English & Chinese), and robust instruction adherence.
+ERNIE-Image-Turbo is a distilled version of ERNIE-Image that matches or exceeds leading competitors with only 8 NFEs (Number of Function Evaluations). It offers sub-second inference latency on enterprise-grade H800 GPUs and fits comfortably within 16G VRAM consumer devices. It excels in photorealistic image generation, bilingual text rendering (English & Chinese), and robust instruction adherence.
 
-## ZImagePipeline
+## ErnieImagePipeline
 
-Use [`ZImageImg2ImgPipeline`] to transform an existing image based on a text prompt.
+Use [`ErnieImagePipeline`] to generate an image based on a text prompt.
```python import torch from diffusers import ErnieImagePipeline from diffusers.utils import load_image -pipe = ErnieImagePipeline.from_pretrained("Tongyi-MAI/Ernie-Image-Turbo", torch_dtype=torch.bfloat16) +pipe = ErnieImagePipeline.from_pretrained("baidu/ERNIE-Image-Turbo", torch_dtype=torch.bfloat16) pipe.to("cuda") prompt = "一只黑白相间的中华田园犬" @@ -43,15 +43,9 @@ images = pipe( prompt=prompt, height=1024, width=1024, - num_inference_steps=50, + num_inference_steps=8, guidance_scale=5.0, generator=generator, ).images -images[0].save("ernie-image-output.png") -``` - -## ZImagePipeline - -[[autodoc]] ErnieImagePipeline - - all - - __call__ +images[0].save("ernie-image-turbo-output.png") +``` \ No newline at end of file From 579e6c7f6642508796a09f8d69bf3cb6ef9195cb Mon Sep 17 00:00:00 2001 From: HsiaWinter Date: Thu, 2 Apr 2026 18:53:56 +0800 Subject: [PATCH 03/17] Update doc --- docs/source/en/api/pipelines/ernie_image.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/ernie_image.md b/docs/source/en/api/pipelines/ernie_image.md index 6162713f2c75..5bb6f550096f 100644 --- a/docs/source/en/api/pipelines/ernie_image.md +++ b/docs/source/en/api/pipelines/ernie_image.md @@ -28,7 +28,7 @@ ERNIE-Image-Turbo is a distilled version of ERNIE-Image that matches or exceeds ## ErnieImagePipeline -Use [`ErnieImagePipeline`] to generate an image based on a text prompt. +Use [`ErnieImagePipeline`] to generate an image based on a text prompt. If you do not want to use PE, please set use_pe=False. 
```python import torch From d16d16e9b2b755659e2bbf08509a5b60620d10d9 Mon Sep 17 00:00:00 2001 From: HsiaWinter Date: Fri, 3 Apr 2026 22:57:21 +0800 Subject: [PATCH 04/17] Change from Custom-Attention to Diffusers Style Attention --- .../transformers/transformer_ernie_image.py | 187 +++++++++--------- .../ernie_image/pipeline_ernie_image.py | 1 + 2 files changed, 100 insertions(+), 88 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ernie_image.py b/src/diffusers/models/transformers/transformer_ernie_image.py index 63fe3c47b811..7a63e6d6818d 100644 --- a/src/diffusers/models/transformers/transformer_ernie_image.py +++ b/src/diffusers/models/transformers/transformer_ernie_image.py @@ -28,6 +28,9 @@ from ..embeddings import Timesteps from ..modeling_utils import ModelMixin from ...utils import BaseOutput +from ..normalization import RMSNorm +from ..attention_processor import Attention +from ..attention_dispatch import dispatch_attention_fn @dataclass @@ -52,8 +55,8 @@ def __init__(self, dim: int, theta: int, axes_dim: Tuple[int, int, int]): def forward(self, ids: torch.Tensor) -> torch.Tensor: emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(3)], dim=-1) - emb = emb.unsqueeze(1).permute(2, 0, 1, 3) - return torch.stack([emb, emb], dim=-1).reshape(*emb.shape[:-1], -1) + emb = emb.unsqueeze(2) # [B, S, 1, head_dim//2] + return torch.stack([emb, emb], dim=-1).reshape(*emb.shape[:-1], -1) # [B, S, 1, head_dim] class PatchEmbedDynamic(nn.Module): @@ -76,78 +79,84 @@ def __init__(self, in_channels: int, time_embed_dim: int): self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim) def forward(self, sample: torch.Tensor) -> torch.Tensor: - sample = self.linear_1(sample.to(self.linear_1.weight.dtype)) - return self.linear_2(self.act(sample).to(self.linear_2.weight.dtype)) + sample = sample.to(self.linear_1.weight.dtype) + return self.linear_2(self.act(self.linear_1(sample))) -class RMSNorm(nn.Module): - """RMSNorm 
implementation matching Megatron's TENorm.""" - - def __init__(self, dim: int, eps: float = 1e-6): - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def _norm(self, x: torch.Tensor) -> torch.Tensor: - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - # 内部计算转换为FP32,对齐transform engine的TENorm计算精度 - x_norm = self._norm(x.float()) - output = x_norm * self.weight.float() - return output.to(x.dtype) +class ErnieImageSingleStreamAttnProcessor: + _attention_backend = None + _parallel_config = None + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "ZSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher." + ) -class Attention(nn.Module): - def __init__(self, hidden_size: int, num_heads: int, eps: float = 1e-6, qk_layernorm: bool = True): - super().__init__() - self.hidden_size = hidden_size - self.num_heads = num_heads - self.head_dim = hidden_size // num_heads - # Separate Q, K, V projections (matches converted weights) - self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False) - self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False) - self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False) - self.linear_proj = nn.Linear(hidden_size, hidden_size, bias=False) - self.qk_layernorm = qk_layernorm - if qk_layernorm: - # self.q_layernorm = RMSNorm(self.head_dim, eps=eps) - # self.k_layernorm = RMSNorm(self.head_dim, eps=eps) - self.q_layernorm = torch.nn.RMSNorm(self.head_dim, eps=eps) - self.k_layernorm = torch.nn.RMSNorm(self.head_dim, eps=eps) - - def forward(self, x: torch.Tensor, rotary_pos_emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: - S, B, H = x.shape - # Separate Q, K, V projections - q = self.q_proj(x).view(S, B, self.num_heads, self.head_dim).contiguous() - k = self.k_proj(x).view(S, B, self.num_heads, 
self.head_dim).contiguous() - v = self.v_proj(x).view(S, B, self.num_heads, self.head_dim).contiguous() - if self.qk_layernorm: - q, k = self.q_layernorm(q), self.k_layernorm(k) - q, k = self._apply_rotary(q, rotary_pos_emb), self._apply_rotary(k, rotary_pos_emb) - q, k, v = q.permute(1, 2, 0, 3), k.permute(1, 2, 0, 3), v.permute(1, 2, 0, 3) - attn_mask = ~attention_mask if attention_mask is not None else None - out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0) - return self.linear_proj(out.permute(2, 0, 1, 3).reshape(S, B, H)) - - def _apply_rotary(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: - """Apply rotary position embedding. - - Matches Megatron's _apply_rotary_pos_emb_bshd with rotary_interleaved=False. - freqs: [S, B, 1, dim] containing angles [θ0, θ0, θ1, θ1, ...] - """ - rot_dim = freqs.shape[-1] - x, x_pass = x[..., :rot_dim], x[..., rot_dim:] - - cos_ = torch.cos(freqs).to(x.dtype) - sin_ = torch.sin(freqs).to(x.dtype) - - # Non-interleaved rotate_half: [-x2, x1] - x1, x2 = x.chunk(2, dim=-1) - x_rotated = torch.cat((-x2, x1), dim=-1) - - x = x * cos_ + x_rotated * sin_ - return torch.cat((x, x_pass), dim=-1) + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + freqs_cis: torch.Tensor | None = None, + ) -> torch.Tensor: + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + + # Apply Norms + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE: same rotate_half logic as Megatron _apply_rotary_pos_emb_bshd (rotary_interleaved=False) + # x_in: [B, S, heads, head_dim], freqs_cis: [B, S, 1, head_dim] with angles 
[θ0,θ0,θ1,θ1,...] + def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + rot_dim = freqs_cis.shape[-1] + x, x_pass = x_in[..., :rot_dim], x_in[..., rot_dim:] + cos_ = torch.cos(freqs_cis).to(x.dtype) + sin_ = torch.sin(freqs_cis).to(x.dtype) + # Non-interleaved rotate_half: [-x2, x1] + x1, x2 = x.chunk(2, dim=-1) + x_rotated = torch.cat((-x2, x1), dim=-1) + return torch.cat((x * cos_ + x_rotated * sin_, x_pass), dim=-1) + + if freqs_cis is not None: + query = apply_rotary_emb(query, freqs_cis) + key = apply_rotary_emb(key, freqs_cis) + + # Cast to correct dtype + dtype = query.dtype + query, key = query.to(dtype), key.to(dtype) + + # From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len] + if attention_mask is not None and attention_mask.ndim == 2: + attention_mask = attention_mask[:, None, None, :] + + # Compute joint attention + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + + # Reshape back + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(dtype) + output = attn.to_out[0](hidden_states) + + return output class FeedForward(nn.Module): @@ -161,22 +170,31 @@ def __init__(self, hidden_size: int, ffn_hidden_size: int): def forward(self, x: torch.Tensor) -> torch.Tensor: return self.linear_fc2(self.up_proj(x) * F.gelu(self.gate_proj(x))) - class SharedAdaLNBlock(nn.Module): def __init__(self, hidden_size: int, num_heads: int, ffn_hidden_size: int, eps: float = 1e-6, qk_layernorm: bool = True): super().__init__() - # self.adaLN_sa_ln = RMSNorm(hidden_size, eps=eps) - self.adaLN_sa_ln = torch.nn.RMSNorm(hidden_size, eps=eps) - self.self_attention = Attention(hidden_size, num_heads, eps=eps, qk_layernorm=qk_layernorm) - # self.adaLN_mlp_ln = RMSNorm(hidden_size, eps=eps) - self.adaLN_mlp_ln = 
torch.nn.RMSNorm(hidden_size, eps=eps) + self.adaLN_sa_ln = RMSNorm(hidden_size, eps=eps) + self.self_attention = Attention( + query_dim=hidden_size, + cross_attention_dim=None, + dim_head=hidden_size // num_heads, + heads=num_heads, + qk_norm="rms_norm" if qk_layernorm else None, + eps=eps, + bias=False, + out_bias=False, + processor=ErnieImageSingleStreamAttnProcessor(), + ) + self.adaLN_mlp_ln = RMSNorm(hidden_size, eps=eps) self.mlp = FeedForward(hidden_size, ffn_hidden_size) def forward(self, x, rotary_pos_emb, shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, attention_mask=None): residual = x x = self.adaLN_sa_ln(x) x = self._modulate(x, shift_msa, scale_msa) - attn_out = self.self_attention(x, rotary_pos_emb, attention_mask) + x_bsh = x.permute(1, 0, 2) # [S, B, H] → [B, S, H] for diffusers Attention (batch-first) + attn_out = self.self_attention(x_bsh, attention_mask=attention_mask, freqs_cis=rotary_pos_emb) + attn_out = attn_out.permute(1, 0, 2) # [B, S, H] → [S, B, H] x = residual + self._apply_gate(gate_msa, attn_out) residual = x x = self._modulate(self.adaLN_mlp_ln(x), shift_mlp, scale_mlp) @@ -231,7 +249,6 @@ def __init__( qk_layernorm: bool = True, ): super().__init__() - self.gradient_checkpointing = False self.hidden_size = hidden_size self.num_heads = num_attention_heads self.head_dim = hidden_size // num_attention_heads @@ -277,26 +294,20 @@ def forward(self, hidden_states: torch.Tensor, timestep: torch.Tensor, encoder_h image_ids = torch.cat([text_lens.float().view(B, 1, 1).expand(-1, N_img, -1), grid_yx.view(1, N_img, 2).expand(B, -1, -1)], dim=-1) rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1)) - # Attention mask + # Attention mask: True = valid (attend), False = padding (mask out), matches sdpa bool convention valid_text = torch.arange(Tmax, device=device).view(1, Tmax) < text_lens.view(B, 1) if Tmax > 0 else torch.zeros((B, 0), device=device, dtype=torch.bool) - attention_mask = 
(~torch.cat([torch.ones((B, N_img), device=device, dtype=torch.bool), valid_text], dim=1))[:, None, None, :] + attention_mask = torch.cat([torch.ones((B, N_img), device=device, dtype=torch.bool), valid_text], dim=1)[:, None, None, :] # AdaLN c = self.time_embedding(self.time_proj(timestep.to(dtype))) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [t.unsqueeze(0).expand(S, -1, -1).contiguous() for t in self.adaLN_modulation(c).chunk(6, dim=-1)] for layer in self.layers: - if self.gradient_checkpointing and self.training: - x = self._gradient_checkpointing_func( - layer.__call__, - x, rotary_pos_emb, shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, attention_mask - ) - else: - x = layer(x, rotary_pos_emb, shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, attention_mask) + x = layer(x, rotary_pos_emb, shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, attention_mask) x = self.final_norm(x, c).type_as(x) patches = self.final_linear(x)[:N_img].transpose(0, 1).contiguous() output = patches.view(B, Hp, Wp, p, p, self.out_channels).permute(0, 5, 1, 3, 2, 4).contiguous().view(B, self.out_channels, H, W) - return ErnieImageTransformer2DModelOutput(sample=output) if return_dict else (output,) + return Text2ImgDiTTransformer2DModelOutput(sample=output) if return_dict else (output,) def _pad_text(self, text_hiddens: List[torch.Tensor], device: torch.device, dtype: torch.dtype): B = len(text_hiddens) diff --git a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py index 29c1b9fd9bab..242017d1caf7 100644 --- a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py +++ b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py @@ -19,6 +19,7 @@ import json import os +import numpy as np import torch from PIL import Image from dataclasses import dataclass From 9cbbf5d92b69b43034210979355b402aedbbc84a Mon Sep 17 00:00:00 2001 From: HsiaWinter Date: 
Fri, 3 Apr 2026 23:03:08 +0800 Subject: [PATCH 05/17] Change from Custom-Attention to Diffusers Style Attention --- src/diffusers/models/transformers/transformer_ernie_image.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ernie_image.py b/src/diffusers/models/transformers/transformer_ernie_image.py index 7a63e6d6818d..a0a0fc08ab31 100644 --- a/src/diffusers/models/transformers/transformer_ernie_image.py +++ b/src/diffusers/models/transformers/transformer_ernie_image.py @@ -90,7 +90,7 @@ class ErnieImageSingleStreamAttnProcessor: def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( - "ZSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher." + "ErnieImageSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher." ) def __call__( @@ -307,7 +307,7 @@ def forward(self, hidden_states: torch.Tensor, timestep: torch.Tensor, encoder_h patches = self.final_linear(x)[:N_img].transpose(0, 1).contiguous() output = patches.view(B, Hp, Wp, p, p, self.out_channels).permute(0, 5, 1, 3, 2, 4).contiguous().view(B, self.out_channels, H, W) - return Text2ImgDiTTransformer2DModelOutput(sample=output) if return_dict else (output,) + return ErnieImageTransformer2DModelOutput(sample=output) if return_dict else (output,) def _pad_text(self, text_hiddens: List[torch.Tensor], device: torch.device, dtype: torch.dtype): B = len(text_hiddens) From 9fca91205f75f4a7914d18e20cd538b0aace36ba Mon Sep 17 00:00:00 2001 From: HsiaWinter Date: Tue, 7 Apr 2026 17:13:22 +0800 Subject: [PATCH 06/17] =?UTF-8?q?=E5=85=BC=E5=AE=B9SGLang?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../pipelines/ernie_image/pipeline_ernie_image.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git 
a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py index 242017d1caf7..56bb1af87444 100644 --- a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py +++ b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py @@ -50,6 +50,8 @@ class ErnieImagePipeline(DiffusionPipeline): """ model_cpu_offload_seq = "pe->text_encoder->transformer->vae" + # For SGLang fallback ... + _optional_components = ["pe", "pe_tokenizer"] def __init__( self, @@ -350,8 +352,8 @@ def __call__( # Handle prompts if isinstance(prompt, str): prompt = [prompt] - - # Enhance prompts with PE if enabled + + # [Phase 1] PE: enhance prompts, then offload to CPU if use_pe and self.pe is not None and self.pe_tokenizer is not None: prompt = [ self._enhance_prompt_with_pe(p, device, width=width, height=height, max_length=max_length) @@ -369,7 +371,7 @@ def __call__( if len(negative_prompt) != batch_size: raise ValueError(f"negative_prompt must have same length as prompt ({batch_size})") - # Encode prompts + # [Phase 2] Text encoding, then offload text_encoder to CPU text_hiddens = self.encode_prompt(prompt, device, num_images_per_prompt, max_length) # CFG with negative prompt From 465f00979b3170b71ab4a3844c4047723f578ea3 Mon Sep 17 00:00:00 2001 From: HsiaWinter Date: Tue, 7 Apr 2026 19:24:58 +0800 Subject: [PATCH 07/17] =?UTF-8?q?=E4=BC=98=E5=8C=96PE=E6=A8=A1=E5=9D=97?= =?UTF-8?q?=E7=9A=84=E5=8A=A0=E8=BD=BD=E4=B8=8Eoffload=E7=AD=96=E7=95=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/source/en/api/pipelines/ernie_image.md | 2 + .../ernie_image/pipeline_ernie_image.py | 95 +++++++++++-------- 2 files changed, 60 insertions(+), 37 deletions(-) diff --git a/docs/source/en/api/pipelines/ernie_image.md b/docs/source/en/api/pipelines/ernie_image.md index 5bb6f550096f..4f2ad62bdd5c 100644 --- a/docs/source/en/api/pipelines/ernie_image.md +++ 
b/docs/source/en/api/pipelines/ernie_image.md @@ -37,6 +37,8 @@ from diffusers.utils import load_image pipe = ErnieImagePipeline.from_pretrained("baidu/ERNIE-Image-Turbo", torch_dtype=torch.bfloat16) pipe.to("cuda") +# 如果显存不足,可以开启offload +pipe.enable_model_cpu_offload() prompt = "一只黑白相间的中华田园犬" images = pipe( diff --git a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py index 56bb1af87444..7ee8293d896e 100644 --- a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py +++ b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py @@ -52,6 +52,7 @@ class ErnieImagePipeline(DiffusionPipeline): model_cpu_offload_seq = "pe->text_encoder->transformer->vae" # For SGLang fallback ... _optional_components = ["pe", "pe_tokenizer"] + _callback_tensor_inputs = ["latents"] def __init__( self, @@ -93,6 +94,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs): torch_dtype = kwargs.pop("torch_dtype", torch.bfloat16) trust_remote_code = kwargs.pop("trust_remote_code", True) + device_map = kwargs.pop("device_map", None) # Determine whether this is a local directory or a Hub repo ID. # For local paths we join sub-directories; for Hub IDs we use `subfolder`. 
@@ -133,6 +135,8 @@ def _path_or_subfolder(subfolder: str): **_path_or_subfolder("pe"), torch_dtype=torch_dtype, trust_remote_code=trust_remote_code, + low_cpu_mem_usage=True, + **({"device_map": device_map} if device_map else {}), ) # Load PE tokenizer (auto-picks up chat_template.jinja in the same dir) @@ -185,8 +189,13 @@ def _enhance_prompt_with_pe( tokenize=False, add_generation_prompt=False, # "Output:" is already in the user block ) - # pe_device = next(self.pe.parameters()).device - inputs = self.pe_tokenizer(input_text, return_tensors="pt").to(device) + # When accelerate offload hooks are installed, use the hook's execution_device + # to ensure inputs land on the same device as the model weights during forward() + if hasattr(self.pe, "_hf_hook") and hasattr(self.pe._hf_hook, "execution_device"): + pe_device = self.pe._hf_hook.execution_device + else: + pe_device = device + inputs = self.pe_tokenizer(input_text, return_tensors="pt").to(pe_device) output_ids = self.pe.generate( **inputs, @@ -314,8 +323,8 @@ def __call__( latents: Optional[torch.Tensor] = None, output_type: str = "pil", return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, - callback_steps: int = 1, + callback_on_step_end: Optional[Callable[[int, int, dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_length: int = 1536, use_pe: bool = True, # 默认使用PE进行改写 ): @@ -334,8 +343,12 @@ def __call__( latents: Pre-generated latents (optional) output_type: "pil" or "latent" return_dict: Whether to return a dataclass - callback: Optional callback function - callback_steps: Steps between callbacks + callback_on_step_end: Optional callback invoked at the end of each denoising step. + Called as `callback_on_step_end(pipeline, step, timestep, callback_kwargs)` where + `callback_kwargs` contains the tensors listed in `callback_on_step_end_tensor_inputs`. 
+ The callback may return a dict to override those tensors for subsequent steps. + callback_on_step_end_tensor_inputs: List of tensor names passed into the callback kwargs. + Must be a subset of `_callback_tensor_inputs` (default: `["latents"]`). max_length: Max token length for text encoding Returns: @@ -344,7 +357,6 @@ def __call__( device = self._execution_device dtype = self.transformer.dtype - self.pe.to(device) # Validate dimensions if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: raise ValueError(f"Height and width must be divisible by {self.vae_scale_factor}") @@ -353,7 +365,7 @@ def __call__( if isinstance(prompt, str): prompt = [prompt] - # [Phase 1] PE: enhance prompts, then offload to CPU + # [Phase 1] PE: enhance prompts if use_pe and self.pe is not None and self.pe_tokenizer is not None: prompt = [ self._enhance_prompt_with_pe(p, device, width=width, height=height, max_length=max_length) @@ -371,7 +383,7 @@ def __call__( if len(negative_prompt) != batch_size: raise ValueError(f"negative_prompt must have same length as prompt ({batch_size})") - # [Phase 2] Text encoding, then offload text_encoder to CPU + # [Phase 2] Text encoding text_hiddens = self.encode_prompt(prompt, device, num_images_per_prompt, max_length) # CFG with negative prompt @@ -396,7 +408,8 @@ def __call__( ) # Setup scheduler - self.scheduler.set_timesteps(num_inference_steps, device=device) + sigmas = torch.linspace(1.0, 0.0, num_inference_steps + 1) + self.scheduler.set_timesteps(sigmas=sigmas[:-1], device=device) # Denoising loop if do_cfg: @@ -404,33 +417,38 @@ def __call__( else: cfg_text_hiddens = text_hiddens - for i, t in enumerate(self.scheduler.timesteps): - if do_cfg: - latent_model_input = torch.cat([latents, latents], dim=0) - t_batch = torch.full((total_batch_size * 2,), t.item(), device=device, dtype=dtype) - else: - latent_model_input = latents - t_batch = torch.full((total_batch_size,), t.item(), device=device, dtype=dtype) - - # Model 
prediction - pred = self.transformer( - hidden_states=latent_model_input, - timestep=t_batch, - encoder_hidden_states=cfg_text_hiddens, - return_dict=False, - )[0] - - # Apply CFG - if do_cfg: - pred_uncond, pred_cond = pred.chunk(2, dim=0) - pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond) - - # Scheduler step - latents = self.scheduler.step(pred, t, latents).prev_sample - - # Callback - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(self.scheduler.timesteps): + if do_cfg: + latent_model_input = torch.cat([latents, latents], dim=0) + t_batch = torch.full((total_batch_size * 2,), t.item(), device=device, dtype=dtype) + else: + latent_model_input = latents + t_batch = torch.full((total_batch_size,), t.item(), device=device, dtype=dtype) + + # Model prediction + pred = self.transformer( + hidden_states=latent_model_input, + timestep=t_batch, + encoder_hidden_states=cfg_text_hiddens, + return_dict=False, + )[0] + + # Apply CFG + if do_cfg: + pred_uncond, pred_cond = pred.chunk(2, dim=0) + pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond) + + # Scheduler step + latents = self.scheduler.step(pred, t, latents).prev_sample + + # Callback + if callback_on_step_end is not None: + callback_kwargs = {k: locals()[k] for k in callback_on_step_end_tensor_inputs} + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + latents = callback_outputs.pop("latents", latents) + + progress_bar.update() if output_type == "latent": return latents @@ -454,6 +472,9 @@ def __call__( if output_type == "pil": images = [Image.fromarray((img * 255).astype("uint8")) for img in images] + # Offload all models + self.maybe_free_model_hooks() + if not return_dict: return (images,) From 6afd5342bb2fec21af31f75a12f59fb4cc0a20a4 Mon Sep 17 00:00:00 2001 From: HsiaWinter Date: Wed, 8 Apr 2026 11:56:09 +0800 Subject: [PATCH 08/17] 
=?UTF-8?q?=E6=9B=B4=E6=96=B0Doc=E6=96=87=E4=BB=B6?= =?UTF-8?q?=E4=B8=8Econfig=E9=85=8D=E7=BD=AE=E7=9B=B8=E5=85=B3=E5=86=85?= =?UTF-8?q?=E5=AE=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../api/models/ernie_image_transformer2d.md | 2 + docs/source/en/api/pipelines/ernie_image.md | 41 ++++++++++++++-- .../transformers/transformer_ernie_image.py | 3 +- .../ernie_image/pipeline_ernie_image.py | 48 +++++++------------ .../pipelines/ernie_image/pipeline_output.py | 36 ++++++++++++++ 5 files changed, 94 insertions(+), 36 deletions(-) create mode 100644 src/diffusers/pipelines/ernie_image/pipeline_output.py diff --git a/docs/source/en/api/models/ernie_image_transformer2d.md b/docs/source/en/api/models/ernie_image_transformer2d.md index 8be37d56bf42..9fe03090577f 100644 --- a/docs/source/en/api/models/ernie_image_transformer2d.md +++ b/docs/source/en/api/models/ernie_image_transformer2d.md @@ -12,6 +12,8 @@ specific language governing permissions and limitations under the License. # ErnieImageTransformer2DModel +A Transformer model for image-like data from [ERNIE-Image](https://huggingface.co/baidu/ERNIE-Image). + A Transformer model for image-like data from [ERNIE-Image-Turbo](https://huggingface.co/baidu/ERNIE-Image-Turbo). ## ErnieImageTransformer2DModel diff --git a/docs/source/en/api/pipelines/ernie_image.md b/docs/source/en/api/pipelines/ernie_image.md index 4f2ad62bdd5c..69c0234d4cbf 100644 --- a/docs/source/en/api/pipelines/ernie_image.md +++ b/docs/source/en/api/pipelines/ernie_image.md @@ -16,19 +16,51 @@ specific language governing permissions and limitations under the License. LoRA -[ERNIE-Image] is a powerful and highly efficient image generation model with 8B parameters. Currently there's only one model with two more to be released: +[ERNIE-Image] is a powerful and highly efficient image generation model with 8B parameters. 
Currently there's only two models to be released: |Model|Hugging Face| |---|---| +|ERNIE-Image|https://huggingface.co/baidu/ERNIE-Image| |ERNIE-Image-Turbo|https://huggingface.co/baidu/ERNIE-Image-Turbo| -## Ernie-Image +## ERNIE-Image -ERNIE-Image-Turbo is a distilled version of ERNIE-Image that matches or exceeds leading competitors with only 8 NFEs (Number of Function Evaluations). It offers sub-second inference latency on enterprise-grade H800 GPUs and fits comfortably within 16G VRAM consumer devices. It excels in photorealistic image generation, bilingual text rendering (English & Chinese), and robust instruction adherence. +ERNIE-Image is designed with a relatively compact architecture and solid instruction-following capability, emphasizing parameter efficiency. Based on an 8B DiT backbone, it provides performance that is comparable in some scenarios to larger (20B+) models, while maintaining reasonable parameter efficiency. It offers a relatively stable level of performance in instruction understanding and execution, text generation (e.g., English / Chinese / Japanese), and overall stability. + +## ERNIE-Image-Turbo + +ERNIE-Image-Turbo is a distilled variant of ERNIE-Image, requiring only 8 NFEs (Number of Function Evaluations) and offering a more efficient alternative with relatively comparable performance to the full model in certain cases. ## ErnieImagePipeline -Use [`ErnieImagePipeline`] to generate an image based on a text prompt. If you do not want to use PE, please set use_pe=False. +Use [ErnieImagePipeline] to generate images from text prompts. The pipeline supports Prompt Enhancer (PE) by default, which enhances the user’s raw prompt to improve output quality, though it may reduce instruction-following accuracy. + +We provide a pretrained 3B-parameter PE model; however, using larger language models (e.g., Gemini or ChatGPT) for prompt enhancement may yield better results. 
The system prompt template is available at: https://huggingface.co/baidu/ERNIE-Image/blob/main/pe/chat_template.jinja. + +If you prefer not to use PE, set use_pe=False. + +```python +import torch +from diffusers import ErnieImagePipeline +from diffusers.utils import load_image + +pipe = ErnieImagePipeline.from_pretrained("baidu/ERNIE-Image", torch_dtype=torch.bfloat16) +pipe.to("cuda") +# 如果显存不足,可以开启offload +pipe.enable_model_cpu_offload() + +prompt = "一只黑白相间的中华田园犬" +images = pipe( + prompt=prompt, + height=1024, + width=1024, + num_inference_steps=50, + guidance_scale=5.0, + generator=generator, + use_pe=True, +).images +images[0].save("ernie-image-output.png") +``` ```python import torch @@ -48,6 +80,7 @@ images = pipe( num_inference_steps=8, guidance_scale=5.0, generator=generator, + use_pe=True, ).images images[0].save("ernie-image-turbo-output.png") ``` \ No newline at end of file diff --git a/src/diffusers/models/transformers/transformer_ernie_image.py b/src/diffusers/models/transformers/transformer_ernie_image.py index a0a0fc08ab31..f92d65eb40b5 100644 --- a/src/diffusers/models/transformers/transformer_ernie_image.py +++ b/src/diffusers/models/transformers/transformer_ernie_image.py @@ -1,5 +1,4 @@ -# Copyright (c) 2025, Baidu Inc. All rights reserved. -# Author: fengzhida (fengzhida@baidu.com) +# Copyright 2025 Baidu ERNIE-Image Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py index 7ee8293d896e..4f79a4501059 100644 --- a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py +++ b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py @@ -1,5 +1,4 @@ -# Copyright (c) 2025, Baidu Inc. All rights reserved. 
-# Author: fengzhida (fengzhida@baidu.com) +# Copyright 2025 Baidu ERNIE-Image Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,20 +21,14 @@ import numpy as np import torch from PIL import Image -from dataclasses import dataclass -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, List, Optional, Union from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import FlowMatchEulerDiscreteScheduler -from ...utils import BaseOutput from ...models import AutoencoderKLFlux2 from ...models.transformers import ErnieImageTransformer2DModel - - -@dataclass -class ErnieImagePipelineOutput(BaseOutput): - images: List[Image.Image] +from .pipeline_output import ErnieImagePipelineOutput class ErnieImagePipeline(DiffusionPipeline): @@ -168,7 +161,6 @@ def _enhance_prompt_with_pe( width: int = 1024, height: int = 1024, system_prompt: Optional[str] = None, - max_length: int = 1536, temperature: float = 0.6, top_p: float = 0.95, ) -> str: @@ -196,10 +188,9 @@ def _enhance_prompt_with_pe( else: pe_device = device inputs = self.pe_tokenizer(input_text, return_tensors="pt").to(pe_device) - output_ids = self.pe.generate( **inputs, - max_new_tokens=max_length, + max_new_tokens=self.pe_tokenizer.model_max_length, do_sample=temperature != 1.0 or top_p != 1.0, temperature=temperature, top_p=top_p, @@ -216,7 +207,6 @@ def encode_prompt( prompt: Union[str, List[str]], device: torch.device, num_images_per_prompt: int = 1, - max_length: int = 64, ) -> List[torch.Tensor]: """Encode text prompts to embeddings.""" if isinstance(prompt, str): @@ -229,7 +219,6 @@ def encode_prompt( p, add_special_tokens=True, truncation=True, - max_length=max_length, padding=False, )["input_ids"] @@ -260,7 +249,6 @@ def _encode_negative_prompt( 
negative_prompt: List[str], device: torch.device, num_images_per_prompt: int = 1, - max_length: int = 64, ) -> List[torch.Tensor]: """Encode negative prompts for CFG.""" text_hiddens = [] @@ -270,7 +258,6 @@ def _encode_negative_prompt( np, add_special_tokens=True, truncation=True, - max_length=max_length, padding=False, )["input_ids"] @@ -314,10 +301,10 @@ def __call__( self, prompt: Union[str, List[str]], negative_prompt: Optional[Union[str, List[str]]] = "", - height: int = 256, - width: int = 256, + height: int = 1024, + width: int = 1024, num_inference_steps: int = 50, - guidance_scale: float = 5.0, + guidance_scale: float = 4.0, num_images_per_prompt: int = 1, generator: Optional[torch.Generator] = None, latents: Optional[torch.Tensor] = None, @@ -325,7 +312,6 @@ def __call__( return_dict: bool = True, callback_on_step_end: Optional[Callable[[int, int, dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], - max_length: int = 1536, use_pe: bool = True, # 默认使用PE进行改写 ): """ @@ -334,10 +320,10 @@ def __call__( Args: prompt: Text prompt(s) negative_prompt: Negative prompt(s) for CFG. Default is "". - height: Image height (must be divisible by 16) - width: Image width (must be divisible by 16) + height: Image height in pixels (must be divisible by 16). Default: 1024. + width: Image width in pixels (must be divisible by 16). Default: 1024. num_inference_steps: Number of denoising steps - guidance_scale: CFG scale (1.0 = no guidance) + guidance_scale: CFG scale (1.0 = no guidance). Default: 4.0. num_images_per_prompt: Number of images per prompt generator: Random generator for reproducibility latents: Pre-generated latents (optional) @@ -349,10 +335,10 @@ def __call__( The callback may return a dict to override those tensors for subsequent steps. callback_on_step_end_tensor_inputs: List of tensor names passed into the callback kwargs. Must be a subset of `_callback_tensor_inputs` (default: `["latents"]`). 
- max_length: Max token length for text encoding + use_pe: Whether to use the PE model to enhance prompts before generation. Returns: - Generated images + :class:`ErnieImagePipelineOutput` with `images` and `revised_prompts`. """ device = self._execution_device dtype = self.transformer.dtype @@ -366,11 +352,13 @@ def __call__( prompt = [prompt] # [Phase 1] PE: enhance prompts + revised_prompts: Optional[List[str]] = None if use_pe and self.pe is not None and self.pe_tokenizer is not None: prompt = [ - self._enhance_prompt_with_pe(p, device, width=width, height=height, max_length=max_length) + self._enhance_prompt_with_pe(p, device, width=width, height=height) for p in prompt ] + revised_prompts = list(prompt) batch_size = len(prompt) total_batch_size = batch_size * num_images_per_prompt @@ -384,13 +372,13 @@ def __call__( raise ValueError(f"negative_prompt must have same length as prompt ({batch_size})") # [Phase 2] Text encoding - text_hiddens = self.encode_prompt(prompt, device, num_images_per_prompt, max_length) + text_hiddens = self.encode_prompt(prompt, device, num_images_per_prompt) # CFG with negative prompt do_cfg = guidance_scale > 1.0 if do_cfg: uncond_text_hiddens = self._encode_negative_prompt( - negative_prompt, device, num_images_per_prompt, max_length + negative_prompt, device, num_images_per_prompt ) # Latent dimensions @@ -478,4 +466,4 @@ def __call__( if not return_dict: return (images,) - return ErnieImagePipelineOutput(images=images) + return ErnieImagePipelineOutput(images=images, revised_prompts=revised_prompts) diff --git a/src/diffusers/pipelines/ernie_image/pipeline_output.py b/src/diffusers/pipelines/ernie_image/pipeline_output.py new file mode 100644 index 000000000000..8919db0c0aca --- /dev/null +++ b/src/diffusers/pipelines/ernie_image/pipeline_output.py @@ -0,0 +1,36 @@ +# Copyright 2025 Baidu ERNIE-Image Team and The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import List, Optional + +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class ErnieImagePipelineOutput(BaseOutput): + """ + Output class for ERNIE-Image pipelines. + + Args: + images (`List[PIL.Image.Image]`): + List of generated images. + revised_prompts (`List[str]`, *optional*): + List of PE-revised prompts. `None` when PE is disabled or unavailable. + """ + + images: List[PIL.Image.Image] + revised_prompts: Optional[List[str]] From b360596fa8933d59abf4edc91f036807ee6bbe61 Mon Sep 17 00:00:00 2001 From: HsiaWinter Date: Wed, 8 Apr 2026 21:18:17 +0800 Subject: [PATCH 09/17] =?UTF-8?q?Fix=E5=AE=98=E6=96=B9=E5=8F=8D=E9=A6=88?= =?UTF-8?q?=E7=9A=84=E5=86=85=E5=AE=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../transformers/transformer_ernie_image.py | 134 +++++++++++++---- .../ernie_image/pipeline_ernie_image.py | 136 +++++------------- 2 files changed, 138 insertions(+), 132 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ernie_image.py b/src/diffusers/models/transformers/transformer_ernie_image.py index f92d65eb40b5..e87995d57bfb 100644 --- a/src/diffusers/models/transformers/transformer_ernie_image.py +++ b/src/diffusers/models/transformers/transformer_ernie_image.py @@ -17,6 +17,7 @@ """ import math +import inspect from dataclasses import dataclass from typing import Any, 
Dict, List, Optional, Tuple, Union @@ -25,11 +26,13 @@ import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config from ..embeddings import Timesteps +from ..embeddings import TimestepEmbedding from ..modeling_utils import ModelMixin from ...utils import BaseOutput from ..normalization import RMSNorm from ..attention_processor import Attention from ..attention_dispatch import dispatch_attention_fn +from ..attention import AttentionMixin, AttentionModuleMixin @dataclass @@ -45,7 +48,7 @@ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: return out.float() -class EmbedND3(nn.Module): +class ErnieImageEmbedND3(nn.Module): def __init__(self, dim: int, theta: int, axes_dim: Tuple[int, int, int]): super().__init__() self.dim = dim @@ -70,18 +73,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x.reshape(B, D, Hp * Wp).transpose(1, 2).contiguous() -class TimestepEmbedding(nn.Module): - def __init__(self, in_channels: int, time_embed_dim: int): - super().__init__() - self.linear_1 = nn.Linear(in_channels, time_embed_dim) - self.act = nn.SiLU() - self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim) - - def forward(self, sample: torch.Tensor) -> torch.Tensor: - sample = sample.to(self.linear_1.weight.dtype) - return self.linear_2(self.act(self.linear_1(sample))) - - class ErnieImageSingleStreamAttnProcessor: _attention_backend = None _parallel_config = None @@ -157,6 +148,89 @@ def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tenso return output +class ErnieImageAttention(torch.nn.Module, AttentionModuleMixin): + _default_processor_cls = ErnieImageSingleStreamAttnProcessor + _available_processors = [ErnieImageSingleStreamAttnProcessor] + + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + qk_norm: str = "rms_norm", + added_kv_proj_dim: int | None = None, + added_proj_bias: bool | None = True, + 
out_bias: bool = True, + eps: float = 1e-5, + out_dim: int = None, + elementwise_affine: bool = True, + processor=None, + ): + super().__init__() + + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.out_dim = out_dim if out_dim is not None else query_dim + self.heads = out_dim // dim_head if out_dim is not None else heads + + self.use_bias = bias + self.dropout = dropout + + self.added_kv_proj_dim = added_kv_proj_dim + self.added_proj_bias = added_proj_bias + + self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + + # QK Norm + if qk_norm == "layer_norm": + self.norm_q = torch.nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = torch.nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + elif qk_norm == "rms_norm": + self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + else: + raise ValueError( + f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'fp32_layer_norm', 'layer_norm_across_heads', 'rms_norm', 'rms_norm_across_heads', 'l2'." 
+ ) + + self.to_out = torch.nn.ModuleList([]) + self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(torch.nn.Dropout(dropout)) + + if added_kv_proj_dim is not None: + self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps) + self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps) + self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias) + + if processor is None: + processor = self._default_processor_cls() + self.set_processor(processor) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + image_rotary_emb: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters] + if len(unused_kwargs) > 0: + logger.warning( + f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." 
+ ) + kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} + return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs) + class FeedForward(nn.Module): def __init__(self, hidden_size: int, ffn_hidden_size: int): @@ -173,9 +247,8 @@ class SharedAdaLNBlock(nn.Module): def __init__(self, hidden_size: int, num_heads: int, ffn_hidden_size: int, eps: float = 1e-6, qk_layernorm: bool = True): super().__init__() self.adaLN_sa_ln = RMSNorm(hidden_size, eps=eps) - self.self_attention = Attention( + self.self_attention = ErnieImageAttention( query_dim=hidden_size, - cross_attention_dim=None, dim_head=hidden_size // num_heads, heads=num_heads, qk_norm="rms_norm" if qk_layernorm else None, @@ -192,7 +265,7 @@ def forward(self, x, rotary_pos_emb, shift_msa, scale_msa, gate_msa, shift_mlp, x = self.adaLN_sa_ln(x) x = self._modulate(x, shift_msa, scale_msa) x_bsh = x.permute(1, 0, 2) # [S, B, H] → [B, S, H] for diffusers Attention (batch-first) - attn_out = self.self_attention(x_bsh, attention_mask=attention_mask, freqs_cis=rotary_pos_emb) + attn_out = self.self_attention(x_bsh, attention_mask=attention_mask, image_rotary_emb=rotary_pos_emb) attn_out = attn_out.permute(1, 0, 2) # [B, S, H] → [S, B, H] x = residual + self._apply_gate(gate_msa, attn_out) residual = x @@ -261,7 +334,7 @@ def __init__( self.text_proj = nn.Linear(text_in_dim, hidden_size, bias=False) if text_in_dim != hidden_size else None self.time_proj = Timesteps(hidden_size, flip_sin_to_cos=False, downscale_freq_shift=0) self.time_embedding = TimestepEmbedding(hidden_size, hidden_size) - self.pos_embed = EmbedND3(dim=self.head_dim, theta=rope_theta, axes_dim=rope_axes_dim) + self.pos_embed = ErnieImageEmbedND3(dim=self.head_dim, theta=rope_theta, axes_dim=rope_axes_dim) self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size)) nn.init.zeros_(self.adaLN_modulation[-1].weight) 
nn.init.zeros_(self.adaLN_modulation[-1].bias) @@ -271,14 +344,22 @@ def __init__( nn.init.zeros_(self.final_linear.weight) nn.init.zeros_(self.final_linear.bias) - def forward(self, hidden_states: torch.Tensor, timestep: torch.Tensor, encoder_hidden_states: List[torch.Tensor], return_dict: bool = True): + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + # encoder_hidden_states: List[torch.Tensor], + text_bth: torch.Tensor, + text_lens: torch.Tensor, + return_dict: bool = True + ): device, dtype = hidden_states.device, hidden_states.dtype B, C, H, W = hidden_states.shape p, Hp, Wp = self.patch_size, H // self.patch_size, W // self.patch_size N_img = Hp * Wp img_sbh = self.x_embedder(hidden_states).transpose(0, 1).contiguous() - text_bth, text_lens = self._pad_text(encoder_hidden_states, device, dtype) + # text_bth, text_lens = self._pad_text(encoder_hidden_states, device, dtype) if self.text_proj is not None and text_bth.numel() > 0: text_bth = self.text_proj(text_bth) Tmax = text_bth.shape[1] @@ -298,7 +379,9 @@ def forward(self, hidden_states: torch.Tensor, timestep: torch.Tensor, encoder_h attention_mask = torch.cat([torch.ones((B, N_img), device=device, dtype=torch.bool), valid_text], dim=1)[:, None, None, :] # AdaLN - c = self.time_embedding(self.time_proj(timestep.to(dtype))) + sample = self.time_proj(timestep.to(dtype)) + sample = sample.to(self.time_embedding.linear_1.weight.dtype) + c = self.time_embedding(sample) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [t.unsqueeze(0).expand(S, -1, -1).contiguous() for t in self.adaLN_modulation(c).chunk(6, dim=-1)] for layer in self.layers: x = layer(x, rotary_pos_emb, shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, attention_mask) @@ -308,14 +391,3 @@ def forward(self, hidden_states: torch.Tensor, timestep: torch.Tensor, encoder_h return ErnieImageTransformer2DModelOutput(sample=output) if return_dict else (output,) - def _pad_text(self, 
text_hiddens: List[torch.Tensor], device: torch.device, dtype: torch.dtype): - B = len(text_hiddens) - if B == 0: - return torch.zeros((0, 0, self.text_in_dim), device=device, dtype=dtype), torch.zeros((0,), device=device, dtype=torch.long) - normalized = [th.squeeze(1).to(device).to(dtype) if th.dim() == 3 else th.to(device).to(dtype) for th in text_hiddens] - lens = torch.tensor([t.shape[0] for t in normalized], device=device, dtype=torch.long) - Tmax = int(lens.max().item()) - text_bth = torch.zeros((B, Tmax, self.text_in_dim), device=device, dtype=dtype) - for i, t in enumerate(normalized): - text_bth[i, :t.shape[0], :] = t - return text_bth, lens diff --git a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py index 4f79a4501059..e2526c0f5700 100644 --- a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py +++ b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py @@ -49,13 +49,13 @@ class ErnieImagePipeline(DiffusionPipeline): def __init__( self, - transformer, - vae, - text_encoder, - tokenizer, + transformer: ErnieImageTransformer2DModel, + vae: AutoencoderKLFlux2, + text_encoder: AutoModel, + tokenizer: AutoTokenizer, scheduler: FlowMatchEulerDiscreteScheduler, - pe=None, - pe_tokenizer=None, + pe: Optional[AutoModelForCausalLM] = None, + pe_tokenizer: Optional[AutoTokenizer] = None, ): super().__init__() self.register_modules( @@ -69,89 +69,13 @@ def __init__( ) self.vae_scale_factor = 16 # VAE downsample factor - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs): - """ - Load pipeline from a pretrained model directory. 
- - Args: - pretrained_model_name_or_path: Path to the saved pipeline directory - **kwargs: Additional arguments passed to component loaders - - torch_dtype: Data type for model weights (default: torch.bfloat16) - - device_map: Device map for model loading - - trust_remote_code: Whether to trust remote code for text encoder + @property + def guidance_scale(self): + return self._guidance_scale - Returns: - ErnieImagePipeline instance - """ - - torch_dtype = kwargs.pop("torch_dtype", torch.bfloat16) - trust_remote_code = kwargs.pop("trust_remote_code", True) - device_map = kwargs.pop("device_map", None) - - # Determine whether this is a local directory or a Hub repo ID. - # For local paths we join sub-directories; for Hub IDs we use `subfolder`. - is_local = os.path.isdir(pretrained_model_name_or_path) - - def _path_or_subfolder(subfolder: str): - if is_local: - return {"pretrained_model_name_or_path": os.path.join(pretrained_model_name_or_path, subfolder)} - return {"pretrained_model_name_or_path": pretrained_model_name_or_path, "subfolder": subfolder} - - # Load transformer - transformer = ErnieImageTransformer2DModel.from_pretrained( - **_path_or_subfolder("transformer"), - torch_dtype=torch_dtype, - ) - - # Load VAE - vae = AutoencoderKLFlux2.from_pretrained( - **_path_or_subfolder("vae"), - torch_dtype=torch_dtype, - ) - - # Load text encoder - text_encoder = AutoModel.from_pretrained( - **_path_or_subfolder("text_encoder"), - torch_dtype=torch_dtype, - trust_remote_code=trust_remote_code, - ) - - # Load tokenizer - tokenizer = AutoTokenizer.from_pretrained( - **_path_or_subfolder("tokenizer"), - trust_remote_code=trust_remote_code, - ) - - # Load PE - pe = AutoModelForCausalLM.from_pretrained( - **_path_or_subfolder("pe"), - torch_dtype=torch_dtype, - trust_remote_code=trust_remote_code, - low_cpu_mem_usage=True, - **({"device_map": device_map} if device_map else {}), - ) - - # Load PE tokenizer (auto-picks up chat_template.jinja in the same dir) - pe_tokenizer 
= AutoTokenizer.from_pretrained( - **_path_or_subfolder("pe"), - trust_remote_code=trust_remote_code, - ) - - # Load scheduler - scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( - **_path_or_subfolder("scheduler"), - ) - - return cls( - transformer=transformer, - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - pe=pe, - pe_tokenizer=pe_tokenizer, - scheduler=scheduler, - ) + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 @torch.no_grad() def _enhance_prompt_with_pe( @@ -181,13 +105,7 @@ def _enhance_prompt_with_pe( tokenize=False, add_generation_prompt=False, # "Output:" is already in the user block ) - # When accelerate offload hooks are installed, use the hook's execution_device - # to ensure inputs land on the same device as the model weights during forward() - if hasattr(self.pe, "_hf_hook") and hasattr(self.pe._hf_hook, "execution_device"): - pe_device = self.pe._hf_hook.execution_device - else: - pe_device = device - inputs = self.pe_tokenizer(input_text, return_tensors="pt").to(pe_device) + inputs = self.pe_tokenizer(input_text, return_tensors="pt").to(device) output_ids = self.pe.generate( **inputs, max_new_tokens=self.pe_tokenizer.model_max_length, @@ -296,6 +214,19 @@ def _unpatchify_latents(latents: torch.Tensor) -> torch.Tensor: latents = latents.permute(0, 1, 4, 2, 5, 3) return latents.reshape(b, c // 4, h * 2, w * 2) + def _pad_text(self, text_hiddens: List[torch.Tensor], device: torch.device, dtype: torch.dtype): + text_in_dim = self.transformer.config.text_in_dim + B = len(text_hiddens) + if B == 0: + return torch.zeros((0, 0, text_in_dim), device=device, dtype=dtype), torch.zeros((0,), device=device, dtype=torch.long) + normalized = [th.squeeze(1).to(device).to(dtype) if th.dim() == 3 else th.to(device).to(dtype) for th in text_hiddens] + lens = torch.tensor([t.shape[0] for t in normalized], device=device, dtype=torch.long) + Tmax = int(lens.max().item()) + text_bth = torch.zeros((B, 
Tmax, text_in_dim), device=device, dtype=dtype) + for i, t in enumerate(normalized): + text_bth[i, :t.shape[0], :] = t + return text_bth, lens + @torch.no_grad() def __call__( self, @@ -343,6 +274,7 @@ def __call__( device = self._execution_device dtype = self.transformer.dtype + self._guidance_scale = guidance_scale # Validate dimensions if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: raise ValueError(f"Height and width must be divisible by {self.vae_scale_factor}") @@ -375,8 +307,7 @@ def __call__( text_hiddens = self.encode_prompt(prompt, device, num_images_per_prompt) # CFG with negative prompt - do_cfg = guidance_scale > 1.0 - if do_cfg: + if self.do_classifier_free_guidance: uncond_text_hiddens = self._encode_negative_prompt( negative_prompt, device, num_images_per_prompt ) @@ -400,14 +331,14 @@ def __call__( self.scheduler.set_timesteps(sigmas=sigmas[:-1], device=device) # Denoising loop - if do_cfg: + if self.do_classifier_free_guidance: cfg_text_hiddens = list(uncond_text_hiddens) + list(text_hiddens) else: cfg_text_hiddens = text_hiddens with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(self.scheduler.timesteps): - if do_cfg: + if self.do_classifier_free_guidance: latent_model_input = torch.cat([latents, latents], dim=0) t_batch = torch.full((total_batch_size * 2,), t.item(), device=device, dtype=dtype) else: @@ -415,15 +346,18 @@ def __call__( t_batch = torch.full((total_batch_size,), t.item(), device=device, dtype=dtype) # Model prediction + text_bth, text_lens = self._pad_text(cfg_text_hiddens, device, dtype) pred = self.transformer( hidden_states=latent_model_input, timestep=t_batch, - encoder_hidden_states=cfg_text_hiddens, + # encoder_hidden_states=cfg_text_hiddens, + text_bth=text_bth, + text_lens=text_lens, return_dict=False, )[0] # Apply CFG - if do_cfg: + if self.do_classifier_free_guidance: pred_uncond, pred_cond = pred.chunk(2, dim=0) pred = pred_uncond + guidance_scale * 
(pred_cond - pred_uncond) From 298322dbfe63b0158972b70b3f447a82bbdd8d97 Mon Sep 17 00:00:00 2001 From: HsiaWinter Date: Thu, 9 Apr 2026 10:26:36 +0800 Subject: [PATCH 10/17] =?UTF-8?q?=E6=A0=B9=E6=8D=AE=E5=AE=98=E6=96=B9?= =?UTF-8?q?=E5=BB=BA=E8=AE=AE=E4=BC=98=E5=8C=96=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../transformers/transformer_ernie_image.py | 63 ++++++------------- .../ernie_image/pipeline_ernie_image.py | 46 ++------------ 2 files changed, 23 insertions(+), 86 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ernie_image.py b/src/diffusers/models/transformers/transformer_ernie_image.py index e87995d57bfb..c3171598366a 100644 --- a/src/diffusers/models/transformers/transformer_ernie_image.py +++ b/src/diffusers/models/transformers/transformer_ernie_image.py @@ -61,7 +61,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: return torch.stack([emb, emb], dim=-1).reshape(*emb.shape[:-1], -1) # [B, S, 1, head_dim] -class PatchEmbedDynamic(nn.Module): +class ErnieImagePatchEmbedDynamic(nn.Module): def __init__(self, in_channels: int, embed_dim: int, patch_size: int): super().__init__() self.patch_size = patch_size @@ -69,8 +69,8 @@ def __init__(self, in_channels: int, embed_dim: int, patch_size: int): def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.proj(x) - B, D, Hp, Wp = x.shape - return x.reshape(B, D, Hp * Wp).transpose(1, 2).contiguous() + batch_size, dim, height, width = x.shape + return x.reshape(batch_size, dim, height * width).transpose(1, 2).contiguous() class ErnieImageSingleStreamAttnProcessor: @@ -87,7 +87,6 @@ def __call__( self, attn: Attention, hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, freqs_cis: torch.Tensor | None = None, ) -> torch.Tensor: @@ -148,9 +147,9 @@ def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tenso return 
output + class ErnieImageAttention(torch.nn.Module, AttentionModuleMixin): _default_processor_cls = ErnieImageSingleStreamAttnProcessor - _available_processors = [ErnieImageSingleStreamAttnProcessor] def __init__( self, @@ -160,7 +159,6 @@ def __init__( dropout: float = 0.0, bias: bool = False, qk_norm: str = "rms_norm", - added_kv_proj_dim: int | None = None, added_proj_bias: bool | None = True, out_bias: bool = True, eps: float = 1e-5, @@ -179,7 +177,6 @@ def __init__( self.use_bias = bias self.dropout = dropout - self.added_kv_proj_dim = added_kv_proj_dim self.added_proj_bias = added_proj_bias self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) @@ -200,15 +197,6 @@ def __init__( self.to_out = torch.nn.ModuleList([]) self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) - self.to_out.append(torch.nn.Dropout(dropout)) - - if added_kv_proj_dim is not None: - self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps) - self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps) - self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) - self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) - self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) - self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias) if processor is None: processor = self._default_processor_cls() @@ -229,10 +217,10 @@ def forward( f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." 
) kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} - return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs) + return self.processor(self, hidden_states, attention_mask, image_rotary_emb, **kwargs) -class FeedForward(nn.Module): +class ErnieImageFeedForward(nn.Module): def __init__(self, hidden_size: int, ffn_hidden_size: int): super().__init__() # Separate gate and up projections (matches converted weights) @@ -243,7 +231,7 @@ def __init__(self, hidden_size: int, ffn_hidden_size: int): def forward(self, x: torch.Tensor) -> torch.Tensor: return self.linear_fc2(self.up_proj(x) * F.gelu(self.gate_proj(x))) -class SharedAdaLNBlock(nn.Module): +class ErnieImageSharedAdaLNBlock(nn.Module): def __init__(self, hidden_size: int, num_heads: int, ffn_hidden_size: int, eps: float = 1e-6, qk_layernorm: bool = True): super().__init__() self.adaLN_sa_ln = RMSNorm(hidden_size, eps=eps) @@ -258,40 +246,27 @@ def __init__(self, hidden_size: int, num_heads: int, ffn_hidden_size: int, eps: processor=ErnieImageSingleStreamAttnProcessor(), ) self.adaLN_mlp_ln = RMSNorm(hidden_size, eps=eps) - self.mlp = FeedForward(hidden_size, ffn_hidden_size) + self.mlp = ErnieImageFeedForward(hidden_size, ffn_hidden_size) def forward(self, x, rotary_pos_emb, shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, attention_mask=None): residual = x x = self.adaLN_sa_ln(x) - x = self._modulate(x, shift_msa, scale_msa) + x = (x.float() * (1 + scale_msa.float()) + shift_msa.float()).to(x.dtype) x_bsh = x.permute(1, 0, 2) # [S, B, H] → [B, S, H] for diffusers Attention (batch-first) attn_out = self.self_attention(x_bsh, attention_mask=attention_mask, image_rotary_emb=rotary_pos_emb) attn_out = attn_out.permute(1, 0, 2) # [B, S, H] → [S, B, H] - x = residual + self._apply_gate(gate_msa, attn_out) + x = residual + (gate_msa.float() * attn_out.float()).to(x.dtype) residual = x - x = self._modulate(self.adaLN_mlp_ln(x), 
shift_mlp, scale_mlp) - return residual + self._apply_gate(gate_mlp, self.mlp(x)) - - def _modulate(self, x, shift, scale): - """AdaLN modulation: x * (1 + scale) + shift,在FP32下计算确保数值稳定""" - x_fp32 = x.float() - shift_fp32 = shift.float() - scale_fp32 = scale.float() - out = x_fp32 * (1 + scale_fp32) + shift_fp32 - return out.to(x.dtype) - - def _apply_gate(self, gate, x): - """Gate乘法在FP32下计算,对齐TE精度""" - return (gate.float() * x.float()).to(x.dtype) - -class AdaLNContinuous(nn.Module): + x = self.adaLN_mlp_ln(x) + x = (x.float() * (1 + scale_mlp.float()) + shift_mlp.float()).to(x.dtype) + return residual + (gate_mlp.float() * self.mlp(x).float()).to(x.dtype) + + +class ErnieImageAdaLNContinuous(nn.Module): def __init__(self, hidden_size: int, eps: float = 1e-6): super().__init__() self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=eps) self.linear = nn.Linear(hidden_size, hidden_size * 2) - # 对齐 Megatron 实现:zero init - nn.init.zeros_(self.linear.weight) - nn.init.zeros_(self.linear.bias) def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor: scale, shift = self.linear(conditioning).chunk(2, dim=-1) @@ -330,7 +305,7 @@ def __init__( self.out_channels = out_channels self.text_in_dim = text_in_dim - self.x_embedder = PatchEmbedDynamic(in_channels, hidden_size, patch_size) + self.x_embedder = ErnieImagePatchEmbedDynamic(in_channels, hidden_size, patch_size) self.text_proj = nn.Linear(text_in_dim, hidden_size, bias=False) if text_in_dim != hidden_size else None self.time_proj = Timesteps(hidden_size, flip_sin_to_cos=False, downscale_freq_shift=0) self.time_embedding = TimestepEmbedding(hidden_size, hidden_size) @@ -338,8 +313,8 @@ def __init__( self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size)) nn.init.zeros_(self.adaLN_modulation[-1].weight) nn.init.zeros_(self.adaLN_modulation[-1].bias) - self.layers = nn.ModuleList([SharedAdaLNBlock(hidden_size, num_attention_heads, ffn_hidden_size, eps, 
qk_layernorm=qk_layernorm) for _ in range(num_layers)]) - self.final_norm = AdaLNContinuous(hidden_size, eps) + self.layers = nn.ModuleList([ErnieImageSharedAdaLNBlock(hidden_size, num_attention_heads, ffn_hidden_size, eps, qk_layernorm=qk_layernorm) for _ in range(num_layers)]) + self.final_norm = ErnieImageAdaLNContinuous(hidden_size, eps) self.final_linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels) nn.init.zeros_(self.final_linear.weight) nn.init.zeros_(self.final_linear.bias) diff --git a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py index e2526c0f5700..3fb1948739fa 100644 --- a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py +++ b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py @@ -161,43 +161,6 @@ def encode_prompt( return text_hiddens - @torch.no_grad() - def _encode_negative_prompt( - self, - negative_prompt: List[str], - device: torch.device, - num_images_per_prompt: int = 1, - ) -> List[torch.Tensor]: - """Encode negative prompts for CFG.""" - text_hiddens = [] - - for np in negative_prompt: - ids = self.tokenizer( - np, - add_special_tokens=True, - truncation=True, - padding=False, - )["input_ids"] - - if len(ids) == 0: - if self.tokenizer.bos_token_id is not None: - ids = [self.tokenizer.bos_token_id] - else: - ids = [0] - - input_ids = torch.tensor([ids], device=device) - with torch.no_grad(): - outputs = self.text_encoder( - input_ids=input_ids, - output_hidden_states=True, - ) - hidden = outputs.hidden_states[-2][0] - - for _ in range(num_images_per_prompt): - text_hiddens.append(hidden) - - return text_hiddens - @staticmethod def _patchify_latents(latents: torch.Tensor) -> torch.Tensor: """2x2 patchify: [B, 32, H, W] -> [B, 128, H/2, W/2]""" @@ -214,8 +177,8 @@ def _unpatchify_latents(latents: torch.Tensor) -> torch.Tensor: latents = latents.permute(0, 1, 4, 2, 5, 3) return latents.reshape(b, c // 4, h * 2, w * 2) - def 
_pad_text(self, text_hiddens: List[torch.Tensor], device: torch.device, dtype: torch.dtype): - text_in_dim = self.transformer.config.text_in_dim + @staticmethod + def _pad_text(text_hiddens: List[torch.Tensor], device: torch.device, dtype: torch.dtype, text_in_dim: int): B = len(text_hiddens) if B == 0: return torch.zeros((0, 0, text_in_dim), device=device, dtype=dtype), torch.zeros((0,), device=device, dtype=torch.long) @@ -308,7 +271,7 @@ def __call__( # CFG with negative prompt if self.do_classifier_free_guidance: - uncond_text_hiddens = self._encode_negative_prompt( + uncond_text_hiddens = self.encode_prompt( negative_prompt, device, num_images_per_prompt ) @@ -335,6 +298,7 @@ def __call__( cfg_text_hiddens = list(uncond_text_hiddens) + list(text_hiddens) else: cfg_text_hiddens = text_hiddens + text_bth, text_lens = self._pad_text(text_hiddens=cfg_text_hiddens, device=device, dtype=dtype, text_in_dim=self.transformer.config.text_in_dim) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(self.scheduler.timesteps): @@ -346,11 +310,9 @@ def __call__( t_batch = torch.full((total_batch_size,), t.item(), device=device, dtype=dtype) # Model prediction - text_bth, text_lens = self._pad_text(cfg_text_hiddens, device, dtype) pred = self.transformer( hidden_states=latent_model_input, timestep=t_batch, - # encoder_hidden_states=cfg_text_hiddens, text_bth=text_bth, text_lens=text_lens, return_dict=False, From c482b0d953ef8704bde3319d723c900619fb1fe5 Mon Sep 17 00:00:00 2001 From: HsiaWinter Date: Fri, 10 Apr 2026 10:12:01 +0800 Subject: [PATCH 11/17] Update code --- .../transformers/transformer_ernie_image.py | 17 ++++++++++++++++- .../test_models_transformer_ernie_image.py | 4 ---- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ernie_image.py b/src/diffusers/models/transformers/transformer_ernie_image.py index c3171598366a..3aaeffe57254 100644 --- 
a/src/diffusers/models/transformers/transformer_ernie_image.py +++ b/src/diffusers/models/transformers/transformer_ernie_image.py @@ -318,6 +318,7 @@ def __init__( self.final_linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels) nn.init.zeros_(self.final_linear.weight) nn.init.zeros_(self.final_linear.bias) + self.gradient_checkpointing = False def forward( self, @@ -359,7 +360,21 @@ def forward( c = self.time_embedding(sample) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [t.unsqueeze(0).expand(S, -1, -1).contiguous() for t in self.adaLN_modulation(c).chunk(6, dim=-1)] for layer in self.layers: - x = layer(x, rotary_pos_emb, shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, attention_mask) + if torch.is_grad_enabled() and self.gradient_checkpointing: + x = self._gradient_checkpointing_func( + layer, + x, + rotary_pos_emb, + shift_msa, + scale_msa, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + attention_mask, + ) + else: + x = layer(x, rotary_pos_emb, shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, attention_mask) x = self.final_norm(x, c).type_as(x) patches = self.final_linear(x)[:N_img].transpose(0, 1).contiguous() output = patches.view(B, Hp, Wp, p, p, self.out_channels).permute(0, 5, 1, 3, 2, 4).contiguous().view(B, self.out_channels, H, W) diff --git a/tests/models/transformers/test_models_transformer_ernie_image.py b/tests/models/transformers/test_models_transformer_ernie_image.py index 7ef855609ed8..3aea22991ce8 100644 --- a/tests/models/transformers/test_models_transformer_ernie_image.py +++ b/tests/models/transformers/test_models_transformer_ernie_image.py @@ -36,10 +36,6 @@ torch.backends.cuda.matmul.allow_tf32 = False -@unittest.skipIf( - IS_GITHUB_ACTIONS, - reason="Skipping test-suite inside the CI because the model has `torch.empty()` inside of it during init and we don't have a clear way to override it in the modeling tests.", -) class 
ErnieImageTransformerTests(ModelTesterMixin, unittest.TestCase): model_class = ErnieImageTransformer2DModel main_input_name = "hidden_states" From f8b1395c10b5065a69964773d5a34a37f2c7b717 Mon Sep 17 00:00:00 2001 From: HsiaWinter Date: Fri, 10 Apr 2026 17:00:09 +0800 Subject: [PATCH 12/17] update --- docs/source/en/_toctree.yml | 4 + docs/source/en/api/pipelines/ernie_image.md | 12 +- fix_turbo_weight_keys.py | 163 +++++++++++++++ .../transformers/transformer_ernie_image.py | 21 +- .../ernie_image/pipeline_ernie_image.py | 5 +- .../test_models_transformer_ernie_image.py | 189 ++++++------------ 6 files changed, 251 insertions(+), 143 deletions(-) create mode 100644 fix_turbo_weight_keys.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 67f0bff38fbf..6871f12eff49 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -412,6 +412,8 @@ title: WanTransformer3DModel - local: api/models/z_image_transformer2d title: ZImageTransformer2DModel + - local: api/models/ernie_image_transformer2d + title: ErnieImageTransformer2DModel title: Transformers - sections: - local: api/models/stable_cascade_unet @@ -634,6 +636,8 @@ title: VisualCloze - local: api/pipelines/z_image title: Z-Image + - local: api/pipelines/ernie_image + title: ERNIE-Image title: Image - sections: - local: api/pipelines/llada2 diff --git a/docs/source/en/api/pipelines/ernie_image.md b/docs/source/en/api/pipelines/ernie_image.md index 69c0234d4cbf..79f35cf93a2e 100644 --- a/docs/source/en/api/pipelines/ernie_image.md +++ b/docs/source/en/api/pipelines/ernie_image.md @@ -46,7 +46,7 @@ from diffusers.utils import load_image pipe = ErnieImagePipeline.from_pretrained("baidu/ERNIE-Image", torch_dtype=torch.bfloat16) pipe.to("cuda") -# 如果显存不足,可以开启offload +# If you are running low on GPU VRAM, you can enable offloading pipe.enable_model_cpu_offload() prompt = "一只黑白相间的中华田园犬" @@ -55,8 +55,8 @@ images = pipe( height=1024, width=1024, num_inference_steps=50, - 
guidance_scale=5.0, - generator=generator, + guidance_scale=4.0, + generator=torch.Generator("cuda").manual_seed(42), use_pe=True, ).images images[0].save("ernie-image-output.png") @@ -69,7 +69,7 @@ from diffusers.utils import load_image pipe = ErnieImagePipeline.from_pretrained("baidu/ERNIE-Image-Turbo", torch_dtype=torch.bfloat16) pipe.to("cuda") -# 如果显存不足,可以开启offload +# If you are running low on GPU VRAM, you can enable offloading pipe.enable_model_cpu_offload() prompt = "一只黑白相间的中华田园犬" @@ -78,8 +78,8 @@ images = pipe( height=1024, width=1024, num_inference_steps=8, - guidance_scale=5.0, - generator=generator, + guidance_scale=1.0, + generator=torch.Generator("cuda").manual_seed(42), use_pe=True, ).images images[0].save("ernie-image-turbo-output.png") diff --git a/fix_turbo_weight_keys.py b/fix_turbo_weight_keys.py new file mode 100644 index 000000000000..a8ad51d87bf8 --- /dev/null +++ b/fix_turbo_weight_keys.py @@ -0,0 +1,163 @@ +""" +将 ERNIE-Image-Turbo/transformer 的权重键名修正为与 ERNIE-Image/transformer 一致。 + +差异均位于每层 self_attention 子模块,共 6 类 × 36 层 = 216 个键需要重命名: + k_layernorm -> norm_k + q_layernorm -> norm_q + k_proj -> to_k + q_proj -> to_q + v_proj -> to_v + linear_proj -> to_out.0 +""" + +import json +import os +import shutil +from pathlib import Path + +import torch +from safetensors.torch import load_file, save_file + +# ── 路径配置 ────────────────────────────────────────────────────────────────── +TURBO_DIR = Path("/root/paddlejob/gpfsspace/model_weights/turbo/ERNIE-Image-Turbo/transformer") +# 修正后的文件直接覆盖原目录(先备份),如需输出到新目录请修改此变量 +OUTPUT_DIR = TURBO_DIR # 或改为 Path("/your/output/path") +BACKUP_SUFFIX = ".bak" # 原文件备份后缀,设为 None 则不备份 + +# ── 键名映射(只处理 self_attention 子键,前缀 layers.N. 
由脚本动态拼接)─── +KEY_REMAP = { + "self_attention.k_layernorm.weight": "self_attention.norm_k.weight", + "self_attention.q_layernorm.weight": "self_attention.norm_q.weight", + "self_attention.k_proj.weight": "self_attention.to_k.weight", + "self_attention.q_proj.weight": "self_attention.to_q.weight", + "self_attention.v_proj.weight": "self_attention.to_v.weight", + "self_attention.linear_proj.weight": "self_attention.to_out.0.weight", +} + +NUM_LAYERS = 36 # layers.0 ~ layers.35 + + +def build_full_remap() -> dict[str, str]: + """构建完整的旧键名 -> 新键名映射表(含层前缀)。""" + remap = {} + for layer_idx in range(NUM_LAYERS): + prefix = f"layers.{layer_idx}." + for old_suffix, new_suffix in KEY_REMAP.items(): + remap[prefix + old_suffix] = prefix + new_suffix + return remap + + +def rename_keys_in_tensor_dict( + tensors: dict[str, torch.Tensor], + remap: dict[str, str], +) -> tuple[dict[str, torch.Tensor], int]: + """重命名张量字典中的键,返回新字典和实际重命名的数量。""" + renamed = 0 + new_tensors: dict[str, torch.Tensor] = {} + for key, tensor in tensors.items(): + new_key = remap.get(key, key) + if new_key != key: + renamed += 1 + new_tensors[new_key] = tensor + return new_tensors, renamed + + +def backup_file(path: Path) -> None: + if BACKUP_SUFFIX is None: + return + backup = path.with_suffix(path.suffix + BACKUP_SUFFIX) + shutil.copy2(path, backup) + print(f" [备份] {path.name} -> {backup.name}") + + +def process_safetensors_files(remap: dict[str, str]) -> None: + index_path = TURBO_DIR / "diffusion_pytorch_model.safetensors.index.json" + with open(index_path, "r", encoding="utf-8") as f: + index = json.load(f) + + # 找出所有需要处理的 shard 文件(去重) + shard_files = sorted(set(index["weight_map"].values())) + print(f"\n共发现 {len(shard_files)} 个 shard 文件,开始处理...\n") + + total_renamed = 0 + for shard_name in shard_files: + shard_path = TURBO_DIR / shard_name + print(f"[处理] {shard_name}") + + tensors = load_file(shard_path) + new_tensors, renamed = rename_keys_in_tensor_dict(tensors, remap) + total_renamed += renamed + 
print(f" 本文件重命名: {renamed} 个键") + + if renamed > 0: + # 保留原始 metadata(如果有) + metadata = {} + + out_path = OUTPUT_DIR / shard_name + if out_path == shard_path and BACKUP_SUFFIX: + backup_file(shard_path) + + OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + save_file(new_tensors, out_path, metadata=metadata) + print(f" [保存] {out_path}") + else: + if OUTPUT_DIR != TURBO_DIR: + shutil.copy2(shard_path, OUTPUT_DIR / shard_name) + print(f" [复制(无变更)] {shard_name}") + + print(f"\n所有 shard 处理完毕,共重命名 {total_renamed} 个键。") + + # ── 更新 index.json 中的 weight_map ───────────────────────────────────── + new_weight_map: dict[str, str] = {} + for old_key, shard_name in index["weight_map"].items(): + new_key = remap.get(old_key, old_key) + new_weight_map[new_key] = shard_name + + index["weight_map"] = new_weight_map + + out_index_path = OUTPUT_DIR / "diffusion_pytorch_model.safetensors.index.json" + if out_index_path == index_path and BACKUP_SUFFIX: + backup_file(index_path) + + OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + with open(out_index_path, "w", encoding="utf-8") as f: + json.dump(index, f, indent=2, ensure_ascii=False) + print(f"[更新] index.json 已写入: {out_index_path}\n") + + +def verify_against_base() -> None: + """(可选)验证修正后的 Turbo 键名与 Base 完全一致。""" + BASE_DIR = Path("/root/paddlejob/gpfsspace/model_weights/base/ERNIE-Image/transformer") + base_index_path = BASE_DIR / "diffusion_pytorch_model.safetensors.index.json" + turbo_index_path = OUTPUT_DIR / "diffusion_pytorch_model.safetensors.index.json" + + if not base_index_path.exists() or not turbo_index_path.exists(): + print("[验证] 找不到 index.json,跳过验证。") + return + + with open(base_index_path, "r") as f: + base_keys = set(json.load(f)["weight_map"].keys()) + with open(turbo_index_path, "r") as f: + turbo_keys = set(json.load(f)["weight_map"].keys()) + + only_in_base = base_keys - turbo_keys + only_in_turbo = turbo_keys - base_keys + + if not only_in_base and only_in_turbo: + print(f"[验证] 警告:Turbo 中多余的键 
({len(only_in_turbo)}):") + for k in sorted(only_in_turbo): + print(f" + {k}") + elif only_in_base: + print(f"[验证] 警告:Base 中存在但 Turbo 中缺少的键 ({len(only_in_base)}):") + for k in sorted(only_in_base): + print(f" - {k}") + else: + print("[验证] 通过!修正后 Turbo 的键名与 Base 完全一致。") + + +if __name__ == "__main__": + remap = build_full_remap() + print(f"键名映射表共 {len(remap)} 条({NUM_LAYERS} 层 × {len(KEY_REMAP)} 类)") + + process_safetensors_files(remap) + verify_against_base() diff --git a/src/diffusers/models/transformers/transformer_ernie_image.py b/src/diffusers/models/transformers/transformer_ernie_image.py index 3aaeffe57254..937dd3c17190 100644 --- a/src/diffusers/models/transformers/transformer_ernie_image.py +++ b/src/diffusers/models/transformers/transformer_ernie_image.py @@ -28,12 +28,13 @@ from ..embeddings import Timesteps from ..embeddings import TimestepEmbedding from ..modeling_utils import ModelMixin -from ...utils import BaseOutput +from ...utils import BaseOutput, logging from ..normalization import RMSNorm from ..attention_processor import Attention from ..attention_dispatch import dispatch_attention_fn from ..attention import AttentionMixin, AttentionModuleMixin +logger = logging.get_logger(__name__) # pylint: disable=invalid-name @dataclass class ErnieImageTransformer2DModelOutput(BaseOutput): @@ -248,7 +249,13 @@ def __init__(self, hidden_size: int, num_heads: int, ffn_hidden_size: int, eps: self.adaLN_mlp_ln = RMSNorm(hidden_size, eps=eps) self.mlp = ErnieImageFeedForward(hidden_size, ffn_hidden_size) - def forward(self, x, rotary_pos_emb, shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, attention_mask=None): + def forward( + self, + x, + rotary_pos_emb, temb: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None = None + ): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = temb residual = x x = self.adaLN_sa_ln(x) x = (x.float() * (1 + scale_msa.float()) 
+ shift_msa.float()).to(x.dtype) @@ -360,21 +367,17 @@ def forward( c = self.time_embedding(sample) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [t.unsqueeze(0).expand(S, -1, -1).contiguous() for t in self.adaLN_modulation(c).chunk(6, dim=-1)] for layer in self.layers: + temb = [shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp] if torch.is_grad_enabled() and self.gradient_checkpointing: x = self._gradient_checkpointing_func( layer, x, rotary_pos_emb, - shift_msa, - scale_msa, - gate_msa, - shift_mlp, - scale_mlp, - gate_mlp, + temb, attention_mask, ) else: - x = layer(x, rotary_pos_emb, shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, attention_mask) + x = layer(x, rotary_pos_emb, temb, attention_mask) x = self.final_norm(x, c).type_as(x) patches = self.final_linear(x)[:N_img].transpose(0, 1).contiguous() output = patches.view(B, Hp, Wp, p, p, self.out_channels).permute(0, 5, 1, 3, 2, 4).contiguous().view(B, self.out_channels, H, W) diff --git a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py index 3fb1948739fa..714ea6fbd01a 100644 --- a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py +++ b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py @@ -67,7 +67,8 @@ def __init__( pe=pe, pe_tokenizer=pe_tokenizer, ) - self.vae_scale_factor = 16 # VAE downsample factor + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels)) if getattr(self, "vae", None) else 16 + print(f"vae_scale_factor: {self.vae_scale_factor}") @property def guidance_scale(self): @@ -278,7 +279,7 @@ def __call__( # Latent dimensions latent_h = height // self.vae_scale_factor latent_w = width // self.vae_scale_factor - latent_channels = 128 # After patchify + latent_channels = self.transformer.config.in_channels # After patchify # Initialize latents if latents is None: diff --git a/tests/models/transformers/test_models_transformer_ernie_image.py 
b/tests/models/transformers/test_models_transformer_ernie_image.py index 3aea22991ce8..bff0894df08b 100644 --- a/tests/models/transformers/test_models_transformer_ernie_image.py +++ b/tests/models/transformers/test_models_transformer_ernie_image.py @@ -13,20 +13,25 @@ # See the License for the specific language governing permissions and # limitations under the License. -import gc import os -import unittest +import pytest import torch from diffusers import ErnieImageTransformer2DModel +from diffusers.utils.torch_utils import randn_tensor -from ...testing_utils import IS_GITHUB_ACTIONS, torch_device -from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin +from ...testing_utils import torch_device +from ..testing_utils import ( + BaseModelTesterConfig, + ModelTesterMixin, + TorchCompileTesterMixin, + TrainingTesterMixin, +) -# Ernie-Image requires torch.use_deterministic_algorithms(False) due to complex64 RoPE operations -# Cannot use enable_full_determinism() which sets it to True +# Ernie-Image requires torch.use_deterministic_algorithms(False) due to complex64 RoPE operations. +# Cannot use enable_full_determinism() which sets it to True. os.environ["CUDA_LAUNCH_BLOCKING"] = "1" os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" torch.use_deterministic_algorithms(False) @@ -36,40 +41,34 @@ torch.backends.cuda.matmul.allow_tf32 = False -class ErnieImageTransformerTests(ModelTesterMixin, unittest.TestCase): - model_class = ErnieImageTransformer2DModel - main_input_name = "hidden_states" - # We override the items here because the transformer under consideration is small. 
- model_split_percents = [0.9, 0.9, 0.9] - - def prepare_dummy_input(self, height=16, width=16): - batch_size = 1 - num_channels = 16 - embedding_dim = 16 - sequence_length = 16 - - hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device) - encoder_hidden_states = [ - torch.randn((sequence_length, embedding_dim)).to(torch_device) for _ in range(batch_size) - ] - timestep = torch.tensor([1.0]).to(torch_device) - - return {"hidden_states": hidden_states, "encoder_hidden_states": encoder_hidden_states, "timestep": timestep} +class ErnieImageTransformerTesterConfig(BaseModelTesterConfig): + @property + def model_class(self): + return ErnieImageTransformer2DModel @property - def dummy_input(self): - return self.prepare_dummy_input() + def main_input_name(self) -> str: + return "hidden_states" @property - def input_shape(self): + def output_shape(self) -> tuple: return (16, 16, 16) @property - def output_shape(self): + def input_shape(self) -> tuple: return (16, 16, 16) - def prepare_init_args_and_inputs_for_common(self): - init_dict = { + @property + def model_split_percents(self) -> list: + # We override the items here because the transformer under consideration is small. 
+ return [0.9, 0.9, 0.9] + + @property + def generator(self): + return torch.Generator("cpu").manual_seed(0) + + def get_init_dict(self) -> dict: + return { "hidden_size": 16, "num_attention_heads": 1, "num_layers": 1, @@ -83,113 +82,51 @@ def prepare_init_args_and_inputs_for_common(self): "eps": 1e-6, "qk_layernorm": True, } - inputs_dict = self.dummy_input - return init_dict, inputs_dict - - def setUp(self): - gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.synchronize() - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - - def tearDown(self): - super().tearDown() - gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.synchronize() - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - - def test_gradient_checkpointing_is_applied(self): - expected_set = {"ErnieImageTransformer2DModel"} - super().test_gradient_checkpointing_is_applied(expected_set=expected_set) - - @unittest.skip("Test is not supported for handling main inputs that are lists.") - def test_training(self): - super().test_training() - - @unittest.skip("Test is not supported for handling main inputs that are lists.") - def test_ema_training(self): - super().test_ema_training() - - @unittest.skip("Test is not supported for handling main inputs that are lists.") - def test_effective_gradient_checkpointing(self): - super().test_effective_gradient_checkpointing() - - @unittest.skip( - "Test needs to be revisited. But we need to ensure `x_pad_token` and `cap_pad_token` are cast to the same dtype as the destination tensor before they are assigned to the padding indices." - ) - def test_layerwise_casting_training(self): - super().test_layerwise_casting_training() - - @unittest.skip( - "TimestepEmbedding uses explicit dtype casting that conflicts with float8 layerwise casting hooks." 
- ) - def test_layerwise_casting_inference(self): - super().test_layerwise_casting_inference() - - @unittest.skip( - "TimestepEmbedding uses explicit dtype casting that conflicts with float8 layerwise casting hooks." - ) - def test_layerwise_casting_memory(self): - super().test_layerwise_casting_memory() - - @unittest.skip( - "TimestepEmbedding uses explicit dtype casting that conflicts with float8 layerwise casting hooks." - ) - def test_group_offloading_with_layerwise_casting(self): - super().test_group_offloading_with_layerwise_casting() - - @unittest.skip( - "TimestepEmbedding uses explicit dtype casting that conflicts with float8 layerwise casting hooks." - ) - def test_group_offloading_with_layerwise_casting_0(self): - pass - @unittest.skip( - "TimestepEmbedding uses explicit dtype casting that conflicts with float8 layerwise casting hooks." - ) - def test_group_offloading_with_layerwise_casting_1(self): - pass - - @unittest.skip("Test is not supported for handling main inputs that are lists.") - def test_outputs_equivalence(self): - super().test_outputs_equivalence() + def get_dummy_inputs(self, height: int = 16, width: int = 16, batch_size: int = 1) -> dict: + num_channels = 16 # in_channels + sequence_length = 16 + text_in_dim = 16 # text_in_dim + + return { + "hidden_states": randn_tensor( + (batch_size, num_channels, height, width), generator=self.generator, device=torch_device + ), + "timestep": torch.tensor([1.0] * batch_size, device=torch_device), + "text_bth": randn_tensor( + (batch_size, sequence_length, text_in_dim), generator=self.generator, device=torch_device + ), + "text_lens": torch.tensor([sequence_length] * batch_size, device=torch_device), + } - @unittest.skip("Test will pass if we change to deterministic values instead of empty in the DiT.") - def test_group_offloading(self): - super().test_group_offloading() - @unittest.skip("Test will pass if we change to deterministic values instead of empty in the DiT.") - def 
test_group_offloading_with_disk(self): - super().test_group_offloading_with_disk() +class TestErnieImageTransformer(ErnieImageTransformerTesterConfig, ModelTesterMixin): + pass -class ErnieImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): - model_class = ErnieImageTransformer2DModel - different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)] +class TestErnieImageTransformerTraining(ErnieImageTransformerTesterConfig, TrainingTesterMixin): + def test_gradient_checkpointing_is_applied(self): + expected_set = {"ErnieImageTransformer2DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) - def prepare_init_args_and_inputs_for_common(self): - return ErnieImageTransformerTests().prepare_init_args_and_inputs_for_common() - def prepare_dummy_input(self, height, width): - return ErnieImageTransformerTests().prepare_dummy_input(height=height, width=width) +class TestErnieImageTransformerCompile(ErnieImageTransformerTesterConfig, TorchCompileTesterMixin): + @property + def different_shapes_for_compilation(self): + return [(4, 4), (4, 8), (8, 8)] - @unittest.skip( - "The repeated block in this model is ZImageTransformerBlock, which is used for noise_refiner, context_refiner, and layers. As a consequence of this, the inputs recorded for the block would vary during compilation and full compilation with fullgraph=True would trigger recompilation at least thrice." + @pytest.mark.skip( + reason="The repeated block in this model is ErnieImageSharedAdaLNBlock. As a consequence of this, " + "the inputs recorded for the block would vary during compilation and full compilation with " + "fullgraph=True would trigger recompilation." 
) def test_torch_compile_recompilation_and_graph_break(self): super().test_torch_compile_recompilation_and_graph_break() - @unittest.skip("Fullgraph AoT is broken") - def test_compile_works_with_aot(self): - super().test_compile_works_with_aot() + @pytest.mark.skip(reason="Fullgraph AoT is broken.") + def test_compile_works_with_aot(self, tmp_path): + super().test_compile_works_with_aot(tmp_path) - @unittest.skip("Fullgraph is broken") + @pytest.mark.skip(reason="Fullgraph is broken.") def test_compile_on_different_shapes(self): super().test_compile_on_different_shapes() From 5024bc795df15ee46509646a9fc23761aa759bc8 Mon Sep 17 00:00:00 2001 From: HsiaWinter Date: Fri, 10 Apr 2026 19:08:17 +0800 Subject: [PATCH 13/17] update --- fix_turbo_weight_keys.py | 163 --------------------------------------- 1 file changed, 163 deletions(-) delete mode 100644 fix_turbo_weight_keys.py diff --git a/fix_turbo_weight_keys.py b/fix_turbo_weight_keys.py deleted file mode 100644 index a8ad51d87bf8..000000000000 --- a/fix_turbo_weight_keys.py +++ /dev/null @@ -1,163 +0,0 @@ -""" -将 ERNIE-Image-Turbo/transformer 的权重键名修正为与 ERNIE-Image/transformer 一致。 - -差异均位于每层 self_attention 子模块,共 6 类 × 36 层 = 216 个键需要重命名: - k_layernorm -> norm_k - q_layernorm -> norm_q - k_proj -> to_k - q_proj -> to_q - v_proj -> to_v - linear_proj -> to_out.0 -""" - -import json -import os -import shutil -from pathlib import Path - -import torch -from safetensors.torch import load_file, save_file - -# ── 路径配置 ────────────────────────────────────────────────────────────────── -TURBO_DIR = Path("/root/paddlejob/gpfsspace/model_weights/turbo/ERNIE-Image-Turbo/transformer") -# 修正后的文件直接覆盖原目录(先备份),如需输出到新目录请修改此变量 -OUTPUT_DIR = TURBO_DIR # 或改为 Path("/your/output/path") -BACKUP_SUFFIX = ".bak" # 原文件备份后缀,设为 None 则不备份 - -# ── 键名映射(只处理 self_attention 子键,前缀 layers.N. 
由脚本动态拼接)─── -KEY_REMAP = { - "self_attention.k_layernorm.weight": "self_attention.norm_k.weight", - "self_attention.q_layernorm.weight": "self_attention.norm_q.weight", - "self_attention.k_proj.weight": "self_attention.to_k.weight", - "self_attention.q_proj.weight": "self_attention.to_q.weight", - "self_attention.v_proj.weight": "self_attention.to_v.weight", - "self_attention.linear_proj.weight": "self_attention.to_out.0.weight", -} - -NUM_LAYERS = 36 # layers.0 ~ layers.35 - - -def build_full_remap() -> dict[str, str]: - """构建完整的旧键名 -> 新键名映射表(含层前缀)。""" - remap = {} - for layer_idx in range(NUM_LAYERS): - prefix = f"layers.{layer_idx}." - for old_suffix, new_suffix in KEY_REMAP.items(): - remap[prefix + old_suffix] = prefix + new_suffix - return remap - - -def rename_keys_in_tensor_dict( - tensors: dict[str, torch.Tensor], - remap: dict[str, str], -) -> tuple[dict[str, torch.Tensor], int]: - """重命名张量字典中的键,返回新字典和实际重命名的数量。""" - renamed = 0 - new_tensors: dict[str, torch.Tensor] = {} - for key, tensor in tensors.items(): - new_key = remap.get(key, key) - if new_key != key: - renamed += 1 - new_tensors[new_key] = tensor - return new_tensors, renamed - - -def backup_file(path: Path) -> None: - if BACKUP_SUFFIX is None: - return - backup = path.with_suffix(path.suffix + BACKUP_SUFFIX) - shutil.copy2(path, backup) - print(f" [备份] {path.name} -> {backup.name}") - - -def process_safetensors_files(remap: dict[str, str]) -> None: - index_path = TURBO_DIR / "diffusion_pytorch_model.safetensors.index.json" - with open(index_path, "r", encoding="utf-8") as f: - index = json.load(f) - - # 找出所有需要处理的 shard 文件(去重) - shard_files = sorted(set(index["weight_map"].values())) - print(f"\n共发现 {len(shard_files)} 个 shard 文件,开始处理...\n") - - total_renamed = 0 - for shard_name in shard_files: - shard_path = TURBO_DIR / shard_name - print(f"[处理] {shard_name}") - - tensors = load_file(shard_path) - new_tensors, renamed = rename_keys_in_tensor_dict(tensors, remap) - total_renamed += renamed - 
print(f" 本文件重命名: {renamed} 个键") - - if renamed > 0: - # 保留原始 metadata(如果有) - metadata = {} - - out_path = OUTPUT_DIR / shard_name - if out_path == shard_path and BACKUP_SUFFIX: - backup_file(shard_path) - - OUTPUT_DIR.mkdir(parents=True, exist_ok=True) - save_file(new_tensors, out_path, metadata=metadata) - print(f" [保存] {out_path}") - else: - if OUTPUT_DIR != TURBO_DIR: - shutil.copy2(shard_path, OUTPUT_DIR / shard_name) - print(f" [复制(无变更)] {shard_name}") - - print(f"\n所有 shard 处理完毕,共重命名 {total_renamed} 个键。") - - # ── 更新 index.json 中的 weight_map ───────────────────────────────────── - new_weight_map: dict[str, str] = {} - for old_key, shard_name in index["weight_map"].items(): - new_key = remap.get(old_key, old_key) - new_weight_map[new_key] = shard_name - - index["weight_map"] = new_weight_map - - out_index_path = OUTPUT_DIR / "diffusion_pytorch_model.safetensors.index.json" - if out_index_path == index_path and BACKUP_SUFFIX: - backup_file(index_path) - - OUTPUT_DIR.mkdir(parents=True, exist_ok=True) - with open(out_index_path, "w", encoding="utf-8") as f: - json.dump(index, f, indent=2, ensure_ascii=False) - print(f"[更新] index.json 已写入: {out_index_path}\n") - - -def verify_against_base() -> None: - """(可选)验证修正后的 Turbo 键名与 Base 完全一致。""" - BASE_DIR = Path("/root/paddlejob/gpfsspace/model_weights/base/ERNIE-Image/transformer") - base_index_path = BASE_DIR / "diffusion_pytorch_model.safetensors.index.json" - turbo_index_path = OUTPUT_DIR / "diffusion_pytorch_model.safetensors.index.json" - - if not base_index_path.exists() or not turbo_index_path.exists(): - print("[验证] 找不到 index.json,跳过验证。") - return - - with open(base_index_path, "r") as f: - base_keys = set(json.load(f)["weight_map"].keys()) - with open(turbo_index_path, "r") as f: - turbo_keys = set(json.load(f)["weight_map"].keys()) - - only_in_base = base_keys - turbo_keys - only_in_turbo = turbo_keys - base_keys - - if not only_in_base and only_in_turbo: - print(f"[验证] 警告:Turbo 中多余的键 
({len(only_in_turbo)}):") - for k in sorted(only_in_turbo): - print(f" + {k}") - elif only_in_base: - print(f"[验证] 警告:Base 中存在但 Turbo 中缺少的键 ({len(only_in_base)}):") - for k in sorted(only_in_base): - print(f" - {k}") - else: - print("[验证] 通过!修正后 Turbo 的键名与 Base 完全一致。") - - -if __name__ == "__main__": - remap = build_full_remap() - print(f"键名映射表共 {len(remap)} 条({NUM_LAYERS} 层 × {len(KEY_REMAP)} 类)") - - process_safetensors_files(remap) - verify_against_base() From 2c43be6babc5962f4a4f3d6fbc5ead9bc04f74fa Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 10 Apr 2026 17:09:44 +0000 Subject: [PATCH 14/17] Apply style fixes --- docs/source/en/_toctree.yml | 8 +- src/diffusers/__init__.py | 8 +- src/diffusers/models/transformers/__init__.py | 2 +- .../transformers/transformer_ernie_image.py | 104 +++++++++++++----- src/diffusers/pipelines/__init__.py | 2 +- .../ernie_image/pipeline_ernie_image.py | 44 ++++---- 6 files changed, 106 insertions(+), 62 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 6871f12eff49..b3f3fae24b90 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -350,6 +350,8 @@ title: DiTTransformer2DModel - local: api/models/easyanimate_transformer3d title: EasyAnimateTransformer3DModel + - local: api/models/ernie_image_transformer2d + title: ErnieImageTransformer2DModel - local: api/models/flux2_transformer title: Flux2Transformer2DModel - local: api/models/flux_transformer @@ -412,8 +414,6 @@ title: WanTransformer3DModel - local: api/models/z_image_transformer2d title: ZImageTransformer2DModel - - local: api/models/ernie_image_transformer2d - title: ErnieImageTransformer2DModel title: Transformers - sections: - local: api/models/stable_cascade_unet @@ -536,6 +536,8 @@ title: DiT - local: api/pipelines/easyanimate title: EasyAnimate + - local: api/pipelines/ernie_image + title: ERNIE-Image - local: api/pipelines/flux title: Flux - local: api/pipelines/flux2 @@ -636,8 +638,6 @@ 
title: VisualCloze - local: api/pipelines/z_image title: Z-Image - - local: api/pipelines/ernie_image - title: ERNIE-Image title: Image - sections: - local: api/pipelines/llada2 diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 80f2415384ae..bd71eded9ad1 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -235,6 +235,7 @@ "CosmosTransformer3DModel", "DiTTransformer2DModel", "EasyAnimateTransformer3DModel", + "ErnieImageTransformer2DModel", "Flux2Transformer2DModel", "FluxControlNetModel", "FluxMultiControlNetModel", @@ -302,7 +303,6 @@ "ZImageControlNetModel", "ZImageTransformer2DModel", "attention_backend", - "ErnieImageTransformer2DModel" ] ) _import_structure["modular_pipelines"].extend( @@ -526,6 +526,7 @@ "EasyAnimateControlPipeline", "EasyAnimateInpaintPipeline", "EasyAnimatePipeline", + "ErnieImagePipeline", "Flux2KleinKVPipeline", "Flux2KleinPipeline", "Flux2Pipeline", @@ -745,7 +746,6 @@ "ZImageInpaintPipeline", "ZImageOmniPipeline", "ZImagePipeline", - "ErnieImagePipeline", ] ) @@ -1037,6 +1037,7 @@ CosmosTransformer3DModel, DiTTransformer2DModel, EasyAnimateTransformer3DModel, + ErnieImageTransformer2DModel, Flux2Transformer2DModel, FluxControlNetModel, FluxMultiControlNetModel, @@ -1103,7 +1104,6 @@ ZImageControlNetModel, ZImageTransformer2DModel, attention_backend, - ErnieImageTransformer2DModel, ) from .modular_pipelines import ( AutoPipelineBlocks, @@ -1303,6 +1303,7 @@ EasyAnimateControlPipeline, EasyAnimateInpaintPipeline, EasyAnimatePipeline, + ErnieImagePipeline, Flux2KleinKVPipeline, Flux2KleinPipeline, Flux2Pipeline, @@ -1520,7 +1521,6 @@ ZImageInpaintPipeline, ZImageOmniPipeline, ZImagePipeline, - ErnieImagePipeline, ) try: diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 9087a2bc857d..2074618f952a 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -25,6 +25,7 @@ from 
.transformer_cogview4 import CogView4Transformer2DModel from .transformer_cosmos import CosmosTransformer3DModel from .transformer_easyanimate import EasyAnimateTransformer3DModel + from .transformer_ernie_image import ErnieImageTransformer2DModel from .transformer_flux import FluxTransformer2DModel from .transformer_flux2 import Flux2Transformer2DModel from .transformer_glm_image import GlmImageTransformer2DModel @@ -53,4 +54,3 @@ from .transformer_wan_animate import WanAnimateTransformer3DModel from .transformer_wan_vace import WanVACETransformer3DModel from .transformer_z_image import ZImageTransformer2DModel - from .transformer_ernie_image import ErnieImageTransformer2DModel diff --git a/src/diffusers/models/transformers/transformer_ernie_image.py b/src/diffusers/models/transformers/transformer_ernie_image.py index 937dd3c17190..09682a218d91 100644 --- a/src/diffusers/models/transformers/transformer_ernie_image.py +++ b/src/diffusers/models/transformers/transformer_ernie_image.py @@ -16,26 +16,27 @@ Ernie-Image Transformer2DModel for HuggingFace Diffusers. 
""" -import math import inspect from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Tuple import torch import torch.nn as nn import torch.nn.functional as F + from ...configuration_utils import ConfigMixin, register_to_config -from ..embeddings import Timesteps -from ..embeddings import TimestepEmbedding -from ..modeling_utils import ModelMixin from ...utils import BaseOutput, logging -from ..normalization import RMSNorm -from ..attention_processor import Attention +from ..attention import AttentionModuleMixin from ..attention_dispatch import dispatch_attention_fn -from ..attention import AttentionMixin, AttentionModuleMixin +from ..attention_processor import Attention +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin +from ..normalization import RMSNorm + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + @dataclass class ErnieImageTransformer2DModelOutput(BaseOutput): sample: torch.Tensor @@ -44,7 +45,7 @@ class ErnieImageTransformer2DModelOutput(BaseOutput): def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: assert dim % 2 == 0 scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim - omega = 1.0 / (theta ** scale) + omega = 1.0 / (theta**scale) out = torch.einsum("...n,d->...nd", pos, omega) return out.float() @@ -232,8 +233,11 @@ def __init__(self, hidden_size: int, ffn_hidden_size: int): def forward(self, x: torch.Tensor) -> torch.Tensor: return self.linear_fc2(self.up_proj(x) * F.gelu(self.gate_proj(x))) + class ErnieImageSharedAdaLNBlock(nn.Module): - def __init__(self, hidden_size: int, num_heads: int, ffn_hidden_size: int, eps: float = 1e-6, qk_layernorm: bool = True): + def __init__( + self, hidden_size: int, num_heads: int, ffn_hidden_size: int, eps: float = 1e-6, qk_layernorm: bool = True + ): super().__init__() self.adaLN_sa_ln = RMSNorm(hidden_size, eps=eps) self.self_attention = 
ErnieImageAttention( @@ -250,22 +254,23 @@ def __init__(self, hidden_size: int, num_heads: int, ffn_hidden_size: int, eps: self.mlp = ErnieImageFeedForward(hidden_size, ffn_hidden_size) def forward( - self, - x, - rotary_pos_emb, temb: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], - attention_mask: torch.Tensor | None = None + self, + x, + rotary_pos_emb, + temb: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None = None, ): shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = temb residual = x x = self.adaLN_sa_ln(x) - x = (x.float() * (1 + scale_msa.float()) + shift_msa.float()).to(x.dtype) + x = (x.float() * (1 + scale_msa.float()) + shift_msa.float()).to(x.dtype) x_bsh = x.permute(1, 0, 2) # [S, B, H] → [B, S, H] for diffusers Attention (batch-first) attn_out = self.self_attention(x_bsh, attention_mask=attention_mask, image_rotary_emb=rotary_pos_emb) attn_out = attn_out.permute(1, 0, 2) # [B, S, H] → [S, B, H] x = residual + (gate_msa.float() * attn_out.float()).to(x.dtype) residual = x x = self.adaLN_mlp_ln(x) - x = (x.float() * (1 + scale_mlp.float()) + shift_mlp.float()).to(x.dtype) + x = (x.float() * (1 + scale_mlp.float()) + shift_mlp.float()).to(x.dtype) return residual + (gate_mlp.float() * self.mlp(x).float()).to(x.dtype) @@ -320,7 +325,14 @@ def __init__( self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size)) nn.init.zeros_(self.adaLN_modulation[-1].weight) nn.init.zeros_(self.adaLN_modulation[-1].bias) - self.layers = nn.ModuleList([ErnieImageSharedAdaLNBlock(hidden_size, num_attention_heads, ffn_hidden_size, eps, qk_layernorm=qk_layernorm) for _ in range(num_layers)]) + self.layers = nn.ModuleList( + [ + ErnieImageSharedAdaLNBlock( + hidden_size, num_attention_heads, ffn_hidden_size, eps, qk_layernorm=qk_layernorm + ) + for _ in range(num_layers) + ] + ) self.final_norm = 
ErnieImageAdaLNContinuous(hidden_size, eps) self.final_linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels) nn.init.zeros_(self.final_linear.weight) @@ -328,13 +340,13 @@ def __init__( self.gradient_checkpointing = False def forward( - self, - hidden_states: torch.Tensor, - timestep: torch.Tensor, - # encoder_hidden_states: List[torch.Tensor], + self, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + # encoder_hidden_states: List[torch.Tensor], text_bth: torch.Tensor, text_lens: torch.Tensor, - return_dict: bool = True + return_dict: bool = True, ): device, dtype = hidden_states.device, hidden_states.dtype B, C, H, W = hidden_states.shape @@ -352,20 +364,48 @@ def forward( S = x.shape[0] # Position IDs - text_ids = torch.cat([torch.arange(Tmax, device=device, dtype=torch.float32).view(1, Tmax, 1).expand(B, -1, -1), torch.zeros((B, Tmax, 2), device=device)], dim=-1) if Tmax > 0 else torch.zeros((B, 0, 3), device=device) - grid_yx = torch.stack(torch.meshgrid(torch.arange(Hp, device=device, dtype=torch.float32), torch.arange(Wp, device=device, dtype=torch.float32), indexing="ij"), dim=-1).reshape(-1, 2) - image_ids = torch.cat([text_lens.float().view(B, 1, 1).expand(-1, N_img, -1), grid_yx.view(1, N_img, 2).expand(B, -1, -1)], dim=-1) + text_ids = ( + torch.cat( + [ + torch.arange(Tmax, device=device, dtype=torch.float32).view(1, Tmax, 1).expand(B, -1, -1), + torch.zeros((B, Tmax, 2), device=device), + ], + dim=-1, + ) + if Tmax > 0 + else torch.zeros((B, 0, 3), device=device) + ) + grid_yx = torch.stack( + torch.meshgrid( + torch.arange(Hp, device=device, dtype=torch.float32), + torch.arange(Wp, device=device, dtype=torch.float32), + indexing="ij", + ), + dim=-1, + ).reshape(-1, 2) + image_ids = torch.cat( + [text_lens.float().view(B, 1, 1).expand(-1, N_img, -1), grid_yx.view(1, N_img, 2).expand(B, -1, -1)], + dim=-1, + ) rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1)) # Attention mask: True = valid (attend), 
False = padding (mask out), matches sdpa bool convention - valid_text = torch.arange(Tmax, device=device).view(1, Tmax) < text_lens.view(B, 1) if Tmax > 0 else torch.zeros((B, 0), device=device, dtype=torch.bool) - attention_mask = torch.cat([torch.ones((B, N_img), device=device, dtype=torch.bool), valid_text], dim=1)[:, None, None, :] + valid_text = ( + torch.arange(Tmax, device=device).view(1, Tmax) < text_lens.view(B, 1) + if Tmax > 0 + else torch.zeros((B, 0), device=device, dtype=torch.bool) + ) + attention_mask = torch.cat([torch.ones((B, N_img), device=device, dtype=torch.bool), valid_text], dim=1)[ + :, None, None, : + ] # AdaLN sample = self.time_proj(timestep.to(dtype)) sample = sample.to(self.time_embedding.linear_1.weight.dtype) c = self.time_embedding(sample) - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [t.unsqueeze(0).expand(S, -1, -1).contiguous() for t in self.adaLN_modulation(c).chunk(6, dim=-1)] + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [ + t.unsqueeze(0).expand(S, -1, -1).contiguous() for t in self.adaLN_modulation(c).chunk(6, dim=-1) + ] for layer in self.layers: temb = [shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp] if torch.is_grad_enabled() and self.gradient_checkpointing: @@ -380,7 +420,11 @@ def forward( x = layer(x, rotary_pos_emb, temb, attention_mask) x = self.final_norm(x, c).type_as(x) patches = self.final_linear(x)[:N_img].transpose(0, 1).contiguous() - output = patches.view(B, Hp, Wp, p, p, self.out_channels).permute(0, 5, 1, 3, 2, 4).contiguous().view(B, self.out_channels, H, W) + output = ( + patches.view(B, Hp, Wp, p, p, self.out_channels) + .permute(0, 5, 1, 3, 2, 4) + .contiguous() + .view(B, self.out_channels, H, W) + ) return ErnieImageTransformer2DModelOutput(sample=output) if return_dict else (output,) - diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index cd3437fcdaaf..1278574f9232 100644 --- 
a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -679,6 +679,7 @@ EasyAnimateInpaintPipeline, EasyAnimatePipeline, ) + from .ernie_image import ErnieImagePipeline from .flux import ( FluxControlImg2ImgPipeline, FluxControlInpaintPipeline, @@ -772,7 +773,6 @@ from .mochi import MochiPipeline from .nucleusmoe_image import NucleusMoEImagePipeline from .omnigen import OmniGenPipeline - from .ernie_image import ErnieImagePipeline from .ovis_image import OvisImagePipeline from .pag import ( AnimateDiffPAGPipeline, diff --git a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py index 714ea6fbd01a..7147e855f142 100644 --- a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py +++ b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py @@ -17,17 +17,16 @@ """ import json -import os -import numpy as np +from typing import Callable, List, Optional, Union + import torch from PIL import Image -from typing import Callable, List, Optional, Union from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer -from ...pipelines.pipeline_utils import DiffusionPipeline -from ...schedulers import FlowMatchEulerDiscreteScheduler from ...models import AutoencoderKLFlux2 from ...models.transformers import ErnieImageTransformer2DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler from .pipeline_output import ErnieImagePipelineOutput @@ -117,7 +116,7 @@ def _enhance_prompt_with_pe( eos_token_id=self.pe_tokenizer.eos_token_id, ) # Decode only newly generated tokens - generated_ids = output_ids[0][inputs["input_ids"].shape[1]:] + generated_ids = output_ids[0][inputs["input_ids"].shape[1] :] return self.pe_tokenizer.decode(generated_ids, skip_special_tokens=True).strip() @torch.no_grad() @@ -182,13 +181,17 @@ def _unpatchify_latents(latents: torch.Tensor) -> torch.Tensor: def 
_pad_text(text_hiddens: List[torch.Tensor], device: torch.device, dtype: torch.dtype, text_in_dim: int): B = len(text_hiddens) if B == 0: - return torch.zeros((0, 0, text_in_dim), device=device, dtype=dtype), torch.zeros((0,), device=device, dtype=torch.long) - normalized = [th.squeeze(1).to(device).to(dtype) if th.dim() == 3 else th.to(device).to(dtype) for th in text_hiddens] + return torch.zeros((0, 0, text_in_dim), device=device, dtype=dtype), torch.zeros( + (0,), device=device, dtype=torch.long + ) + normalized = [ + th.squeeze(1).to(device).to(dtype) if th.dim() == 3 else th.to(device).to(dtype) for th in text_hiddens + ] lens = torch.tensor([t.shape[0] for t in normalized], device=device, dtype=torch.long) Tmax = int(lens.max().item()) text_bth = torch.zeros((B, Tmax, text_in_dim), device=device, dtype=dtype) for i, t in enumerate(normalized): - text_bth[i, :t.shape[0], :] = t + text_bth[i, : t.shape[0], :] = t return text_bth, lens @torch.no_grad() @@ -207,7 +210,7 @@ def __call__( return_dict: bool = True, callback_on_step_end: Optional[Callable[[int, int, dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], - use_pe: bool = True, # 默认使用PE进行改写 + use_pe: bool = True, # 默认使用PE进行改写 ): """ Generate images from text prompts. @@ -225,9 +228,9 @@ def __call__( output_type: "pil" or "latent" return_dict: Whether to return a dataclass callback_on_step_end: Optional callback invoked at the end of each denoising step. - Called as `callback_on_step_end(pipeline, step, timestep, callback_kwargs)` where - `callback_kwargs` contains the tensors listed in `callback_on_step_end_tensor_inputs`. - The callback may return a dict to override those tensors for subsequent steps. + Called as `callback_on_step_end(pipeline, step, timestep, callback_kwargs)` where `callback_kwargs` + contains the tensors listed in `callback_on_step_end_tensor_inputs`. The callback may return a dict to + override those tensors for subsequent steps. 
callback_on_step_end_tensor_inputs: List of tensor names passed into the callback kwargs. Must be a subset of `_callback_tensor_inputs` (default: `["latents"]`). use_pe: Whether to use the PE model to enhance prompts before generation. @@ -250,10 +253,7 @@ def __call__( # [Phase 1] PE: enhance prompts revised_prompts: Optional[List[str]] = None if use_pe and self.pe is not None and self.pe_tokenizer is not None: - prompt = [ - self._enhance_prompt_with_pe(p, device, width=width, height=height) - for p in prompt - ] + prompt = [self._enhance_prompt_with_pe(p, device, width=width, height=height) for p in prompt] revised_prompts = list(prompt) batch_size = len(prompt) @@ -272,9 +272,7 @@ def __call__( # CFG with negative prompt if self.do_classifier_free_guidance: - uncond_text_hiddens = self.encode_prompt( - negative_prompt, device, num_images_per_prompt - ) + uncond_text_hiddens = self.encode_prompt(negative_prompt, device, num_images_per_prompt) # Latent dimensions latent_h = height // self.vae_scale_factor @@ -299,8 +297,10 @@ def __call__( cfg_text_hiddens = list(uncond_text_hiddens) + list(text_hiddens) else: cfg_text_hiddens = text_hiddens - text_bth, text_lens = self._pad_text(text_hiddens=cfg_text_hiddens, device=device, dtype=dtype, text_in_dim=self.transformer.config.text_in_dim) - + text_bth, text_lens = self._pad_text( + text_hiddens=cfg_text_hiddens, device=device, dtype=dtype, text_in_dim=self.transformer.config.text_in_dim + ) + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(self.scheduler.timesteps): if self.do_classifier_free_guidance: From a4ebb0c523bd0792459753bd7fd075184270118b Mon Sep 17 00:00:00 2001 From: HsiaWinter Date: Sat, 11 Apr 2026 10:16:49 +0800 Subject: [PATCH 15/17] update --- .../ernie_image/pipeline_ernie_image.py | 62 ++++++++++++++----- .../dummy_torch_and_transformers_objects.py | 30 ++++----- 2 files changed, 63 insertions(+), 29 deletions(-) diff --git 
a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py index 7147e855f142..55a46ddbfeeb 100644 --- a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py +++ b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py @@ -23,6 +23,7 @@ from PIL import Image from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer +from ...utils.torch_utils import randn_tensor from ...models import AutoencoderKLFlux2 from ...models.transformers import ErnieImageTransformer2DModel from ...pipelines.pipeline_utils import DiffusionPipeline @@ -67,7 +68,6 @@ def __init__( pe_tokenizer=pe_tokenizer, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels)) if getattr(self, "vae", None) else 16 - print(f"vae_scale_factor: {self.vae_scale_factor}") @property def guidance_scale(self): @@ -197,7 +197,7 @@ def _pad_text(text_hiddens: List[torch.Tensor], device: torch.device, dtype: tor @torch.no_grad() def __call__( self, - prompt: Union[str, List[str]], + prompt: Optional[Union[str, List[str]]] = None, negative_prompt: Optional[Union[str, List[str]]] = "", height: int = 1024, width: int = 1024, @@ -206,6 +206,8 @@ def __call__( num_images_per_prompt: int = 1, generator: Optional[torch.Generator] = None, latents: Optional[torch.Tensor] = None, + prompt_embeds: list[torch.FloatTensor] | None = None, + negative_prompt_embeds: list[torch.FloatTensor] | None = None, output_type: str = "pil", return_dict: bool = True, callback_on_step_end: Optional[Callable[[int, int, dict], None]] = None, @@ -225,6 +227,10 @@ def __call__( num_images_per_prompt: Number of images per prompt generator: Random generator for reproducibility latents: Pre-generated latents (optional) + prompt_embeds: Pre-computed text embeddings for positive prompts (optional). + If provided, `encode_prompt` is skipped for positive prompts. + negative_prompt_embeds: Pre-computed text embeddings for negative prompts (optional). 
+ If provided, `encode_prompt` is skipped for negative prompts. output_type: "pil" or "latent" return_dict: Whether to return a dataclass callback_on_step_end: Optional callback invoked at the end of each denoising step. @@ -242,21 +248,35 @@ def __call__( dtype = self.transformer.dtype self._guidance_scale = guidance_scale + + # Validate prompt / prompt_embeds + if prompt is None and prompt_embeds is None: + raise ValueError("Must provide either `prompt` or `prompt_embeds`.") + if prompt is not None and prompt_embeds is not None: + raise ValueError("Cannot provide both `prompt` and `prompt_embeds` at the same time.") + # Validate dimensions if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: raise ValueError(f"Height and width must be divisible by {self.vae_scale_factor}") # Handle prompts - if isinstance(prompt, str): - prompt = [prompt] + if prompt is not None: + if isinstance(prompt, str): + prompt = [prompt] # [Phase 1] PE: enhance prompts revised_prompts: Optional[List[str]] = None - if use_pe and self.pe is not None and self.pe_tokenizer is not None: - prompt = [self._enhance_prompt_with_pe(p, device, width=width, height=height) for p in prompt] + if prompt is not None and use_pe and self.pe is not None and self.pe_tokenizer is not None: + prompt = [ + self._enhance_prompt_with_pe(p, device, width=width, height=height) + for p in prompt + ] revised_prompts = list(prompt) - batch_size = len(prompt) + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = len(prompt_embeds) total_batch_size = batch_size * num_images_per_prompt # Handle negative prompt @@ -268,11 +288,19 @@ def __call__( raise ValueError(f"negative_prompt must have same length as prompt ({batch_size})") # [Phase 2] Text encoding - text_hiddens = self.encode_prompt(prompt, device, num_images_per_prompt) + if prompt_embeds is not None: + text_hiddens = prompt_embeds + else: + text_hiddens = self.encode_prompt(prompt, device, num_images_per_prompt) # CFG 
with negative prompt if self.do_classifier_free_guidance: - uncond_text_hiddens = self.encode_prompt(negative_prompt, device, num_images_per_prompt) + if negative_prompt_embeds is not None: + uncond_text_hiddens = negative_prompt_embeds + else: + uncond_text_hiddens = self.encode_prompt( + negative_prompt, device, num_images_per_prompt + ) # Latent dimensions latent_h = height // self.vae_scale_factor @@ -281,12 +309,18 @@ def __call__( # Initialize latents if latents is None: - latents = torch.randn( - (total_batch_size, latent_channels, latent_h, latent_w), - device=device, - dtype=dtype, - generator=generator, + latents = randn_tensor( + (total_batch_size, latent_channels, latent_h, latent_w), + generator=generator, + device=device, + dtype=dtype ) + # latents = torch.randn( + # (total_batch_size, latent_channels, latent_h, latent_w), + # device=device, + # dtype=dtype, + # generator=generator, + # ) # Setup scheduler sigmas = torch.linspace(1.0, 0.0, num_inference_steps + 1) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index fecc0882e695..4f3906a5f98c 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1202,6 +1202,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class ErnieImagePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class Flux2KleinKVPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -2597,21 +2612,6 @@ def from_pretrained(cls, *args, **kwargs): 
requires_backends(cls, ["torch", "transformers"]) -class ErnieImagePipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - class OvisImagePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 071d181c0fb619fbeab72aefe3c04d8eb6201a70 Mon Sep 17 00:00:00 2001 From: HsiaWinter Date: Sat, 11 Apr 2026 10:35:33 +0800 Subject: [PATCH 16/17] update --- src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py index 55a46ddbfeeb..19dd4a7fafb0 100644 --- a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py +++ b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py @@ -315,12 +315,6 @@ def __call__( device=device, dtype=dtype ) - # latents = torch.randn( - # (total_batch_size, latent_channels, latent_h, latent_w), - # device=device, - # dtype=dtype, - # generator=generator, - # ) # Setup scheduler sigmas = torch.linspace(1.0, 0.0, num_inference_steps + 1) From 3aec976fc30347e4ea70e5f97c1bb4123cc218fd Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Sat, 11 Apr 2026 02:39:09 +0000 Subject: [PATCH 17/17] Apply style fixes --- .../ernie_image/pipeline_ernie_image.py | 23 ++++++++----------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py index 19dd4a7fafb0..9fbeee3395ec 100644 --- a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py +++ 
b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py @@ -23,11 +23,11 @@ from PIL import Image from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer -from ...utils.torch_utils import randn_tensor from ...models import AutoencoderKLFlux2 from ...models.transformers import ErnieImageTransformer2DModel from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils.torch_utils import randn_tensor from .pipeline_output import ErnieImagePipelineOutput @@ -206,8 +206,8 @@ def __call__( num_images_per_prompt: int = 1, generator: Optional[torch.Generator] = None, latents: Optional[torch.Tensor] = None, - prompt_embeds: list[torch.FloatTensor] | None = None, - negative_prompt_embeds: list[torch.FloatTensor] | None = None, + prompt_embeds: list[torch.FloatTensor] | None = None, + negative_prompt_embeds: list[torch.FloatTensor] | None = None, output_type: str = "pil", return_dict: bool = True, callback_on_step_end: Optional[Callable[[int, int, dict], None]] = None, @@ -267,10 +267,7 @@ def __call__( # [Phase 1] PE: enhance prompts revised_prompts: Optional[List[str]] = None if prompt is not None and use_pe and self.pe is not None and self.pe_tokenizer is not None: - prompt = [ - self._enhance_prompt_with_pe(p, device, width=width, height=height) - for p in prompt - ] + prompt = [self._enhance_prompt_with_pe(p, device, width=width, height=height) for p in prompt] revised_prompts = list(prompt) if prompt is not None: @@ -298,9 +295,7 @@ def __call__( if negative_prompt_embeds is not None: uncond_text_hiddens = negative_prompt_embeds else: - uncond_text_hiddens = self.encode_prompt( - negative_prompt, device, num_images_per_prompt - ) + uncond_text_hiddens = self.encode_prompt(negative_prompt, device, num_images_per_prompt) # Latent dimensions latent_h = height // self.vae_scale_factor @@ -310,10 +305,10 @@ def __call__( # Initialize latents if latents is None: latents = randn_tensor( - 
(total_batch_size, latent_channels, latent_h, latent_w), - generator=generator, - device=device, - dtype=dtype + (total_batch_size, latent_channels, latent_h, latent_w), + generator=generator, + device=device, + dtype=dtype, ) # Setup scheduler