From c1f00a293c51a67b07ed4c2f3d92a1a5bfce1aa7 Mon Sep 17 00:00:00 2001 From: Lancer Date: Thu, 9 Apr 2026 19:35:36 +0800 Subject: [PATCH 1/6] [Feat] support JoyAIImagePipeline Signed-off-by: Lancer --- src/diffusers/__init__.py | 1 + src/diffusers/models/transformers/__init__.py | 1 + .../transformers/transformer_joyai_image.py | 5 + src/diffusers/pipelines/__init__.py | 2 + .../pipelines/joyai_image/__init__.py | 47 + .../joyai_image/pipeline_joyai_image.py | 896 ++++++++++++++++++ .../pipelines/joyai_image/pipeline_output.py | 15 + 7 files changed, 967 insertions(+) create mode 100644 src/diffusers/models/transformers/transformer_joyai_image.py create mode 100644 src/diffusers/pipelines/joyai_image/__init__.py create mode 100644 src/diffusers/pipelines/joyai_image/pipeline_joyai_image.py create mode 100644 src/diffusers/pipelines/joyai_image/pipeline_output.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 0f74c0bbcb4a..a23b89a77b0a 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -545,6 +545,7 @@ "HunyuanDiTControlNetPipeline", "HunyuanDiTPAGPipeline", "HunyuanDiTPipeline", + "JoyAIImagePipeline", "HunyuanImagePipeline", "HunyuanImageRefinerPipeline", "HunyuanSkyreelsImageToVideoPipeline", diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 45157ee91808..ce5a0a7cf5e3 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -34,6 +34,7 @@ from .transformer_hunyuan_video15 import HunyuanVideo15Transformer3DModel from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel from .transformer_hunyuanimage import HunyuanImageTransformer2DModel + from .transformer_joyai_image import JoyAIImageTransformer3DModel from .transformer_kandinsky import Kandinsky5Transformer3DModel from .transformer_longcat_image import LongCatImageTransformer2DModel from .transformer_ltx import LTXVideoTransformer3DModel diff --git a/src/diffusers/models/transformers/transformer_joyai_image.py b/src/diffusers/models/transformers/transformer_joyai_image.py new file mode 100644 index 000000000000..86493eae48c1 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_joyai_image.py @@ -0,0 +1,5 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. + +from ..joyai_image.transformer import Transformer3DModel as JoyAIImageTransformer3DModel + +__all__ = ["JoyAIImageTransformer3DModel"] diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 05aad6e349f6..432fc4034f22 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -263,6 +263,7 @@ _import_structure["helios"] = ["HeliosPipeline", "HeliosPyramidPipeline"] _import_structure["hidream_image"] = ["HiDreamImagePipeline"] _import_structure["hunyuandit"] = ["HunyuanDiTPipeline"] + _import_structure["joyai_image"] = ["JoyAIImagePipeline"] _import_structure["hunyuan_video"] = [ "HunyuanVideoPipeline", "HunyuanSkyreelsImageToVideoPipeline", @@ -706,6 +707,7 @@ ) from .hunyuan_video1_5 import HunyuanVideo15ImageToVideoPipeline, HunyuanVideo15Pipeline from .hunyuandit import HunyuanDiTPipeline + from .joyai_image import JoyAIImagePipeline from .kandinsky import ( KandinskyCombinedPipeline, KandinskyImg2ImgCombinedPipeline, diff --git a/src/diffusers/pipelines/joyai_image/__init__.py b/src/diffusers/pipelines/joyai_image/__init__.py new file mode 100644 index 000000000000..cacb9296401a --- /dev/null +++ b/src/diffusers/pipelines/joyai_image/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from diffusers.utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from diffusers.utils import dummy_torch_and_transformers_objects + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_joyai_image"] = ["JoyAIImagePipeline"] + _import_structure["pipeline_output"] = ["JoyAIImagePipelineOutput"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from diffusers.utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_joyai_image import JoyAIImagePipeline + from .pipeline_output import JoyAIImagePipelineOutput +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/joyai_image/pipeline_joyai_image.py b/src/diffusers/pipelines/joyai_image/pipeline_joyai_image.py new file mode 100644 index 000000000000..dc77ffe11e2e --- /dev/null +++ b/src/diffusers/pipelines/joyai_image/pipeline_joyai_image.py @@ -0,0 +1,896 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import json +import os +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import AutoProcessor + +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.joyai_image import JoyAIFlowMatchDiscreteScheduler, Transformer3DModel, WanxVAE, load_joyai_text_encoder +from diffusers.pipelines.pipeline_utils import DiffusionPipeline, empty_device_cache, get_device +from diffusers.utils import is_accelerate_available, is_accelerate_version, logging +from diffusers.utils.torch_utils import randn_tensor + +from .pipeline_output import JoyAIImagePipelineOutput + + +logger = logging.get_logger(__name__) + +PRECISION_TO_TYPE = { + "fp32": torch.float32, + "float32": torch.float32, + "fp16": torch.float16, + "float16": torch.float16, + "bf16": torch.bfloat16, + "bfloat16": torch.bfloat16, +} + +PROMPT_TEMPLATE_ENCODE = { + "image": "<|im_start|>system\n \nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", + "multiple_images": "<|im_start|>system\n \nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n{}<|im_start|>assistant\n", + "video": "<|im_start|>system\n \nDescribe the video by detailing the following aspects:\n1. The main content and theme of the video.\n2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.\n3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.\n4. background environment, light, style and atmosphere.\n5. camera angles, movements, and transitions used in the video:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", +} + +PROMPT_TEMPLATE_START_IDX = { + "image": 34, + "multiple_images": 34, + "video": 91, +} + + +@dataclass +class JoyAIImageSourceConfig: + source_root: Path + dit_precision: str = "bf16" + vae_precision: str = "bf16" + text_encoder_precision: str = "bf16" + text_token_max_length: int = 2048 + enable_multi_task_training: bool = False + hsdp_shard_dim: int = 1 + reshard_after_forward: bool = False + use_fsdp_inference: bool = False + cpu_offload: bool = False + pin_cpu_memory: bool = False + dit_arch_config: dict = field(default_factory=lambda: { + "hidden_size": 4096, + "in_channels": 16, + "heads_num": 32, + "mm_double_blocks_depth": 40, + "out_channels": 16, + "patch_size": [1, 2, 2], + "rope_dim_list": [16, 56, 56], + "text_states_dim": 4096, + "rope_type": "rope", + "dit_modulation_type": "wanx", + "theta": 10000, + "attn_backend": "flash_attn", + }) + scheduler_arch_config: dict = field(default_factory=lambda: { + "num_train_timesteps": 1000, + "shift": 4.0, + }) + + @property + def text_encoder_arch_config(self) -> dict: + return {"params": {"text_encoder_ckpt": str(self.source_root / "JoyAI-Image-Und")}} + + +def _load_transformer_state_dict(checkpoint_path: Path) -> dict[str, torch.Tensor]: + state = torch.load(checkpoint_path, map_location="cpu", weights_only=True) + if "model" in state: + state = state["model"] + return state + + +def _build_joyai_source_config(source_root: Path) -> JoyAIImageSourceConfig: + return JoyAIImageSourceConfig(source_root=source_root) + + + +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class JoyAIImagePipeline(DiffusionPipeline): + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: Any, + text_encoder: Any, + tokenizer: Any, + transformer: Any, + scheduler: Any, + args: Any = None, + ): + super().__init__() + self.args = args + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + + self.enable_multi_task = bool(getattr(self.args, "enable_multi_task_training", False)) + if hasattr(self.vae, "ffactor_spatial"): + self.vae_scale_factor = self.vae.ffactor_spatial + self.vae_scale_factor_temporal = self.vae.ffactor_temporal + else: + self.vae_scale_factor = 8 + self.vae_scale_factor_temporal = 4 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + self.qwen_processor = None + text_encoder_ckpt = None + text_encoder_cfg = getattr(self.args, "text_encoder_arch_config", None) + if isinstance(text_encoder_cfg, dict): + text_encoder_params = text_encoder_cfg.get("params", {}) + text_encoder_ckpt = text_encoder_params.get("text_encoder_ckpt") + if text_encoder_ckpt is not None: + self.qwen_processor = AutoProcessor.from_pretrained( + text_encoder_ckpt, + local_files_only=True, + trust_remote_code=True, + ) + + self.text_token_max_length = int(getattr(self.args, "text_token_max_length", 2048)) + self.prompt_template_encode = PROMPT_TEMPLATE_ENCODE + self.prompt_template_encode_start_idx = PROMPT_TEMPLATE_START_IDX + + @staticmethod + def _dtype_to_precision(torch_dtype: Optional[torch.dtype]) -> Optional[str]: + if torch_dtype is None: + return None + for name, value in PRECISION_TO_TYPE.items(): + if value == torch_dtype and name in {"fp32", "fp16", "bf16"}: + return name + raise ValueError(f"Unsupported torch dtype for JoyAIImagePipeline: {torch_dtype}") + + @staticmethod + def _resolve_manifest_path(source_root: Path, manifest_value: Optional[str]) -> Optional[Path]: + if manifest_value is None: + return None + path = Path(manifest_value) + if path.parts and path.parts[0] == source_root.name: + path = Path(*path.parts[1:]) + return source_root / path + + @classmethod + def _is_joyai_source_dir(cls, path: Path) -> bool: + return ( + path.is_dir() + and (path / "infer_config.py").is_file() + and (path / "manifest.json").is_file() + and (path / "transformer").is_dir() + and (path / "vae").is_dir() + ) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + source_path = Path(pretrained_model_name_or_path) if pretrained_model_name_or_path is not None else None + if source_path is not None and cls._is_joyai_source_dir(source_path): + return cls.from_joyai_sources(pretrained_model_name_or_path, **kwargs) + return super().from_pretrained(pretrained_model_name_or_path, **kwargs) + + @classmethod + def from_joyai_sources( + cls, + pretrained_model_name_or_path: Union[str, Path], + torch_dtype: Optional[torch.dtype] = None, + official_repo_path: Optional[Union[str, Path]] = None, + device: Optional[Union[str, torch.device]] = None, + hsdp_shard_dim_override: Optional[int] = None, + **kwargs, + ): + source_root = Path(pretrained_model_name_or_path) + if not cls._is_joyai_source_dir(source_root): + raise ValueError(f"Not a valid JoyAI source checkpoint directory: {source_root}") + + precision = cls._dtype_to_precision(torch_dtype) + cfg = _build_joyai_source_config(source_root) + + manifest = json.loads((source_root / "manifest.json").read_text()) + transformer_ckpt = cls._resolve_manifest_path(source_root, manifest.get("transformer_ckpt")) + vae_ckpt = source_root / "vae" / "Wan2.1_VAE.pth" + text_encoder_ckpt = source_root / "JoyAI-Image-Und" + + if precision is not None: + cfg.dit_precision = precision + cfg.vae_precision = precision + cfg.text_encoder_precision = precision + + if hsdp_shard_dim_override is not None: + cfg.hsdp_shard_dim = hsdp_shard_dim_override + + load_device = torch.device(device) if device is not None else torch.device("cpu") + dit = Transformer3DModel( + args=cfg, + dtype=PRECISION_TO_TYPE[cfg.dit_precision], + device=load_device, + **cfg.dit_arch_config, + ) + state_dict = _load_transformer_state_dict(transformer_ckpt) + if "img_in.weight" in state_dict and dit.img_in.weight.shape != state_dict["img_in.weight"].shape: + v = state_dict["img_in.weight"] + v_new = v.new_zeros(dit.img_in.weight.shape) + v_new[:, : v.shape[1], :, :, :] = v + state_dict["img_in.weight"] = v_new + dit.load_state_dict(state_dict, strict=True) + dit = dit.to(dtype=PRECISION_TO_TYPE[cfg.dit_precision]) + dit = dit.eval() + + vae = WanxVAE( + pretrained=str(vae_ckpt), + torch_dtype=PRECISION_TO_TYPE[cfg.vae_precision], + device=load_device, + ) + tokenizer, text_encoder = load_joyai_text_encoder( + text_encoder_ckpt=str(text_encoder_ckpt), + torch_dtype=PRECISION_TO_TYPE[cfg.text_encoder_precision], + device=load_device, + ) + scheduler = JoyAIFlowMatchDiscreteScheduler(**cfg.scheduler_arch_config) + + pipe = cls( + vae=vae, + tokenizer=tokenizer, + text_encoder=text_encoder, + transformer=dit, + scheduler=scheduler, + args=cfg, + ) + if device is not None: + pipe._joyai_execution_device_override = torch.device(device) + return pipe + + def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + return torch.split(selected, valid_lengths.tolist(), dim=0) + + def _get_qwen_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + template_type: str = "image", + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._get_runtime_execution_device() + dtype = dtype or next(self.text_encoder.parameters()).dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + template = self.prompt_template_encode[template_type] + drop_idx = self.prompt_template_encode_start_idx[template_type] + txt = [template.format(e) for e in prompt] + txt_tokens = self.tokenizer( + txt, + max_length=self.text_token_max_length + drop_idx, + padding=True, + truncation=True, + return_tensors="pt", + ).to(device) + encoder_hidden_states = self._run_text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask, + ) + hidden_states = encoder_hidden_states.hidden_states[-1] + split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = min( + self.text_token_max_length, + max(u.size(0) for u in split_hidden_states), + max(u.size(0) for u in attn_mask_list), + ) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + return prompt_embeds.to(dtype=dtype, device=device), encoder_attention_mask + + def encode_prompt_multiple_images( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + images: Optional[List[Any]] = None, + template_type: str = "multiple_images", + max_sequence_length: Optional[int] = None, + drop_vit_feature: bool = False, + ): + if self.qwen_processor is None: + raise ValueError("Qwen processor is required for JoyAI image-edit prompt encoding.") + device = device or self._get_runtime_execution_device() + template = self.prompt_template_encode[template_type] + drop_idx = self.prompt_template_encode_start_idx[template_type] + prompt = [p.replace("\n", "<|vision_start|><|image_pad|><|vision_end|>") for p in prompt] + prompt = [template.format(p) for p in prompt] + inputs = self.qwen_processor(text=prompt, images=images, padding=True, return_tensors="pt").to(device) + encoder_hidden_states = self._run_text_encoder(**inputs) + last_hidden_states = encoder_hidden_states.hidden_states[-1] + if drop_vit_feature: + input_ids = inputs["input_ids"] + vlm_image_end_idx = torch.where(input_ids[0] == 151653)[0][-1] + drop_idx = int(vlm_image_end_idx.item()) + 1 + prompt_embeds = last_hidden_states[:, drop_idx:] + prompt_embeds_mask = inputs["attention_mask"][:, drop_idx:] + if max_sequence_length is not None and prompt_embeds.shape[1] > max_sequence_length: + prompt_embeds = prompt_embeds[:, -max_sequence_length:, :] + prompt_embeds_mask = prompt_embeds_mask[:, -max_sequence_length:] + return prompt_embeds, prompt_embeds_mask + + def encode_prompt( + self, + prompt: Union[str, List[str]], + images: Optional[List[Any]] = None, + device: Optional[torch.device] = None, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 1024, + template_type: str = "image", + drop_vit_feature: bool = False, + ): + if images is not None: + return self.encode_prompt_multiple_images( + prompt=prompt, + images=images, + device=device, + max_sequence_length=max_sequence_length, + drop_vit_feature=drop_vit_feature, + ) + + device = device or self._get_runtime_execution_device() + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, template_type, device) + + prompt_embeds = prompt_embeds[:, :max_sequence_length] + prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_videos_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_videos_per_prompt, seq_len) + return prompt_embeds, prompt_embeds_mask + + def check_inputs( + self, + prompt, + height, + width, + images=None, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + callback_on_step_end_tensor_inputs=None, + ): + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError("Cannot forward both `prompt` and `prompt_embeds`.") + if prompt is None and prompt_embeds is None: + raise ValueError("Provide either `prompt` or `prompt_embeds`.") + if prompt is not None and not isinstance(prompt, (str, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError("Cannot forward both `negative_prompt` and `negative_prompt_embeds`.") + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError("If `prompt_embeds` are provided, `prompt_embeds_mask` must also be passed.") + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` must also be passed." + ) + + def _vae_compute_dtype(self) -> torch.dtype: + if hasattr(self.vae, "model"): + return next(self.vae.model.parameters()).dtype + return next(self.vae.parameters()).dtype + + def _get_runtime_execution_device(self) -> torch.device: + override = getattr(self, "_joyai_execution_device_override", None) + if override is not None: + return torch.device(override) + return self._execution_device + + def _is_sequential_cpu_offload_enabled(self) -> bool: + return bool(getattr(self, "_joyai_sequential_cpu_offload_enabled", False)) + + def _uses_manual_sequential_offload(self, component_name: str) -> bool: + manual_components = getattr(self, "_joyai_manual_offload_components", set()) + return self._is_sequential_cpu_offload_enabled() and component_name in manual_components + + def _offload_component_to_cpu(self, component_name: str): + component = getattr(self, component_name, None) + if component is None: + return + component.to("cpu") + empty_device_cache(getattr(self._get_runtime_execution_device(), "type", "cuda")) + + def _run_text_encoder(self, **inputs): + if self._uses_manual_sequential_offload("text_encoder"): + self.text_encoder.to(self._get_runtime_execution_device()) + try: + return self.text_encoder(**inputs, output_hidden_states=True) + finally: + self._offload_component_to_cpu("text_encoder") + return self.text_encoder(**inputs, output_hidden_states=True) + + def _get_vae_scale(self, device: torch.device, dtype: torch.dtype): + mean = getattr(self.vae, "mean", None) + std = getattr(self.vae, "std", None) + if mean is None or std is None: + return None + mean = mean.to(device=device, dtype=dtype) + std = std.to(device=device, dtype=dtype) + return [mean, 1.0 / std] + + def _encode_with_vae(self, videos: torch.Tensor) -> torch.Tensor: + device = self._get_runtime_execution_device() + vae_dtype = self._vae_compute_dtype() + videos = videos.to(device=device, dtype=vae_dtype) + + if self._uses_manual_sequential_offload("vae") and hasattr(self.vae, "model"): + scale = self._get_vae_scale(device=device, dtype=vae_dtype) + self.vae.model.to(device=device, dtype=vae_dtype) + try: + return self.vae.model.encode(videos, scale=scale) + finally: + self.vae.model.to("cpu") + empty_device_cache(device.type) + + if hasattr(self.vae, "mean"): + self.vae.mean = self.vae.mean.to(device=device, dtype=vae_dtype) + if hasattr(self.vae, "std"): + self.vae.std = self.vae.std.to(device=device, dtype=vae_dtype) + if hasattr(self.vae, "scale"): + self.vae.scale = [self.vae.mean, 1.0 / self.vae.std] + if hasattr(self.vae, "config"): + if hasattr(self.vae.config, "latents_mean"): + self.vae.config.latents_mean = self.vae.mean + if hasattr(self.vae.config, "latents_std"): + self.vae.config.latents_std = self.vae.std + + self.vae.to(device=device, dtype=vae_dtype) + return self.vae.encode(videos) + + def _decode_with_vae(self, latents: torch.Tensor): + device = self._get_runtime_execution_device() + vae_dtype = self._vae_compute_dtype() + latents = latents.to(device=device, dtype=vae_dtype) + + if self._uses_manual_sequential_offload("vae") and hasattr(self.vae, "model"): + scale = self._get_vae_scale(device=device, dtype=vae_dtype) + self.vae.model.to(device=device, dtype=vae_dtype) + try: + videos = [self.vae.model.decode(u.unsqueeze(0), scale=scale).clamp_(-1, 1).squeeze(0) for u in latents] + return torch.stack(videos, dim=0) + finally: + self.vae.model.to("cpu") + empty_device_cache(device.type) + + if hasattr(self.vae, "mean"): + self.vae.mean = self.vae.mean.to(device=device, dtype=vae_dtype) + if hasattr(self.vae, "std"): + self.vae.std = self.vae.std.to(device=device, dtype=vae_dtype) + if hasattr(self.vae, "scale"): + self.vae.scale = [self.vae.mean, 1.0 / self.vae.std] + if hasattr(self.vae, "config"): + if hasattr(self.vae.config, "latents_mean"): + self.vae.config.latents_mean = self.vae.mean + if hasattr(self.vae.config, "latents_std"): + self.vae.config.latents_std = self.vae.std + + self.vae.to(device=device, dtype=vae_dtype) + return self.vae.decode(latents, return_dict=False)[0] + + def prepare_latents( + self, + batch_size, + num_items, + num_channels_latents, + height, + width, + video_length, + dtype, + device, + generator, + latents=None, + reference_images=None, + image=None, + last_image=None, + ): + shape = ( + batch_size, + num_items, + num_channels_latents, + (video_length - 1) // self.vae_scale_factor_temporal + 1, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch size of {batch_size}." + ) + + if latents is None: + if reference_images is not None and len(reference_images) > 0: + ref_img = [torch.from_numpy(np.array(x.convert("RGB"))) for x in reference_images] + ref_img = torch.stack(ref_img).to(device=device, dtype=dtype) + ref_img = ref_img / 127.5 - 1.0 + ref_img = ref_img.permute(0, 3, 1, 2).unsqueeze(2) + ref_vae = self._encode_with_vae(ref_img) + ref_vae = ref_vae.reshape(shape[0], num_items - 1, *ref_vae.shape[1:]) + noise = randn_tensor((shape[0], 1, *shape[2:]), generator=generator, device=device, dtype=dtype) + latents = torch.cat([ref_vae, noise], dim=1) + else: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + if not self.enable_multi_task: + return latents, None + raise NotImplementedError("JoyAI multi-task conditioning is not implemented in the diffusers adaptation yet.") + + def enable_sequential_cpu_offload(self, gpu_id: int | None = None, device: torch.device | str = None): + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): + from accelerate import cpu_offload + else: + raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") + + self._maybe_raise_error_if_group_offload_active(raise_error=True) + self.remove_all_hooks() + + is_pipeline_device_mapped = self._is_pipeline_device_mapped() + if is_pipeline_device_mapped: + raise ValueError( + "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_sequential_cpu_offload()` isn't allowed. You can call `reset_device_map()` first and then call `enable_sequential_cpu_offload()`." + ) + + if device is None: + device = get_device() + if device == "cpu": + raise RuntimeError("`enable_sequential_cpu_offload` requires accelerator, but not found") + + torch_device = torch.device(device) + device_index = torch_device.index + if gpu_id is not None and device_index is not None: + raise ValueError( + f"You have passed both `gpu_id`={gpu_id} and an index as part of the passed device `device`={device}" + f"Cannot pass both. Please make sure to either not define `gpu_id` or not pass the index as part of the device: `device`={torch_device.type}" + ) + + self._offload_gpu_id = gpu_id or torch_device.index or getattr(self, "_offload_gpu_id", 0) + device_type = torch_device.type + device = torch.device(f"{device_type}:{self._offload_gpu_id}") + self._offload_device = device + + if self.device.type != "cpu": + orig_device_type = self.device.type + self.to("cpu", silence_dtype_warnings=True) + empty_device_cache(orig_device_type) + + self._joyai_manual_offload_components = {"text_encoder", "vae"} + + for name, model in self.components.items(): + if not isinstance(model, torch.nn.Module): + continue + + if name in self._exclude_from_cpu_offload: + model.to(device) + continue + + if name in self._joyai_manual_offload_components: + model.to("cpu") + continue + + offload_buffers = len(model._parameters) > 0 + params = list(model.parameters()) + on_cpu = len(params) == 0 or all(param.device.type == "cpu" for param in params) + state_dict = model.state_dict() if on_cpu else None + cpu_offload(model, device, offload_buffers=offload_buffers, state_dict=state_dict) + + self._joyai_sequential_cpu_offload_enabled = True + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + def pad_sequence(self, x: torch.Tensor, target_length: int): + current_length = x.shape[1] + if current_length >= target_length: + return x[:, -target_length:] + padding_length = target_length - current_length + if x.ndim >= 3: + padding = torch.zeros((x.shape[0], padding_length, *x.shape[2:]), dtype=x.dtype, device=x.device) + else: + padding = torch.zeros((x.shape[0], padding_length), dtype=x.dtype, device=x.device) + return torch.cat([x, padding], dim=1) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + height: int, + width: int, + num_frames: int = 1, + images: Optional[List[Any]] = None, + image_condition: Optional[torch.Tensor] = None, + last_image_condition: Optional[torch.Tensor] = None, + data_type: str = "image", + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_mask: Optional[torch.Tensor] = None, + output_type: str = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 4096, + drop_vit_feature: bool = False, + **kwargs, + ): + self.check_inputs( + prompt, + height, + width, + images=images, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._interrupt = False + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._get_runtime_execution_device() + template_type = "image" if num_frames == 1 else "video" + num_items = 1 if images is None or len(images) == 0 else 1 + len(images) + + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + images=images, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + template_type=template_type, + drop_vit_feature=drop_vit_feature, + ) + + if self.do_classifier_free_guidance: + if negative_prompt is None and negative_prompt_embeds is None: + default_negative_prompt = "" + if num_items <= 1: + negative_prompt = [f"<|im_start|>user\n{default_negative_prompt}<|im_end|>\n"] * batch_size + else: + image_tokens = "\n" * (num_items - 1) + negative_prompt = [ + f"<|im_start|>user\n{image_tokens}{default_negative_prompt}<|im_end|>\n" + ] * batch_size + + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + images=images, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + template_type=template_type, + ) + + max_seq_len = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1]) + prompt_embeds = torch.cat( + [ + self.pad_sequence(negative_prompt_embeds, max_seq_len), + self.pad_sequence(prompt_embeds, max_seq_len), + ] + ) + if prompt_embeds_mask is not None: + prompt_embeds_mask = torch.cat( + [ + self.pad_sequence(negative_prompt_embeds_mask, max_seq_len), + self.pad_sequence(prompt_embeds_mask, max_seq_len), + ] + ) + + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + ) + + num_channels_latents = self.vae.config.latent_channels + latents, condition = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_items, + num_channels_latents, + height, + width, + num_frames, + prompt_embeds.dtype, + device, + generator, + latents, + reference_images=images, + image=image_condition, + last_image=last_image_condition, + ) + + target_dtype = PRECISION_TO_TYPE.get(getattr(self.args, "dit_precision", "bf16"), prompt_embeds.dtype) + autocast_enabled = target_dtype != torch.float32 and device.type == "cuda" + vae_dtype = PRECISION_TO_TYPE.get(getattr(self.args, "vae_precision", "bf16"), prompt_embeds.dtype) + vae_autocast_enabled = vae_dtype != torch.float32 and device.type == "cuda" + + self._num_timesteps = len(timesteps) + if num_items > 1: + ref_latents = latents[:, : (num_items - 1)].clone() + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + if num_items > 1: + latents[:, : (num_items - 1)] = ref_latents.clone() + + latents_ = torch.cat([latents, condition], dim=2) if condition is not None else latents + latent_model_input = torch.cat([latents_] * 2) if self.do_classifier_free_guidance else latents_ + latent_model_input = latent_model_input.to(device=device, dtype=target_dtype) + prompt_embeds_input = prompt_embeds.to(device=device, dtype=target_dtype) + t_expand = t.repeat(latent_model_input.shape[0]) + + with torch.autocast(device_type=device.type, dtype=target_dtype, enabled=autocast_enabled): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=t_expand, + encoder_hidden_states=prompt_embeds_input, + encoder_hidden_states_mask=prompt_embeds_mask, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + cond_norm = torch.norm(noise_pred_text, dim=2, keepdim=True) + noise_norm = torch.norm(noise_pred, dim=2, keepdim=True) + noise_pred = noise_pred * (cond_norm / noise_norm) + + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {k: locals()[k] for k in callback_on_step_end_tensor_inputs} + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if output_type == "latent": + image = latents + else: + latents = latents.reshape(-1, *latents.shape[2:]) + with torch.autocast(device_type=device.type, dtype=vae_dtype, enabled=vae_autocast_enabled): + decoded = self._decode_with_vae(latents) + decoded = decoded.reshape(batch_size, num_items, *decoded.shape[1:]) + image = decoded[:, -1, :, 0] + image = (image / 2 + 0.5).clamp(0, 1) + + self.maybe_free_model_hooks() + + if output_type == "pt": + output_image = image.cpu().float() + elif output_type == "pil": + output_image = self.image_processor.numpy_to_pil(image.cpu().permute(0, 2, 3, 1).float().numpy()) + else: + output_image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if not return_dict: + return (output_image,) + return JoyAIImagePipelineOutput(images=output_image) + + +__all__ = ["JoyAIImagePipeline"] diff --git a/src/diffusers/pipelines/joyai_image/pipeline_output.py b/src/diffusers/pipelines/joyai_image/pipeline_output.py new file mode 100644 index 000000000000..d085cafc4790 --- /dev/null +++ b/src/diffusers/pipelines/joyai_image/pipeline_output.py @@ -0,0 +1,15 @@ +from dataclasses import dataclass +from typing import Union + +import numpy as np +from PIL import Image + +from diffusers.utils import BaseOutput + + +@dataclass +class JoyAIImagePipelineOutput(BaseOutput): + images: Union[Image.Image, np.ndarray] + + +__all__ = ["JoyAIImagePipelineOutput"] From d6365ec0969880e8b034a7e095a92ef52147001a Mon Sep 17 00:00:00 2001 From: Lancer Date: Thu, 9 Apr 2026 19:53:41 +0800 Subject: [PATCH 2/6] upd Signed-off-by: Lancer --- src/diffusers/pipelines/joyai_image/pipeline_joyai_image.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/diffusers/pipelines/joyai_image/pipeline_joyai_image.py b/src/diffusers/pipelines/joyai_image/pipeline_joyai_image.py index dc77ffe11e2e..dd17648bf55a 100644 --- a/src/diffusers/pipelines/joyai_image/pipeline_joyai_image.py +++ b/src/diffusers/pipelines/joyai_image/pipeline_joyai_image.py @@ -191,6 +191,7 @@ def __init__( self.text_token_max_length = int(getattr(self.args, "text_token_max_length", 2048)) self.prompt_template_encode = PROMPT_TEMPLATE_ENCODE self.prompt_template_encode_start_idx = PROMPT_TEMPLATE_START_IDX + self._joyai_force_vae_fp32 = True @staticmethod def _dtype_to_precision(torch_dtype: Optional[torch.dtype]) -> Optional[str]: @@ -449,6 +450,8 @@ def check_inputs( ) def _vae_compute_dtype(self) -> torch.dtype: + if getattr(self, "_joyai_force_vae_fp32", False): + return torch.float32 if hasattr(self.vae, "model"): return next(self.vae.model.parameters()).dtype return next(self.vae.parameters()).dtype From e774af28136b42b3bd759c8dcd50c82ced1a8ce6 Mon Sep 17 00:00:00 2001 From: Lancer Date: Fri, 10 Apr 2026 01:07:09 +0800 Subject: [PATCH 3/6] upd Signed-off-by: Lancer --- src/diffusers/models/__init__.py | 4 + src/diffusers/models/autoencoders/__init__.py | 1 + .../autoencoder_kl_joyai_image.py | 700 ++++++++++++++++++ .../transformers/transformer_joyai_image.py | 662 ++++++++++++++++- .../joyai_image/pipeline_joyai_image.py | 380 ++++++---- src/diffusers/schedulers/__init__.py | 8 + .../scheduling_joyai_flow_match_discrete.py | 260 +++++++ 7 files changed, 1859 insertions(+), 156 deletions(-) create mode 100644 src/diffusers/models/autoencoders/autoencoder_kl_joyai_image.py create mode 100644 src/diffusers/schedulers/scheduling_joyai_flow_match_discrete.py diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 7ded56049833..3b139c2fcdd5 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -48,6 +48,7 @@ _import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"] _import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"] _import_structure["autoencoders.autoencoder_kl_qwenimage"] = ["AutoencoderKLQwenImage"] + _import_structure["autoencoders.autoencoder_kl_joyai_image"] = ["JoyAIImageVAE"] _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"] _import_structure["autoencoders.autoencoder_kl_wan"] = ["AutoencoderKLWan"] _import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"] @@ -110,6 +111,7 @@ _import_structure["transformers.transformer_hunyuan_video15"] = ["HunyuanVideo15Transformer3DModel"] _import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"] _import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"] + _import_structure["transformers.transformer_joyai_image"] = ["JoyAIImageTransformer3DModel"] _import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"] _import_structure["transformers.transformer_longcat_image"] = ["LongCatImageTransformer2DModel"] _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"] @@ -171,6 +173,7 @@ AutoencoderKLMagvit, AutoencoderKLMochi, AutoencoderKLQwenImage, + JoyAIImageVAE, AutoencoderKLTemporalDecoder, AutoencoderKLWan, AutoencoderOobleck, @@ -225,6 +228,7 @@ HiDreamImageTransformer2DModel, HunyuanDiT2DModel, HunyuanImageTransformer2DModel, + JoyAIImageTransformer3DModel, HunyuanVideo15Transformer3DModel, HunyuanVideoFramepackTransformer3DModel, HunyuanVideoTransformer3DModel, diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index 609146ec340d..317055ee6d26 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -16,6 +16,7 @@ from .autoencoder_kl_ltx2_audio import AutoencoderKLLTX2Audio from .autoencoder_kl_magvit import AutoencoderKLMagvit from .autoencoder_kl_mochi import AutoencoderKLMochi +from .autoencoder_kl_joyai_image import JoyAIImageVAE from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder from .autoencoder_kl_wan import AutoencoderKLWan diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_joyai_image.py b/src/diffusers/models/autoencoders/autoencoder_kl_joyai_image.py new file mode 100644 index 000000000000..759d53f7aef9 --- /dev/null +++ b/src/diffusers/models/autoencoders/autoencoder_kl_joyai_image.py @@ -0,0 +1,700 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import logging + +import torch +import torch.cuda.amp as amp +import torch.nn as nn +import torch.nn.functional as F + + +CACHE_T = 2 + + +class CausalConv3d(nn.Conv3d): + """ + Causal 3d convolusion. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._padding = (self.padding[2], self.padding[2], self.padding[1], + self.padding[1], 2 * self.padding[0], 0) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + + return super().forward(x) + + +class RMS_norm(nn.Module): + + def __init__(self, dim, channel_first=True, images=True, bias=False): + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0. + + def forward(self, x): + return F.normalize( + x, dim=(1 if self.channel_first else + -1)) * self.scale * self.gamma + self.bias + + +class Upsample(nn.Upsample): + + def forward(self, x): + """ + Fix bfloat16 support for nearest neighbor interpolation. + """ + return super().forward(x.float()).type_as(x) + + +class Resample(nn.Module): + + def __init__(self, dim, mode): + assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d', + 'downsample3d') + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == 'upsample2d': + self.resample = nn.Sequential( + Upsample(scale_factor=(2., 2.), mode='nearest-exact'), + nn.Conv2d(dim, dim // 2, 3, padding=1)) + elif mode == 'upsample3d': + self.resample = nn.Sequential( + Upsample(scale_factor=(2., 2.), mode='nearest-exact'), + nn.Conv2d(dim, dim // 2, 3, padding=1)) + self.time_conv = CausalConv3d( + dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + + elif mode == 'downsample2d': + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == 'downsample3d': + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = CausalConv3d( + dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == 'upsample3d': + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = 'Rep' + feat_idx[0] += 1 + else: + + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[ + idx] is not None and feat_cache[idx] != 'Rep': + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + if cache_x.shape[2] < 2 and feat_cache[ + idx] is not None and feat_cache[idx] == 'Rep': + cache_x = torch.cat([ + torch.zeros_like(cache_x).to(cache_x.device), + cache_x + ], + dim=2) + if feat_cache[idx] == 'Rep': + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), + 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.resample(x) + resampled_height, resampled_width = x.shape[-2:] + x = x.reshape(b, t, x.shape[1], resampled_height, resampled_width).permute(0, 2, 1, 3, 4) + + if self.mode == 'downsample3d': + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + + cache_x = x[:, :, -1:, :, :].clone() + # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep': + # # cache last frame of last two chunk + # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.time_conv( + torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + def init_weight(self, conv): + conv_weight = conv.weight + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + one_matrix = torch.eye(c1, c2) + init_matrix = one_matrix + nn.init.zeros_(conv_weight) + #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5 + conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5 + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + def init_weight2(self, conv): + conv_weight = conv.weight.data + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + init_matrix = torch.eye(c1 // 2, c2) + #init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2) + conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix + conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + +class ResidualBlock(nn.Module): + + def __init__(self, in_dim, out_dim, dropout=0.0): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # layers + self.residual = nn.Sequential( + RMS_norm(in_dim, images=False), nn.SiLU(), + CausalConv3d(in_dim, out_dim, 3, padding=1), + RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout), + CausalConv3d(out_dim, out_dim, 3, padding=1)) + self.shortcut = CausalConv3d(in_dim, out_dim, 1) \ + if in_dim != out_dim else nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + h = self.shortcut(x) + for layer in self.residual: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + h + + +class AttentionBlock(nn.Module): + """ + Causal self-attention with a single head. + """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + # layers + self.norm = RMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x): + identity = x + b, c, t, h, w = x.size() + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.norm(x) + # compute query, key, value + q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, + -1).permute(0, 1, 3, + 2).contiguous().chunk( + 3, dim=-1) + + # apply attention + x = F.scaled_dot_product_attention( + q, + k, + v, + ) + x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) + + # output + x = self.proj(x) + x = x.reshape(b, t, c, h, w).permute(0, 2, 1, 3, 4) + return x + identity + + +class Encoder3d(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv1 = CausalConv3d(3, dims[0], 3, padding=1) + + # downsample blocks + downsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks): + downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + downsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = 'downsample3d' if temperal_downsample[ + i] else 'downsample2d' + downsamples.append(Resample(out_dim, mode=mode)) + scale /= 2.0 + self.downsamples = nn.Sequential(*downsamples) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim), + ResidualBlock(out_dim, out_dim, dropout)) + + # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), nn.SiLU(), + CausalConv3d(out_dim, z_dim, 3, padding=1)) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## downsamples + for layer in self.downsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## middle + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +class Decoder3d(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2**(len(dim_mult) - 2) + + # init block + self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]), + ResidualBlock(dims[0], dims[0], dropout)) + + # upsample blocks + upsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + if i == 1 or i == 2 or i == 3: + in_dim = in_dim // 2 + for _ in range(num_res_blocks + 1): + upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + upsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # upsample block + if i != len(dim_mult) - 1: + mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d' + upsamples.append(Resample(out_dim, mode=mode)) + scale *= 2.0 + self.upsamples = nn.Sequential(*upsamples) + + # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), nn.SiLU(), + CausalConv3d(out_dim, 3, 3, padding=1)) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + ## conv1 + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## middle + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## upsamples + for layer in self.upsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +def count_conv3d(model): + count = 0 + for m in model.modules(): + if isinstance(m, CausalConv3d): + count += 1 + return count + + +class WanVAE_(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + # modules + self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, + attn_scales, self.temperal_downsample, dropout) + self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) + self.conv2 = CausalConv3d(z_dim, z_dim, 1) + self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, + attn_scales, self.temperal_upsample, dropout) + + def forward(self, x): + mu, log_var = self.encode(x) + z = self.reparameterize(mu, log_var) + x_recon = self.decode(z) + return x_recon, mu, log_var + + def encode(self, x, scale=None, return_posterior=False): + self.clear_cache() + ## cache + t = x.shape[2] + iter_ = 1 + (t - 1) // 4 + ## 对encode输入的x,按时间拆分为1、4、4、4.... + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out = self.encoder( + x[:, :, :1, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx) + else: + out_ = self.encoder( + x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx) + out = torch.cat([out, out_], 2) + mu, log_var = self.conv1(out).chunk(2, dim=1) + if scale is None or return_posterior: + return mu, log_var + + mu = self.reparameterize(mu, log_var) + if isinstance(scale[0], torch.Tensor): + mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view( + 1, self.z_dim, 1, 1, 1) + else: + mu = (mu - scale[0]) * scale[1] + self.clear_cache() + return mu + + def decode(self, z, scale=None): + self.clear_cache() + # z: [b,c,t,h,w] + if scale is not None: + if isinstance(scale[0], torch.Tensor): + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view( + 1, self.z_dim, 1, 1, 1) + else: + z = z / scale[1] + scale[0] + iter_ = z.shape[2] + x = self.conv2(z) + for i in range(iter_): + self._conv_idx = [0] + if i == 0: + out = self.decoder( + x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx) + else: + out_ = self.decoder( + x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx) + out = torch.cat([out, out_], 2) + self.clear_cache() + return out + + def reparameterize(self, mu, log_var): + std = torch.exp(0.5 * log_var) + eps = torch.randn_like(std) + return eps * std + mu + + def sample(self, imgs, deterministic=False, scale=None): + mu, log_var = self.encode(imgs) + if deterministic: + return mu + std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0)) + mu = mu + std * torch.randn_like(std) + if isinstance(scale[0], torch.Tensor): + mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view( + 1, self.z_dim, 1, 1, 1) + else: + mu = (mu - scale[0]) * scale[1] + self.clear_cache() + return mu + + def clear_cache(self): + self._conv_num = count_conv3d(self.decoder) + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + #cache encode + self._enc_conv_num = count_conv3d(self.encoder) + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + + +def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs): + """ + Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL. + """ + # params + cfg = { + "dim": 96, + "z_dim": z_dim, + "dim_mult": [1, 2, 4, 4], + "num_res_blocks": 2, + "attn_scales": [], + "temperal_downsample": [False, True, True], + "dropout": 0.0, + } + cfg.update(**kwargs) + + # init model + with torch.device('meta'): + model = WanVAE_(**cfg) + + # load checkpoint + logging.info(f'loading {pretrained_path}') + + if pretrained_path.endswith('.safetensors'): + from safetensors.torch import load_file + pretrained_state_dict = load_file(pretrained_path, device='cpu') + else: + pretrained_state_dict = torch.load(pretrained_path, map_location='cpu') + + model.load_state_dict(pretrained_state_dict, assign=True) + + return model + + + +class WanxVAE(nn.Module): + # @register_to_config + def __init__(self, + pretrained='', + torch_dtype=torch.float32, + device='cuda' + ): + super().__init__() + + self.dtype = torch_dtype + self.device = device + + mean = [ + -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, + 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921 + ] + std = [ + 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, + 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160 + ] + self.mean = torch.tensor(mean, dtype=self.dtype, device=device) + self.std = torch.tensor(std, dtype=self.dtype, device=device) + self.scale = [self.mean, 1.0 / self.std] + + self.config = lambda: None + self.config.latents_mean = self.mean + self.config.latents_std = self.std + self.ffactor_spatial = 8 + self.ffactor_temporal = 4 + self.config.latent_channels = 16 + + # init model + self.model = _video_vae( + pretrained_path=pretrained, + z_dim=16, + ).eval().requires_grad_(False) + self.model = self.model.to(device=device, dtype=torch_dtype) + + def encode(self, videos, return_posterior=False, **kwargs): + """ + videos: A list of videos each with shape [C, T, H, W]. + """ + with amp.autocast(dtype=torch.float): + if return_posterior: + mus, log_vars = self.model.encode( + videos, scale=self.scale, return_posterior=True) + return mus, log_vars + else: + latents = self.model.encode(videos, scale=self.scale) + return latents + + def decode(self, zs, **kwargs): + with amp.autocast(dtype=torch.float): + videos = [ + self.model.decode(u.unsqueeze(0), scale=self.scale).clamp_(-1, 1).squeeze(0) + for u in zs + ] + videos = torch.stack(videos, dim=0) + return (videos, ) + + + +JoyAIImageVAE = WanxVAE + +__all__ = ["WanxVAE", "JoyAIImageVAE"] diff --git a/src/diffusers/models/transformers/transformer_joyai_image.py b/src/diffusers/models/transformers/transformer_joyai_image.py index 86493eae48c1..6003bc7f038a 100644 --- a/src/diffusers/models/transformers/transformer_joyai_image.py +++ b/src/diffusers/models/transformers/transformer_joyai_image.py @@ -1,5 +1,663 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. +import math +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models import ModelMixin +from diffusers.models.attention import AttentionModuleMixin, FeedForward +from diffusers.models.attention_dispatch import AttentionBackendName, dispatch_attention_fn +from diffusers.models.embeddings import ( + PixArtAlphaTextProjection, + TimestepEmbedding, + Timesteps, + apply_rotary_emb, + get_1d_rotary_pos_embed, +) +from diffusers.models.normalization import RMSNorm + + +def _create_modulation( + modulate_type: str, + hidden_size: int, + factor: int, + dtype=None, + device=None): + factory_kwargs = {"dtype": dtype, "device": device} + if modulate_type == 'wanx': + return _WanModulation(hidden_size, factor, **factory_kwargs) + raise ValueError( + f"Unknown modulation type: {modulate_type}. Only 'wanx' is supported.") + + +class _WanModulation(nn.Module): + """Modulation layer for WanX.""" + + def __init__( + self, + hidden_size: int, + factor: int, + dtype=None, + device=None, + ): + super().__init__() + self.factor = factor + self.modulate_table = nn.Parameter( + torch.zeros(1, factor, hidden_size, + dtype=dtype, device=device) / hidden_size**0.5, + requires_grad=True + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if len(x.shape) != 3: + x = x.unsqueeze(1) + return [o.squeeze(1) for o in (self.modulate_table + x).chunk(self.factor, dim=1)] + + +class JoyAIJointAttnProcessor: + _attention_backend = None + _parallel_config = None + + def __call__( + self, + attn: "JoyAIJointAttention", + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None = None, + attention_kwargs: Optional[dict[str, Any]] = None, + ) -> torch.Tensor: + attention_kwargs = attention_kwargs or {} + backend = AttentionBackendName.NATIVE if attn.backend == "torch_spda" else AttentionBackendName.FLASH_VARLEN + + try: + return dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + attention_kwargs=attention_kwargs, + backend=backend, + parallel_config=self._parallel_config, + ) + except (RuntimeError, ValueError, TypeError): + return dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + attention_kwargs=attention_kwargs, + backend=AttentionBackendName.NATIVE, + parallel_config=self._parallel_config, + ) + + +class JoyAIJointAttention(nn.Module, AttentionModuleMixin): + _default_processor_cls = JoyAIJointAttnProcessor + _available_processors = [JoyAIJointAttnProcessor] + + def __init__(self, backend: str = "flash_attn", processor=None) -> None: + super().__init__() + self.backend = backend + if processor is None: + processor = self._default_processor_cls() + self.set_processor(processor) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None = None, + attention_kwargs: Optional[dict[str, Any]] = None, + ) -> torch.Tensor: + return self.processor(self, query, key, value, attention_mask, attention_kwargs) + + +class JoyAIImageTransformerBlock(nn.Module): + """Joint text-image transformer block.""" + + def __init__( + self, + hidden_size: int, + heads_num: int, + mlp_width_ratio: float, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + dit_modulation_type: Optional[str] = "wanx", + attn_backend: str = 'flash_attn', + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.dit_modulation_type = dit_modulation_type + self.heads_num = heads_num + head_dim = hidden_size // heads_num + mlp_hidden_dim = int(hidden_size * mlp_width_ratio) + + self.img_mod = _create_modulation( + modulate_type=self.dit_modulation_type, + hidden_size=hidden_size, + factor=6, + **factory_kwargs, + ) + self.img_norm1 = nn.LayerNorm( + hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs + ) + + self.img_attn_qkv = nn.Linear( + hidden_size, hidden_size * 3, bias=True, **factory_kwargs + ) + self.img_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6) + self.img_attn_k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6) + self.img_attn_proj = nn.Linear( + hidden_size, hidden_size, bias=True, **factory_kwargs + ) + + self.img_norm2 = nn.LayerNorm( + hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs + ) + self.img_mlp = FeedForward(hidden_size, inner_dim=mlp_hidden_dim, + activation_fn="gelu-approximate") + + self.txt_mod = _create_modulation( + modulate_type=self.dit_modulation_type, + hidden_size=hidden_size, + factor=6, + **factory_kwargs, + ) + self.txt_norm1 = nn.LayerNorm( + hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs + ) + + self.txt_attn_qkv = nn.Linear( + hidden_size, hidden_size * 3, bias=True, **factory_kwargs + ) + self.txt_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6) + self.txt_attn_k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6) + self.txt_attn_proj = nn.Linear( + hidden_size, hidden_size, bias=True, **factory_kwargs + ) + self.attn = JoyAIJointAttention(attn_backend) + + self.txt_norm2 = nn.LayerNorm( + hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs + ) + self.txt_mlp = FeedForward(hidden_size, inner_dim=mlp_hidden_dim, + activation_fn="gelu-approximate") + + @staticmethod + def _modulate( + hidden_states: torch.Tensor, shift: torch.Tensor | None = None, scale: torch.Tensor | None = None + ) -> torch.Tensor: + if scale is None and shift is None: + return hidden_states + if shift is None: + return hidden_states * (1 + scale.unsqueeze(1)) + if scale is None: + return hidden_states + shift.unsqueeze(1) + return hidden_states * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + @staticmethod + def _apply_gate( + hidden_states: torch.Tensor, gate: torch.Tensor | None = None, tanh: bool = False + ) -> torch.Tensor: + if gate is None: + return hidden_states + if tanh: + return hidden_states * gate.unsqueeze(1).tanh() + return hidden_states * gate.unsqueeze(1) + + def forward( + self, + img: torch.Tensor, + txt: torch.Tensor, + vec: torch.Tensor, + vis_freqs_cis: tuple = None, + txt_freqs_cis: tuple = None, + attn_kwargs: Optional[dict] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + ( + img_mod1_shift, + img_mod1_scale, + img_mod1_gate, + img_mod2_shift, + img_mod2_scale, + img_mod2_gate, + ) = self.img_mod(vec) + ( + txt_mod1_shift, + txt_mod1_scale, + txt_mod1_gate, + txt_mod2_shift, + txt_mod2_scale, + txt_mod2_gate, + ) = self.txt_mod(vec) + + img_modulated = self.img_norm1(img) + img_modulated = self._modulate( + img_modulated, shift=img_mod1_shift, scale=img_mod1_scale + ) + img_qkv = self.img_attn_qkv(img_modulated) + batch_size, image_sequence_length, _ = img_qkv.shape + img_qkv = img_qkv.view(batch_size, image_sequence_length, 3, self.heads_num, -1).permute(2, 0, 1, 3, 4) + img_q, img_k, img_v = img_qkv.unbind(0) + img_q = self.img_attn_q_norm(img_q).to(img_v) + img_k = self.img_attn_k_norm(img_k).to(img_v) + + if vis_freqs_cis is not None: + img_q = apply_rotary_emb(img_q, vis_freqs_cis, sequence_dim=1) + img_k = apply_rotary_emb(img_k, vis_freqs_cis, sequence_dim=1) + + txt_modulated = self.txt_norm1(txt) + txt_modulated = self._modulate( + txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale + ) + txt_qkv = self.txt_attn_qkv(txt_modulated) + _, text_sequence_length, _ = txt_qkv.shape + txt_qkv = txt_qkv.view(batch_size, text_sequence_length, 3, self.heads_num, -1).permute(2, 0, 1, 3, 4) + txt_q, txt_k, txt_v = txt_qkv.unbind(0) + txt_q = self.txt_attn_q_norm(txt_q).to(txt_v) + txt_k = self.txt_attn_k_norm(txt_k).to(txt_v) + + if txt_freqs_cis is not None: + raise NotImplementedError("RoPE text is not supported for inference") + + + attention_output = self.attn( + torch.cat((img_q, txt_q), dim=1), + torch.cat((img_k, txt_k), dim=1), + torch.cat((img_v, txt_v), dim=1), + attention_mask=attn_kwargs.get("attention_mask") if attn_kwargs is not None else None, + attention_kwargs=attn_kwargs, + ) + attention_output = attention_output.flatten(2, 3) + image_attention_output = attention_output[:, : img.shape[1]] + text_attention_output = attention_output[:, img.shape[1]:] + + img = img + self._apply_gate(self.img_attn_proj(image_attention_output), + gate=img_mod1_gate) + img = img + self._apply_gate( + self.img_mlp( + self._modulate( + self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale + ) + ), + gate=img_mod2_gate, + ) + + txt = txt + self._apply_gate(self.txt_attn_proj(text_attention_output), + gate=txt_mod1_gate) + txt = txt + self._apply_gate( + self.txt_mlp( + self._modulate( + self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale + ) + ), + gate=txt_mod2_gate, + ) + + return img, txt + + +class JoyAITimeTextEmbedding(nn.Module): + def __init__( + self, + dim: int, + time_freq_dim: int, + time_proj_dim: int, + text_embed_dim: int, + ): + super().__init__() + + self.timesteps_proj = Timesteps( + num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.time_embedder = TimestepEmbedding( + in_channels=time_freq_dim, time_embed_dim=dim) + self.act_fn = nn.SiLU() + self.time_proj = nn.Linear(dim, time_proj_dim) + self.text_embedder = PixArtAlphaTextProjection( + text_embed_dim, dim, act_fn="gelu_tanh") + + def forward( + self, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + timestep = self.timesteps_proj(timestep) + + time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype + if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8: + timestep = timestep.to(time_embedder_dtype) + timestep_embedding = self.time_embedder(timestep).type_as(encoder_hidden_states) + modulation_states = self.time_proj(self.act_fn(timestep_embedding)) + encoder_hidden_states = self.text_embedder(encoder_hidden_states) + + return modulation_states, encoder_hidden_states + + +class JoyAIImageTransformer3DModel(ModelMixin, ConfigMixin): + _fsdp_shard_conditions: list = [ + lambda name, module: isinstance(module, JoyAIImageTransformerBlock)] + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + patch_size: tuple[int, int, int] = (1, 2, 2), + in_channels: int = 4, + out_channels: int = None, + hidden_size: int = 3072, + heads_num: int = 24, + text_states_dim: int = 4096, + mlp_width_ratio: float = 4.0, + mm_double_blocks_depth: int = 20, + rope_dim_list: tuple[int, int, int] = (16, 56, 56), + rope_type: str = 'rope', + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + dit_modulation_type: str = "wanx", + attn_backend: str = 'flash_attn', + theta: int = 256, + ): + self.out_channels = out_channels or in_channels + self.patch_size = patch_size + self.hidden_size = hidden_size + self.heads_num = heads_num + self.rope_dim_list = rope_dim_list + self.dit_modulation_type = dit_modulation_type + self.rope_type = rope_type + self.theta = theta + + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + if hidden_size % heads_num != 0: + raise ValueError( + f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}" + ) + + self.img_in = nn.Conv3d( + in_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + self.condition_embedder = JoyAITimeTextEmbedding( + dim=hidden_size, + time_freq_dim=256, + time_proj_dim=hidden_size * 6, + text_embed_dim=text_states_dim, + ) + + self.double_blocks = nn.ModuleList( + [ + JoyAIImageTransformerBlock( + self.hidden_size, + self.heads_num, + mlp_width_ratio=mlp_width_ratio, + dit_modulation_type=self.dit_modulation_type, + attn_backend=attn_backend, + **factory_kwargs, + ) + for _ in range(mm_double_blocks_depth) + ] + ) + + self.norm_out = nn.LayerNorm( + hidden_size, elementwise_affine=False, eps=1e-6 + ) + self.proj_out = nn.Linear( + hidden_size, out_channels * math.prod(patch_size), + **factory_kwargs) + + + @staticmethod + def _get_meshgrid_nd(start, *args, dim=2): + """Build an N-D meshgrid from integer sizes or ranges.""" + + def as_tuple(value): + if isinstance(value, int): + return (value,) * dim + if len(value) == dim: + return value + raise ValueError(f"Expected length {dim} or int, but got {value}") + if len(args) == 0: + num = as_tuple(start) + start = (0,) * dim + stop = num + elif len(args) == 1: + start = as_tuple(start) + stop = as_tuple(args[0]) + num = [stop[i] - start[i] for i in range(dim)] + elif len(args) == 2: + start = as_tuple(start) + stop = as_tuple(args[0]) + num = as_tuple(args[1]) + else: + raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}") + + axis_grid = [] + for i in range(dim): + a, b, n = start[i], stop[i], num[i] + g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n] + axis_grid.append(g) + grid = torch.meshgrid(*axis_grid, indexing="ij") + grid = torch.stack(grid, dim=0) + + return grid + + + @staticmethod + def _get_nd_rotary_pos_embed( + rope_dim_list, + start, + *args, + theta=10000.0, + use_real=False, + text_sequence_length=None, + ): + """Build visual and optional text rotary embeddings.""" + + grid = JoyAIImageTransformer3DModel._get_meshgrid_nd( + start, *args, dim=len(rope_dim_list) + ) + + embs = [] + for i in range(len(rope_dim_list)): + emb = get_1d_rotary_pos_embed( + rope_dim_list[i], + grid[i].reshape(-1), + theta=theta, + use_real=use_real, + ) + embs.append(emb) + + if use_real: + cos = torch.cat([emb[0] for emb in embs], dim=1) + sin = torch.cat([emb[1] for emb in embs], dim=1) + vis_emb = (cos, sin) + else: + vis_emb = torch.cat(embs, dim=1) + if text_sequence_length is not None: + embs_txt = [] + vis_max_ids = grid.view(-1).max().item() + text_positions = torch.arange(text_sequence_length) + vis_max_ids + 1 + for i in range(len(rope_dim_list)): + emb = get_1d_rotary_pos_embed( + rope_dim_list[i], + text_positions, + theta=theta, + use_real=use_real, + ) + embs_txt.append(emb) + if use_real: + cos = torch.cat([emb[0] for emb in embs_txt], dim=1) + sin = torch.cat([emb[1] for emb in embs_txt], dim=1) + txt_emb = (cos, sin) + else: + txt_emb = torch.cat(embs_txt, dim=1) + else: + txt_emb = None + return vis_emb, txt_emb + + + + + def get_rotary_pos_embed(self, image_grid_size, text_sequence_length=None): + target_ndim = 3 + + if len(image_grid_size) != target_ndim: + image_grid_size = [1] * (target_ndim - len(image_grid_size)) + image_grid_size + head_dim = self.hidden_size // self.heads_num + rope_dim_list = self.rope_dim_list + if rope_dim_list is None: + rope_dim_list = [head_dim // + target_ndim for _ in range(target_ndim)] + assert ( + sum(rope_dim_list) == head_dim + ), "sum(rope_dim_list) should equal to head_dim of attention layer" + image_rotary_emb, text_rotary_emb = self._get_nd_rotary_pos_embed( + rope_dim_list, + image_grid_size, + text_sequence_length=text_sequence_length, + theta=self.theta, + use_real=True, + ) + return image_rotary_emb, text_rotary_emb + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + encoder_hidden_states_mask: torch.Tensor = None, + return_dict: bool = True, + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + is_multi_item = (len(hidden_states.shape) == 6) + num_items = 0 + if is_multi_item: + num_items = hidden_states.shape[1] + if num_items > 1: + assert self.patch_size[0] == 1, "For multi-item input, patch_size[0] must be 1" + hidden_states = torch.cat( + [ + hidden_states[:, -1:], + hidden_states[:, :-1] + ], + dim=1 + ) + batch_size, num_items, channels, frames_per_item, height, width = hidden_states.shape + hidden_states = hidden_states.permute(0, 2, 1, 3, 4, 5).reshape( + batch_size, channels, num_items * frames_per_item, height, width + ) + + _, _, output_frames, output_height, output_width = hidden_states.shape + latent_frames, latent_height, latent_width = ( + output_frames // self.patch_size[0], + output_height // self.patch_size[1], + output_width // self.patch_size[2], + ) + image_hidden_states = self.img_in(hidden_states).flatten(2).transpose(1, 2) + + if encoder_hidden_states_mask is None: + encoder_hidden_states_mask = torch.ones( + (encoder_hidden_states.shape[0], encoder_hidden_states.shape[1]), + dtype=torch.bool, + device=image_hidden_states.device, + ) + else: + encoder_hidden_states_mask = encoder_hidden_states_mask.to(device=image_hidden_states.device, dtype=torch.bool) + modulation_states, text_hidden_states = self.condition_embedder(timestep, encoder_hidden_states) + if modulation_states.shape[-1] > self.hidden_size: + modulation_states = modulation_states.unflatten(1, (6, -1)) + + text_seq_len = text_hidden_states.shape[1] + image_seq_len = image_hidden_states.shape[1] + image_rotary_emb, text_rotary_emb = self.get_rotary_pos_embed( + image_grid_size=(latent_frames, latent_height, latent_width), + text_sequence_length=text_seq_len if self.rope_type == 'mrope' else None, + ) + + attention_mask = torch.cat( + [ + torch.ones( + (encoder_hidden_states_mask.shape[0], image_seq_len), + dtype=torch.bool, + device=encoder_hidden_states_mask.device, + ), + encoder_hidden_states_mask.bool(), + ], + dim=1, + ) + attention_kwargs = { + 'thw': [latent_frames, latent_height, latent_width], + 'txt_len': text_seq_len, + 'attention_mask': attention_mask, + } + + for block in self.double_blocks: + image_hidden_states, text_hidden_states = block( + image_hidden_states, + text_hidden_states, + modulation_states, + image_rotary_emb, + text_rotary_emb, + attention_kwargs, + ) + + image_seq_len = image_hidden_states.shape[1] + hidden_states = torch.cat((image_hidden_states, text_hidden_states), dim=1) + image_hidden_states = hidden_states[:, :image_seq_len, ...] + image_hidden_states = self.proj_out(self.norm_out(image_hidden_states)) + image_hidden_states = self.unpatchify(image_hidden_states, latent_frames, latent_height, latent_width) + + if is_multi_item: + batch_size, channels, total_frames, height, width = image_hidden_states.shape + image_hidden_states = image_hidden_states.reshape( + batch_size, channels, num_items, total_frames // num_items, height, width + ).permute(0, 2, 1, 3, 4, 5) + if num_items > 1: + image_hidden_states = torch.cat( + [ + image_hidden_states[:, 1:], + image_hidden_states[:, :1], + ], + dim=1, + ) + + if return_dict: + return {"sample": image_hidden_states, "encoder_hidden_states": text_hidden_states} + return image_hidden_states, text_hidden_states + + def unpatchify(self, hidden_states, latent_frames, latent_height, latent_width): + channels = self.out_channels + patch_frames, patch_height, patch_width = self.patch_size + assert latent_frames * latent_height * latent_width == hidden_states.shape[1] + + hidden_states = hidden_states.reshape( + shape=( + hidden_states.shape[0], + latent_frames, + latent_height, + latent_width, + patch_frames, + patch_height, + patch_width, + channels, + ) + ) + hidden_states = torch.einsum("nthwopqc->nctohpwq", hidden_states) + + return hidden_states.reshape( + shape=( + hidden_states.shape[0], + channels, + latent_frames * patch_frames, + latent_height * patch_height, + latent_width * patch_width, + ) + ) -from ..joyai_image.transformer import Transformer3DModel as JoyAIImageTransformer3DModel __all__ = ["JoyAIImageTransformer3DModel"] diff --git a/src/diffusers/pipelines/joyai_image/pipeline_joyai_image.py b/src/diffusers/pipelines/joyai_image/pipeline_joyai_image.py index dd17648bf55a..1a0f7aabf5a5 100644 --- a/src/diffusers/pipelines/joyai_image/pipeline_joyai_image.py +++ b/src/diffusers/pipelines/joyai_image/pipeline_joyai_image.py @@ -14,19 +14,20 @@ import inspect import json -import os from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, TypedDict, Union import numpy as np import torch -from transformers import AutoProcessor +from transformers import AutoProcessor, AutoTokenizer, PreTrainedTokenizerBase, Qwen3VLForConditionalGeneration from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback from diffusers.image_processor import VaeImageProcessor -from diffusers.models.joyai_image import JoyAIFlowMatchDiscreteScheduler, Transformer3DModel, WanxVAE, load_joyai_text_encoder +from diffusers.models.autoencoders.autoencoder_kl_joyai_image import JoyAIImageVAE +from diffusers.models.transformers.transformer_joyai_image import JoyAIImageTransformer3DModel from diffusers.pipelines.pipeline_utils import DiffusionPipeline, empty_device_cache, get_device +from diffusers.schedulers.scheduling_joyai_flow_match_discrete import JoyAIFlowMatchDiscreteScheduler from diffusers.utils import is_accelerate_available, is_accelerate_version, logging from diffusers.utils.torch_utils import randn_tensor @@ -35,6 +36,7 @@ logger = logging.get_logger(__name__) + PRECISION_TO_TYPE = { "fp32": torch.float32, "float32": torch.float32, @@ -44,17 +46,34 @@ "bfloat16": torch.bfloat16, } -PROMPT_TEMPLATE_ENCODE = { - "image": "<|im_start|>system\n \nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", - "multiple_images": "<|im_start|>system\n \nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n{}<|im_start|>assistant\n", - "video": "<|im_start|>system\n \nDescribe the video by detailing the following aspects:\n1. The main content and theme of the video.\n2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.\n3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.\n4. background environment, light, style and atmosphere.\n5. camera angles, movements, and transitions used in the video:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", -} -PROMPT_TEMPLATE_START_IDX = { - "image": 34, - "multiple_images": 34, - "video": 91, -} +class JoyAIDitArchConfig(TypedDict): + hidden_size: int + in_channels: int + heads_num: int + mm_double_blocks_depth: int + out_channels: int + patch_size: list[int] + rope_dim_list: list[int] + text_states_dim: int + rope_type: str + dit_modulation_type: str + theta: int + attn_backend: str + + +class JoyAISchedulerArchConfig(TypedDict): + num_train_timesteps: int + shift: float + + +class JoyAIImageComponents(TypedDict): + args: "JoyAIImageSourceConfig" + tokenizer: PreTrainedTokenizerBase + text_encoder: Qwen3VLForConditionalGeneration + transformer: JoyAIImageTransformer3DModel + scheduler: JoyAIFlowMatchDiscreteScheduler + vae: JoyAIImageVAE @dataclass @@ -65,46 +84,146 @@ class JoyAIImageSourceConfig: text_encoder_precision: str = "bf16" text_token_max_length: int = 2048 enable_multi_task_training: bool = False - hsdp_shard_dim: int = 1 - reshard_after_forward: bool = False - use_fsdp_inference: bool = False - cpu_offload: bool = False - pin_cpu_memory: bool = False - dit_arch_config: dict = field(default_factory=lambda: { - "hidden_size": 4096, - "in_channels": 16, - "heads_num": 32, - "mm_double_blocks_depth": 40, - "out_channels": 16, - "patch_size": [1, 2, 2], - "rope_dim_list": [16, 56, 56], - "text_states_dim": 4096, - "rope_type": "rope", - "dit_modulation_type": "wanx", - "theta": 10000, - "attn_backend": "flash_attn", - }) - scheduler_arch_config: dict = field(default_factory=lambda: { - "num_train_timesteps": 1000, - "shift": 4.0, - }) + dit_arch_config: JoyAIDitArchConfig = field( + default_factory=lambda: { + "hidden_size": 4096, + "in_channels": 16, + "heads_num": 32, + "mm_double_blocks_depth": 40, + "out_channels": 16, + "patch_size": [1, 2, 2], + "rope_dim_list": [16, 56, 56], + "text_states_dim": 4096, + "rope_type": "rope", + "dit_modulation_type": "wanx", + "theta": 10000, + "attn_backend": "flash_attn", + } + ) + scheduler_arch_config: JoyAISchedulerArchConfig = field( + default_factory=lambda: { + "num_train_timesteps": 1000, + "shift": 4.0, + } + ) @property def text_encoder_arch_config(self) -> dict: return {"params": {"text_encoder_ckpt": str(self.source_root / "JoyAI-Image-Und")}} -def _load_transformer_state_dict(checkpoint_path: Path) -> dict[str, torch.Tensor]: +def dtype_to_precision(torch_dtype: Optional[torch.dtype]) -> Optional[str]: + if torch_dtype is None: + return None + for name, value in PRECISION_TO_TYPE.items(): + if value == torch_dtype and name in {"fp32", "fp16", "bf16"}: + return name + raise ValueError(f"Unsupported torch dtype for JoyAIImagePipeline: {torch_dtype}") + + +def resolve_manifest_path(source_root: Path, manifest_value: Optional[str]) -> Optional[Path]: + if manifest_value is None: + return None + path = Path(manifest_value) + if path.parts and path.parts[0] == source_root.name: + path = Path(*path.parts[1:]) + return source_root / path + + +def is_joyai_source_dir(path: Path) -> bool: + return ( + path.is_dir() + and (path / "infer_config.py").is_file() + and (path / "manifest.json").is_file() + and (path / "transformer").is_dir() + and (path / "vae").is_dir() + ) + + +def load_transformer_state_dict(checkpoint_path: Path) -> dict[str, torch.Tensor]: state = torch.load(checkpoint_path, map_location="cpu", weights_only=True) if "model" in state: state = state["model"] return state -def _build_joyai_source_config(source_root: Path) -> JoyAIImageSourceConfig: - return JoyAIImageSourceConfig(source_root=source_root) +def load_joyai_components( + source_root: Union[str, Path], + torch_dtype: Optional[torch.dtype] = None, + device: Optional[Union[str, torch.device]] = None, +) -> JoyAIImageComponents: + source_root = Path(source_root) + if not is_joyai_source_dir(source_root): + raise ValueError(f"Not a valid JoyAI source checkpoint directory: {source_root}") + + precision = dtype_to_precision(torch_dtype) + cfg = JoyAIImageSourceConfig(source_root=source_root) + + manifest = json.loads((source_root / "manifest.json").read_text()) + transformer_ckpt = resolve_manifest_path(source_root, manifest.get("transformer_ckpt")) + vae_ckpt = source_root / "vae" / "Wan2.1_VAE.pth" + text_encoder_ckpt = source_root / "JoyAI-Image-Und" + + if precision is not None: + cfg.dit_precision = precision + cfg.vae_precision = precision + cfg.text_encoder_precision = precision + + load_device = torch.device(device) if device is not None else torch.device("cpu") + transformer = JoyAIImageTransformer3DModel( + dtype=PRECISION_TO_TYPE[cfg.dit_precision], + device=load_device, + **cfg.dit_arch_config, + ) + state_dict = load_transformer_state_dict(transformer_ckpt) + if "img_in.weight" in state_dict and transformer.img_in.weight.shape != state_dict["img_in.weight"].shape: + value = state_dict["img_in.weight"] + padded = value.new_zeros(transformer.img_in.weight.shape) + padded[:, : value.shape[1], :, :, :] = value + state_dict["img_in.weight"] = padded + transformer.load_state_dict(state_dict, strict=True) + transformer = transformer.to(dtype=PRECISION_TO_TYPE[cfg.dit_precision]).eval() + + vae = JoyAIImageVAE( + pretrained=str(vae_ckpt), + torch_dtype=PRECISION_TO_TYPE[cfg.vae_precision], + device=load_device, + ) + text_encoder = Qwen3VLForConditionalGeneration.from_pretrained( + str(text_encoder_ckpt), + torch_dtype=PRECISION_TO_TYPE[cfg.text_encoder_precision], + local_files_only=True, + trust_remote_code=True, + ).to(load_device).eval() + tokenizer = AutoTokenizer.from_pretrained( + str(text_encoder_ckpt), + local_files_only=True, + trust_remote_code=True, + ) + scheduler = JoyAIFlowMatchDiscreteScheduler(**cfg.scheduler_arch_config) + + return { + "args": cfg, + "tokenizer": tokenizer, + "text_encoder": text_encoder, + "transformer": transformer, + "scheduler": scheduler, + "vae": vae, + } +PROMPT_TEMPLATE_ENCODE = { + "image": "<|im_start|>system\n \nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", + "multiple_images": "<|im_start|>system\n \nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n{}<|im_start|>assistant\n", + "video": "<|im_start|>system\n \nDescribe the video by detailing the following aspects:\n1. The main content and theme of the video.\n2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.\n3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.\n4. background environment, light, style and atmosphere.\n5. camera angles, movements, and transitions used in the video:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", +} + +PROMPT_TEMPLATE_START_IDX = { + "image": 34, + "multiple_images": 34, + "video": 91, +} + def retrieve_timesteps( scheduler, @@ -148,12 +267,12 @@ class JoyAIImagePipeline(DiffusionPipeline): def __init__( self, - vae: Any, - text_encoder: Any, - tokenizer: Any, - transformer: Any, - scheduler: Any, - args: Any = None, + vae: JoyAIImageVAE, + text_encoder: Qwen3VLForConditionalGeneration, + tokenizer: PreTrainedTokenizerBase, + transformer: JoyAIImageTransformer3DModel, + scheduler: JoyAIFlowMatchDiscreteScheduler, + args: JoyAIImageSourceConfig | None = None, ): super().__init__() self.args = args @@ -193,38 +312,10 @@ def __init__( self.prompt_template_encode_start_idx = PROMPT_TEMPLATE_START_IDX self._joyai_force_vae_fp32 = True - @staticmethod - def _dtype_to_precision(torch_dtype: Optional[torch.dtype]) -> Optional[str]: - if torch_dtype is None: - return None - for name, value in PRECISION_TO_TYPE.items(): - if value == torch_dtype and name in {"fp32", "fp16", "bf16"}: - return name - raise ValueError(f"Unsupported torch dtype for JoyAIImagePipeline: {torch_dtype}") - - @staticmethod - def _resolve_manifest_path(source_root: Path, manifest_value: Optional[str]) -> Optional[Path]: - if manifest_value is None: - return None - path = Path(manifest_value) - if path.parts and path.parts[0] == source_root.name: - path = Path(*path.parts[1:]) - return source_root / path - - @classmethod - def _is_joyai_source_dir(cls, path: Path) -> bool: - return ( - path.is_dir() - and (path / "infer_config.py").is_file() - and (path / "manifest.json").is_file() - and (path / "transformer").is_dir() - and (path / "vae").is_dir() - ) - @classmethod def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): source_path = Path(pretrained_model_name_or_path) if pretrained_model_name_or_path is not None else None - if source_path is not None and cls._is_joyai_source_dir(source_path): + if source_path is not None and is_joyai_source_dir(source_path): return cls.from_joyai_sources(pretrained_model_name_or_path, **kwargs) return super().from_pretrained(pretrained_model_name_or_path, **kwargs) @@ -233,67 +324,22 @@ def from_joyai_sources( cls, pretrained_model_name_or_path: Union[str, Path], torch_dtype: Optional[torch.dtype] = None, - official_repo_path: Optional[Union[str, Path]] = None, device: Optional[Union[str, torch.device]] = None, - hsdp_shard_dim_override: Optional[int] = None, **kwargs, ): - source_root = Path(pretrained_model_name_or_path) - if not cls._is_joyai_source_dir(source_root): - raise ValueError(f"Not a valid JoyAI source checkpoint directory: {source_root}") - - precision = cls._dtype_to_precision(torch_dtype) - cfg = _build_joyai_source_config(source_root) - - manifest = json.loads((source_root / "manifest.json").read_text()) - transformer_ckpt = cls._resolve_manifest_path(source_root, manifest.get("transformer_ckpt")) - vae_ckpt = source_root / "vae" / "Wan2.1_VAE.pth" - text_encoder_ckpt = source_root / "JoyAI-Image-Und" - - if precision is not None: - cfg.dit_precision = precision - cfg.vae_precision = precision - cfg.text_encoder_precision = precision - - if hsdp_shard_dim_override is not None: - cfg.hsdp_shard_dim = hsdp_shard_dim_override - - load_device = torch.device(device) if device is not None else torch.device("cpu") - dit = Transformer3DModel( - args=cfg, - dtype=PRECISION_TO_TYPE[cfg.dit_precision], - device=load_device, - **cfg.dit_arch_config, - ) - state_dict = _load_transformer_state_dict(transformer_ckpt) - if "img_in.weight" in state_dict and dit.img_in.weight.shape != state_dict["img_in.weight"].shape: - v = state_dict["img_in.weight"] - v_new = v.new_zeros(dit.img_in.weight.shape) - v_new[:, : v.shape[1], :, :, :] = v - state_dict["img_in.weight"] = v_new - dit.load_state_dict(state_dict, strict=True) - dit = dit.to(dtype=PRECISION_TO_TYPE[cfg.dit_precision]) - dit = dit.eval() - - vae = WanxVAE( - pretrained=str(vae_ckpt), - torch_dtype=PRECISION_TO_TYPE[cfg.vae_precision], - device=load_device, - ) - tokenizer, text_encoder = load_joyai_text_encoder( - text_encoder_ckpt=str(text_encoder_ckpt), - torch_dtype=PRECISION_TO_TYPE[cfg.text_encoder_precision], - device=load_device, + components = load_joyai_components( + source_root=pretrained_model_name_or_path, + torch_dtype=torch_dtype, + device=device, ) - scheduler = JoyAIFlowMatchDiscreteScheduler(**cfg.scheduler_arch_config) pipe = cls( - vae=vae, - tokenizer=tokenizer, - text_encoder=text_encoder, - transformer=dit, - scheduler=scheduler, - args=cfg, + vae=components["vae"], + tokenizer=components["tokenizer"], + text_encoder=components["text_encoder"], + transformer=components["transformer"], + scheduler=components["scheduler"], + args=components["args"], ) if device is not None: pipe._joyai_execution_device_override = torch.device(device) @@ -318,9 +364,9 @@ def _get_qwen_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt template = self.prompt_template_encode[template_type] drop_idx = self.prompt_template_encode_start_idx[template_type] - txt = [template.format(e) for e in prompt] + formatted_prompts = [template.format(prompt_text) for prompt_text in prompt] txt_tokens = self.tokenizer( - txt, + formatted_prompts, max_length=self.text_token_max_length + drop_idx, padding=True, truncation=True, @@ -340,10 +386,21 @@ def _get_qwen_prompt_embeds( max(u.size(0) for u in attn_mask_list), ) prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + [ + torch.cat( + [ + hidden_state, + hidden_state.new_zeros(max_seq_len - hidden_state.size(0), hidden_state.size(1)), + ] + ) + for hidden_state in split_hidden_states + ] ) encoder_attention_mask = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + [ + torch.cat([attention_mask_row, attention_mask_row.new_zeros(max_seq_len - attention_mask_row.size(0))]) + for attention_mask_row in attn_mask_list + ] ) return prompt_embeds.to(dtype=dtype, device=device), encoder_attention_mask @@ -417,17 +474,17 @@ def encode_prompt( def check_inputs( self, - prompt, - height, - width, - images=None, - negative_prompt=None, - prompt_embeds=None, - negative_prompt_embeds=None, - prompt_embeds_mask=None, - negative_prompt_embeds_mask=None, - callback_on_step_end_tensor_inputs=None, - ): + prompt: Optional[Union[str, List[str]]], + height: int, + width: int, + images: Optional[List[Any]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds_mask: Optional[torch.Tensor] = None, + callback_on_step_end_tensor_inputs: Optional[List[str]] = None, + ) -> None: if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): @@ -465,6 +522,28 @@ def _get_runtime_execution_device(self) -> torch.device: def _is_sequential_cpu_offload_enabled(self) -> bool: return bool(getattr(self, "_joyai_sequential_cpu_offload_enabled", False)) + def enable_manual_cpu_offload( + self, + device: torch.device | str, + components: Optional[List[str]] = None, + ) -> None: + """Enable manual CPU offload for selected components.""" + runtime_device = torch.device(device) + component_names = set(components or ["text_encoder", "vae"]) + + invalid_components = [name for name in component_names if name not in self.components] + if invalid_components: + raise ValueError(f"Unknown components for manual cpu offload: {invalid_components}") + + self._joyai_execution_device_override = runtime_device + self._joyai_sequential_cpu_offload_enabled = True + self._joyai_manual_offload_components = component_names + + for name in component_names: + component = getattr(self, name, None) + if isinstance(component, torch.nn.Module): + component.to("cpu") + def _uses_manual_sequential_offload(self, component_name: str) -> bool: manual_components = getattr(self, "_joyai_manual_offload_components", set()) return self._is_sequential_cpu_offload_enabled() and component_name in manual_components @@ -496,7 +575,7 @@ def _get_vae_scale(self, device: torch.device, dtype: torch.dtype): def _encode_with_vae(self, videos: torch.Tensor) -> torch.Tensor: device = self._get_runtime_execution_device() - vae_dtype = self._vae_compute_dtype() + vae_dtype = PRECISION_TO_TYPE.get(getattr(self.args, "vae_precision", "bf16"), videos.dtype) videos = videos.to(device=device, dtype=vae_dtype) if self._uses_manual_sequential_offload("vae") and hasattr(self.vae, "model"): @@ -566,8 +645,6 @@ def prepare_latents( generator, latents=None, reference_images=None, - image=None, - last_image=None, ): shape = ( batch_size, @@ -696,9 +773,6 @@ def __call__( width: int, num_frames: int = 1, images: Optional[List[Any]] = None, - image_condition: Optional[torch.Tensor] = None, - last_image_condition: Optional[torch.Tensor] = None, - data_type: str = "image", num_inference_steps: int = 50, timesteps: Optional[List[int]] = None, sigmas: Optional[List[float]] = None, @@ -818,13 +892,11 @@ def __call__( generator, latents, reference_images=images, - image=image_condition, - last_image=last_image_condition, ) target_dtype = PRECISION_TO_TYPE.get(getattr(self.args, "dit_precision", "bf16"), prompt_embeds.dtype) autocast_enabled = target_dtype != torch.float32 and device.type == "cuda" - vae_dtype = PRECISION_TO_TYPE.get(getattr(self.args, "vae_precision", "bf16"), prompt_embeds.dtype) + vae_dtype = self._vae_compute_dtype() vae_autocast_enabled = vae_dtype != torch.float32 and device.type == "cuda" self._num_timesteps = len(timesteps) diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index b1f75bed7dc5..b0fd70c2cd8e 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -61,6 +61,10 @@ _import_structure["scheduling_euler_discrete"] = ["EulerDiscreteScheduler"] _import_structure["scheduling_flow_match_euler_discrete"] = ["FlowMatchEulerDiscreteScheduler"] _import_structure["scheduling_flow_match_heun_discrete"] = ["FlowMatchHeunDiscreteScheduler"] + _import_structure["scheduling_joyai_flow_match_discrete"] = [ + "JoyAIFlowMatchDiscreteScheduler", + "JoyAIFlowMatchDiscreteSchedulerOutput", + ] _import_structure["scheduling_flow_match_lcm"] = ["FlowMatchLCMScheduler"] _import_structure["scheduling_helios"] = ["HeliosScheduler"] _import_structure["scheduling_helios_dmd"] = ["HeliosDMDScheduler"] @@ -167,6 +171,10 @@ from .scheduling_euler_discrete import EulerDiscreteScheduler from .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler from .scheduling_flow_match_heun_discrete import FlowMatchHeunDiscreteScheduler + from .scheduling_joyai_flow_match_discrete import ( + JoyAIFlowMatchDiscreteScheduler, + JoyAIFlowMatchDiscreteSchedulerOutput, + ) from .scheduling_flow_match_lcm import FlowMatchLCMScheduler from .scheduling_helios import HeliosScheduler from .scheduling_helios_dmd import HeliosDMDScheduler diff --git a/src/diffusers/schedulers/scheduling_joyai_flow_match_discrete.py b/src/diffusers/schedulers/scheduling_joyai_flow_match_discrete.py new file mode 100644 index 000000000000..62b98399bfe0 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_joyai_flow_match_discrete.py @@ -0,0 +1,260 @@ +# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Modified from diffusers==0.29.2 +# +# ============================================================================== + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput, logging +from diffusers.schedulers.scheduling_utils import SchedulerMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class JoyAIFlowMatchDiscreteSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.FloatTensor + + +class JoyAIFlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Euler scheduler. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + shift (`float`, defaults to 1.0): + The shift value for the timestep schedule. + reverse (`bool`, defaults to `True`): + Whether to reverse the timestep schedule. + """ + + _compatibles = [] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + reverse: bool = True, + solver: str = "euler", + n_tokens: Optional[int] = None, + ): + sigmas = torch.linspace(1, 0, num_train_timesteps + 1) + + if not reverse: + sigmas = sigmas.flip(0) + + self.sigmas = sigmas + # the value fed to model + self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32) + + + self._step_index = None + self._begin_index = None + + self.supported_solver = ["euler"] + if solver not in self.supported_solver: + raise ValueError( + f"Solver {solver} not supported. Supported solvers: {self.supported_solver}" + ) + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def set_timesteps( + self, + num_inference_steps: int, + device: Union[str, torch.device] = None, + n_tokens: int = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + n_tokens (`int`, *optional*): + Number of tokens in the input sequence. + """ + self.num_inference_steps = num_inference_steps + + sigmas = torch.linspace(1, 0, num_inference_steps + 1) + + sigmas = self.sd3_time_shift(sigmas) + + if not self.config.reverse: + sigmas = 1 - sigmas + + self.sigmas = sigmas + self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to( + dtype=torch.float32, device=device + ) + + # Reset step index + self._step_index = None + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def scale_model_input( + self, sample: torch.Tensor, timestep: Optional[int] = None + ) -> torch.Tensor: + return sample + + def sd3_time_shift(self, t: torch.Tensor): + # print("sd3:self.config.shift",self.config.shift) + return (self.config.shift * t) / (1 + (self.config.shift - 1) * t) + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + return_dict: bool = True, + ) -> Union[JoyAIFlowMatchDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + n_tokens (`int`, *optional*): + Number of tokens in the input sequence. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or + tuple. + + Returns: + [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is + returned, otherwise a tuple is returned where the first element is the sample tensor. + """ + + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + + dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index] + + if self.config.solver == "euler": + prev_sample = sample + model_output.to(torch.float32) * dt + else: + raise ValueError( + f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}" + ) + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return JoyAIFlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample) + + def __len__(self): + return self.config.num_train_timesteps From 8190e3f1791973587cc5f56ad7b6d08d7958eddd Mon Sep 17 00:00:00 2001 From: Lancer Date: Fri, 10 Apr 2026 02:14:00 +0800 Subject: [PATCH 4/6] upd Signed-off-by: Lancer --- src/diffusers/__init__.py | 4 + .../autoencoder_kl_joyai_image.py | 627 +++++++----------- .../transformers/transformer_joyai_image.py | 5 +- .../joyai_image/pipeline_joyai_image.py | 8 +- .../test_models_autoencoder_joyai_image.py | 80 +++ .../test_models_transformer_joyai_image.py | 73 ++ 6 files changed, 413 insertions(+), 384 deletions(-) create mode 100644 tests/models/autoencoders/test_models_autoencoder_joyai_image.py create mode 100644 tests/models/transformers/test_models_transformer_joyai_image.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index a23b89a77b0a..244340f87252 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -209,6 +209,7 @@ "AutoencoderKLMagvit", "AutoencoderKLMochi", "AutoencoderKLQwenImage", + "JoyAIImageVAE", "AutoencoderKLTemporalDecoder", "AutoencoderKLWan", "AutoencoderOobleck", @@ -245,6 +246,7 @@ "HunyuanDiT2DModel", "HunyuanDiT2DMultiControlNetModel", "HunyuanImageTransformer2DModel", + "JoyAIImageTransformer3DModel", "HunyuanVideo15Transformer3DModel", "HunyuanVideoFramepackTransformer3DModel", "HunyuanVideoTransformer3DModel", @@ -1005,6 +1007,7 @@ AutoencoderKLMagvit, AutoencoderKLMochi, AutoencoderKLQwenImage, + JoyAIImageVAE, AutoencoderKLTemporalDecoder, AutoencoderKLWan, AutoencoderOobleck, @@ -1041,6 +1044,7 @@ HunyuanDiT2DModel, HunyuanDiT2DMultiControlNetModel, HunyuanImageTransformer2DModel, + JoyAIImageTransformer3DModel, HunyuanVideo15Transformer3DModel, HunyuanVideoFramepackTransformer3DModel, HunyuanVideoTransformer3DModel, diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_joyai_image.py b/src/diffusers/models/autoencoders/autoencoder_kl_joyai_image.py index 759d53f7aef9..36b642e43e66 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_joyai_image.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_joyai_image.py @@ -1,264 +1,38 @@ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. -import logging + +from contextlib import nullcontext import torch -import torch.cuda.amp as amp import torch.nn as nn -import torch.nn.functional as F - +from ...configuration_utils import ConfigMixin +from ...loaders import FromOriginalModelMixin +from ...utils import logging +from ...utils.accelerate_utils import apply_forward_hook +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from .autoencoder_kl_wan import ( + WanAttentionBlock as AttentionBlock, +) +from .autoencoder_kl_wan import ( + WanCausalConv3d as CausalConv3d, +) +from .autoencoder_kl_wan import ( + WanResample as Resample, +) +from .autoencoder_kl_wan import ( + WanResidualBlock as ResidualBlock, +) +from .autoencoder_kl_wan import ( + WanRMS_norm as RMS_norm, +) +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution + + +logger = logging.get_logger(__name__) CACHE_T = 2 -class CausalConv3d(nn.Conv3d): - """ - Causal 3d convolusion. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._padding = (self.padding[2], self.padding[2], self.padding[1], - self.padding[1], 2 * self.padding[0], 0) - self.padding = (0, 0, 0) - - def forward(self, x, cache_x=None): - padding = list(self._padding) - if cache_x is not None and self._padding[4] > 0: - cache_x = cache_x.to(x.device) - x = torch.cat([cache_x, x], dim=2) - padding[4] -= cache_x.shape[2] - x = F.pad(x, padding) - - return super().forward(x) - - -class RMS_norm(nn.Module): - - def __init__(self, dim, channel_first=True, images=True, bias=False): - super().__init__() - broadcastable_dims = (1, 1, 1) if not images else (1, 1) - shape = (dim, *broadcastable_dims) if channel_first else (dim,) - - self.channel_first = channel_first - self.scale = dim**0.5 - self.gamma = nn.Parameter(torch.ones(shape)) - self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0. - - def forward(self, x): - return F.normalize( - x, dim=(1 if self.channel_first else - -1)) * self.scale * self.gamma + self.bias - - -class Upsample(nn.Upsample): - - def forward(self, x): - """ - Fix bfloat16 support for nearest neighbor interpolation. - """ - return super().forward(x.float()).type_as(x) - - -class Resample(nn.Module): - - def __init__(self, dim, mode): - assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d', - 'downsample3d') - super().__init__() - self.dim = dim - self.mode = mode - - # layers - if mode == 'upsample2d': - self.resample = nn.Sequential( - Upsample(scale_factor=(2., 2.), mode='nearest-exact'), - nn.Conv2d(dim, dim // 2, 3, padding=1)) - elif mode == 'upsample3d': - self.resample = nn.Sequential( - Upsample(scale_factor=(2., 2.), mode='nearest-exact'), - nn.Conv2d(dim, dim // 2, 3, padding=1)) - self.time_conv = CausalConv3d( - dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) - - elif mode == 'downsample2d': - self.resample = nn.Sequential( - nn.ZeroPad2d((0, 1, 0, 1)), - nn.Conv2d(dim, dim, 3, stride=(2, 2))) - elif mode == 'downsample3d': - self.resample = nn.Sequential( - nn.ZeroPad2d((0, 1, 0, 1)), - nn.Conv2d(dim, dim, 3, stride=(2, 2))) - self.time_conv = CausalConv3d( - dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) - - else: - self.resample = nn.Identity() - - def forward(self, x, feat_cache=None, feat_idx=[0]): - b, c, t, h, w = x.size() - if self.mode == 'upsample3d': - if feat_cache is not None: - idx = feat_idx[0] - if feat_cache[idx] is None: - feat_cache[idx] = 'Rep' - feat_idx[0] += 1 - else: - - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[ - idx] is not None and feat_cache[idx] != 'Rep': - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) - if cache_x.shape[2] < 2 and feat_cache[ - idx] is not None and feat_cache[idx] == 'Rep': - cache_x = torch.cat([ - torch.zeros_like(cache_x).to(cache_x.device), - cache_x - ], - dim=2) - if feat_cache[idx] == 'Rep': - x = self.time_conv(x) - else: - x = self.time_conv(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - - x = x.reshape(b, 2, c, t, h, w) - x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), - 3) - x = x.reshape(b, c, t * 2, h, w) - t = x.shape[2] - x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) - x = self.resample(x) - resampled_height, resampled_width = x.shape[-2:] - x = x.reshape(b, t, x.shape[1], resampled_height, resampled_width).permute(0, 2, 1, 3, 4) - - if self.mode == 'downsample3d': - if feat_cache is not None: - idx = feat_idx[0] - if feat_cache[idx] is None: - feat_cache[idx] = x.clone() - feat_idx[0] += 1 - else: - - cache_x = x[:, :, -1:, :, :].clone() - # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep': - # # cache last frame of last two chunk - # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) - - x = self.time_conv( - torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - return x - - def init_weight(self, conv): - conv_weight = conv.weight - nn.init.zeros_(conv_weight) - c1, c2, t, h, w = conv_weight.size() - one_matrix = torch.eye(c1, c2) - init_matrix = one_matrix - nn.init.zeros_(conv_weight) - #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5 - conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5 - conv.weight.data.copy_(conv_weight) - nn.init.zeros_(conv.bias.data) - - def init_weight2(self, conv): - conv_weight = conv.weight.data - nn.init.zeros_(conv_weight) - c1, c2, t, h, w = conv_weight.size() - init_matrix = torch.eye(c1 // 2, c2) - #init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2) - conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix - conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix - conv.weight.data.copy_(conv_weight) - nn.init.zeros_(conv.bias.data) - - -class ResidualBlock(nn.Module): - - def __init__(self, in_dim, out_dim, dropout=0.0): - super().__init__() - self.in_dim = in_dim - self.out_dim = out_dim - - # layers - self.residual = nn.Sequential( - RMS_norm(in_dim, images=False), nn.SiLU(), - CausalConv3d(in_dim, out_dim, 3, padding=1), - RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout), - CausalConv3d(out_dim, out_dim, 3, padding=1)) - self.shortcut = CausalConv3d(in_dim, out_dim, 1) \ - if in_dim != out_dim else nn.Identity() - - def forward(self, x, feat_cache=None, feat_idx=[0]): - h = self.shortcut(x) - for layer in self.residual: - if isinstance(layer, CausalConv3d) and feat_cache is not None: - idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) - x = layer(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - else: - x = layer(x) - return x + h - - -class AttentionBlock(nn.Module): - """ - Causal self-attention with a single head. - """ - - def __init__(self, dim): - super().__init__() - self.dim = dim - - # layers - self.norm = RMS_norm(dim) - self.to_qkv = nn.Conv2d(dim, dim * 3, 1) - self.proj = nn.Conv2d(dim, dim, 1) - - # zero out the last layer params - nn.init.zeros_(self.proj.weight) - - def forward(self, x): - identity = x - b, c, t, h, w = x.size() - x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) - x = self.norm(x) - # compute query, key, value - q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, - -1).permute(0, 1, 3, - 2).contiguous().chunk( - 3, dim=-1) - - # apply attention - x = F.scaled_dot_product_attention( - q, - k, - v, - ) - x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) - - # output - x = self.proj(x) - x = x.reshape(b, t, c, h, w).permute(0, 2, 1, 3, 4) - return x + identity - - class Encoder3d(nn.Module): def __init__(self, @@ -300,17 +74,25 @@ def __init__(self, i] else 'downsample2d' downsamples.append(Resample(out_dim, mode=mode)) scale /= 2.0 - self.downsamples = nn.Sequential(*downsamples) + self.downsamples = nn.ModuleList(downsamples) # middle blocks - self.middle = nn.Sequential( - ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim), - ResidualBlock(out_dim, out_dim, dropout)) + self.middle = nn.ModuleList( + [ + ResidualBlock(out_dim, out_dim, dropout), + AttentionBlock(out_dim), + ResidualBlock(out_dim, out_dim, dropout), + ] + ) # output blocks - self.head = nn.Sequential( - RMS_norm(out_dim, images=False), nn.SiLU(), - CausalConv3d(out_dim, z_dim, 3, padding=1)) + self.head = nn.ModuleList( + [ + RMS_norm(out_dim, images=False), + nn.SiLU(), + CausalConv3d(out_dim, z_dim, 3, padding=1), + ] + ) def forward(self, x, feat_cache=None, feat_idx=[0]): if feat_cache is not None: @@ -329,17 +111,15 @@ def forward(self, x, feat_cache=None, feat_idx=[0]): else: x = self.conv1(x) - ## downsamples for layer in self.downsamples: if feat_cache is not None: - x = layer(x, feat_cache, feat_idx) + x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx) else: x = layer(x) - ## middle for layer in self.middle: if isinstance(layer, ResidualBlock) and feat_cache is not None: - x = layer(x, feat_cache, feat_idx) + x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx) else: x = layer(x) @@ -389,9 +169,13 @@ def __init__(self, self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) # middle blocks - self.middle = nn.Sequential( - ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]), - ResidualBlock(dims[0], dims[0], dropout)) + self.middle = nn.ModuleList( + [ + ResidualBlock(dims[0], dims[0], dropout), + AttentionBlock(dims[0]), + ResidualBlock(dims[0], dims[0], dropout), + ] + ) # upsample blocks upsamples = [] @@ -410,12 +194,16 @@ def __init__(self, mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d' upsamples.append(Resample(out_dim, mode=mode)) scale *= 2.0 - self.upsamples = nn.Sequential(*upsamples) + self.upsamples = nn.ModuleList(upsamples) # output blocks - self.head = nn.Sequential( - RMS_norm(out_dim, images=False), nn.SiLU(), - CausalConv3d(out_dim, 3, 3, padding=1)) + self.head = nn.ModuleList( + [ + RMS_norm(out_dim, images=False), + nn.SiLU(), + CausalConv3d(out_dim, 3, 3, padding=1), + ] + ) def forward(self, x, feat_cache=None, feat_idx=[0]): ## conv1 @@ -435,21 +223,18 @@ def forward(self, x, feat_cache=None, feat_idx=[0]): else: x = self.conv1(x) - ## middle for layer in self.middle: if isinstance(layer, ResidualBlock) and feat_cache is not None: - x = layer(x, feat_cache, feat_idx) + x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx) else: x = layer(x) - ## upsamples for layer in self.upsamples: if feat_cache is not None: - x = layer(x, feat_cache, feat_idx) + x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx) else: x = layer(x) - ## head for layer in self.head: if isinstance(layer, CausalConv3d) and feat_cache is not None: idx = feat_idx[0] @@ -504,6 +289,46 @@ def __init__(self, self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout) + @property + def quant_conv(self): + return self.conv1 + + @property + def post_quant_conv(self): + return self.conv2 + + def _encode_frames(self, x): + num_frames = x.shape[2] + num_chunks = 1 + (num_frames - 1) // 4 + + for chunk_idx in range(num_chunks): + self._enc_conv_idx = [0] + if chunk_idx == 0: + encoded = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + else: + encoded_chunk = self.encoder( + x[:, :, 1 + 4 * (chunk_idx - 1): 1 + 4 * chunk_idx, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx, + ) + encoded = torch.cat([encoded, encoded_chunk], dim=2) + return encoded + + def _decode_frames(self, x): + num_frames = x.shape[2] + for frame_idx in range(num_frames): + self._conv_idx = [0] + decoded_chunk = self.decoder( + x[:, :, frame_idx : frame_idx + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx, + ) + if frame_idx == 0: + decoded = decoded_chunk + else: + decoded = torch.cat([decoded, decoded_chunk], dim=2) + return decoded + def forward(self, x): mu, log_var = self.encode(x) z = self.reparameterize(mu, log_var) @@ -512,24 +337,8 @@ def forward(self, x): def encode(self, x, scale=None, return_posterior=False): self.clear_cache() - ## cache - t = x.shape[2] - iter_ = 1 + (t - 1) // 4 - ## 对encode输入的x,按时间拆分为1、4、4、4.... - for i in range(iter_): - self._enc_conv_idx = [0] - if i == 0: - out = self.encoder( - x[:, :, :1, :, :], - feat_cache=self._enc_feat_map, - feat_idx=self._enc_conv_idx) - else: - out_ = self.encoder( - x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], - feat_cache=self._enc_feat_map, - feat_idx=self._enc_conv_idx) - out = torch.cat([out, out_], 2) - mu, log_var = self.conv1(out).chunk(2, dim=1) + encoded = self._encode_frames(x) + mu, log_var = self.quant_conv(encoded).chunk(2, dim=1) if scale is None or return_posterior: return mu, log_var @@ -544,30 +353,15 @@ def encode(self, x, scale=None, return_posterior=False): def decode(self, z, scale=None): self.clear_cache() - # z: [b,c,t,h,w] if scale is not None: if isinstance(scale[0], torch.Tensor): z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view( 1, self.z_dim, 1, 1, 1) else: z = z / scale[1] + scale[0] - iter_ = z.shape[2] - x = self.conv2(z) - for i in range(iter_): - self._conv_idx = [0] - if i == 0: - out = self.decoder( - x[:, :, i:i + 1, :, :], - feat_cache=self._feat_map, - feat_idx=self._conv_idx) - else: - out_ = self.decoder( - x[:, :, i:i + 1, :, :], - feat_cache=self._feat_map, - feat_idx=self._conv_idx) - out = torch.cat([out, out_], 2) + decoded = self._decode_frames(self.post_quant_conv(z)) self.clear_cache() - return out + return decoded def reparameterize(self, mu, log_var): std = torch.exp(0.5 * log_var) @@ -598,11 +392,8 @@ def clear_cache(self): self._enc_feat_map = [None] * self._enc_conv_num -def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs): - """ - Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL. - """ - # params +def _build_video_vae(z_dim=None, use_meta=False, **kwargs): + """Build the JoyAI/Wan-derived VAE backbone without loading external weights.""" cfg = { "dim": 96, "z_dim": z_dim, @@ -614,87 +405,163 @@ def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs): } cfg.update(**kwargs) - # init model - with torch.device('meta'): - model = WanVAE_(**cfg) + if use_meta: + with torch.device("meta"): + return WanVAE_(**cfg) + return WanVAE_(**cfg) - # load checkpoint - logging.info(f'loading {pretrained_path}') - if pretrained_path.endswith('.safetensors'): +def _remap_joyai_vae_state_dict_keys(pretrained_state_dict): + remapped_state_dict = {} + for key, value in pretrained_state_dict.items(): + key = key.replace(".residual.0.gamma", ".norm1.gamma") + key = key.replace(".residual.2.weight", ".conv1.weight") + key = key.replace(".residual.2.bias", ".conv1.bias") + key = key.replace(".residual.3.gamma", ".norm2.gamma") + key = key.replace(".residual.6.weight", ".conv2.weight") + key = key.replace(".residual.6.bias", ".conv2.bias") + key = key.replace(".shortcut.weight", ".conv_shortcut.weight") + key = key.replace(".shortcut.bias", ".conv_shortcut.bias") + remapped_state_dict[key] = value + return remapped_state_dict + + +def _load_pretrained_weights(model, pretrained_path): + if not pretrained_path: + return model + + logger.info(f"loading {pretrained_path}") + + if pretrained_path.endswith(".safetensors"): from safetensors.torch import load_file - pretrained_state_dict = load_file(pretrained_path, device='cpu') + + pretrained_state_dict = load_file(pretrained_path, device="cpu") else: - pretrained_state_dict = torch.load(pretrained_path, map_location='cpu') + pretrained_state_dict = torch.load(pretrained_path, map_location="cpu") + pretrained_state_dict = _remap_joyai_vae_state_dict_keys(pretrained_state_dict) model.load_state_dict(pretrained_state_dict, assign=True) - return model - -class WanxVAE(nn.Module): - # @register_to_config - def __init__(self, - pretrained='', - torch_dtype=torch.float32, - device='cuda' - ): +def _video_vae(pretrained_path=None, z_dim=None, use_meta=False, **kwargs): + model = _build_video_vae(z_dim=z_dim, use_meta=use_meta, **kwargs) + return _load_pretrained_weights(model, pretrained_path) + + +class JoyAIImageVAE(ModelMixin, ConfigMixin, AutoencoderMixin, FromOriginalModelMixin): + def __init__( + self, + pretrained: str = "", + torch_dtype: torch.dtype = torch.float32, + device: str | torch.device = "cpu", + z_dim: int = 16, + latent_channels: int | None = None, + dim: int = 96, + dim_mult: list[int] | tuple[int, ...] = (1, 2, 4, 4), + num_res_blocks: int = 2, + attn_scales: list[float] | tuple[float, ...] = (), + temperal_downsample: list[bool] | tuple[bool, ...] = (False, True, True), + dropout: float = 0.0, + latents_mean: list[float] | tuple[float, ...] = ( + -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, + 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921, + ), + latents_std: list[float] | tuple[float, ...] = ( + 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, + 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160, + ), + spatial_compression_ratio: int = 8, + temporal_compression_ratio: int = 4, + ): super().__init__() - self.dtype = torch_dtype - self.device = device + if latent_channels is not None: + z_dim = latent_channels + + self.register_to_config( + pretrained=pretrained, + z_dim=z_dim, + dim=dim, + dim_mult=list(dim_mult), + num_res_blocks=num_res_blocks, + attn_scales=list(attn_scales), + temperal_downsample=list(temperal_downsample), + dropout=dropout, + latent_channels=z_dim, + latents_mean=list(latents_mean), + latents_std=list(latents_std), + spatial_compression_ratio=spatial_compression_ratio, + temporal_compression_ratio=temporal_compression_ratio, + ) - mean = [ - -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, - 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921 - ] - std = [ - 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, - 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160 - ] - self.mean = torch.tensor(mean, dtype=self.dtype, device=device) - self.std = torch.tensor(std, dtype=self.dtype, device=device) - self.scale = [self.mean, 1.0 / self.std] - - self.config = lambda: None - self.config.latents_mean = self.mean - self.config.latents_std = self.std - self.ffactor_spatial = 8 - self.ffactor_temporal = 4 - self.config.latent_channels = 16 - - # init model + self.register_buffer("mean", torch.tensor(latents_mean, dtype=torch.float32), persistent=True) + self.register_buffer("std", torch.tensor(latents_std, dtype=torch.float32), persistent=True) + + self.ffactor_spatial = spatial_compression_ratio + self.ffactor_temporal = temporal_compression_ratio + + use_meta = bool(pretrained) self.model = _video_vae( pretrained_path=pretrained, - z_dim=16, - ).eval().requires_grad_(False) - self.model = self.model.to(device=device, dtype=torch_dtype) - - def encode(self, videos, return_posterior=False, **kwargs): - """ - videos: A list of videos each with shape [C, T, H, W]. - """ - with amp.autocast(dtype=torch.float): + z_dim=z_dim, + dim=dim, + dim_mult=list(dim_mult), + num_res_blocks=num_res_blocks, + attn_scales=list(attn_scales), + temperal_downsample=list(temperal_downsample), + dropout=dropout, + use_meta=use_meta, + ) + self.model.eval() + + def _latent_scale_tensors(self, device: torch.device, dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]: + mean = self.mean.to(device=device, dtype=dtype).view(1, -1, 1, 1, 1) + inv_std = self.std.to(device=device, dtype=dtype).reciprocal().view(1, -1, 1, 1, 1) + return mean, inv_std + + @apply_forward_hook + def encode(self, videos: torch.Tensor, return_dict: bool = True, return_posterior: bool = False, **kwargs): + autocast_context = torch.amp.autocast(device_type="cuda", dtype=torch.float32) if videos.device.type == "cuda" else nullcontext() + with autocast_context: + mean, logvar = self.model.encode(videos, scale=None, return_posterior=True) if return_posterior: - mus, log_vars = self.model.encode( - videos, scale=self.scale, return_posterior=True) - return mus, log_vars - else: - latents = self.model.encode(videos, scale=self.scale) - return latents - - def decode(self, zs, **kwargs): - with amp.autocast(dtype=torch.float): - videos = [ - self.model.decode(u.unsqueeze(0), scale=self.scale).clamp_(-1, 1).squeeze(0) - for u in zs - ] + return mean, logvar + + latent_mean, latent_inv_std = self._latent_scale_tensors(mean.device, mean.dtype) + scaled_mean = (mean - latent_mean) * latent_inv_std + scaled_logvar = logvar + 2 * torch.log(latent_inv_std) + posterior = DiagonalGaussianDistribution(torch.cat([scaled_mean, scaled_logvar], dim=1)) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + @apply_forward_hook + def decode(self, zs: torch.Tensor, return_dict: bool = True, **kwargs): + autocast_context = torch.amp.autocast(device_type="cuda", dtype=torch.float32) if zs.device.type == "cuda" else nullcontext() + with autocast_context: + mean, inv_std = self._latent_scale_tensors(zs.device, zs.dtype) + scale = [mean.view(-1), inv_std.view(-1)] + videos = [self.model.decode(z.unsqueeze(0), scale=scale).clamp_(-1, 1).squeeze(0) for z in zs] videos = torch.stack(videos, dim=0) - return (videos, ) + if not return_dict: + return (videos,) + return DecoderOutput(sample=videos) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: torch.Generator | None = None, + ): + posterior = self.encode(sample).latent_dist + latents = posterior.sample(generator=generator) if sample_posterior else posterior.mode() + return self.decode(latents, return_dict=return_dict) -JoyAIImageVAE = WanxVAE +WanxVAE = JoyAIImageVAE -__all__ = ["WanxVAE", "JoyAIImageVAE"] +__all__ = ["JoyAIImageVAE", "WanxVAE"] diff --git a/src/diffusers/models/transformers/transformer_joyai_image.py b/src/diffusers/models/transformers/transformer_joyai_image.py index 6003bc7f038a..ec785e5525aa 100644 --- a/src/diffusers/models/transformers/transformer_joyai_image.py +++ b/src/diffusers/models/transformers/transformer_joyai_image.py @@ -6,7 +6,7 @@ from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.models import ModelMixin -from diffusers.models.attention import AttentionModuleMixin, FeedForward +from diffusers.models.attention import AttentionMixin, AttentionModuleMixin, FeedForward from diffusers.models.attention_dispatch import AttentionBackendName, dispatch_attention_fn from diffusers.models.embeddings import ( PixArtAlphaTextProjection, @@ -100,6 +100,7 @@ def __call__( class JoyAIJointAttention(nn.Module, AttentionModuleMixin): _default_processor_cls = JoyAIJointAttnProcessor _available_processors = [JoyAIJointAttnProcessor] + _supports_qkv_fusion = False def __init__(self, backend: str = "flash_attn", processor=None) -> None: super().__init__() @@ -340,7 +341,7 @@ def forward( return modulation_states, encoder_hidden_states -class JoyAIImageTransformer3DModel(ModelMixin, ConfigMixin): +class JoyAIImageTransformer3DModel(ModelMixin, ConfigMixin, AttentionMixin): _fsdp_shard_conditions: list = [ lambda name, module: isinstance(module, JoyAIImageTransformerBlock)] _supports_gradient_checkpointing = True diff --git a/src/diffusers/pipelines/joyai_image/pipeline_joyai_image.py b/src/diffusers/pipelines/joyai_image/pipeline_joyai_image.py index 1a0f7aabf5a5..f47d66994f8d 100644 --- a/src/diffusers/pipelines/joyai_image/pipeline_joyai_image.py +++ b/src/diffusers/pipelines/joyai_image/pipeline_joyai_image.py @@ -189,9 +189,10 @@ def load_joyai_components( torch_dtype=PRECISION_TO_TYPE[cfg.vae_precision], device=load_device, ) + vae = vae.to(device=load_device, dtype=PRECISION_TO_TYPE[cfg.vae_precision]).eval() text_encoder = Qwen3VLForConditionalGeneration.from_pretrained( str(text_encoder_ckpt), - torch_dtype=PRECISION_TO_TYPE[cfg.text_encoder_precision], + dtype=PRECISION_TO_TYPE[cfg.text_encoder_precision], local_files_only=True, trust_remote_code=True, ).to(load_device).eval() @@ -600,7 +601,10 @@ def _encode_with_vae(self, videos: torch.Tensor) -> torch.Tensor: self.vae.config.latents_std = self.vae.std self.vae.to(device=device, dtype=vae_dtype) - return self.vae.encode(videos) + encoded = self.vae.encode(videos) + if hasattr(encoded, "latent_dist"): + return encoded.latent_dist.sample() + return encoded def _decode_with_vae(self, latents: torch.Tensor): device = self._get_runtime_execution_device() diff --git a/tests/models/autoencoders/test_models_autoencoder_joyai_image.py b/tests/models/autoencoders/test_models_autoencoder_joyai_image.py new file mode 100644 index 000000000000..a6180c462fb0 --- /dev/null +++ b/tests/models/autoencoders/test_models_autoencoder_joyai_image.py @@ -0,0 +1,80 @@ +# coding=utf-8 +# Copyright 2026 HuggingFace Inc. + +import unittest + +from diffusers import JoyAIImageVAE + +from ...testing_utils import enable_full_determinism, floats_tensor, torch_device +from ..test_modeling_common import ModelTesterMixin +from .testing_utils import AutoencoderTesterMixin + + +enable_full_determinism() + + +class JoyAIImageVAETests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase): + model_class = JoyAIImageVAE + main_input_name = "sample" + base_precision = 1e-2 + + def get_joyai_image_vae_config(self): + return { + "dim": 3, + "z_dim": 16, + "dim_mult": [1, 1, 1, 1], + "num_res_blocks": 1, + "temperal_downsample": [False, True, True], + } + + @property + def dummy_input(self): + batch_size = 2 + num_frames = 9 + num_channels = 3 + sizes = (16, 16) + sample = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device) + return {"sample": sample} + + @property + def dummy_input_tiling(self): + batch_size = 2 + num_frames = 9 + num_channels = 3 + sizes = (128, 128) + sample = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device) + return {"sample": sample} + + @property + def input_shape(self): + return (3, 9, 16, 16) + + @property + def output_shape(self): + return (3, 9, 16, 16) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = self.get_joyai_image_vae_config() + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def prepare_init_args_and_inputs_for_tiling(self): + init_dict = self.get_joyai_image_vae_config() + inputs_dict = self.dummy_input_tiling + return init_dict, inputs_dict + + @unittest.skip("Gradient checkpointing has not been implemented yet") + def test_gradient_checkpointing_is_applied(self): + pass + + @unittest.skip("Test not supported") + def test_forward_with_norm_groups(self): + pass + + @unittest.skip("RuntimeError: fill_out not implemented for 'Float8_e4m3fn'") + def test_layerwise_casting_inference(self): + pass + + @unittest.skip("RuntimeError: fill_out not implemented for 'Float8_e4m3fn'") + def test_layerwise_casting_training(self): + pass diff --git a/tests/models/transformers/test_models_transformer_joyai_image.py b/tests/models/transformers/test_models_transformer_joyai_image.py new file mode 100644 index 000000000000..aab8530b3b16 --- /dev/null +++ b/tests/models/transformers/test_models_transformer_joyai_image.py @@ -0,0 +1,73 @@ +# coding=utf-8 + +import torch + +from diffusers import JoyAIImageTransformer3DModel +from diffusers.utils.torch_utils import randn_tensor + +from ...testing_utils import enable_full_determinism, torch_device +from ..testing_utils import AttentionTesterMixin, BaseModelTesterConfig, ModelTesterMixin + + +enable_full_determinism() + + +class JoyAIImageTransformerTesterConfig(BaseModelTesterConfig): + @property + def model_class(self): + return JoyAIImageTransformer3DModel + + @property + def output_shape(self) -> tuple[int, int, int, int]: + return (4, 2, 4, 4) + + @property + def main_input_name(self) -> str: + return "hidden_states" + + @property + def generator(self): + return torch.Generator("cpu").manual_seed(0) + + def get_init_dict(self) -> dict[str, int | float | tuple[int, int, int] | str]: + return { + "patch_size": (1, 2, 2), + "in_channels": 4, + "out_channels": 4, + "hidden_size": 32, + "heads_num": 4, + "text_states_dim": 24, + "mlp_width_ratio": 2.0, + "mm_double_blocks_depth": 2, + "rope_dim_list": (2, 2, 4), + "rope_type": "rope", + "attn_backend": "torch_spda", + "theta": 1000, + } + + def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]: + hidden_states = randn_tensor((batch_size, 4, 2, 4, 4), generator=self.generator, device=torch_device) + timestep = torch.tensor([1.0] * batch_size, device=torch_device) + encoder_hidden_states = randn_tensor((batch_size, 6, 24), generator=self.generator, device=torch_device) + encoder_hidden_states_mask = torch.tensor( + [[1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1]], device=torch_device, dtype=torch.long + ) + + return { + "hidden_states": hidden_states, + "timestep": timestep, + "encoder_hidden_states": encoder_hidden_states, + "encoder_hidden_states_mask": encoder_hidden_states_mask, + } + + +class TestJoyAIImageTransformer(JoyAIImageTransformerTesterConfig, ModelTesterMixin): + pass + + +class TestJoyAIImageTransformerAttention(JoyAIImageTransformerTesterConfig, AttentionTesterMixin): + def test_exposes_attention_processors(self): + model = self.model_class(**self.get_init_dict()).to(torch_device) + + assert hasattr(model, "attn_processors") + assert len(model.attn_processors) == len(model.double_blocks) From 5ebbbd79770e2212cb1243e8eba3be18bd52eaa8 Mon Sep 17 00:00:00 2001 From: Lancer Date: Fri, 10 Apr 2026 08:31:47 +0800 Subject: [PATCH 5/6] upd Signed-off-by: Lancer --- src/diffusers/__init__.py | 4 +- .../autoencoder_kl_joyai_image.py | 241 +++++++++++++++-- .../scheduling_joyai_flow_match_discrete.py | 243 ++---------------- 3 files changed, 255 insertions(+), 233 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 244340f87252..60cf203700d5 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -1007,7 +1007,6 @@ AutoencoderKLMagvit, AutoencoderKLMochi, AutoencoderKLQwenImage, - JoyAIImageVAE, AutoencoderKLTemporalDecoder, AutoencoderKLWan, AutoencoderOobleck, @@ -1044,11 +1043,12 @@ HunyuanDiT2DModel, HunyuanDiT2DMultiControlNetModel, HunyuanImageTransformer2DModel, - JoyAIImageTransformer3DModel, HunyuanVideo15Transformer3DModel, HunyuanVideoFramepackTransformer3DModel, HunyuanVideoTransformer3DModel, I2VGenXLUNet, + JoyAIImageTransformer3DModel, + JoyAIImageVAE, Kandinsky3UNet, Kandinsky5Transformer3DModel, LatteTransformer3DModel, diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_joyai_image.py b/src/diffusers/models/autoencoders/autoencoder_kl_joyai_image.py index 36b642e43e66..337f0d88c3a7 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_joyai_image.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_joyai_image.py @@ -4,28 +4,15 @@ import torch import torch.nn as nn +import torch.nn.functional as F from ...configuration_utils import ConfigMixin from ...loaders import FromOriginalModelMixin from ...utils import logging from ...utils.accelerate_utils import apply_forward_hook +from ..activations import get_activation from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin -from .autoencoder_kl_wan import ( - WanAttentionBlock as AttentionBlock, -) -from .autoencoder_kl_wan import ( - WanCausalConv3d as CausalConv3d, -) -from .autoencoder_kl_wan import ( - WanResample as Resample, -) -from .autoencoder_kl_wan import ( - WanResidualBlock as ResidualBlock, -) -from .autoencoder_kl_wan import ( - WanRMS_norm as RMS_norm, -) from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution @@ -33,6 +20,230 @@ CACHE_T = 2 +# Copied from diffusers.models.autoencoders.autoencoder_kl_wan.WanCausalConv3d +class CausalConv3d(nn.Conv3d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int | tuple[int, int, int], + stride: int | tuple[int, int, int] = 1, + padding: int | tuple[int, int, int] = 0, + ) -> None: + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + + self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + return super().forward(x) + + +# Copied from diffusers.models.autoencoders.autoencoder_kl_wan.WanRMS_norm +class RMS_norm(nn.Module): + def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None: + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + + def forward(self, x): + needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any( + t in str(x.dtype) for t in ("float4_", "float8_") + ) + normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to( + x.dtype + ) + + return normalized * self.scale * self.gamma + self.bias + + +# Copied from diffusers.models.autoencoders.autoencoder_kl_wan.WanUpsample +class Upsample(nn.Upsample): + def forward(self, x): + return super().forward(x.float()).type_as(x) + + +# Copied from diffusers.models.autoencoders.autoencoder_kl_wan.WanResample +class Resample(nn.Module): + def __init__(self, dim: int, mode: str, upsample_out_dim: int = None) -> None: + super().__init__() + self.dim = dim + self.mode = mode + + if upsample_out_dim is None: + upsample_out_dim = dim // 2 + + if mode == "upsample2d": + self.resample = nn.Sequential( + Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, upsample_out_dim, 3, padding=1), + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, upsample_out_dim, 3, padding=1), + ) + self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + elif mode == "downsample2d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == "downsample3d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 + else: + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": + cache_x = torch.cat( + [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 + ) + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": + cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.resample(x) + x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4) + + if self.mode == "downsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + +# Copied from diffusers.models.autoencoders.autoencoder_kl_wan.WanResidualBlock +class ResidualBlock(nn.Module): + def __init__( + self, + in_dim: int, + out_dim: int, + dropout: float = 0.0, + non_linearity: str = "silu", + ) -> None: + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.nonlinearity = get_activation(non_linearity) + + self.norm1 = RMS_norm(in_dim, images=False) + self.conv1 = CausalConv3d(in_dim, out_dim, 3, padding=1) + self.norm2 = RMS_norm(out_dim, images=False) + self.dropout = nn.Dropout(dropout) + self.conv2 = CausalConv3d(out_dim, out_dim, 3, padding=1) + self.conv_shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + h = self.conv_shortcut(x) + + x = self.norm1(x) + x = self.nonlinearity(x) + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + x = self.norm2(x) + x = self.nonlinearity(x) + x = self.dropout(x) + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.conv2(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv2(x) + + return x + h + + +# Copied from diffusers.models.autoencoders.autoencoder_kl_wan.WanAttentionBlock +class AttentionBlock(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + self.norm = RMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + def forward(self, x): + identity = x + batch_size, channels, time, height, width = x.size() + + x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width) + x = self.norm(x) + + qkv = self.to_qkv(x) + qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1) + qkv = qkv.permute(0, 1, 3, 2).contiguous() + q, k, v = qkv.chunk(3, dim=-1) + + x = F.scaled_dot_product_attention(q, k, v) + + x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width) + x = self.proj(x) + x = x.view(batch_size, time, channels, height, width) + x = x.permute(0, 2, 1, 3, 4) + + return x + identity + + class Encoder3d(nn.Module): def __init__(self, diff --git a/src/diffusers/schedulers/scheduling_joyai_flow_match_discrete.py b/src/diffusers/schedulers/scheduling_joyai_flow_match_discrete.py index 62b98399bfe0..b68848a17897 100644 --- a/src/diffusers/schedulers/scheduling_joyai_flow_match_discrete.py +++ b/src/diffusers/schedulers/scheduling_joyai_flow_match_discrete.py @@ -11,59 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ============================================================================== -# -# Modified from diffusers==0.29.2 -# -# ============================================================================== -from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Union -import numpy as np import torch -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.utils import BaseOutput, logging -from diffusers.schedulers.scheduling_utils import SchedulerMixin - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -@dataclass -class JoyAIFlowMatchDiscreteSchedulerOutput(BaseOutput): - """ - Output class for the scheduler's `step` function output. - - Args: - prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): - Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the - denoising loop. - """ - - prev_sample: torch.FloatTensor - - -class JoyAIFlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin): - """ - Euler scheduler. - - This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic - methods the library implements for all schedulers such as loading and saving. +from diffusers.configuration_utils import register_to_config +from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler - Args: - num_train_timesteps (`int`, defaults to 1000): - The number of diffusion steps to train the model. - timestep_spacing (`str`, defaults to `"linspace"`): - The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. - shift (`float`, defaults to 1.0): - The shift value for the timestep schedule. - reverse (`bool`, defaults to `True`): - Whether to reverse the timestep schedule. - """ +class JoyAIFlowMatchDiscreteScheduler(FlowMatchEulerDiscreteScheduler): _compatibles = [] order = 1 @@ -73,188 +30,42 @@ def __init__( num_train_timesteps: int = 1000, shift: float = 1.0, reverse: bool = True, - solver: str = "euler", - n_tokens: Optional[int] = None, ): - sigmas = torch.linspace(1, 0, num_train_timesteps + 1) - - if not reverse: - sigmas = sigmas.flip(0) - - self.sigmas = sigmas - # the value fed to model - self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32) - - - self._step_index = None - self._begin_index = None - - self.supported_solver = ["euler"] - if solver not in self.supported_solver: - raise ValueError( - f"Solver {solver} not supported. Supported solvers: {self.supported_solver}" - ) - - @property - def step_index(self): - """ - The index counter for current timestep. It will increase 1 after each scheduler step. - """ - return self._step_index - - @property - def begin_index(self): - """ - The index for the first timestep. It should be set from pipeline with `set_begin_index` method. - """ - return self._begin_index - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index - def set_begin_index(self, begin_index: int = 0): - """ - Sets the begin index for the scheduler. This function should be run from pipeline before the inference. - - Args: - begin_index (`int`): - The begin index for the scheduler. - """ - self._begin_index = begin_index + super().__init__( + num_train_timesteps=num_train_timesteps, + shift=shift, + use_dynamic_shifting=False, + base_shift=0.5, + max_shift=1.15, + base_image_seq_len=256, + max_image_seq_len=4096, + invert_sigmas=False, + shift_terminal=None, + use_karras_sigmas=False, + use_exponential_sigmas=False, + use_beta_sigmas=False, + time_shift_type="exponential", + stochastic_sampling=False, + ) + self.register_to_config(reverse=reverse) - def _sigma_to_t(self, sigma): - return sigma * self.config.num_train_timesteps + def sd3_time_shift(self, timesteps: torch.Tensor) -> torch.Tensor: + return (self.config.shift * timesteps) / (1 + (self.config.shift - 1) * timesteps) def set_timesteps( self, num_inference_steps: int, - device: Union[str, torch.device] = None, - n_tokens: int = None, + device: Union[str, torch.device, None] = None, + **kwargs, ): - """ - Sets the discrete timesteps used for the diffusion chain (to be run before inference). - - Args: - num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. - device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - n_tokens (`int`, *optional*): - Number of tokens in the input sequence. - """ self.num_inference_steps = num_inference_steps sigmas = torch.linspace(1, 0, num_inference_steps + 1) - sigmas = self.sd3_time_shift(sigmas) if not self.config.reverse: sigmas = 1 - sigmas - self.sigmas = sigmas - self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to( - dtype=torch.float32, device=device - ) - - # Reset step index + self.sigmas = sigmas.to(device=device) + self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to(dtype=torch.float32, device=device) self._step_index = None - - def index_for_timestep(self, timestep, schedule_timesteps=None): - if schedule_timesteps is None: - schedule_timesteps = self.timesteps - - indices = (schedule_timesteps == timestep).nonzero() - - # The sigma index that is taken for the **very** first `step` - # is always the second index (or the last index if there is only 1) - # This way we can ensure we don't accidentally skip a sigma in - # case we start in the middle of the denoising schedule (e.g. for image-to-image) - pos = 1 if len(indices) > 1 else 0 - - return indices[pos].item() - - def _init_step_index(self, timestep): - if self.begin_index is None: - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - self._step_index = self.index_for_timestep(timestep) - else: - self._step_index = self._begin_index - - def scale_model_input( - self, sample: torch.Tensor, timestep: Optional[int] = None - ) -> torch.Tensor: - return sample - - def sd3_time_shift(self, t: torch.Tensor): - # print("sd3:self.config.shift",self.config.shift) - return (self.config.shift * t) / (1 + (self.config.shift - 1) * t) - - def step( - self, - model_output: torch.FloatTensor, - timestep: Union[float, torch.FloatTensor], - sample: torch.FloatTensor, - return_dict: bool = True, - ) -> Union[JoyAIFlowMatchDiscreteSchedulerOutput, Tuple]: - """ - Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion - process from the learned model outputs (most often the predicted noise). - - Args: - model_output (`torch.FloatTensor`): - The direct output from learned diffusion model. - timestep (`float`): - The current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): - A current instance of a sample created by the diffusion process. - generator (`torch.Generator`, *optional*): - A random number generator. - n_tokens (`int`, *optional*): - Number of tokens in the input sequence. - return_dict (`bool`): - Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or - tuple. - - Returns: - [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: - If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is - returned, otherwise a tuple is returned where the first element is the sample tensor. - """ - - if ( - isinstance(timestep, int) - or isinstance(timestep, torch.IntTensor) - or isinstance(timestep, torch.LongTensor) - ): - raise ValueError( - ( - "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" - " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" - " one of the `scheduler.timesteps` as a timestep." - ), - ) - - if self.step_index is None: - self._init_step_index(timestep) - - # Upcast to avoid precision issues when computing prev_sample - sample = sample.to(torch.float32) - - dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index] - - if self.config.solver == "euler": - prev_sample = sample + model_output.to(torch.float32) * dt - else: - raise ValueError( - f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}" - ) - - # upon completion increase step index by one - self._step_index += 1 - - if not return_dict: - return (prev_sample,) - - return JoyAIFlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample) - - def __len__(self): - return self.config.num_train_timesteps From 2ccc45fb474a7526d723e6bb48199bc353d0fcf0 Mon Sep 17 00:00:00 2001 From: Lancer Date: Fri, 10 Apr 2026 12:53:50 +0800 Subject: [PATCH 6/6] upd Signed-off-by: Lancer --- docs/source/en/_toctree.yml | 6 + .../api/models/autoencoder_kl_joyai_image.md | 35 +++ .../api/models/joyai_image_transformer3d.md | 26 ++ docs/source/en/api/pipelines/joyai_image.md | 128 ++++++++ docs/source/en/api/pipelines/overview.md | 1 + scripts/convert_joyai_image_to_diffusers.py | 286 ++++++++++++++++++ src/diffusers/__init__.py | 2 + src/diffusers/models/__init__.py | 4 +- src/diffusers/models/autoencoders/__init__.py | 2 +- .../autoencoder_kl_joyai_image.py | 174 ++++++----- .../transformers/transformer_joyai_image.py | 174 ++++------- .../joyai_image/pipeline_joyai_image.py | 219 +------------- .../pipelines/joyai_image/pipeline_output.py | 14 + src/diffusers/schedulers/__init__.py | 8 +- .../scheduling_joyai_flow_match_discrete.py | 36 ++- .../joyai_image/test_pipeline_joyai_image.py | 56 ++++ ...est_scheduler_joyai_flow_match_discrete.py | 37 +++ 17 files changed, 797 insertions(+), 411 deletions(-) create mode 100644 docs/source/en/api/models/autoencoder_kl_joyai_image.md create mode 100644 docs/source/en/api/models/joyai_image_transformer3d.md create mode 100644 docs/source/en/api/pipelines/joyai_image.md create mode 100644 scripts/convert_joyai_image_to_diffusers.py create mode 100644 tests/pipelines/joyai_image/test_pipeline_joyai_image.py create mode 100644 tests/schedulers/test_scheduler_joyai_flow_match_discrete.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 7582a56505f7..848e26630825 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -370,6 +370,8 @@ title: LatteTransformer3DModel - local: api/models/longcat_image_transformer2d title: LongCatImageTransformer2DModel + - local: api/models/joyai_image_transformer3d + title: JoyAIImageTransformer3DModel - local: api/models/ltx2_video_transformer3d title: LTX2VideoTransformer3DModel - local: api/models/ltx_video_transformer3d @@ -466,6 +468,8 @@ title: AutoencoderKLQwenImage - local: api/models/autoencoder_kl_wan title: AutoencoderKLWan + - local: api/models/autoencoder_kl_joyai_image + title: JoyAIImageVAE - local: api/models/autoencoder_rae title: AutoencoderRAE - local: api/models/consistency_decoder_vae @@ -558,6 +562,8 @@ title: Kandinsky 5.0 Image - local: api/pipelines/kolors title: Kolors + - local: api/pipelines/joyai_image + title: JoyAI-Image - local: api/pipelines/latent_consistency_models title: Latent Consistency Models - local: api/pipelines/latent_diffusion diff --git a/docs/source/en/api/models/autoencoder_kl_joyai_image.md b/docs/source/en/api/models/autoencoder_kl_joyai_image.md new file mode 100644 index 000000000000..e65909092da6 --- /dev/null +++ b/docs/source/en/api/models/autoencoder_kl_joyai_image.md @@ -0,0 +1,35 @@ + + +# JoyAIImageVAE + +The 3D variational autoencoder (VAE) model with KL loss used in JoyAI-Image by JDopensource. + +The model can be loaded with the following code snippet. + +```python +from diffusers import JoyAIImageVAE + +vae = JoyAIImageVAE.from_pretrained("path/to/checkpoint", subfolder="vae", torch_dtype=torch.bfloat16) +``` + + +## JoyAIImageVAE + +[[autodoc]] JoyAIImageVAE + - decode + - all + + +## DecoderOutput + +[[autodoc]] diffusers.models.autoencoders.autoencoder_kl.AutoencoderKLOutput \ No newline at end of file diff --git a/docs/source/en/api/models/joyai_image_transformer3d.md b/docs/source/en/api/models/joyai_image_transformer3d.md new file mode 100644 index 000000000000..c90ac5e50ea4 --- /dev/null +++ b/docs/source/en/api/models/joyai_image_transformer3d.md @@ -0,0 +1,26 @@ + + +# JoyAIImageTransformer3DModel + +The model can be loaded with the following code snippet. + +```python +from diffusers import JoyAIImageTransformer3DModel + +transformer = JoyAIImageTransformer3DModel.from_pretrained("path/to/checkpoint", subfolder="transformer", torch_dtype=torch.bfloat16) +``` + + +## JoyAIImageTransformer3DModel + +[[autodoc]] JoyAIImageTransformer3DModel \ No newline at end of file diff --git a/docs/source/en/api/pipelines/joyai_image.md b/docs/source/en/api/pipelines/joyai_image.md new file mode 100644 index 000000000000..7ed482a66978 --- /dev/null +++ b/docs/source/en/api/pipelines/joyai_image.md @@ -0,0 +1,128 @@ + + +# JoyAI-Image + +
+ LoRA +
+ +JoyAI-Image is a multimodal foundation model specialized in instruction-guided image editing. It enables precise and controllable edits by leveraging strong spatial understanding, including scene parsing, relational grounding, and instruction decomposition, allowing complex modifications to be applied accurately to specified regions. + + +### Key Features +- 🌟 **Unified Multimodal Understanding and Generation**: Combines powerful image understanding with generation capabilities in a single model. +- 🌟 **Spatial Editing**: Supports precise spatial editing including object movement, rotation, and camera control. +- 🌟 **Instruction Following**: Accurately interprets user instructions for image modifications while preserving image quality. +- 🌟 **Qwen2.5-VL Integration**: Leverages Qwen2.5-VL for enhanced multimodal understanding. + +For more details, please refer to the [JoyAI-Image GitHub](https://github.com/jd-opensource/JoyAI-Image). + + +## Usage Example + +```py +import torch +from diffusers import JoyAIImagePipeline + +pipe = JoyAIImagePipeline.from_pretrained("path/to/converted/checkpoint", torch_dtype=torch.bfloat16) +pipe.to("cuda") + +prompt = "Move the apple into the red box and finally remove the red box." +image = pipe( + prompt, + image=input_image, + num_inference_steps=30, + guidance_scale=5.0, +).images[0] +image.save("./output.png") +``` + + +### Supported Prompt Patterns + +#### 1. Object Move +```text +Move the into the red box and finally remove the red box. +``` + +#### 2. Object Rotation +```text +Rotate the to show the side view. +``` +Supported views: front, right, left, rear, front right, front left, rear right, rear left + +#### 3. Camera Control +```text +Move the camera. +- Camera rotation: Yaw {y_rotation}°, Pitch {p_rotation}°. +- Camera zoom: in/out/unchanged. +- Keep the 3D scene static; only change the viewpoint. +``` + +This pipeline was contributed by JDopensource Team. The original codebase can be found [here](https://github.com/jd-opensource/JoyAI-Image). + + +## Available Models +
+ + + + + + + + + + + + + + + + + +
ModelsTypeDescriptionDownload Link
JoyAI‑Image‑EditImage EditingFinal Release. Specialized model for instruction-guided image editing. + 🤗 Huggingface +
+
+ +## Converting Original Checkpoint to Diffusers Format + +If you have the original JoyAI checkpoint, you can convert it to diffusers format using the provided conversion script: + +```bash +python scripts/convert_joyai_image_to_diffusers.py \ + --source_path /path/to/original/JoyAI-Image-Edit \ + --output_path /path/to/converted/checkpoint \ + --dtype bf16 +``` + +After conversion, load the model with: + +```py +from diffusers import JoyAIImagePipeline +pipe = JoyAIImagePipeline.from_pretrained("/path/to/converted/checkpoint") +``` + + +## JoyAIImagePipeline + +[[autodoc]] JoyAIImagePipeline + - all + - __call__ + + +## JoyAIImagePipelineOutput + +[[autodoc]] pipelines.joyai_image.pipeline_output.JoyAIImagePipelineOutput + diff --git a/docs/source/en/api/pipelines/overview.md b/docs/source/en/api/pipelines/overview.md index c3e493c63d6a..bac1a810529e 100644 --- a/docs/source/en/api/pipelines/overview.md +++ b/docs/source/en/api/pipelines/overview.md @@ -49,6 +49,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an | [Kandinsky 2.2](kandinsky_v22) | text2image, image2image, inpainting | | [Kandinsky 3](kandinsky3) | text2image, image2image | | [Kolors](kolors) | text2image | +| [JoyAI-Image](joyai_image) | image editing | | [Latent Consistency Models](latent_consistency_models) | text2image | | [Latent Diffusion](latent_diffusion) | text2image, super-resolution | | [Latte](latte) | text2image | diff --git a/scripts/convert_joyai_image_to_diffusers.py b/scripts/convert_joyai_image_to_diffusers.py new file mode 100644 index 000000000000..343885a76f6f --- /dev/null +++ b/scripts/convert_joyai_image_to_diffusers.py @@ -0,0 +1,286 @@ +#!/usr/bin/env python3 + +import argparse +import json +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Optional, Union + +import torch +from transformers import AutoProcessor, AutoTokenizer, Qwen3VLForConditionalGeneration + +from diffusers import JoyAIImagePipeline +from diffusers.configuration_utils import FrozenDict +from diffusers.models.autoencoders.autoencoder_kl_joyai_image import JoyAIImageVAE +from diffusers.models.transformers.transformer_joyai_image import JoyAIImageTransformer3DModel +from diffusers.schedulers.scheduling_joyai_flow_match_discrete import JoyAIFlowMatchDiscreteScheduler + + +DTYPE_MAP = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} + +PRECISION_TO_TYPE = { + "fp32": torch.float32, + "float32": torch.float32, + "fp16": torch.float16, + "float16": torch.float16, + "bf16": torch.bfloat16, + "bfloat16": torch.bfloat16, +} + + +@dataclass +class JoyAIImageSourceConfig: + source_root: Path + dit_precision: str = "bf16" + vae_precision: str = "bf16" + text_encoder_precision: str = "bf16" + text_token_max_length: int = 2048 + enable_multi_task_training: bool = False + dit_arch_config: dict[str, Any] = field( + default_factory=lambda: { + "hidden_size": 4096, + "in_channels": 16, + "heads_num": 32, + "mm_double_blocks_depth": 40, + "out_channels": 16, + "patch_size": [1, 2, 2], + "rope_dim_list": [16, 56, 56], + "text_states_dim": 4096, + "rope_type": "rope", + "dit_modulation_type": "wanx", + "theta": 10000, + "attn_backend": "flash_attn", + } + ) + scheduler_arch_config: dict[str, Any] = field( + default_factory=lambda: { + "num_train_timesteps": 1000, + "shift": 4.0, + } + ) + + @property + def text_encoder_arch_config(self) -> dict[str, Any]: + return {"params": {"text_encoder_ckpt": str(self.source_root / "JoyAI-Image-Und")}} + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Convert a raw JoyAI checkpoint directory to standard diffusers format." + ) + parser.add_argument( + "--source_path", type=str, required=True, help="Path to the original JoyAI checkpoint directory" + ) + parser.add_argument( + "--output_path", type=str, required=True, help="Output path for the converted diffusers checkpoint" + ) + parser.add_argument( + "--dtype", type=str, default="bf16", choices=sorted(DTYPE_MAP), help="Component dtype to load and save" + ) + parser.add_argument("--device", type=str, default="cpu", help="Device used while loading the raw JoyAI checkpoint") + parser.add_argument( + "--safe_serialization", + action="store_true", + default=True, + help="Save diffusers weights with safetensors when supported (default: True)", + ) + return parser.parse_args() + + +def dtype_to_precision(torch_dtype: Optional[torch.dtype]) -> Optional[str]: + if torch_dtype is None: + return None + for name, value in PRECISION_TO_TYPE.items(): + if value == torch_dtype and name in {"fp32", "fp16", "bf16"}: + return name + raise ValueError(f"Unsupported torch dtype for JoyAI conversion: {torch_dtype}") + + +def resolve_manifest_path(source_root: Path, manifest_value: Optional[str]) -> Optional[Path]: + if manifest_value is None: + return None + path = Path(manifest_value) + if path.parts and path.parts[0] == source_root.name: + path = Path(*path.parts[1:]) + return source_root / path + + +def is_joyai_source_dir(path: Path) -> bool: + return ( + path.is_dir() + and (path / "infer_config.py").is_file() + and (path / "manifest.json").is_file() + and (path / "transformer").is_dir() + and (path / "vae").is_dir() + ) + + +def load_transformer_state_dict(checkpoint_path: Path) -> dict[str, torch.Tensor]: + state = torch.load(checkpoint_path, map_location="cpu", weights_only=True) + if "model" in state: + state = state["model"] + return state + + +def load_joyai_components( + source_root: Union[str, Path], + torch_dtype: Optional[torch.dtype] = None, + device: Optional[Union[str, torch.device]] = None, +) -> dict[str, Any]: + source_root = Path(source_root) + if not is_joyai_source_dir(source_root): + raise ValueError(f"Not a valid JoyAI source checkpoint directory: {source_root}") + + precision = dtype_to_precision(torch_dtype) + cfg = JoyAIImageSourceConfig(source_root=source_root) + + manifest = json.loads((source_root / "manifest.json").read_text()) + transformer_ckpt = resolve_manifest_path(source_root, manifest.get("transformer_ckpt")) + vae_ckpt = source_root / "vae" / "Wan2.1_VAE.pth" + text_encoder_ckpt = source_root / "JoyAI-Image-Und" + + if precision is not None: + cfg.dit_precision = precision + cfg.vae_precision = precision + cfg.text_encoder_precision = precision + + load_device = torch.device(device) if device is not None else torch.device("cpu") + transformer = JoyAIImageTransformer3DModel( + dtype=PRECISION_TO_TYPE[cfg.dit_precision], + device=load_device, + **cfg.dit_arch_config, + ) + state_dict = load_transformer_state_dict(transformer_ckpt) + if "img_in.weight" in state_dict and transformer.img_in.weight.shape != state_dict["img_in.weight"].shape: + value = state_dict["img_in.weight"] + padded = value.new_zeros(transformer.img_in.weight.shape) + padded[:, : value.shape[1], :, :, :] = value + state_dict["img_in.weight"] = padded + transformer.load_state_dict(state_dict, strict=True) + transformer = transformer.to(dtype=PRECISION_TO_TYPE[cfg.dit_precision]).eval() + + vae = JoyAIImageVAE( + pretrained=str(vae_ckpt), + torch_dtype=PRECISION_TO_TYPE[cfg.vae_precision], + device=load_device, + ) + vae = vae.to(device=load_device, dtype=PRECISION_TO_TYPE[cfg.vae_precision]).eval() + text_encoder = ( + Qwen3VLForConditionalGeneration.from_pretrained( + str(text_encoder_ckpt), + dtype=PRECISION_TO_TYPE[cfg.text_encoder_precision], + local_files_only=True, + trust_remote_code=True, + ) + .to(load_device) + .eval() + ) + tokenizer = AutoTokenizer.from_pretrained( + str(text_encoder_ckpt), + local_files_only=True, + trust_remote_code=True, + ) + processor = AutoProcessor.from_pretrained( + str(text_encoder_ckpt), + local_files_only=True, + trust_remote_code=True, + ) + scheduler = JoyAIFlowMatchDiscreteScheduler(**cfg.scheduler_arch_config) + + return { + "args": cfg, + "processor": processor, + "tokenizer": tokenizer, + "text_encoder": text_encoder, + "transformer": transformer, + "scheduler": scheduler, + "vae": vae, + } + + +def _sanitize_config_value(value: Any) -> Any: + if isinstance(value, (torch.dtype, torch.device)): + raise TypeError("Drop non-JSON torch config values") + if isinstance(value, Path): + return str(value) + if isinstance(value, dict): + sanitized = {} + for key, item in value.items(): + try: + sanitized[key] = _sanitize_config_value(item) + json.dumps(sanitized[key]) + except TypeError: + continue + return sanitized + if isinstance(value, (list, tuple)): + sanitized = [] + for item in value: + try: + converted = _sanitize_config_value(item) + json.dumps(converted) + sanitized.append(converted) + except TypeError: + continue + return sanitized + return value + + +def _sanitize_component_config(component: Any) -> None: + config = getattr(component, "config", None) + if config is None: + return + + sanitized_config = {} + for key, value in dict(config).items(): + try: + sanitized_value = _sanitize_config_value(value) + json.dumps(sanitized_value) + sanitized_config[key] = sanitized_value + except TypeError: + continue + + component._internal_dict = FrozenDict(sanitized_config) + + +def _sanitize_pipeline_for_export(pipeline: JoyAIImagePipeline) -> None: + for component_name in ["vae", "transformer", "scheduler"]: + _sanitize_component_config(getattr(pipeline, component_name, None)) + + +def main(): + args = parse_args() + source_path = Path(args.source_path) + output_path = Path(args.output_path) + + if not source_path.exists(): + raise ValueError(f"Source path does not exist: {source_path}") + + output_path.mkdir(parents=True, exist_ok=True) + + components = load_joyai_components( + source_root=source_path, + torch_dtype=DTYPE_MAP[args.dtype], + device=args.device, + ) + pipeline = JoyAIImagePipeline( + vae=components["vae"], + text_encoder=components["text_encoder"], + tokenizer=components["tokenizer"], + transformer=components["transformer"], + scheduler=components["scheduler"], + processor=components["processor"], + args=components["args"], + ) + _sanitize_pipeline_for_export(pipeline) + pipeline.save_pretrained(output_path, safe_serialization=args.safe_serialization) + + print(f"Converted JoyAI checkpoint saved to: {output_path}") + print(f"Load with: JoyAIImagePipeline.from_pretrained({str(output_path)!r})") + + +if __name__ == "__main__": + main() diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 60cf203700d5..baeda1625a09 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -378,6 +378,7 @@ "FlowMatchLCMScheduler", "HeliosDMDScheduler", "HeliosScheduler", + "JoyAIFlowMatchDiscreteScheduler", "HeunDiscreteScheduler", "IPNDMScheduler", "KarrasVeScheduler", @@ -1174,6 +1175,7 @@ HeliosScheduler, HeunDiscreteScheduler, IPNDMScheduler, + JoyAIFlowMatchDiscreteScheduler, KarrasVeScheduler, KDPM2AncestralDiscreteScheduler, KDPM2DiscreteScheduler, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 3b139c2fcdd5..7dfa6c34d58c 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -173,7 +173,6 @@ AutoencoderKLMagvit, AutoencoderKLMochi, AutoencoderKLQwenImage, - JoyAIImageVAE, AutoencoderKLTemporalDecoder, AutoencoderKLWan, AutoencoderOobleck, @@ -181,6 +180,7 @@ AutoencoderTiny, AutoencoderVidTok, ConsistencyDecoderVAE, + JoyAIImageVAE, VQModel, ) from .cache_utils import CacheMixin @@ -228,10 +228,10 @@ HiDreamImageTransformer2DModel, HunyuanDiT2DModel, HunyuanImageTransformer2DModel, - JoyAIImageTransformer3DModel, HunyuanVideo15Transformer3DModel, HunyuanVideoFramepackTransformer3DModel, HunyuanVideoTransformer3DModel, + JoyAIImageTransformer3DModel, Kandinsky5Transformer3DModel, LatteTransformer3DModel, LongCatImageTransformer2DModel, diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index 317055ee6d26..a54a005f6812 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -9,6 +9,7 @@ from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15 +from .autoencoder_kl_joyai_image import JoyAIImageVAE from .autoencoder_kl_kvae import AutoencoderKLKVAE from .autoencoder_kl_kvae_video import AutoencoderKLKVAEVideo from .autoencoder_kl_ltx import AutoencoderKLLTXVideo @@ -16,7 +17,6 @@ from .autoencoder_kl_ltx2_audio import AutoencoderKLLTX2Audio from .autoencoder_kl_magvit import AutoencoderKLMagvit from .autoencoder_kl_mochi import AutoencoderKLMochi -from .autoencoder_kl_joyai_image import JoyAIImageVAE from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder from .autoencoder_kl_wan import AutoencoderKLWan diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_joyai_image.py b/src/diffusers/models/autoencoders/autoencoder_kl_joyai_image.py index 337f0d88c3a7..c2d08d20c166 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_joyai_image.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_joyai_image.py @@ -1,4 +1,17 @@ -# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This model is adapted from https://github.com/jd-opensource/JoyAI-Image from contextlib import nullcontext @@ -245,15 +258,16 @@ def forward(self, x): class Encoder3d(nn.Module): - - def __init__(self, - dim=128, - z_dim=4, - dim_mult=[1, 2, 4, 4], - num_res_blocks=2, - attn_scales=[], - temperal_downsample=[True, True, False], - dropout=0.0): + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + ): super().__init__() self.dim = dim self.z_dim = z_dim @@ -281,8 +295,7 @@ def __init__(self, # downsample block if i != len(dim_mult) - 1: - mode = 'downsample3d' if temperal_downsample[ - i] else 'downsample2d' + mode = "downsample3d" if temperal_downsample[i] else "downsample2d" downsamples.append(Resample(out_dim, mode=mode)) scale /= 2.0 self.downsamples = nn.ModuleList(downsamples) @@ -311,11 +324,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]): cache_x = x[:, :, -CACHE_T:, :, :].clone() if cache_x.shape[2] < 2 and feat_cache[idx] is not None: # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) x = self.conv1(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -341,11 +350,9 @@ def forward(self, x, feat_cache=None, feat_idx=[0]): cache_x = x[:, :, -CACHE_T:, :, :].clone() if cache_x.shape[2] < 2 and feat_cache[idx] is not None: # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = torch.cat( + [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 + ) x = layer(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -355,15 +362,16 @@ def forward(self, x, feat_cache=None, feat_idx=[0]): class Decoder3d(nn.Module): - - def __init__(self, - dim=128, - z_dim=4, - dim_mult=[1, 2, 4, 4], - num_res_blocks=2, - attn_scales=[], - temperal_upsample=[False, True, True], - dropout=0.0): + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0, + ): super().__init__() self.dim = dim self.z_dim = z_dim @@ -374,7 +382,7 @@ def __init__(self, # dimensions dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] - scale = 1.0 / 2**(len(dim_mult) - 2) + scale = 1.0 / 2 ** (len(dim_mult) - 2) # init block self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) @@ -402,7 +410,7 @@ def __init__(self, # upsample block if i != len(dim_mult) - 1: - mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d' + mode = "upsample3d" if temperal_upsample[i] else "upsample2d" upsamples.append(Resample(out_dim, mode=mode)) scale *= 2.0 self.upsamples = nn.ModuleList(upsamples) @@ -423,11 +431,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]): cache_x = x[:, :, -CACHE_T:, :, :].clone() if cache_x.shape[2] < 2 and feat_cache[idx] is not None: # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) x = self.conv1(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -452,11 +456,9 @@ def forward(self, x, feat_cache=None, feat_idx=[0]): cache_x = x[:, :, -CACHE_T:, :, :].clone() if cache_x.shape[2] < 2 and feat_cache[idx] is not None: # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = torch.cat( + [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 + ) x = layer(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -474,15 +476,16 @@ def count_conv3d(model): class WanVAE_(nn.Module): - - def __init__(self, - dim=128, - z_dim=4, - dim_mult=[1, 2, 4, 4], - num_res_blocks=2, - attn_scales=[], - temperal_downsample=[True, True, False], - dropout=0.0): + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + ): super().__init__() self.dim = dim self.z_dim = z_dim @@ -493,12 +496,12 @@ def __init__(self, self.temperal_upsample = temperal_downsample[::-1] # modules - self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, - attn_scales, self.temperal_downsample, dropout) + self.encoder = Encoder3d( + dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout + ) self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) self.conv2 = CausalConv3d(z_dim, z_dim, 1) - self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, - attn_scales, self.temperal_upsample, dropout) + self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout) @property def quant_conv(self): @@ -518,7 +521,7 @@ def _encode_frames(self, x): encoded = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) else: encoded_chunk = self.encoder( - x[:, :, 1 + 4 * (chunk_idx - 1): 1 + 4 * chunk_idx, :, :], + x[:, :, 1 + 4 * (chunk_idx - 1) : 1 + 4 * chunk_idx, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx, ) @@ -555,8 +558,7 @@ def encode(self, x, scale=None, return_posterior=False): mu = self.reparameterize(mu, log_var) if isinstance(scale[0], torch.Tensor): - mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view( - 1, self.z_dim, 1, 1, 1) + mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1) else: mu = (mu - scale[0]) * scale[1] self.clear_cache() @@ -566,8 +568,7 @@ def decode(self, z, scale=None): self.clear_cache() if scale is not None: if isinstance(scale[0], torch.Tensor): - z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view( - 1, self.z_dim, 1, 1, 1) + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1) else: z = z / scale[1] + scale[0] decoded = self._decode_frames(self.post_quant_conv(z)) @@ -586,8 +587,7 @@ def sample(self, imgs, deterministic=False, scale=None): std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0)) mu = mu + std * torch.randn_like(std) if isinstance(scale[0], torch.Tensor): - mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view( - 1, self.z_dim, 1, 1, 1) + mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1) else: mu = (mu - scale[0]) * scale[1] self.clear_cache() @@ -597,7 +597,7 @@ def clear_cache(self): self._conv_num = count_conv3d(self.decoder) self._conv_idx = [0] self._feat_map = [None] * self._conv_num - #cache encode + # cache encode self._enc_conv_num = count_conv3d(self.encoder) self._enc_conv_idx = [0] self._enc_feat_map = [None] * self._enc_conv_num @@ -675,12 +675,40 @@ def __init__( temperal_downsample: list[bool] | tuple[bool, ...] = (False, True, True), dropout: float = 0.0, latents_mean: list[float] | tuple[float, ...] = ( - -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, - 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921, + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921, ), latents_std: list[float] | tuple[float, ...] = ( - 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, - 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160, + 2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.9160, ), spatial_compression_ratio: int = 8, temporal_compression_ratio: int = 4, @@ -733,7 +761,11 @@ def _latent_scale_tensors(self, device: torch.device, dtype: torch.dtype) -> tup @apply_forward_hook def encode(self, videos: torch.Tensor, return_dict: bool = True, return_posterior: bool = False, **kwargs): - autocast_context = torch.amp.autocast(device_type="cuda", dtype=torch.float32) if videos.device.type == "cuda" else nullcontext() + autocast_context = ( + torch.amp.autocast(device_type="cuda", dtype=torch.float32) + if videos.device.type == "cuda" + else nullcontext() + ) with autocast_context: mean, logvar = self.model.encode(videos, scale=None, return_posterior=True) if return_posterior: @@ -750,7 +782,9 @@ def encode(self, videos: torch.Tensor, return_dict: bool = True, return_posterio @apply_forward_hook def decode(self, zs: torch.Tensor, return_dict: bool = True, **kwargs): - autocast_context = torch.amp.autocast(device_type="cuda", dtype=torch.float32) if zs.device.type == "cuda" else nullcontext() + autocast_context = ( + torch.amp.autocast(device_type="cuda", dtype=torch.float32) if zs.device.type == "cuda" else nullcontext() + ) with autocast_context: mean, inv_std = self._latent_scale_tensors(zs.device, zs.dtype) scale = [mean.view(-1), inv_std.view(-1)] diff --git a/src/diffusers/models/transformers/transformer_joyai_image.py b/src/diffusers/models/transformers/transformer_joyai_image.py index ec785e5525aa..33e3cfff866b 100644 --- a/src/diffusers/models/transformers/transformer_joyai_image.py +++ b/src/diffusers/models/transformers/transformer_joyai_image.py @@ -1,3 +1,18 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This model is adapted from https://github.com/jd-opensource/JoyAI-Image + import math from typing import Any, Dict, Optional, Tuple, Union @@ -18,17 +33,11 @@ from diffusers.models.normalization import RMSNorm -def _create_modulation( - modulate_type: str, - hidden_size: int, - factor: int, - dtype=None, - device=None): +def _create_modulation(modulate_type: str, hidden_size: int, factor: int, dtype=None, device=None): factory_kwargs = {"dtype": dtype, "device": device} - if modulate_type == 'wanx': + if modulate_type == "wanx": return _WanModulation(hidden_size, factor, **factory_kwargs) - raise ValueError( - f"Unknown modulation type: {modulate_type}. Only 'wanx' is supported.") + raise ValueError(f"Unknown modulation type: {modulate_type}. Only 'wanx' is supported.") class _WanModulation(nn.Module): @@ -44,9 +53,7 @@ def __init__( super().__init__() self.factor = factor self.modulate_table = nn.Parameter( - torch.zeros(1, factor, hidden_size, - dtype=dtype, device=device) / hidden_size**0.5, - requires_grad=True + torch.zeros(1, factor, hidden_size, dtype=dtype, device=device) / hidden_size**0.5, requires_grad=True ) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -131,7 +138,7 @@ def __init__( dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, dit_modulation_type: Optional[str] = "wanx", - attn_backend: str = 'flash_attn', + attn_backend: str = "flash_attn", ): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() @@ -146,24 +153,15 @@ def __init__( factor=6, **factory_kwargs, ) - self.img_norm1 = nn.LayerNorm( - hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs - ) + self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) - self.img_attn_qkv = nn.Linear( - hidden_size, hidden_size * 3, bias=True, **factory_kwargs - ) + self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True, **factory_kwargs) self.img_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6) self.img_attn_k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6) - self.img_attn_proj = nn.Linear( - hidden_size, hidden_size, bias=True, **factory_kwargs - ) + self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=True, **factory_kwargs) - self.img_norm2 = nn.LayerNorm( - hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs - ) - self.img_mlp = FeedForward(hidden_size, inner_dim=mlp_hidden_dim, - activation_fn="gelu-approximate") + self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.img_mlp = FeedForward(hidden_size, inner_dim=mlp_hidden_dim, activation_fn="gelu-approximate") self.txt_mod = _create_modulation( modulate_type=self.dit_modulation_type, @@ -171,25 +169,16 @@ def __init__( factor=6, **factory_kwargs, ) - self.txt_norm1 = nn.LayerNorm( - hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs - ) + self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) - self.txt_attn_qkv = nn.Linear( - hidden_size, hidden_size * 3, bias=True, **factory_kwargs - ) + self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True, **factory_kwargs) self.txt_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6) self.txt_attn_k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6) - self.txt_attn_proj = nn.Linear( - hidden_size, hidden_size, bias=True, **factory_kwargs - ) + self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=True, **factory_kwargs) self.attn = JoyAIJointAttention(attn_backend) - self.txt_norm2 = nn.LayerNorm( - hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs - ) - self.txt_mlp = FeedForward(hidden_size, inner_dim=mlp_hidden_dim, - activation_fn="gelu-approximate") + self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.txt_mlp = FeedForward(hidden_size, inner_dim=mlp_hidden_dim, activation_fn="gelu-approximate") @staticmethod def _modulate( @@ -204,9 +193,7 @@ def _modulate( return hidden_states * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) @staticmethod - def _apply_gate( - hidden_states: torch.Tensor, gate: torch.Tensor | None = None, tanh: bool = False - ) -> torch.Tensor: + def _apply_gate(hidden_states: torch.Tensor, gate: torch.Tensor | None = None, tanh: bool = False) -> torch.Tensor: if gate is None: return hidden_states if tanh: @@ -240,9 +227,7 @@ def forward( ) = self.txt_mod(vec) img_modulated = self.img_norm1(img) - img_modulated = self._modulate( - img_modulated, shift=img_mod1_shift, scale=img_mod1_scale - ) + img_modulated = self._modulate(img_modulated, shift=img_mod1_shift, scale=img_mod1_scale) img_qkv = self.img_attn_qkv(img_modulated) batch_size, image_sequence_length, _ = img_qkv.shape img_qkv = img_qkv.view(batch_size, image_sequence_length, 3, self.heads_num, -1).permute(2, 0, 1, 3, 4) @@ -255,9 +240,7 @@ def forward( img_k = apply_rotary_emb(img_k, vis_freqs_cis, sequence_dim=1) txt_modulated = self.txt_norm1(txt) - txt_modulated = self._modulate( - txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale - ) + txt_modulated = self._modulate(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale) txt_qkv = self.txt_attn_qkv(txt_modulated) _, text_sequence_length, _ = txt_qkv.shape txt_qkv = txt_qkv.view(batch_size, text_sequence_length, 3, self.heads_num, -1).permute(2, 0, 1, 3, 4) @@ -268,7 +251,6 @@ def forward( if txt_freqs_cis is not None: raise NotImplementedError("RoPE text is not supported for inference") - attention_output = self.attn( torch.cat((img_q, txt_q), dim=1), torch.cat((img_k, txt_k), dim=1), @@ -278,27 +260,17 @@ def forward( ) attention_output = attention_output.flatten(2, 3) image_attention_output = attention_output[:, : img.shape[1]] - text_attention_output = attention_output[:, img.shape[1]:] + text_attention_output = attention_output[:, img.shape[1] :] - img = img + self._apply_gate(self.img_attn_proj(image_attention_output), - gate=img_mod1_gate) + img = img + self._apply_gate(self.img_attn_proj(image_attention_output), gate=img_mod1_gate) img = img + self._apply_gate( - self.img_mlp( - self._modulate( - self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale - ) - ), + self.img_mlp(self._modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)), gate=img_mod2_gate, ) - txt = txt + self._apply_gate(self.txt_attn_proj(text_attention_output), - gate=txt_mod1_gate) + txt = txt + self._apply_gate(self.txt_attn_proj(text_attention_output), gate=txt_mod1_gate) txt = txt + self._apply_gate( - self.txt_mlp( - self._modulate( - self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale - ) - ), + self.txt_mlp(self._modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)), gate=txt_mod2_gate, ) @@ -315,14 +287,11 @@ def __init__( ): super().__init__() - self.timesteps_proj = Timesteps( - num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0) - self.time_embedder = TimestepEmbedding( - in_channels=time_freq_dim, time_embed_dim=dim) + self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim) self.act_fn = nn.SiLU() self.time_proj = nn.Linear(dim, time_proj_dim) - self.text_embedder = PixArtAlphaTextProjection( - text_embed_dim, dim, act_fn="gelu_tanh") + self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh") def forward( self, @@ -342,8 +311,7 @@ def forward( class JoyAIImageTransformer3DModel(ModelMixin, ConfigMixin, AttentionMixin): - _fsdp_shard_conditions: list = [ - lambda name, module: isinstance(module, JoyAIImageTransformerBlock)] + _fsdp_shard_conditions: list = [lambda name, module: isinstance(module, JoyAIImageTransformerBlock)] _supports_gradient_checkpointing = True @register_to_config @@ -358,11 +326,11 @@ def __init__( mlp_width_ratio: float = 4.0, mm_double_blocks_depth: int = 20, rope_dim_list: tuple[int, int, int] = (16, 56, 56), - rope_type: str = 'rope', + rope_type: str = "rope", dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, dit_modulation_type: str = "wanx", - attn_backend: str = 'flash_attn', + attn_backend: str = "flash_attn", theta: int = 256, ): self.out_channels = out_channels or in_channels @@ -377,12 +345,9 @@ def __init__( factory_kwargs = {"device": device, "dtype": dtype} super().__init__() if hidden_size % heads_num != 0: - raise ValueError( - f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}" - ) + raise ValueError(f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}") - self.img_in = nn.Conv3d( - in_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + self.img_in = nn.Conv3d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size) self.condition_embedder = JoyAITimeTextEmbedding( dim=hidden_size, @@ -405,13 +370,8 @@ def __init__( ] ) - self.norm_out = nn.LayerNorm( - hidden_size, elementwise_affine=False, eps=1e-6 - ) - self.proj_out = nn.Linear( - hidden_size, out_channels * math.prod(patch_size), - **factory_kwargs) - + self.norm_out = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(hidden_size, out_channels * math.prod(patch_size), **factory_kwargs) @staticmethod def _get_meshgrid_nd(start, *args, dim=2): @@ -423,6 +383,7 @@ def as_tuple(value): if len(value) == dim: return value raise ValueError(f"Expected length {dim} or int, but got {value}") + if len(args) == 0: num = as_tuple(start) start = (0,) * dim @@ -448,7 +409,6 @@ def as_tuple(value): return grid - @staticmethod def _get_nd_rotary_pos_embed( rope_dim_list, @@ -460,9 +420,7 @@ def _get_nd_rotary_pos_embed( ): """Build visual and optional text rotary embeddings.""" - grid = JoyAIImageTransformer3DModel._get_meshgrid_nd( - start, *args, dim=len(rope_dim_list) - ) + grid = JoyAIImageTransformer3DModel._get_meshgrid_nd(start, *args, dim=len(rope_dim_list)) embs = [] for i in range(len(rope_dim_list)): @@ -502,9 +460,6 @@ def _get_nd_rotary_pos_embed( txt_emb = None return vis_emb, txt_emb - - - def get_rotary_pos_embed(self, image_grid_size, text_sequence_length=None): target_ndim = 3 @@ -513,11 +468,8 @@ def get_rotary_pos_embed(self, image_grid_size, text_sequence_length=None): head_dim = self.hidden_size // self.heads_num rope_dim_list = self.rope_dim_list if rope_dim_list is None: - rope_dim_list = [head_dim // - target_ndim for _ in range(target_ndim)] - assert ( - sum(rope_dim_list) == head_dim - ), "sum(rope_dim_list) should equal to head_dim of attention layer" + rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)] + assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer" image_rotary_emb, text_rotary_emb = self._get_nd_rotary_pos_embed( rope_dim_list, image_grid_size, @@ -535,19 +487,13 @@ def forward( encoder_hidden_states_mask: torch.Tensor = None, return_dict: bool = True, ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: - is_multi_item = (len(hidden_states.shape) == 6) + is_multi_item = len(hidden_states.shape) == 6 num_items = 0 if is_multi_item: num_items = hidden_states.shape[1] if num_items > 1: assert self.patch_size[0] == 1, "For multi-item input, patch_size[0] must be 1" - hidden_states = torch.cat( - [ - hidden_states[:, -1:], - hidden_states[:, :-1] - ], - dim=1 - ) + hidden_states = torch.cat([hidden_states[:, -1:], hidden_states[:, :-1]], dim=1) batch_size, num_items, channels, frames_per_item, height, width = hidden_states.shape hidden_states = hidden_states.permute(0, 2, 1, 3, 4, 5).reshape( batch_size, channels, num_items * frames_per_item, height, width @@ -568,7 +514,9 @@ def forward( device=image_hidden_states.device, ) else: - encoder_hidden_states_mask = encoder_hidden_states_mask.to(device=image_hidden_states.device, dtype=torch.bool) + encoder_hidden_states_mask = encoder_hidden_states_mask.to( + device=image_hidden_states.device, dtype=torch.bool + ) modulation_states, text_hidden_states = self.condition_embedder(timestep, encoder_hidden_states) if modulation_states.shape[-1] > self.hidden_size: modulation_states = modulation_states.unflatten(1, (6, -1)) @@ -577,7 +525,7 @@ def forward( image_seq_len = image_hidden_states.shape[1] image_rotary_emb, text_rotary_emb = self.get_rotary_pos_embed( image_grid_size=(latent_frames, latent_height, latent_width), - text_sequence_length=text_seq_len if self.rope_type == 'mrope' else None, + text_sequence_length=text_seq_len if self.rope_type == "mrope" else None, ) attention_mask = torch.cat( @@ -592,9 +540,9 @@ def forward( dim=1, ) attention_kwargs = { - 'thw': [latent_frames, latent_height, latent_width], - 'txt_len': text_seq_len, - 'attention_mask': attention_mask, + "thw": [latent_frames, latent_height, latent_width], + "txt_len": text_seq_len, + "attention_mask": attention_mask, } for block in self.double_blocks: diff --git a/src/diffusers/pipelines/joyai_image/pipeline_joyai_image.py b/src/diffusers/pipelines/joyai_image/pipeline_joyai_image.py index f47d66994f8d..d17ec3ce568f 100644 --- a/src/diffusers/pipelines/joyai_image/pipeline_joyai_image.py +++ b/src/diffusers/pipelines/joyai_image/pipeline_joyai_image.py @@ -1,4 +1,4 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. +# Copyright 2026 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,16 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# This pipeline is adapted from https://github.com/jd-opensource/JoyAI-Image import inspect -import json -from dataclasses import dataclass, field -from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, TypedDict, Union +from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import torch -from transformers import AutoProcessor, AutoTokenizer, PreTrainedTokenizerBase, Qwen3VLForConditionalGeneration +from transformers import AutoProcessor, PreTrainedTokenizerBase, Qwen3VLForConditionalGeneration from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback from diffusers.image_processor import VaeImageProcessor @@ -36,7 +34,6 @@ logger = logging.get_logger(__name__) - PRECISION_TO_TYPE = { "fp32": torch.float32, "float32": torch.float32, @@ -47,172 +44,6 @@ } -class JoyAIDitArchConfig(TypedDict): - hidden_size: int - in_channels: int - heads_num: int - mm_double_blocks_depth: int - out_channels: int - patch_size: list[int] - rope_dim_list: list[int] - text_states_dim: int - rope_type: str - dit_modulation_type: str - theta: int - attn_backend: str - - -class JoyAISchedulerArchConfig(TypedDict): - num_train_timesteps: int - shift: float - - -class JoyAIImageComponents(TypedDict): - args: "JoyAIImageSourceConfig" - tokenizer: PreTrainedTokenizerBase - text_encoder: Qwen3VLForConditionalGeneration - transformer: JoyAIImageTransformer3DModel - scheduler: JoyAIFlowMatchDiscreteScheduler - vae: JoyAIImageVAE - - -@dataclass -class JoyAIImageSourceConfig: - source_root: Path - dit_precision: str = "bf16" - vae_precision: str = "bf16" - text_encoder_precision: str = "bf16" - text_token_max_length: int = 2048 - enable_multi_task_training: bool = False - dit_arch_config: JoyAIDitArchConfig = field( - default_factory=lambda: { - "hidden_size": 4096, - "in_channels": 16, - "heads_num": 32, - "mm_double_blocks_depth": 40, - "out_channels": 16, - "patch_size": [1, 2, 2], - "rope_dim_list": [16, 56, 56], - "text_states_dim": 4096, - "rope_type": "rope", - "dit_modulation_type": "wanx", - "theta": 10000, - "attn_backend": "flash_attn", - } - ) - scheduler_arch_config: JoyAISchedulerArchConfig = field( - default_factory=lambda: { - "num_train_timesteps": 1000, - "shift": 4.0, - } - ) - - @property - def text_encoder_arch_config(self) -> dict: - return {"params": {"text_encoder_ckpt": str(self.source_root / "JoyAI-Image-Und")}} - - -def dtype_to_precision(torch_dtype: Optional[torch.dtype]) -> Optional[str]: - if torch_dtype is None: - return None - for name, value in PRECISION_TO_TYPE.items(): - if value == torch_dtype and name in {"fp32", "fp16", "bf16"}: - return name - raise ValueError(f"Unsupported torch dtype for JoyAIImagePipeline: {torch_dtype}") - - -def resolve_manifest_path(source_root: Path, manifest_value: Optional[str]) -> Optional[Path]: - if manifest_value is None: - return None - path = Path(manifest_value) - if path.parts and path.parts[0] == source_root.name: - path = Path(*path.parts[1:]) - return source_root / path - - -def is_joyai_source_dir(path: Path) -> bool: - return ( - path.is_dir() - and (path / "infer_config.py").is_file() - and (path / "manifest.json").is_file() - and (path / "transformer").is_dir() - and (path / "vae").is_dir() - ) - - -def load_transformer_state_dict(checkpoint_path: Path) -> dict[str, torch.Tensor]: - state = torch.load(checkpoint_path, map_location="cpu", weights_only=True) - if "model" in state: - state = state["model"] - return state - - -def load_joyai_components( - source_root: Union[str, Path], - torch_dtype: Optional[torch.dtype] = None, - device: Optional[Union[str, torch.device]] = None, -) -> JoyAIImageComponents: - source_root = Path(source_root) - if not is_joyai_source_dir(source_root): - raise ValueError(f"Not a valid JoyAI source checkpoint directory: {source_root}") - - precision = dtype_to_precision(torch_dtype) - cfg = JoyAIImageSourceConfig(source_root=source_root) - - manifest = json.loads((source_root / "manifest.json").read_text()) - transformer_ckpt = resolve_manifest_path(source_root, manifest.get("transformer_ckpt")) - vae_ckpt = source_root / "vae" / "Wan2.1_VAE.pth" - text_encoder_ckpt = source_root / "JoyAI-Image-Und" - - if precision is not None: - cfg.dit_precision = precision - cfg.vae_precision = precision - cfg.text_encoder_precision = precision - - load_device = torch.device(device) if device is not None else torch.device("cpu") - transformer = JoyAIImageTransformer3DModel( - dtype=PRECISION_TO_TYPE[cfg.dit_precision], - device=load_device, - **cfg.dit_arch_config, - ) - state_dict = load_transformer_state_dict(transformer_ckpt) - if "img_in.weight" in state_dict and transformer.img_in.weight.shape != state_dict["img_in.weight"].shape: - value = state_dict["img_in.weight"] - padded = value.new_zeros(transformer.img_in.weight.shape) - padded[:, : value.shape[1], :, :, :] = value - state_dict["img_in.weight"] = padded - transformer.load_state_dict(state_dict, strict=True) - transformer = transformer.to(dtype=PRECISION_TO_TYPE[cfg.dit_precision]).eval() - - vae = JoyAIImageVAE( - pretrained=str(vae_ckpt), - torch_dtype=PRECISION_TO_TYPE[cfg.vae_precision], - device=load_device, - ) - vae = vae.to(device=load_device, dtype=PRECISION_TO_TYPE[cfg.vae_precision]).eval() - text_encoder = Qwen3VLForConditionalGeneration.from_pretrained( - str(text_encoder_ckpt), - dtype=PRECISION_TO_TYPE[cfg.text_encoder_precision], - local_files_only=True, - trust_remote_code=True, - ).to(load_device).eval() - tokenizer = AutoTokenizer.from_pretrained( - str(text_encoder_ckpt), - local_files_only=True, - trust_remote_code=True, - ) - scheduler = JoyAIFlowMatchDiscreteScheduler(**cfg.scheduler_arch_config) - - return { - "args": cfg, - "tokenizer": tokenizer, - "text_encoder": text_encoder, - "transformer": transformer, - "scheduler": scheduler, - "vae": vae, - } - - PROMPT_TEMPLATE_ENCODE = { "image": "<|im_start|>system\n \nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", "multiple_images": "<|im_start|>system\n \nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n{}<|im_start|>assistant\n", @@ -264,6 +95,7 @@ def retrieve_timesteps( class JoyAIImagePipeline(DiffusionPipeline): model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = ["processor"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( @@ -273,7 +105,8 @@ def __init__( tokenizer: PreTrainedTokenizerBase, transformer: JoyAIImageTransformer3DModel, scheduler: JoyAIFlowMatchDiscreteScheduler, - args: JoyAIImageSourceConfig | None = None, + processor: Any | None = None, + args: Any | None = None, ): super().__init__() self.args = args @@ -282,6 +115,7 @@ def __init__( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, + processor=processor, transformer=transformer, scheduler=scheduler, ) @@ -295,13 +129,13 @@ def __init__( self.vae_scale_factor_temporal = 4 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - self.qwen_processor = None + self.qwen_processor = processor text_encoder_ckpt = None text_encoder_cfg = getattr(self.args, "text_encoder_arch_config", None) if isinstance(text_encoder_cfg, dict): text_encoder_params = text_encoder_cfg.get("params", {}) text_encoder_ckpt = text_encoder_params.get("text_encoder_ckpt") - if text_encoder_ckpt is not None: + if self.qwen_processor is None and text_encoder_ckpt is not None: self.qwen_processor = AutoProcessor.from_pretrained( text_encoder_ckpt, local_files_only=True, @@ -313,39 +147,6 @@ def __init__( self.prompt_template_encode_start_idx = PROMPT_TEMPLATE_START_IDX self._joyai_force_vae_fp32 = True - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): - source_path = Path(pretrained_model_name_or_path) if pretrained_model_name_or_path is not None else None - if source_path is not None and is_joyai_source_dir(source_path): - return cls.from_joyai_sources(pretrained_model_name_or_path, **kwargs) - return super().from_pretrained(pretrained_model_name_or_path, **kwargs) - - @classmethod - def from_joyai_sources( - cls, - pretrained_model_name_or_path: Union[str, Path], - torch_dtype: Optional[torch.dtype] = None, - device: Optional[Union[str, torch.device]] = None, - **kwargs, - ): - components = load_joyai_components( - source_root=pretrained_model_name_or_path, - torch_dtype=torch_dtype, - device=device, - ) - - pipe = cls( - vae=components["vae"], - tokenizer=components["tokenizer"], - text_encoder=components["text_encoder"], - transformer=components["transformer"], - scheduler=components["scheduler"], - args=components["args"], - ) - if device is not None: - pipe._joyai_execution_device_override = torch.device(device) - return pipe - def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): bool_mask = mask.bool() valid_lengths = bool_mask.sum(dim=1) diff --git a/src/diffusers/pipelines/joyai_image/pipeline_output.py b/src/diffusers/pipelines/joyai_image/pipeline_output.py index d085cafc4790..131da308bed5 100644 --- a/src/diffusers/pipelines/joyai_image/pipeline_output.py +++ b/src/diffusers/pipelines/joyai_image/pipeline_output.py @@ -1,3 +1,17 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from dataclasses import dataclass from typing import Union diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index b0fd70c2cd8e..10f23a0d770b 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -171,15 +171,15 @@ from .scheduling_euler_discrete import EulerDiscreteScheduler from .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler from .scheduling_flow_match_heun_discrete import FlowMatchHeunDiscreteScheduler - from .scheduling_joyai_flow_match_discrete import ( - JoyAIFlowMatchDiscreteScheduler, - JoyAIFlowMatchDiscreteSchedulerOutput, - ) from .scheduling_flow_match_lcm import FlowMatchLCMScheduler from .scheduling_helios import HeliosScheduler from .scheduling_helios_dmd import HeliosDMDScheduler from .scheduling_heun_discrete import HeunDiscreteScheduler from .scheduling_ipndm import IPNDMScheduler + from .scheduling_joyai_flow_match_discrete import ( + JoyAIFlowMatchDiscreteScheduler, + JoyAIFlowMatchDiscreteSchedulerOutput, + ) from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler from .scheduling_k_dpm_2_discrete import KDPM2DiscreteScheduler from .scheduling_lcm import LCMScheduler diff --git a/src/diffusers/schedulers/scheduling_joyai_flow_match_discrete.py b/src/diffusers/schedulers/scheduling_joyai_flow_match_discrete.py index b68848a17897..b3acaaba10e6 100644 --- a/src/diffusers/schedulers/scheduling_joyai_flow_match_discrete.py +++ b/src/diffusers/schedulers/scheduling_joyai_flow_match_discrete.py @@ -30,22 +30,34 @@ def __init__( num_train_timesteps: int = 1000, shift: float = 1.0, reverse: bool = True, + use_dynamic_shifting: bool = False, + base_shift: float = 0.5, + max_shift: float = 1.15, + base_image_seq_len: int = 256, + max_image_seq_len: int = 4096, + invert_sigmas: bool = False, + shift_terminal: float | None = None, + use_karras_sigmas: bool = False, + use_exponential_sigmas: bool = False, + use_beta_sigmas: bool = False, + time_shift_type: str = "exponential", + stochastic_sampling: bool = False, ): super().__init__( num_train_timesteps=num_train_timesteps, shift=shift, - use_dynamic_shifting=False, - base_shift=0.5, - max_shift=1.15, - base_image_seq_len=256, - max_image_seq_len=4096, - invert_sigmas=False, - shift_terminal=None, - use_karras_sigmas=False, - use_exponential_sigmas=False, - use_beta_sigmas=False, - time_shift_type="exponential", - stochastic_sampling=False, + use_dynamic_shifting=use_dynamic_shifting, + base_shift=base_shift, + max_shift=max_shift, + base_image_seq_len=base_image_seq_len, + max_image_seq_len=max_image_seq_len, + invert_sigmas=invert_sigmas, + shift_terminal=shift_terminal, + use_karras_sigmas=use_karras_sigmas, + use_exponential_sigmas=use_exponential_sigmas, + use_beta_sigmas=use_beta_sigmas, + time_shift_type=time_shift_type, + stochastic_sampling=stochastic_sampling, ) self.register_to_config(reverse=reverse) diff --git a/tests/pipelines/joyai_image/test_pipeline_joyai_image.py b/tests/pipelines/joyai_image/test_pipeline_joyai_image.py new file mode 100644 index 000000000000..5adae87b6fbb --- /dev/null +++ b/tests/pipelines/joyai_image/test_pipeline_joyai_image.py @@ -0,0 +1,56 @@ +from unittest.mock import patch + +from diffusers import DiffusionPipeline, JoyAIImagePipeline +from diffusers.configuration_utils import FrozenDict +from diffusers.pipelines.joyai_image import pipeline_joyai_image + + +class _DummyModule: + pass + + +def test_joyai_pipeline_uses_base_from_pretrained(): + assert JoyAIImagePipeline.from_pretrained.__func__ is DiffusionPipeline.from_pretrained.__func__ + + +def test_joyai_pipeline_does_not_expose_source_loader_api(): + assert not hasattr(JoyAIImagePipeline, "from_joyai_sources") + + +def test_joyai_pipeline_module_does_not_expose_raw_source_helpers(): + assert not hasattr(pipeline_joyai_image, "load_joyai_components") + + +def test_joyai_pipeline_keeps_passed_processor_without_reloading(): + pipe = object.__new__(JoyAIImagePipeline) + pipe._internal_dict = FrozenDict({}) + pipe.args = type("Args", (), {"text_encoder_arch_config": {"params": {"text_encoder_ckpt": "/tmp/raw"}}})() + pipe.vae = type("VAE", (), {"ffactor_spatial": 8, "ffactor_temporal": 4})() + + registered = {} + + def fake_register_modules(**kwargs): + registered.update(kwargs) + for key, value in kwargs.items(): + setattr(pipe, key, value) + + pipe.register_modules = fake_register_modules + + processor = _DummyModule() + with patch( + "diffusers.pipelines.joyai_image.pipeline_joyai_image.AutoProcessor.from_pretrained" + ) as mock_from_pretrained: + JoyAIImagePipeline.__init__( + pipe, + vae=pipe.vae, + text_encoder=_DummyModule(), + tokenizer=_DummyModule(), + transformer=_DummyModule(), + scheduler=_DummyModule(), + processor=processor, + args=pipe.args, + ) + + assert pipe.qwen_processor is processor + mock_from_pretrained.assert_not_called() + assert registered["processor"] is processor diff --git a/tests/schedulers/test_scheduler_joyai_flow_match_discrete.py b/tests/schedulers/test_scheduler_joyai_flow_match_discrete.py new file mode 100644 index 000000000000..e24466ea136d --- /dev/null +++ b/tests/schedulers/test_scheduler_joyai_flow_match_discrete.py @@ -0,0 +1,37 @@ +import tempfile + +import torch + +from diffusers import JoyAIFlowMatchDiscreteScheduler +from diffusers.utils import logging + +from .test_schedulers import CaptureLogger + + +def test_joyai_scheduler_roundtrip_config_has_no_unexpected_warning(): + scheduler = JoyAIFlowMatchDiscreteScheduler(num_train_timesteps=1000, shift=4.0, reverse=True) + logger = logging.get_logger("diffusers.configuration_utils") + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + with CaptureLogger(logger) as cap_logger: + config = JoyAIFlowMatchDiscreteScheduler.load_config(tmpdirname) + reloaded = JoyAIFlowMatchDiscreteScheduler.from_config(config) + + assert isinstance(reloaded, JoyAIFlowMatchDiscreteScheduler) + assert cap_logger.out == "" + + +def test_joyai_scheduler_reloaded_instance_supports_step(): + scheduler = JoyAIFlowMatchDiscreteScheduler(num_train_timesteps=1000, shift=4.0, reverse=True) + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + reloaded = JoyAIFlowMatchDiscreteScheduler.from_pretrained(tmpdirname) + + reloaded.set_timesteps(2) + sample = torch.zeros(1, 2, 2) + model_output = torch.zeros_like(sample) + prev_sample = reloaded.step(model_output, reloaded.timesteps[0], sample, return_dict=False)[0] + + assert prev_sample.shape == sample.shape