diff --git a/README.md b/README.md index 1b0dc69c..e6584d23 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ [![Unit Tests](https://github.com/google/maxtext/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/AI-Hypercomputer/maxdiffusion/actions/workflows/UnitTests.yml) # What's new? +- **`2026/1/15`**: Wan2.1 and Wan2.2 Img2vid generation is now supported - **`2025/11/11`**: Wan2.2 txt2vid generation is now supported - **`2025/10/10`**: Wan2.1 txt2vid training and generation is now supported. - **`2025/10/14`**: NVIDIA DGX Spark Flux support. @@ -482,19 +483,38 @@ To generate images, run the following command: Although not required, attaching an external disk is recommended as weights take up a lot of disk space. [Follow these instructions if you would like to attach an external disk](https://cloud.google.com/tpu/docs/attach-durable-block-storage). + ### Text2Vid + ```bash HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_14b.yml attention="flash" num_inference_steps=50 num_frames=81 width=1280 height=720 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=5.0 enable_profiler=True run_name=wan-inference-testing-720p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445 ``` + + ### Img2Vid + + ```bash + HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ + LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_i2v_14b.yml attention="flash" num_inference_steps=30 num_frames=81 width=832 height=480 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=3.0 enable_profiler=True run_name=wan-i2v-inference-testing-480p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445 + ``` + ## Wan2.2 Although not required, attaching an external disk is recommended as weights take up a lot of disk space. [Follow these instructions if you would like to attach an external disk](https://cloud.google.com/tpu/docs/attach-durable-block-storage). + ### Text2Vid + ```bash HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_27b.yml attention="flash" num_inference_steps=50 num_frames=81 width=1280 height=720 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=5.0 enable_profiler=True run_name=wan-inference-testing-720p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445 ``` + ### Img2Vid + + ```bash + HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ + LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_i2v_27b.yml attention="flash" num_inference_steps=30 num_frames=81 width=832 height=480 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=3.0 enable_profiler=True run_name=wan-i2v-inference-testing-480p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445 + ``` + ## Flux First make sure you have permissions to access the Flux repos in Huggingface. diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer.py b/src/maxdiffusion/checkpointing/wan_checkpointer.py index b601cb34..006b3ec8 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer.py @@ -19,6 +19,8 @@ from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager) from ..pipelines.wan.wan_pipeline_2_1 import WanPipeline2_1 from ..pipelines.wan.wan_pipeline_2_2 import WanPipeline2_2 +from ..pipelines.wan.wan_pipeline_i2v_2p1 import WanPipelineI2V_2_1 +from ..pipelines.wan.wan_pipeline_i2v_2p2 import WanPipelineI2V_2_2 from .. import max_logging, max_utils import orbax.checkpoint as ocp @@ -59,7 +61,7 @@ def load_diffusers_checkpoint(self): raise NotImplementedError @abstractmethod - def load_checkpoint(self, step=None) -> Tuple[Optional[WanPipeline2_1 | WanPipeline2_2], Optional[dict], Optional[int]]: + def load_checkpoint(self, step=None) -> Tuple[Optional[WanPipeline2_1 | WanPipeline2_2 | WanPipelineI2V_2_1 | WanPipelineI2V_2_2], Optional[dict], Optional[int]]: raise NotImplementedError @abstractmethod diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p1.py b/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p1.py new file mode 100644 index 00000000..6f4bbc90 --- /dev/null +++ b/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p1.py @@ -0,0 +1,95 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import json +import jax +import numpy as np +from typing import Optional, Tuple +from ..pipelines.wan.wan_pipeline_i2v_2p1 import WanPipelineI2V_2_1 +from .. import max_logging +import orbax.checkpoint as ocp +from etils import epath +from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer + +class WanCheckpointerI2V_2_1(WanCheckpointer): + + def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]: + if step is None: + step = self.checkpoint_manager.latest_step() + max_logging.log(f"Latest WAN checkpoint step: {step}") + if step is None: + max_logging.log("No WAN checkpoint found.") + return None, None + max_logging.log(f"Loading WAN checkpoint from step {step}") + metadatas = self.checkpoint_manager.item_metadata(step) + transformer_metadata = metadatas.wan_state + abstract_tree_structure_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, transformer_metadata) + params_restore = ocp.args.PyTreeRestore( + restore_args=jax.tree.map( + lambda _: ocp.RestoreArgs(restore_type=np.ndarray), + abstract_tree_structure_params, + ) + ) + + max_logging.log("Restoring WAN checkpoint") + restored_checkpoint = self.checkpoint_manager.restore( + directory=epath.Path(self.config.checkpoint_dir), + step=step, + args=ocp.args.Composite( + wan_state=params_restore, + wan_config=ocp.args.JsonRestore(), + ), + ) + max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}") + max_logging.log(f"restored checkpoint wan_state {restored_checkpoint.wan_state.keys()}") + max_logging.log(f"optimizer found in checkpoint {'opt_state' in restored_checkpoint.wan_state.keys()}") + max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}") + return restored_checkpoint, step + + def load_diffusers_checkpoint(self): + pipeline = WanPipelineI2V_2_1.from_pretrained(self.config) + return pipeline + + def load_checkpoint(self, step=None) -> Tuple[WanPipelineI2V_2_1, Optional[dict], Optional[int]]: + restored_checkpoint, step = self.load_wan_configs_from_orbax(step) + opt_state = None + if restored_checkpoint: + max_logging.log("Loading WAN pipeline from checkpoint") + pipeline = WanPipelineI2V_2_1.from_checkpoint(self.config, restored_checkpoint) + if "opt_state" in restored_checkpoint.wan_state.keys(): + opt_state = restored_checkpoint.wan_state["opt_state"] + else: + max_logging.log("No checkpoint found, loading default pipeline.") + pipeline = self.load_diffusers_checkpoint() + + return pipeline, opt_state, step + + def save_checkpoint(self, train_step, pipeline: WanPipelineI2V_2_1, train_states: dict): + """Saves the training state and model configurations.""" + + def config_to_json(model_or_config): + return json.loads(model_or_config.to_json_string()) + + max_logging.log(f"Saving checkpoint for step {train_step}") + items = { + "wan_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)), + } + + items["wan_state"] = ocp.args.PyTreeSave(train_states) + + # Save the checkpoint + self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items)) + max_logging.log(f"Checkpoint for step {train_step} saved.") diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p2.py b/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p2.py new file mode 100644 index 00000000..a55048cf --- /dev/null +++ b/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p2.py @@ -0,0 +1,114 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import json +import jax +import numpy as np +from typing import Optional, Tuple +from ..pipelines.wan.wan_pipeline_i2v_2p2 import WanPipelineI2V_2_2 +from .. import max_logging +import orbax.checkpoint as ocp +from etils import epath +from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer + +class WanCheckpointerI2V_2_2(WanCheckpointer): + + def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]: + if step is None: + step = self.checkpoint_manager.latest_step() + max_logging.log(f"Latest WAN checkpoint step: {step}") + if step is None: + max_logging.log("No WAN checkpoint found.") + return None, None + max_logging.log(f"Loading WAN checkpoint from step {step}") + metadatas = self.checkpoint_manager.item_metadata(step) + + # Handle low_noise_transformer + low_noise_transformer_metadata = metadatas.low_noise_transformer_state + abstract_tree_structure_low_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, low_noise_transformer_metadata) + low_params_restore = ocp.args.PyTreeRestore( + restore_args=jax.tree.map( + lambda _: ocp.RestoreArgs(restore_type=np.ndarray), + abstract_tree_structure_low_params, + ) + ) + + # Handle high_noise_transformer + high_noise_transformer_metadata = metadatas.high_noise_transformer_state + abstract_tree_structure_high_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, high_noise_transformer_metadata) + high_params_restore = ocp.args.PyTreeRestore( + restore_args=jax.tree.map( + lambda _: ocp.RestoreArgs(restore_type=np.ndarray), + abstract_tree_structure_high_params, + ) + ) + + max_logging.log("Restoring WAN 2.2 checkpoint") + restored_checkpoint = self.checkpoint_manager.restore( + directory=epath.Path(self.config.checkpoint_dir), + step=step, + args=ocp.args.Composite( + low_noise_transformer_state=low_params_restore, + high_noise_transformer_state=high_params_restore, + wan_config=ocp.args.JsonRestore(), + ), + ) + max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}") + max_logging.log(f"restored checkpoint low_noise_transformer_state {restored_checkpoint.low_noise_transformer_state.keys()}") + max_logging.log(f"restored checkpoint high_noise_transformer_state {restored_checkpoint.high_noise_transformer_state.keys()}") + max_logging.log(f"optimizer found in low_noise checkpoint {'opt_state' in restored_checkpoint.low_noise_transformer_state.keys()}") + max_logging.log(f"optimizer found in high_noise checkpoint {'opt_state' in restored_checkpoint.high_noise_transformer_state.keys()}") + max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}") + return restored_checkpoint, step + + def load_diffusers_checkpoint(self): + pipeline = WanPipelineI2V_2_2.from_pretrained(self.config) + return pipeline + + def load_checkpoint(self, step=None) -> Tuple[WanPipelineI2V_2_2, Optional[dict], Optional[int]]: + restored_checkpoint, step = self.load_wan_configs_from_orbax(step) + opt_state = None + if restored_checkpoint: + max_logging.log("Loading WAN pipeline from checkpoint") + pipeline = WanPipelineI2V_2_2.from_checkpoint(self.config, restored_checkpoint) + # Check for optimizer state in either transformer + if "opt_state" in restored_checkpoint.low_noise_transformer_state.keys(): + opt_state = restored_checkpoint.low_noise_transformer_state["opt_state"] + elif "opt_state" in restored_checkpoint.high_noise_transformer_state.keys(): + opt_state = restored_checkpoint.high_noise_transformer_state["opt_state"] + else: + max_logging.log("No checkpoint found, loading default pipeline.") + pipeline = self.load_diffusers_checkpoint() + + return pipeline, opt_state, step + + def save_checkpoint(self, train_step, pipeline: WanPipelineI2V_2_2, train_states: dict): + """Saves the training state and model configurations.""" + + def config_to_json(model_or_config): + return json.loads(model_or_config.to_json_string()) + + max_logging.log(f"Saving checkpoint for step {train_step}") + items = { + "wan_config": ocp.args.JsonSave(config_to_json(pipeline.low_noise_transformer)), + } + + items["low_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["low_noise_transformer"]) + items["high_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["high_noise_transformer"]) + + # Save the checkpoint + self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items)) + max_logging.log(f"Checkpoint for step {train_step} saved.") diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index f152ac73..b2a11dba 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -29,6 +29,7 @@ log_period: 100 pretrained_model_name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers' model_name: wan2.1 +model_type: 'T2V' # Overrides the transformer from pretrained_model_name_or_path wan_transformer_pretrained_model_name_or_path: '' diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml index 314d1141..cff70a94 100644 --- a/src/maxdiffusion/configs/base_wan_27b.yml +++ b/src/maxdiffusion/configs/base_wan_27b.yml @@ -29,6 +29,7 @@ log_period: 100 pretrained_model_name_or_path: 'Wan-AI/Wan2.2-T2V-A14B-Diffusers' model_name: wan2.2 +model_type: 'T2V' # Overrides the transformer from pretrained_model_name_or_path wan_transformer_pretrained_model_name_or_path: '' diff --git a/src/maxdiffusion/configs/base_wan_i2v_14b.yml b/src/maxdiffusion/configs/base_wan_i2v_14b.yml new file mode 100644 index 00000000..92a371e3 --- /dev/null +++ b/src/maxdiffusion/configs/base_wan_i2v_14b.yml @@ -0,0 +1,345 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This sentinel is a reminder to choose a real run name. +run_name: '' + +metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written. +# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/ +write_metrics: True + +timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written. +write_timing_metrics: True + +gcs_metrics: False +# If true save config to GCS in {base_output_directory}/{run_name}/ +save_config_to_gcs: False +log_period: 100 + +pretrained_model_name_or_path: 'Wan-AI/Wan2.1-I2V-14B-480P-Diffusers' +model_name: wan2.1 +model_type: 'I2V' + +# Overrides the transformer from pretrained_model_name_or_path +wan_transformer_pretrained_model_name_or_path: '' + +unet_checkpoint: '' +revision: '' +# This will convert the weights to this dtype. +# When running inference on TPUv5e, use weights_dtype: 'bfloat16' +weights_dtype: 'bfloat16' +# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype) +activations_dtype: 'bfloat16' + +# Replicates vae across devices instead of using the model's sharding annotations for sharding. +replicate_vae: False + +# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision +# Options are "DEFAULT", "HIGH", "HIGHEST" +# fp32 activations and fp32 weights with HIGHEST will provide the best precision +# at the cost of time. +precision: "DEFAULT" +# Use jax.lax.scan for transformer layers +scan_layers: True + +# if False state is not jitted and instead replicate is called. This is good for debugging on single host +# It must be True for multi-host. +jit_initializers: True + +# Set true to load weights from pytorch +from_pt: True +split_head_dim: True +attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring +flash_min_seq_length: 4096 +dropout: 0.1 + +# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. +# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster. +# However, when padding tokens are significant, this will lead to worse quality and should be set to True. +mask_padding_tokens: True +attention_sharding_uniform: True + +flash_block_sizes: { + "block_q" : 2048, + "block_kv_compute" : 512, + "block_kv" : 2048, + "block_q_dkv" : 2048, + "block_kv_dkv" : 2048, + "block_kv_dkv_compute" : 512, + "use_fused_bwd_kernel" : True +} +# Use on v6e +# flash_block_sizes: { +# "block_q" : 3024, +# "block_kv_compute" : 1024, +# "block_kv" : 2048, +# "block_q_dkv" : 3024, +# "block_kv_dkv" : 2048, +# "block_kv_dkv_compute" : 2048, +# "block_q_dq" : 3024, +# "block_kv_dq" : 2048, +# "use_fused_bwd_kernel": False, +# } +# GroupNorm groups +norm_num_groups: 32 + +# train text_encoder - Currently not supported for SDXL +train_text_encoder: False +text_encoder_learning_rate: 4.25e-6 + +# https://arxiv.org/pdf/2305.08891.pdf +snr_gamma: -1.0 + +timestep_bias: { + # a value of later will increase the frequence of the model's final training steps. + # none, earlier, later, range + strategy: "none", + # multiplier for bias, a value of 2.0 will double the weight of the bias, 0.5 will halve it. + multiplier: 1.0, + # when using strategy=range, the beginning (inclusive) timestep to bias. + begin: 0, + # when using strategy=range, the final step (inclusive) to bias. + end: 1000, + # portion of timesteps to bias. + # 0.5 will bias one half of the timesteps. Value of strategy determines + # whether the biased portions are in the earlier or later timesteps. + portion: 0.25 +} + +# Override parameters from checkpoints's scheduler. +# Don't override _class_name - use the pretrained UniPCMultistepScheduler +diffusion_scheduler_config: { + prediction_type: 'flow_prediction', + rescale_zero_terminal_snr: False, + timestep_spacing: 'linspace' +} + +# Output directory +# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/" +base_output_directory: "" + +# Hardware +hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' +skip_jax_distributed_system: False + +# Parallelism +mesh_axes: ['data', 'fsdp', 'tensor'] + +# batch : batch dimension of data and activations +# hidden : +# embed : attention qkv dense layer hidden dim named as embed +# heads : attention head dim = num_heads * head_dim +# length : attention sequence length +# temb_in : dense.shape[0] of resnet dense before conv +# out_c : dense.shape[1] of resnet dense before conv +# out_channels : conv.shape[-1] activation +# keep_1 : conv.shape[0] weight +# keep_2 : conv.shape[1] weight +# conv_in : conv.shape[2] weight +# conv_out : conv.shape[-1] weight +logical_axis_rules: [ + ['batch', 'data'], + ['activation_batch', 'data'], + ['activation_self_attn_heads', ['fsdp', 'tensor']], + ['activation_cross_attn_q_length', ['fsdp', 'tensor']], + ['activation_length', 'fsdp'], + ['activation_heads', 'tensor'], + ['mlp','tensor'], + ['embed','fsdp'], + ['heads', 'tensor'], + ['norm', 'tensor'], + ['conv_batch', ['data','fsdp']], + ['out_channels', 'tensor'], + ['conv_out', 'fsdp'], + ] +data_sharding: [['data', 'fsdp', 'tensor']] + +# One axis for each parallelism type may hold a placeholder (-1) +# value to auto-shard based on available slices and devices. +# By default, product of the DCN axes should equal number of slices +# and product of the ICI axes should equal number of devices per slice. +dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded +dcn_fsdp_parallelism: -1 +dcn_tensor_parallelism: 1 +ici_data_parallelism: 1 +ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded +ici_tensor_parallelism: 1 + +allow_split_physical_axes: False + +# Dataset +# Replace with dataset path or train_data_dir. One has to be set. +dataset_name: 'diffusers/pokemon-gpt4-captions' +train_split: 'train' +dataset_type: 'tfrecord' +cache_latents_text_encoder_outputs: True +# cache_latents_text_encoder_outputs only apply to dataset_type="tf", +# only apply to small dataset that fits in memory +# prepare image latents and text encoder outputs +# Reduce memory consumption and reduce step time during training +# transformed dataset is saved at dataset_save_location +dataset_save_location: '' +load_tfrecord_cached: True +train_data_dir: '' +dataset_config_name: '' +jax_cache_dir: '' +hf_data_dir: '' +hf_train_files: '' +hf_access_token: '' +image_column: 'image' +caption_column: 'text' +resolution: 1024 +center_crop: False +random_flip: False +# If cache_latents_text_encoder_outputs is True +# the num_proc is set to 1 +tokenize_captions_num_proc: 4 +transform_images_num_proc: 4 +reuse_example_batch: False +enable_data_shuffling: True + +# Defines the type of gradient checkpoint to enable. +# NONE - means no gradient checkpoint +# FULL - means full gradient checkpoint, whenever possible (minimum memory usage) +# MATMUL_WITHOUT_BATCH - means gradient checkpoint for every linear/matmul operation, +# except for ones that involve batch dimension - that means that all attention and projection +# layers will have gradient checkpoint, but not the backward with respect to the parameters. +# OFFLOAD_MATMUL_WITHOUT_BATCH - same as MATMUL_WITHOUT_BATCH but offload instead of recomputing. +# CUSTOM - set names to offload and save. +remat_policy: "NONE" +# For CUSTOM policy set below, current annotations are for: attn_output, query_proj, key_proj, value_proj +# xq_out, xk_out, ffn_activation +names_which_can_be_saved: [] +names_which_can_be_offloaded: [] + +# checkpoint every number of samples, -1 means don't checkpoint. +checkpoint_every: -1 +checkpoint_dir: "" +# enables one replica to read the ckpt then broadcast to the rest +enable_single_replica_ckpt_restoring: False + +# Training loop +learning_rate: 1.e-5 +scale_lr: False +max_train_samples: -1 +# max_train_steps takes priority over num_train_epochs. +max_train_steps: 1500 +num_train_epochs: 1 +seed: 0 +output_dir: 'sdxl-model-finetuned' +per_device_batch_size: 1.0 +# If global_batch_size % jax.device_count is not 0, use FSDP sharding. +global_batch_size: 0 + +# For creating tfrecords from dataset +tfrecords_dir: '' +no_records_per_shard: 0 +enable_eval_timesteps: False +timesteps_list: [125, 250, 375, 500, 625, 750, 875] +num_eval_samples: 420 + +warmup_steps_fraction: 0.1 +learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps. +save_optimizer: False + +# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before +# dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0. + +# AdamW optimizer parameters +adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients. +adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients. +adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. +adam_weight_decay: 0 # AdamW Weight decay +max_grad_norm: 1.0 + +enable_profiler: False +# Skip first n steps for profiling, to omit things like compilation and to give +# the iteration time a chance to stabilize. +skip_first_n_steps_for_profiler: 5 +profiler_steps: 10 + +# Enable JAX named scopes for detailed profiling and debugging +# When enabled, adds named scopes around key operations in transformer and attention layers +enable_jax_named_scopes: False + +# Generation parameters +prompt: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." +prompt_2: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." +negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" +do_classifier_free_guidance: True +height: 480 +width: 832 +num_frames: 81 +guidance_scale: 5.0 +flow_shift: 3.0 + +# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf +guidance_rescale: 0.0 +num_inference_steps: 30 +fps: 24 +save_final_checkpoint: False + +# SDXL Lightning parameters +lightning_from_pt: True +# Empty or "ByteDance/SDXL-Lightning" to enable lightning. +lightning_repo: "" +# Empty or "sdxl_lightning_4step_unet.safetensors" to enable lightning. +lightning_ckpt: "" + +# LoRA parameters +# Values are lists to support multiple LoRA loading during inference in the future. +lora_config: { + lora_model_name_or_path: [], + weight_name: [], + adapter_name: [], + scale: [], + from_pt: [] +} +# Ex with values: +# lora_config : { +# lora_model_name_or_path: ["ByteDance/Hyper-SD"], +# weight_name: ["Hyper-SDXL-2steps-lora.safetensors"], +# adapter_name: ["hyper-sdxl"], +# scale: [0.7], +# from_pt: [True] +# } + +enable_mllog: False + +#controlnet +controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0' +controlnet_from_pt: True +controlnet_conditioning_scale: 0.5 +controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png' +quantization: '' +# Shard the range finding operation for quantization. By default this is set to number of slices. +quantization_local_shard_count: -1 +compile_topology_num_slices: -1 # Number of target slices, set to a positive integer. +use_qwix_quantization: False # Whether to use qwix for quantization. If set to True, the transformer of WAN will be quantized using qwix. +# Quantization calibration method used for weights and activations. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80 +quantization_calibration_method: "absmax" +qwix_module_path: ".*" + +# Eval model on per eval_every steps. -1 means don't eval. +eval_every: -1 +eval_data_dir: "" +enable_generate_video_for_eval: False # This will increase the used TPU memory. +eval_max_number_of_samples_in_bucket: 60 # The number of samples per bucket for evaluation. This is calculated by num_eval_samples / len(timesteps_list). + +enable_ssim: False + +# i2v specific parameters +# I2V Input Image +# URL or local path to the conditioning image +image_url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg" \ No newline at end of file diff --git a/src/maxdiffusion/configs/base_wan_i2v_27b.yml b/src/maxdiffusion/configs/base_wan_i2v_27b.yml new file mode 100644 index 00000000..f8982b44 --- /dev/null +++ b/src/maxdiffusion/configs/base_wan_i2v_27b.yml @@ -0,0 +1,357 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This sentinel is a reminder to choose a real run name. +run_name: '' + +metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written. +# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/ +write_metrics: True + +timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written. +write_timing_metrics: True + +gcs_metrics: False +# If true save config to GCS in {base_output_directory}/{run_name}/ +save_config_to_gcs: False +log_period: 100 + +pretrained_model_name_or_path: 'Wan-AI/Wan2.2-I2V-A14B-Diffusers' +model_name: wan2.2 +model_type: 'I2V' + +# Overrides the transformer from pretrained_model_name_or_path +wan_transformer_pretrained_model_name_or_path: '' + +unet_checkpoint: '' +revision: '' +# This will convert the weights to this dtype. +# When running inference on TPUv5e, use weights_dtype: 'bfloat16' +weights_dtype: 'bfloat16' +# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype) +activations_dtype: 'bfloat16' + +# Replicates vae across devices instead of using the model's sharding annotations for sharding. +replicate_vae: False + +# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision +# Options are "DEFAULT", "HIGH", "HIGHEST" +# fp32 activations and fp32 weights with HIGHEST will provide the best precision +# at the cost of time. +precision: "DEFAULT" +# Use jax.lax.scan for transformer layers +scan_layers: True + +# if False state is not jitted and instead replicate is called. This is good for debugging on single host +# It must be True for multi-host. +jit_initializers: True + +# Set true to load weights from pytorch +from_pt: True +split_head_dim: True +attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring +flash_min_seq_length: 4096 +dropout: 0.1 + +# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. +# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster. +# However, when padding tokens are significant, this will lead to worse quality and should be set to True. +mask_padding_tokens: True +attention_sharding_uniform: True + +flash_block_sizes: { + "block_q" : 1024, + "block_kv_compute" : 256, + "block_kv" : 1024, + "block_q_dkv" : 1024, + "block_kv_dkv" : 1024, + "block_kv_dkv_compute" : 256, + "block_q_dq" : 1024, + "block_kv_dq" : 1024 +} +# Use on v6e +# flash_block_sizes: { +# "block_q" : 3024, +# "block_kv_compute" : 1024, +# "block_kv" : 2048, +# "block_q_dkv" : 3024, +# "block_kv_dkv" : 2048, +# "block_kv_dkv_compute" : 2048, +# "block_q_dq" : 3024, +# "block_kv_dq" : 2048 +# "use_fused_bwd_kernel": False, +# } +# GroupNorm groups +norm_num_groups: 32 + +# train text_encoder - Currently not supported for SDXL +train_text_encoder: False +text_encoder_learning_rate: 4.25e-6 + +# https://arxiv.org/pdf/2305.08891.pdf +snr_gamma: -1.0 + +timestep_bias: { + # a value of later will increase the frequence of the model's final training steps. + # none, earlier, later, range + strategy: "none", + # multiplier for bias, a value of 2.0 will double the weight of the bias, 0.5 will halve it. + multiplier: 1.0, + # when using strategy=range, the beginning (inclusive) timestep to bias. + begin: 0, + # when using strategy=range, the final step (inclusive) to bias. + end: 1000, + # portion of timesteps to bias. + # 0.5 will bias one half of the timesteps. Value of strategy determines + # whether the biased portions are in the earlier or later timesteps. + portion: 0.25 +} + +# Override parameters from checkpoints's scheduler. +diffusion_scheduler_config: { + _class_name: 'FlaxEulerDiscreteScheduler', + prediction_type: 'epsilon', + rescale_zero_terminal_snr: False, + timestep_spacing: 'trailing' +} + +# Output directory +# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/" +base_output_directory: "" + +# Hardware +hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' +skip_jax_distributed_system: False + +# Parallelism +mesh_axes: ['data', 'fsdp', 'tensor'] + +# batch : batch dimension of data and activations +# hidden : +# embed : attention qkv dense layer hidden dim named as embed +# heads : attention head dim = num_heads * head_dim +# length : attention sequence length +# temb_in : dense.shape[0] of resnet dense before conv +# out_c : dense.shape[1] of resnet dense before conv +# out_channels : conv.shape[-1] activation +# keep_1 : conv.shape[0] weight +# keep_2 : conv.shape[1] weight +# conv_in : conv.shape[2] weight +# conv_out : conv.shape[-1] weight +logical_axis_rules: [ + ['batch', 'data'], + ['activation_batch', 'data'], + ['activation_self_attn_heads', ['fsdp', 'tensor']], + ['activation_cross_attn_q_length', ['fsdp', 'tensor']], + ['activation_length', 'fsdp'], + ['activation_heads', 'tensor'], + ['mlp','tensor'], + ['embed','fsdp'], + ['heads', 'tensor'], + ['norm', 'tensor'], + ['conv_batch', ['data','fsdp']], + ['out_channels', 'tensor'], + ['conv_out', 'fsdp'], + ] +data_sharding: [['data', 'fsdp', 'tensor']] + +# One axis for each parallelism type may hold a placeholder (-1) +# value to auto-shard based on available slices and devices. +# By default, product of the DCN axes should equal number of slices +# and product of the ICI axes should equal number of devices per slice. +dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded +dcn_fsdp_parallelism: -1 +dcn_tensor_parallelism: 1 +ici_data_parallelism: 1 +ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded +ici_tensor_parallelism: 1 + +allow_split_physical_axes: False + +# Dataset +# Replace with dataset path or train_data_dir. One has to be set. +dataset_name: 'diffusers/pokemon-gpt4-captions' +train_split: 'train' +dataset_type: 'tfrecord' +cache_latents_text_encoder_outputs: True +# cache_latents_text_encoder_outputs only apply to dataset_type="tf", +# only apply to small dataset that fits in memory +# prepare image latents and text encoder outputs +# Reduce memory consumption and reduce step time during training +# transformed dataset is saved at dataset_save_location +dataset_save_location: '' +load_tfrecord_cached: True +train_data_dir: '' +dataset_config_name: '' +jax_cache_dir: '' +hf_data_dir: '' +hf_train_files: '' +hf_access_token: '' +image_column: 'image' +caption_column: 'text' +resolution: 1024 +center_crop: False +random_flip: False +# If cache_latents_text_encoder_outputs is True +# the num_proc is set to 1 +tokenize_captions_num_proc: 4 +transform_images_num_proc: 4 +reuse_example_batch: False +enable_data_shuffling: True + +# Defines the type of gradient checkpoint to enable. +# NONE - means no gradient checkpoint +# FULL - means full gradient checkpoint, whenever possible (minimum memory usage) +# MATMUL_WITHOUT_BATCH - means gradient checkpoint for every linear/matmul operation, +# except for ones that involve batch dimension - that means that all attention and projection +# layers will have gradient checkpoint, but not the backward with respect to the parameters. +# OFFLOAD_MATMUL_WITHOUT_BATCH - same as MATMUL_WITHOUT_BATCH but offload instead of recomputing. +# CUSTOM - set names to offload and save. +remat_policy: "NONE" +# For CUSTOM policy set below, current annotations are for: attn_output, query_proj, key_proj, value_proj +# xq_out, xk_out, ffn_activation +names_which_can_be_saved: [] +names_which_can_be_offloaded: [] + +# checkpoint every number of samples, -1 means don't checkpoint. +checkpoint_every: -1 +checkpoint_dir: "" +# enables one replica to read the ckpt then broadcast to the rest +enable_single_replica_ckpt_restoring: False + +# Training loop +learning_rate: 1.e-5 +scale_lr: False +max_train_samples: -1 +# max_train_steps takes priority over num_train_epochs. +max_train_steps: 1500 +num_train_epochs: 1 +seed: 0 +output_dir: 'sdxl-model-finetuned' +per_device_batch_size: 1.0 +# If global_batch_size % jax.device_count is not 0, use FSDP sharding. +global_batch_size: 0 + +# For creating tfrecords from dataset +tfrecords_dir: '' +no_records_per_shard: 0 +enable_eval_timesteps: False +timesteps_list: [125, 250, 375, 500, 625, 750, 875] +num_eval_samples: 420 + +warmup_steps_fraction: 0.1 +learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps. +save_optimizer: False + +# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before +# dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0. + +# AdamW optimizer parameters +adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients. +adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients. +adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. +adam_weight_decay: 0 # AdamW Weight decay +max_grad_norm: 1.0 + +enable_profiler: False +# Skip first n steps for profiling, to omit things like compilation and to give +# the iteration time a chance to stabilize. +skip_first_n_steps_for_profiler: 5 +profiler_steps: 10 + +# Enable JAX named scopes for detailed profiling and debugging +# When enabled, adds named scopes around key operations in transformer and attention layers +enable_jax_named_scopes: False + +# Generation parameters +prompt: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." +prompt_2: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." +negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" +do_classifier_free_guidance: True +height: 480 +width: 832 +num_frames: 81 +flow_shift: 3.0 + +# Reference for below guidance scale and boundary values: https://github.com/Wan-Video/Wan2.2/blob/main/wan/configs/wan_t2v_A14B.py +# guidance scale factor for low noise transformer +guidance_scale_low: 3.0 + +# guidance scale factor for high noise transformer +guidance_scale_high: 4.0 + +# The timestep threshold. If `t` is at or above this value, +# the `high_noise_model` is considered as the required model. +# timestep to switch between low noise and high noise transformer +boundary_ratio: 0.875 + +# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf +guidance_rescale: 0.0 +num_inference_steps: 30 +fps: 24 +save_final_checkpoint: False + +# SDXL Lightning parameters +lightning_from_pt: True +# Empty or "ByteDance/SDXL-Lightning" to enable lightning. +lightning_repo: "" +# Empty or "sdxl_lightning_4step_unet.safetensors" to enable lightning. +lightning_ckpt: "" + +# LoRA parameters +# Values are lists to support multiple LoRA loading during inference in the future. +lora_config: { + lora_model_name_or_path: [], + weight_name: [], + adapter_name: [], + scale: [], + from_pt: [] +} +# Ex with values: +# lora_config : { +# lora_model_name_or_path: ["ByteDance/Hyper-SD"], +# weight_name: ["Hyper-SDXL-2steps-lora.safetensors"], +# adapter_name: ["hyper-sdxl"], +# scale: [0.7], +# from_pt: [True] +# } + +enable_mllog: False + +#controlnet +controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0' +controlnet_from_pt: True +controlnet_conditioning_scale: 0.5 +controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png' +quantization: '' +# Shard the range finding operation for quantization. By default this is set to number of slices. +quantization_local_shard_count: -1 +compile_topology_num_slices: -1 # Number of target slices, set to a positive integer. +use_qwix_quantization: False # Whether to use qwix for quantization. If set to True, the transformer of WAN will be quantized using qwix. +# Quantization calibration method used for weights and activations. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80 +quantization_calibration_method: "absmax" +qwix_module_path: ".*" + +# Eval model on per eval_every steps. -1 means don't eval. +eval_every: -1 +eval_data_dir: "" +enable_generate_video_for_eval: False # This will increase the used TPU memory. +eval_max_number_of_samples_in_bucket: 60 # The number of samples per bucket for evaluation. This is calculated by num_eval_samples / len(timesteps_list). + +enable_ssim: False + +# i2v specific parameters +# I2V Input Image +# URL or local path to the conditioning image +image_url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg" \ No newline at end of file diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index e3365e96..d3aad31d 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -19,9 +19,12 @@ import subprocess from maxdiffusion.checkpointing.wan_checkpointer_2_1 import WanCheckpointer2_1 from maxdiffusion.checkpointing.wan_checkpointer_2_2 import WanCheckpointer2_2 +from maxdiffusion.checkpointing.wan_checkpointer_i2v_2p1 import WanCheckpointerI2V_2_1 +from maxdiffusion.checkpointing.wan_checkpointer_i2v_2p2 import WanCheckpointerI2V_2_2 from maxdiffusion import pyconfig, max_logging, max_utils from absl import app from maxdiffusion.utils import export_to_video +from maxdiffusion.utils.loading_utils import load_image from google.cloud import storage import flax from maxdiffusion.common_types import WAN2_1, WAN2_2 @@ -79,30 +82,59 @@ def get_git_commit_hash(): def call_pipeline(config, pipeline, prompt, negative_prompt): model_key = config.model_name - if model_key == WAN2_1: - return pipeline( - prompt=prompt, - negative_prompt=negative_prompt, - height=config.height, - width=config.width, - num_frames=config.num_frames, - num_inference_steps=config.num_inference_steps, - guidance_scale=config.guidance_scale, - ) - elif model_key == WAN2_2: - return pipeline( - prompt=prompt, - negative_prompt=negative_prompt, - height=config.height, - width=config.width, - num_frames=config.num_frames, - num_inference_steps=config.num_inference_steps, - guidance_scale_low=config.guidance_scale_low, - guidance_scale_high=config.guidance_scale_high, - boundary=config.boundary_timestep, - ) - else: - raise ValueError(f"Unsupported model_name in config: {model_key}") + model_type = config.model_type + if model_type == "I2V": + image = load_image(config.image_url) + if model_key == WAN2_1: + return pipeline( + prompt=prompt, + image=image, + negative_prompt=negative_prompt, + height=config.height, + width=config.width, + num_frames=config.num_frames, + num_inference_steps=config.num_inference_steps, + guidance_scale=config.guidance_scale, + ) + elif model_key == WAN2_2: + return pipeline( + prompt=prompt, + image=image, + negative_prompt=negative_prompt, + height=config.height, + width=config.width, + num_frames=config.num_frames, + num_inference_steps=config.num_inference_steps, + guidance_scale_low=config.guidance_scale_low, + guidance_scale_high=config.guidance_scale_high, + ) + else: + raise ValueError(f"Unsupported model_name for I2V in config: {model_key}") + elif model_type == "T2V": + if model_key == WAN2_1: + return pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + height=config.height, + width=config.width, + num_frames=config.num_frames, + num_inference_steps=config.num_inference_steps, + guidance_scale=config.guidance_scale, + ) + elif model_key == WAN2_2: + return pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + height=config.height, + width=config.width, + num_frames=config.num_frames, + num_inference_steps=config.num_inference_steps, + guidance_scale_low=config.guidance_scale_low, + guidance_scale_high=config.guidance_scale_high, + boundary=config.boundary_timestep, + ) + else: + raise ValueError(f"Unsupported model_name for T2Vin config: {model_key}") def inference_generate_video(config, pipeline, filename_prefix=""): @@ -141,10 +173,17 @@ def run(config, pipeline=None, filename_prefix=""): max_logging.log("Could not retrieve Git commit hash.") if pipeline is None: + model_type = config.model_type if model_key == WAN2_1: - checkpoint_loader = WanCheckpointer2_1(config=config) + if model_type == "I2V": + checkpoint_loader = WanCheckpointerI2V_2_1(config=config) + else: + checkpoint_loader = WanCheckpointer2_1(config=config) elif model_key == WAN2_2: - checkpoint_loader = WanCheckpointer2_2(config=config) + if model_type == "I2V": + checkpoint_loader = WanCheckpointerI2V_2_2(config=config) + else: + checkpoint_loader = WanCheckpointer2_2(config=config) else: raise ValueError(f"Unsupported model_name for checkpointer: {model_key}") pipeline, _, _ = checkpoint_loader.load_checkpoint() @@ -162,7 +201,7 @@ def run(config, pipeline=None, filename_prefix=""): max_logging.log("===================== Model details =======================") max_logging.log(f"model name: {config.model_name}") max_logging.log(f"model path: {config.pretrained_model_name_or_path}") - max_logging.log("model type: t2v") + max_logging.log(f"model type: {config.model_type}") max_logging.log(f"hardware: {jax.devices()[0].platform}") max_logging.log(f"number of devices: {jax.device_count()}") max_logging.log(f"per_device_batch_size: {config.per_device_batch_size}") diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index f8ba9310..2982e19e 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -31,6 +31,7 @@ from .. import common_types, max_logging from . import quantizations +from .modeling_flax_utils import get_activation Array = common_types.Array @@ -220,6 +221,7 @@ def _tpu_flash_attention( attention_kernel: str = "flash", mask_padding_tokens: bool = True, residual_checkpoint_name: str | None = None, + attention_mask: jax.Array = None, ) -> jax.Array: """TPU Flash Attention""" @@ -294,6 +296,24 @@ def wrap_flash_attention(query, key, value): kv_padded_len = key.shape[2] kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0) kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32) + + # If attention_mask is provided, apply it to kv_segment_ids + if attention_mask is not None: + mask_len = min(key_seq_len, attention_mask.shape[1]) + kv_mask_for_batch = attention_mask[0, :mask_len] # (mask_len,) + # If key_seq_len > mask_len, pad the mask with 1s (assume remaining tokens are valid) + if key_seq_len > mask_len: + extra_valid = jnp.ones((key_seq_len - mask_len,), dtype=jnp.int32) + kv_mask_for_batch = jnp.concatenate([kv_mask_for_batch, extra_valid], axis=0) # (key_seq_len,) + # Pad to kv_padded_len + if kv_padded_len > key_seq_len: + padding = jnp.zeros((kv_padded_len - key_seq_len,), dtype=jnp.int32) + kv_mask_padded = jnp.concatenate([kv_mask_for_batch, padding], axis=0) # (kv_padded_len,) + else: + kv_mask_padded = kv_mask_for_batch + # Both are (kv_padded_len,) - element-wise multiplication + kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32) + segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids) # make_splash_mha is wrapped around shardmap and seq and head is already @@ -502,6 +522,7 @@ def _apply_attention( dpa_layer: Callable, mask_padding_tokens: bool = True, residual_checkpoint_name: str | None = None, + attention_mask: Array = None, ): """Routes to different attention kernels.""" _check_attention_inputs(query, key, value) @@ -534,6 +555,7 @@ def _apply_attention( attention_kernel, mask_padding_tokens=mask_padding_tokens, residual_checkpoint_name=residual_checkpoint_name, + attention_mask=attention_mask, ) elif attention_kernel == "ring": return _tpu_flash_attention( @@ -649,6 +671,40 @@ def apply_rope(xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]: return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype(xk.dtype) +# New Class for Wan I2V +class NNXSimpleFeedForward(nnx.Module): + def __init__(self, rngs: nnx.Rngs, dim: int, dim_out: Optional[int] = None, mult: int = 4, activation_fn: str = "gelu", dtype: jnp.dtype = jnp.float32, weights_dtype: jnp.dtype = jnp.float32, precision: Optional[jax.lax.Precision] = None): + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + self.net_0 = nnx.Linear( + dim, + inner_dim, + rngs=rngs, + use_bias=True, + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", None)), + bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None,)), + ) + self.act = get_activation(activation_fn) + self.net_2 = nnx.Linear( + inner_dim, + dim_out, + rngs=rngs, + use_bias=True, + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "mlp")), + bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)), + ) + + def __call__(self, hidden_states: Array) -> Array: + hidden_states = self.net_0(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.net_2(hidden_states) + return hidden_states class NNXAttentionOp(nnx.Module): @@ -693,7 +749,7 @@ def __init__( self.mask_padding_tokens = mask_padding_tokens self.residual_checkpoint_name = residual_checkpoint_name - def apply_attention(self, query: Array, key: Array, value: Array): + def apply_attention(self, query: Array, key: Array, value: Array, attention_mask: Array = None): return _apply_attention( query=query, key=key, @@ -714,6 +770,7 @@ def apply_attention(self, query: Array, key: Array, value: Array): dpa_layer=self.dpa_layer, mask_padding_tokens=self.mask_padding_tokens, residual_checkpoint_name=self.residual_checkpoint_name, + attention_mask=attention_mask, ) @@ -753,7 +810,7 @@ def setup(self): transpose_batch_sequence=False, ) - def apply_attention(self, query: Array, key: Array, value: Array): + def apply_attention(self, query: Array, key: Array, value: Array, attention_mask: Array = None): return _apply_attention( query=query, key=key, @@ -772,6 +829,7 @@ def apply_attention(self, query: Array, key: Array, value: Array): axis_names_kv=self.axis_names_kv, flash_block_sizes=self.flash_block_sizes, dpa_layer=self.dpa_layer, + attention_mask=attention_mask, ) @@ -806,6 +864,8 @@ def __init__( mask_padding_tokens: bool = True, residual_checkpoint_name: str | None = None, enable_jax_named_scopes: bool = False, + added_kv_proj_dim: Optional[int] = None, # New for I2V + image_seq_len: Optional[int] = None, # New for I2V ): if attention_kernel == "cudnn_flash_te": raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}") @@ -829,6 +889,8 @@ def __init__( else: axis_names_q = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_Q_LENGTH, D_KV) axis_names_kv = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_KV_LENGTH, D_KV) + self.added_kv_proj_dim = added_kv_proj_dim # New for I2V + self.image_seq_len = image_seq_len # New for I2V self.attention_op = NNXAttentionOp( mesh=mesh, @@ -938,6 +1000,35 @@ def __init__( param_dtype=weights_dtype, ) + # New layers for I2V image conditioning + self.add_k_proj = nnx.data(None) + self.add_v_proj = nnx.data(None) + self.norm_added_k = nnx.data(None) + if self.added_kv_proj_dim is not None: + self.add_k_proj = nnx.Linear( + self.added_kv_proj_dim, self.inner_dim, rngs=rngs, + dtype=dtype, param_dtype=weights_dtype, precision=precision, + bias_init=nnx.with_partitioning( + nnx.initializers.zeros, + ("embed",), + ), + ) + self.add_v_proj = nnx.Linear( + self.added_kv_proj_dim, self.inner_dim, rngs=rngs, + dtype=dtype, param_dtype=weights_dtype, precision=precision, + bias_init=nnx.with_partitioning( + nnx.initializers.zeros, + ("embed",), + ), + ) + self.norm_added_k = nnx.RMSNorm( + num_features=self.inner_dim, rngs=rngs, epsilon=eps, dtype=dtype, param_dtype=weights_dtype, + scale_init=nnx.with_partitioning( + nnx.initializers.ones, + ("norm",), + ), + ) + def _apply_rope(self, xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array) -> Tuple[jax.Array, jax.Array]: dtype = xq.dtype reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2) @@ -963,42 +1054,132 @@ def __call__( hidden_states: jax.Array, encoder_hidden_states: jax.Array = None, rotary_emb: Optional[jax.Array] = None, + encoder_attention_mask: Optional[jax.Array] = None, deterministic: bool = True, rngs: nnx.Rngs = None, ) -> jax.Array: hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor")) encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", "tensor")) dtype = hidden_states.dtype + is_self_attention = encoder_hidden_states is None if encoder_hidden_states is None: encoder_hidden_states = hidden_states - with jax.named_scope("query_proj"): - query_proj = self.query(hidden_states) - with jax.named_scope("key_proj"): - key_proj = self.key(encoder_hidden_states) - with jax.named_scope("value_proj"): - value_proj = self.value(encoder_hidden_states) - - if self.qk_norm: - with self.conditional_named_scope("attn_q_norm"): - query_proj = self.norm_q(query_proj) - with self.conditional_named_scope("attn_k_norm"): - key_proj = self.norm_k(key_proj) - - if rotary_emb is not None: - with self.conditional_named_scope("attn_rope"): - query_proj = _unflatten_heads(query_proj, self.heads) - key_proj = _unflatten_heads(key_proj, self.heads) - value_proj = _unflatten_heads(value_proj, self.heads) - # output of _unflatten_heads Batch, heads, seq_len, head_dim - query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb) - - query_proj = checkpoint_name(query_proj, "query_proj") - key_proj = checkpoint_name(key_proj, "key_proj") - value_proj = checkpoint_name(value_proj, "value_proj") - - with jax.named_scope("apply_attention"): - attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj) + is_i2v_cross_attention = self.added_kv_proj_dim is not None and not is_self_attention + + if not is_i2v_cross_attention: + with jax.named_scope("query_proj"): + query_proj = self.query(hidden_states) + with jax.named_scope("key_proj"): + key_proj = self.key(encoder_hidden_states) + with jax.named_scope("value_proj"): + value_proj = self.value(encoder_hidden_states) + + if self.qk_norm: + with self.conditional_named_scope("attn_q_norm"): + query_proj = self.norm_q(query_proj) + with self.conditional_named_scope("attn_k_norm"): + key_proj = self.norm_k(key_proj) + + if rotary_emb is not None: + with self.conditional_named_scope("attn_rope"): + query_proj = _unflatten_heads(query_proj, self.heads) + key_proj = _unflatten_heads(key_proj, self.heads) + value_proj = _unflatten_heads(value_proj, self.heads) + # output of _unflatten_heads Batch, heads, seq_len, head_dim + query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb) + + query_proj = checkpoint_name(query_proj, "query_proj") + key_proj = checkpoint_name(key_proj, "key_proj") + value_proj = checkpoint_name(value_proj, "value_proj") + + with jax.named_scope("apply_attention"): + attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj) + + else: + # NEW PATH for I2V CROSS-ATTENTION + with self.conditional_named_scope("proj_query"): + query_proj_raw = self.query(hidden_states) + + # Image embeddings are padded to multiples of 128 for TPU flash attention + # Calculate the padded length to correctly split image and text embeddings + if self.added_kv_proj_dim is not None: + alignment = 128 + if self.image_seq_len is not None: + image_seq_len_actual = self.image_seq_len + else: + image_seq_len_actual = 257 + padded_img_len = ((image_seq_len_actual + alignment - 1) // alignment) * alignment # 257 -> 384 + + if encoder_attention_mask is None: + padded_img_len = image_seq_len_actual + + encoder_hidden_states_img = encoder_hidden_states[:, :padded_img_len, :] + encoder_hidden_states_text = encoder_hidden_states[:, padded_img_len:, :] + + # Use the passed encoder_attention_mask (created in embeddings_flax.py) if using Flash Attention + # It contains the image mask: [1]*257 + [0]*127 for 257 real image tokens padded to 384 + if encoder_attention_mask is not None: + encoder_attention_mask_img = encoder_attention_mask[:, :padded_img_len] + else: + # Fallback: no mask means treat all as valid (for dot product attention) + encoder_attention_mask_img = None + else: + # If no image_seq_len is specified, treat all as text + encoder_hidden_states_img = None + encoder_hidden_states_text = encoder_hidden_states + encoder_attention_mask_img = None + + if self.qk_norm: + with self.conditional_named_scope("attn_q_norm"): + query_proj_text = self.norm_q(query_proj_raw) + else: + query_proj_text = query_proj_raw + + # Text K/V + with self.conditional_named_scope("proj_key"): + key_proj_text = self.key(encoder_hidden_states_text) + if self.qk_norm: + with self.conditional_named_scope("attn_k_norm"): + key_proj_text = self.norm_k(key_proj_text) + with self.conditional_named_scope("proj_value"): + value_proj_text = self.value(encoder_hidden_states_text) + + # Image K/V (only if image embeddings are present) + if encoder_hidden_states_img is not None: + with self.conditional_named_scope("add_proj_k"): + key_proj_img = self.add_k_proj(encoder_hidden_states_img) + with self.conditional_named_scope("norm_add_k"): + key_proj_img = self.norm_added_k(key_proj_img) + with self.conditional_named_scope("add_proj_v"): + value_proj_img = self.add_v_proj(encoder_hidden_states_img) + query_proj_img = query_proj_raw + # Check norm_added_k too + # Checkpointing + query_proj_text = checkpoint_name(query_proj_text, "query_proj") + key_proj_text = checkpoint_name(key_proj_text, "key_proj_text") + value_proj_text = checkpoint_name(value_proj_text, "value_proj_text") + key_proj_img = checkpoint_name(key_proj_img, "key_proj_img") + value_proj_img = checkpoint_name(value_proj_img, "value_proj_img") + query_proj_img = checkpoint_name(query_proj_img, "query_proj_img") + + + # Attention - tensors are (B, S, D) + with self.conditional_named_scope("cross_attn_text_apply"): + attn_output_text = self.attention_op.apply_attention(query_proj_text, key_proj_text, value_proj_text) + with self.conditional_named_scope("cross_attn_img_apply"): + # Pass encoder_attention_mask_img for image cross-attention to mask padded tokens + attn_output_img = self.attention_op.apply_attention(query_proj_img, key_proj_img, value_proj_img, attention_mask=encoder_attention_mask_img) + + attn_output = attn_output_text + attn_output_img + else: + # No image embeddings, only text cross-attention + query_proj_text = checkpoint_name(query_proj_text, "query_proj") + key_proj_text = checkpoint_name(key_proj_text, "key_proj_text") + value_proj_text = checkpoint_name(value_proj_text, "value_proj_text") + + with self.conditional_named_scope("cross_attn_text_apply"): + attn_output = self.attention_op.apply_attention(query_proj_text, key_proj_text, value_proj_text) attn_output = attn_output.astype(dtype=dtype) attn_output = checkpoint_name(attn_output, "attn_output") diff --git a/src/maxdiffusion/models/embeddings_flax.py b/src/maxdiffusion/models/embeddings_flax.py index ad34dd55..21c67e10 100644 --- a/src/maxdiffusion/models/embeddings_flax.py +++ b/src/maxdiffusion/models/embeddings_flax.py @@ -19,6 +19,8 @@ from typing import List, Union import jax from .modeling_flax_utils import get_activation +from ..models.attention_flax import NNXSimpleFeedForward +from ..models.normalization_flax import FP32LayerNorm def get_sinusoidal_embeddings( @@ -247,6 +249,59 @@ def get_1d_rotary_pos_embed( out = jnp.exp(1j * freqs) return out +class NNXWanImageEmbedding(nnx.Module): + def __init__(self, rngs: nnx.Rngs, in_features: int, out_features: int, dtype: jnp.dtype, weights_dtype: jnp.dtype, precision: jax.lax.Precision, pos_embed_seq_len=None, alignment: int = 128, flash_min_seq_length: int = 4096): + self.norm1 = FP32LayerNorm(rngs=rngs, dim=in_features, elementwise_affine=True, eps=1e-6) + self.ff = NNXSimpleFeedForward(rngs=rngs, dim=in_features, dim_out=out_features, mult=1, activation_fn="gelu", dtype=dtype, weights_dtype=weights_dtype, precision=precision) + self.norm2 = FP32LayerNorm(rngs=rngs, dim=out_features, elementwise_affine=True, eps=1e-6) + self.alignment = alignment + self.flash_min_seq_length = flash_min_seq_length + if pos_embed_seq_len is not None: + self.pos_embed = nnx.Param(jnp.zeros((1, pos_embed_seq_len, in_features), dtype=dtype)) + else: + self.pos_embed = nnx.data(None) + + def __call__(self, encoder_hidden_states_image: jax.Array) -> tuple[jax.Array, jax.Array]: + hidden_states = encoder_hidden_states_image + B, current_seq_len, D_in = hidden_states.shape + + if self.pos_embed is not None: + pe_len = self.pos_embed.value.shape[1] + add_len = min(current_seq_len, pe_len) + # Apply pos_embed to the original sequence length + hidden_states = hidden_states.at[:, :add_len, :].add(self.pos_embed.value[:, :add_len, :]) + if current_seq_len > pe_len: + print(f"[WARN] Input seq_len {current_seq_len} > pos_embed len {pe_len}") + + hidden_states = self.norm1(hidden_states) + hidden_states = self.ff(hidden_states) + hidden_states = self.norm2(hidden_states) + # hidden_states shape: (B, current_seq_len, out_features) + B, current_seq_len, D_out = hidden_states.shape + use_flash_attn = current_seq_len>=self.flash_min_seq_length + + if use_flash_attn: + # --- Dynamic Padding to nearest multiple of self.alignment --- + num_blocks = (current_seq_len + self.alignment - 1) // self.alignment + target_seq_len = num_blocks * self.alignment + else: + target_seq_len = current_seq_len + + # Create attention mask: 1 for real tokens, 0 for padded tokens + attention_mask = jnp.ones((B, current_seq_len), dtype=jnp.int32) + + if current_seq_len < target_seq_len: + padding_size = target_seq_len - current_seq_len + padding = jnp.zeros((B, padding_size, D_out), dtype=hidden_states.dtype) + hidden_states = jnp.concatenate([hidden_states, padding], axis=1) + + # Extend mask with zeros for padded positions + padding_mask = jnp.zeros((B, padding_size), dtype=jnp.int32) + attention_mask = jnp.concatenate([attention_mask, padding_mask], axis=1) + if not use_flash_attn: + attention_mask = None + return hidden_states, attention_mask + class NNXPixArtAlphaTextProjection(nnx.Module): diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index cb952afa..a18b127c 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -28,6 +28,7 @@ from ...modeling_flax_utils import FlaxModelMixin, get_activation from ....configuration_utils import ConfigMixin, register_to_config from ...embeddings_flax import ( + NNXWanImageEmbedding, get_1d_rotary_pos_embed, NNXFlaxTimesteps, NNXTimestepEmbedding, @@ -103,6 +104,7 @@ def __init__( dtype: jnp.dtype = jnp.float32, weights_dtype: jnp.dtype = jnp.float32, precision: jax.lax.Precision = None, + flash_min_seq_length: int = 4096 ): self.timesteps_proj = NNXFlaxTimesteps(dim=time_freq_dim, flip_sin_to_cos=True, freq_shift=0) self.time_embedder = NNXTimestepEmbedding( @@ -137,6 +139,19 @@ def __init__( act_fn="gelu_tanh", ) + self.image_embedder = nnx.data(None) + if image_embed_dim is not None: + self.image_embedder = NNXWanImageEmbedding( + rngs=rngs, + in_features=image_embed_dim, + out_features=dim, + pos_embed_seq_len=pos_embed_seq_len, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, + flash_min_seq_length=flash_min_seq_length + ) + def __call__( self, timestep: jax.Array, encoder_hidden_states: jax.Array, encoder_hidden_states_image: Optional[jax.Array] = None ): @@ -146,9 +161,10 @@ def __call__( timestep_proj = self.time_proj(self.act_fn(temb)) encoder_hidden_states = self.text_embedder(encoder_hidden_states) + encoder_attention_mask = None if encoder_hidden_states_image is not None: - raise NotImplementedError("currently img2vid is not supported") - return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image + encoder_hidden_states_image, encoder_attention_mask = self.image_embedder(encoder_hidden_states_image) + return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image, encoder_attention_mask class ApproximateGELU(nnx.Module): @@ -263,8 +279,8 @@ def __init__( qk_norm: str = "rms_norm_across_heads", cross_attn_norm: bool = False, eps: float = 1e-6, - # In torch, this is none, so it can be ignored. - # added_kv_proj_dim: Optional[int] = None, + added_kv_proj_dim: Optional[int] = None, + image_seq_len: Optional[int] = None, flash_min_seq_length: int = 4096, flash_block_sizes: BlockSizes = None, mesh: jax.sharding.Mesh = None, @@ -310,6 +326,8 @@ def __init__( dim_head=dim // num_heads, qk_norm=qk_norm, eps=eps, + added_kv_proj_dim=added_kv_proj_dim, + image_seq_len=image_seq_len, flash_min_seq_length=flash_min_seq_length, flash_block_sizes=flash_block_sizes, mesh=mesh, @@ -357,6 +375,7 @@ def __call__( rotary_emb: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None, + encoder_attention_mask: Optional[jax.Array] = None, ): with self.conditional_named_scope("transformer_block"): shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split( @@ -393,6 +412,7 @@ def __call__( encoder_hidden_states=encoder_hidden_states, deterministic=deterministic, rngs=rngs, + encoder_attention_mask = encoder_attention_mask ) with self.conditional_named_scope("cross_attn_residual"): hidden_states = hidden_states + attn_output @@ -436,6 +456,7 @@ def __init__( added_kv_proj_dim: Optional[int] = None, rope_max_seq_len: int = 1024, pos_embed_seq_len: Optional[int] = None, + image_seq_len: Optional[int] = None, flash_min_seq_length: int = 4096, flash_block_sizes: BlockSizes = None, mesh: jax.sharding.Mesh = None, @@ -483,6 +504,7 @@ def __init__( text_embed_dim=text_dim, image_embed_dim=image_dim, pos_embed_seq_len=pos_embed_seq_len, + flash_min_seq_length=flash_min_seq_length ) # 3. Transformer blocks @@ -507,6 +529,8 @@ def init_block(rngs): dropout=dropout, mask_padding_tokens=mask_padding_tokens, enable_jax_named_scopes=enable_jax_named_scopes, + added_kv_proj_dim=added_kv_proj_dim, + image_seq_len=image_seq_len, ) self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy) @@ -525,6 +549,8 @@ def init_block(rngs): qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, + added_kv_proj_dim=added_kv_proj_dim, + image_seq_len=image_seq_len, flash_min_seq_length=flash_min_seq_length, flash_block_sizes=flash_block_sizes, mesh=mesh, @@ -583,20 +609,24 @@ def __call__( hidden_states = self.patch_embedding(hidden_states) hidden_states = jax.lax.collapse(hidden_states, 1, -1) with self.conditional_named_scope("condition_embedder"): - temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image, encoder_attention_mask = self.condition_embedder( timestep, encoder_hidden_states, encoder_hidden_states_image ) timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1) if encoder_hidden_states_image is not None: - raise NotImplementedError("img2vid is not yet implemented.") + encoder_hidden_states = jnp.concatenate([encoder_hidden_states_image, encoder_hidden_states], axis=1) + if encoder_attention_mask is not None: + text_mask = jnp.ones((encoder_hidden_states.shape[0], encoder_hidden_states.shape[1] - encoder_hidden_states_image.shape[1]), dtype=jnp.int32) + encoder_attention_mask = jnp.concatenate([encoder_attention_mask, text_mask], axis=1) + encoder_hidden_states = encoder_hidden_states.astype(hidden_states.dtype) if self.scan_layers: def scan_fn(carry, block): hidden_states_carry, rngs_carry = carry hidden_states = block( - hidden_states_carry, encoder_hidden_states, timestep_proj, rotary_emb, deterministic, rngs_carry + hidden_states_carry, encoder_hidden_states, timestep_proj, rotary_emb, deterministic, rngs_carry, encoder_attention_mask ) new_carry = (hidden_states, rngs_carry) return new_carry, None @@ -617,7 +647,7 @@ def scan_fn(carry, block): for block in self.blocks: def layer_forward(hidden_states): - return block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, deterministic, rngs) + return block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, deterministic, rngs, encoder_attention_mask=encoder_attention_mask) rematted_layer_forward = self.gradient_checkpoint.apply( layer_forward, self.names_which_can_be_saved, self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers diff --git a/src/maxdiffusion/models/wan/wan_utils.py b/src/maxdiffusion/models/wan/wan_utils.py index 191d8b61..7a4b8841 100644 --- a/src/maxdiffusion/models/wan/wan_utils.py +++ b/src/maxdiffusion/models/wan/wan_utils.py @@ -254,7 +254,30 @@ def load_base_wan_transformer( random_flax_state_dict[string_tuple] = flattened_dict[key] del flattened_dict for pt_key, tensor in tensors.items(): + # The diffusers implementation explicitly describes this key in keys to be ignored. + if "norm_added_q" in pt_key: + continue renamed_pt_key = rename_key(pt_key) + + if "condition_embedder" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("time_embedding_0", "time_embedder.linear_1") + renamed_pt_key = renamed_pt_key.replace("time_embedding_2", "time_embedder.linear_2") + renamed_pt_key = renamed_pt_key.replace("time_projection_1", "time_proj") + renamed_pt_key = renamed_pt_key.replace("text_embedding_0", "text_embedder.linear_1") + renamed_pt_key = renamed_pt_key.replace("text_embedding_2", "text_embedder.linear_2") + + if "image_embedder" in renamed_pt_key: + if "net.0.proj" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("net.0.proj", "net_0") + elif "net_0.proj" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("net_0.proj", "net_0") + if "net.2" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("net.2", "net_2") + renamed_pt_key = renamed_pt_key.replace("norm1", "norm1.layer_norm") + if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("weight", "scale") + renamed_pt_key = renamed_pt_key.replace("kernel", "scale") + renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.") renamed_pt_key = renamed_pt_key.replace(".scale_shift_table", ".adaln_scale_shift_table") renamed_pt_key = renamed_pt_key.replace("to_out_0", "proj_attn") diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 68c2ddab..0bc93f0c 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -13,8 +13,9 @@ # limitations under the License. from abc import abstractmethod -from typing import List, Union, Optional +from typing import List, Union, Optional, Tuple from functools import partial +from maxdiffusion.image_processor import PipelineImageInput import numpy as np import jax import jax.numpy as jnp @@ -39,6 +40,9 @@ import re import torch import qwix +from transformers import CLIPImageProcessor +from transformers.models.clip.modeling_flax_clip import FlaxCLIPVisionModel +import PIL def cast_with_exclusion(path, x, dtype_to_cast): @@ -102,6 +106,13 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): wan_config = restored_checkpoint["wan_config"] else: wan_config = WanModel.load_config(config.pretrained_model_name_or_path, subfolder=subfolder) + if config.model_type == "I2V": + # WAN 2.1 I2V uses image embeddings via CLIP encoder (image_dim and added_kv_proj_dim are set) + # WAN 2.2 I2V uses VAE-encoded latent conditioning (image_dim and added_kv_proj_dim are None in the transformer config) + if config.model_name == "wan2.1": + if wan_config.get("image_seq_len") is None: + wan_config["image_seq_len"] = 257 + wan_config["mesh"] = mesh wan_config["dtype"] = config.activations_dtype wan_config["weights_dtype"] = config.weights_dtype @@ -201,6 +212,8 @@ def __init__( devices_array: np.array, mesh: Mesh, config: HyperParameters, + image_processor: Optional[CLIPImageProcessor] = None, + image_encoder: Optional[FlaxCLIPVisionModel] = None, ): self.tokenizer = tokenizer self.text_encoder = text_encoder @@ -212,6 +225,8 @@ def __init__( self.mesh = mesh self.config = config self.model_name = config.model_name + self.image_processor = image_processor + self.image_encoder = image_encoder self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 @@ -235,6 +250,20 @@ def load_tokenizer(cls, config: HyperParameters): ) return tokenizer + @classmethod + def load_image_encoder(cls, config: HyperParameters): + image_processor = CLIPImageProcessor.from_pretrained( + config.pretrained_model_name_or_path, subfolder="image_processor" + ) + try: + image_encoder = FlaxCLIPVisionModel.from_pretrained( + config.pretrained_model_name_or_path, subfolder="image_encoder", dtype=jnp.float32 + ) + except Exception as e: + max_logging.error(f"Failed to load FlaxCLIPVisionModel: {e}") + raise + return image_processor, image_encoder + @classmethod def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters): @@ -371,6 +400,18 @@ def load_scheduler(cls, config): ) return scheduler, scheduler_state + def encode_image(self, image: PipelineImageInput, num_videos_per_prompt: int = 1): + if not isinstance(image, list): + image = [image] + image_inputs = self.image_processor(images=image, return_tensors="np") + pixel_values = jnp.array(image_inputs.pixel_values) + + image_encoder_output = self.image_encoder(pixel_values, output_hidden_states=True) + image_embeds = image_encoder_output.hidden_states[-2] + + image_embeds = jnp.repeat(image_embeds, num_videos_per_prompt, axis=0) + return image_embeds + def _get_t5_prompt_embeds( self, @@ -460,6 +501,47 @@ def prepare_latents( return latents + def prepare_latents_i2v_base( + self, + image: jax.Array, + num_frames: int, + dtype: jnp.dtype, + last_image: Optional[jax.Array] = None, + ) -> Tuple[jax.Array, jax.Array]: + """ + Encodes the initial image(s) into latents to be used as conditioning. + Returns: + latent_condition: The VAE encoded latents of the image(s). + video_condition: The input to the VAE. + """ + height, width = image.shape[-2:] + image = image[:, :, jnp.newaxis, :, :] # [B, C, 1, H, W] + + if last_image is None: + video_condition = jnp.concatenate( + [image, jnp.zeros((image.shape[0], image.shape[1], num_frames - 1, height, width), dtype=image.dtype)], axis=2 + ) + else: + last_image = last_image[:, :, jnp.newaxis, :, :] + video_condition = jnp.concatenate( + [image, jnp.zeros((image.shape[0], image.shape[1], num_frames - 2, height, width), dtype=image.dtype), last_image], axis=2 + ) + + vae_dtype = getattr(self.vae, "dtype", jnp.float32) + video_condition = video_condition.astype(vae_dtype) + + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + encoded_output = self.vae.encode(video_condition, self.vae_cache)[0].mode() + + # Normalize latents + latents_mean = jnp.array(self.vae.latents_mean).reshape(1, 1, 1, 1, self.vae.z_dim) + latents_std = jnp.array(self.vae.latents_std).reshape(1, 1, 1, 1, self.vae.z_dim) + latent_condition = encoded_output + latent_condition = latent_condition.astype(dtype) + latent_condition = (latent_condition - latents_mean) / latents_std + + return latent_condition, video_condition + def _denormalize_latents(self, latents: jax.Array) -> jax.Array: """Denormalizes latents using VAE statistics.""" latents_mean = jnp.array(self.vae.latents_mean).reshape(1, self.vae.z_dim, 1, 1, 1) @@ -479,7 +561,7 @@ def _decode_latents_to_video(self, latents: jax.Array) -> np.ndarray: return self.video_processor.postprocess_video(video, output_type="np") @classmethod - def _create_common_components(cls, config, vae_only=False): + def _create_common_components(cls, config, vae_only=False, i2v=False): devices_array = max_utils.create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) rng = jax.random.key(config.seed) @@ -491,13 +573,16 @@ def _create_common_components(cls, config, vae_only=False): components = { "vae": wan_vae, "vae_cache": vae_cache, "devices_array": devices_array, "rngs": rngs, "mesh": mesh, - "tokenizer": None, "text_encoder": None, "scheduler": None, "scheduler_state": None + "tokenizer": None, "text_encoder": None, "scheduler": None, "scheduler_state": None, + "image_processor": None, "image_encoder": None } if not vae_only: components["tokenizer"] = cls.load_tokenizer(config=config) components["text_encoder"] = cls.load_text_encoder(config=config) components["scheduler"], components["scheduler_state"] = cls.load_scheduler(config=config) + if i2v and config.model_name == 'wan2.1': + components["image_processor"], components["image_encoder"] = cls.load_image_encoder(config) return components @abstractmethod @@ -505,6 +590,73 @@ def _get_num_channel_latents(self) -> int: """Returns the number of input channels for the transformer.""" pass + def _prepare_model_inputs_i2v( + self, + prompt: Union[str, List[str]], + image: Union[PIL.Image.Image, List[PIL.Image.Image]], + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + prompt_embeds: Optional[jax.Array] = None, + negative_prompt_embeds: Optional[jax.Array] = None, + image_embeds: Optional[jax.Array] = None, + last_image: Optional[PIL.Image.Image] = None, + ): + if prompt is not None and isinstance(prompt, str): + prompt = [prompt] + batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0] // num_videos_per_prompt + effective_batch_size = batch_size * num_videos_per_prompt + + # 1. Encode Prompts + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + + # 2. Encode Image (only for WAN 2.1 I2V which uses CLIP image embeddings) + # WAN 2.2 I2V does not use CLIP image embeddings, it uses VAE latent conditioning instead + transformer_dtype = self.config.activations_dtype + + if self.config.model_name == "wan2.1": + # WAN 2.1 I2V: Use CLIP image encoder + if image_embeds is None: + images_to_encode = [image] + if last_image is None: + images_to_encode = [image] + else: + images_to_encode = [image, last_image] + image_embeds = self.encode_image(images_to_encode, num_videos_per_prompt=num_videos_per_prompt) + self.image_seq_len = image_embeds.shape[1] + + if batch_size > 1: + image_embeds = jnp.tile(image_embeds, (batch_size, 1, 1)) + + image_embeds = image_embeds.astype(transformer_dtype) + else: + # WAN 2.2 I2V: No CLIP image embeddings, set to None or empty tensor + # The actual image conditioning happens via VAE latents in prepare_latents + image_embeds = None + prompt_embeds = prompt_embeds.astype(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.astype(transformer_dtype) + + # Use same sharding logic as T2V pipeline for consistent behavior + data_sharding = NamedSharding(self.mesh, P()) + if self.config.global_batch_size_to_train_on // self.config.per_device_batch_size == 0: + data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding)) + + prompt_embeds = jax.device_put(prompt_embeds, data_sharding) + negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding) + image_embeds = jax.device_put(image_embeds, data_sharding) + + return prompt_embeds, negative_prompt_embeds, image_embeds, effective_batch_size + + def _prepare_model_inputs( self, prompt: Union[str, List[str]] = None, @@ -585,14 +737,16 @@ def transformer_forward_pass( prompt_embeds, do_classifier_free_guidance, guidance_scale, + encoder_hidden_states_image=None, ): wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state) - noise_pred = wan_transformer(hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds) + noise_pred = wan_transformer(hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds, encoder_hidden_states_image=encoder_hidden_states_image) if do_classifier_free_guidance: bsz = latents.shape[0] // 2 - noise_uncond = noise_pred[bsz:] - noise_pred = noise_pred[:bsz] - noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + noise_cond = noise_pred[:bsz] # First half = conditional + noise_uncond = noise_pred[bsz:] # Second half = unconditional + noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond) + latents = latents[:bsz] return noise_pred, latents diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py index 9efccf90..c0400f60 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py @@ -41,7 +41,7 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t rngs=common_components["rngs"], config=config, restored_checkpoint=restored_checkpoint, - subfolder="transformer" + subfolder="transformer_2" ) high_noise_transformer = super().load_transformer( devices_array=common_components["devices_array"], @@ -49,7 +49,7 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t rngs=common_components["rngs"], config=config, restored_checkpoint=restored_checkpoint, - subfolder="transformer_2" + subfolder="transformer" ) pipeline = cls( diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py new file mode 100644 index 00000000..0380a07c --- /dev/null +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py @@ -0,0 +1,290 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from maxdiffusion import max_logging +from maxdiffusion.image_processor import PipelineImageInput +from .wan_pipeline import WanPipeline, transformer_forward_pass +from ...models.wan.transformers.transformer_wan import WanModel +from typing import List, Union, Optional, Tuple +from ...pyconfig import HyperParameters +from functools import partial +from flax import nnx +from flax.linen import partitioning as nn_partitioning +import jax +import jax.numpy as jnp +from jax.sharding import NamedSharding, PartitionSpec as P +from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler + +class WanPipelineI2V_2_1(WanPipeline): + """Pipeline for WAN 2.1 Image-to-Video.""" + def __init__(self, config: HyperParameters, transformer: Optional[WanModel], **kwargs): + super().__init__(config=config, **kwargs) + self.transformer = transformer + + @classmethod + def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_transformer=True): + common_components = cls._create_common_components(config, vae_only, i2v=True) + transformer = None + if not vae_only: + if load_transformer: + transformer = super().load_transformer( + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + rngs=common_components["rngs"], + config=config, + restored_checkpoint=restored_checkpoint, + subfolder="transformer" + ) + + pipeline = cls( + tokenizer=common_components["tokenizer"], + text_encoder=common_components["text_encoder"], + image_processor=common_components["image_processor"], + image_encoder=common_components["image_encoder"], + transformer=transformer, + vae=common_components["vae"], + vae_cache=common_components["vae_cache"], + scheduler=common_components["scheduler"], + scheduler_state=common_components["scheduler_state"], + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + config=config, + ) + return pipeline, transformer + + @classmethod + def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True): + pipeline , transformer = cls._load_and_init(config, None, vae_only, load_transformer) + pipeline.transformer = cls.quantize_transformer(config, transformer, pipeline, pipeline.mesh) + return pipeline + + @classmethod + def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_only=False, load_transformer=True): + pipeline, _ = cls._load_and_init(config, restored_checkpoint, vae_only, load_transformer) + return pipeline + + def prepare_latents( + self, + image: jax.Array, + batch_size: int, + height: int, + width: int, + num_frames: int, + dtype: jnp.dtype, + rng: jax.Array, + latents: Optional[jax.Array] = None, + last_image: Optional[jax.Array] = None, + num_videos_per_prompt: int = 1, + ) -> Tuple[jax.Array, jax.Array, Optional[jax.Array]]: + + if hasattr(image, "detach"): + image = image.detach().cpu().numpy() + image = jnp.array(image) + + if last_image is not None: + if hasattr(last_image, "detach"): + last_image = last_image.detach().cpu().numpy() + last_image = jnp.array(last_image) + + if num_videos_per_prompt > 1: + image = jnp.repeat(image, num_videos_per_prompt, axis=0) + if last_image is not None: + last_image = jnp.repeat(last_image, num_videos_per_prompt, axis=0) + + num_channels_latents = self.vae.z_dim + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + + shape = (batch_size, num_latent_frames, latent_height, latent_width, num_channels_latents) + + if latents is None: + latents = jax.random.normal(rng, shape=shape, dtype=jnp.float32) + else: + latents = latents.astype(dtype) + latent_condition, _ = self.prepare_latents_i2v_base(image, num_frames, dtype, last_image) + mask_lat_size = jnp.ones((batch_size, 1, num_frames, latent_height, latent_width), dtype=dtype) + if last_image is None: + mask_lat_size = mask_lat_size.at[:, :, 1:, :, :].set(0) + else: + mask_lat_size = mask_lat_size.at[:, :, 1:-1, :, :].set(0) + first_frame_mask = mask_lat_size[:, :, 0:1] + first_frame_mask = jnp.repeat(first_frame_mask, self.vae_scale_factor_temporal, axis=2) + mask_lat_size = jnp.concatenate([first_frame_mask, mask_lat_size[:, :, 1:]], axis=2) + mask_lat_size = mask_lat_size.reshape( + batch_size, + 1, + num_latent_frames, + self.vae_scale_factor_temporal, + latent_height, + latent_width + ) + mask_lat_size = jnp.transpose(mask_lat_size, (0, 2, 4, 5, 3, 1)).squeeze(-1) + condition = jnp.concatenate([mask_lat_size, latent_condition], axis=-1) + return latents, condition, None + + + def __call__( + self, + prompt: Union[str, List[str]], + image: PipelineImageInput, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_frames: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + latents: Optional[jax.Array] = None, + prompt_embeds: Optional[jax.Array] = None, + negative_prompt_embeds: Optional[jax.Array] = None, + image_embeds: Optional[jax.Array] = None, + last_image: Optional[PipelineImageInput] = None, + output_type: Optional[str] = "np", + rng: Optional[jax.Array] = None, + ): + + height = height or self.config.height + width = width or self.config.width + num_frames = num_frames or self.config.num_frames + + # Validate and adjust num_frames to ensure proper reshaping in prepare_latents + if num_frames % self.vae_scale_factor_temporal != 1: + max_logging.log( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. " + f"Rounding {num_frames} to the nearest valid number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + max_logging.log(f"Adjusted num_frames to: {num_frames}") + num_frames = max(num_frames, 1) + + prompt_embeds, negative_prompt_embeds, image_embeds, effective_batch_size = self._prepare_model_inputs_i2v( + prompt, image, negative_prompt, num_videos_per_prompt, max_sequence_length, + prompt_embeds, negative_prompt_embeds, image_embeds, last_image + ) + + def _process_image_input(img_input, height, width, num_videos_per_prompt): + if img_input is None: + return None + tensor = self.video_processor.preprocess(img_input, height=height, width=width) + jax_array = jnp.array(tensor.cpu().numpy()) + if jax_array.ndim == 3: + jax_array = jax_array[None, ...] # Add batch dimension + if num_videos_per_prompt > 1: + jax_array = jnp.repeat(jax_array, num_videos_per_prompt, axis=0) + return jax_array + + image_tensor = _process_image_input(image, height, width, effective_batch_size) + last_image_tensor = _process_image_input(last_image, height, width, effective_batch_size) + + if rng is None: + rng = jax.random.key(self.config.seed) + latents_rng, inference_rng = jax.random.split(rng) + + latents, condition, first_frame_mask = self.prepare_latents( + image=image_tensor, + batch_size=effective_batch_size, + height=height, + width=width, + num_frames=num_frames, + dtype=image_embeds.dtype, + rng=latents_rng, + latents=latents, + last_image=last_image_tensor, + num_videos_per_prompt=num_videos_per_prompt, + ) + + scheduler_state = self.scheduler.set_timesteps( + self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape + ) + + graphdef, state, rest_of_state = nnx.split(self.transformer, nnx.Param, ...) + data_sharding = NamedSharding(self.mesh, P()) + if self.config.global_batch_size_to_train_on // self.config.per_device_batch_size == 0: + data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding)) + + latents = jax.device_put(latents, data_sharding) + condition = jax.device_put(condition, data_sharding) + prompt_embeds = jax.device_put(prompt_embeds, data_sharding) + negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding) + image_embeds = jax.device_put(image_embeds, data_sharding) + if first_frame_mask is not None: + first_frame_mask = jax.device_put(first_frame_mask, data_sharding) + + p_run_inference = partial( + run_inference_2_1_i2v, + graphdef=graphdef, + sharded_state=state, + rest_of_state=rest_of_state, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + scheduler=self.scheduler, + ) + + + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + latents = p_run_inference( + latents=latents, + condition=condition, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + image_embeds=image_embeds, + scheduler_state=scheduler_state, + ) + latents = jnp.transpose(latents, (0, 4, 1, 2, 3)) + latents = self._denormalize_latents(latents) + + if output_type == "latent": + return latents + return self._decode_latents_to_video(latents) + + +def run_inference_2_1_i2v( + graphdef, sharded_state, rest_of_state, + latents: jnp.array, + condition: jnp.array, + prompt_embeds: jnp.array, + negative_prompt_embeds: jnp.array, + image_embeds: jnp.array, + guidance_scale: float, + num_inference_steps: int, + scheduler: FlaxUniPCMultistepScheduler, + scheduler_state, +): + do_classifier_free_guidance = guidance_scale > 1.0 + + if do_classifier_free_guidance: + prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) + image_embeds = jnp.concatenate([image_embeds, image_embeds], axis=0) + condition = jnp.concatenate([condition] * 2) + for step in range(num_inference_steps): + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + latents_input = latents + if do_classifier_free_guidance: + latents_input = jnp.concatenate([latents, latents], axis=0) + + latent_model_input = jnp.concatenate([latents_input, condition], axis=-1) + timestep = jnp.broadcast_to(t, latents_input.shape[0]) + latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3)) + noise_pred, _ = transformer_forward_pass( + graphdef, sharded_state, rest_of_state, + latent_model_input, timestep, prompt_embeds, + do_classifier_free_guidance=do_classifier_free_guidance, + guidance_scale=guidance_scale, + encoder_hidden_states_image=image_embeds, + ) + noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1)) + latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents, return_dict=False) + return latents diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py new file mode 100644 index 00000000..ab24a651 --- /dev/null +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py @@ -0,0 +1,310 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from maxdiffusion.image_processor import PipelineImageInput +from maxdiffusion import max_logging +from .wan_pipeline import WanPipeline, transformer_forward_pass +from ...models.wan.transformers.transformer_wan import WanModel +from typing import List, Union, Optional, Tuple +from ...pyconfig import HyperParameters +from functools import partial +from flax import nnx +from flax.linen import partitioning as nn_partitioning +import jax +import jax.numpy as jnp +from jax.sharding import NamedSharding, PartitionSpec as P +from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler + +class WanPipelineI2V_2_2(WanPipeline): + """Pipeline for WAN 2.2 Image-to-Video.""" + def __init__(self, config: HyperParameters, low_noise_transformer: Optional[WanModel], high_noise_transformer: Optional[WanModel], **kwargs): + super().__init__(config=config, **kwargs) + self.low_noise_transformer = low_noise_transformer + self.high_noise_transformer = high_noise_transformer + self.boundary_ratio = config.boundary_ratio + + @classmethod + def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_transformer=True): + common_components = cls._create_common_components(config, vae_only, i2v=True) + low_noise_transformer, high_noise_transformer = None, None + if not vae_only: + if load_transformer: + high_noise_transformer = super().load_transformer( + devices_array=common_components["devices_array"], mesh=common_components["mesh"], + rngs=common_components["rngs"], config=config, restored_checkpoint=restored_checkpoint, + subfolder="transformer" + ) + low_noise_transformer = super().load_transformer( + devices_array=common_components["devices_array"], mesh=common_components["mesh"], + rngs=common_components["rngs"], config=config, restored_checkpoint=restored_checkpoint, + subfolder="transformer_2" + ) + + pipeline = cls( + tokenizer=common_components["tokenizer"], text_encoder=common_components["text_encoder"], + image_processor=common_components["image_processor"], image_encoder=common_components["image_encoder"], + low_noise_transformer=low_noise_transformer, high_noise_transformer=high_noise_transformer, + vae=common_components["vae"], vae_cache=common_components["vae_cache"], + scheduler=common_components["scheduler"], scheduler_state=common_components["scheduler_state"], + devices_array=common_components["devices_array"], mesh=common_components["mesh"], + config=config, + ) + return pipeline, low_noise_transformer, high_noise_transformer + + @classmethod + def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True): + pipeline, low_noise_transformer, high_noise_transformer = cls._load_and_init(config, None, vae_only, load_transformer) + pipeline.low_noise_transformer = cls.quantize_transformer(config, low_noise_transformer, pipeline, pipeline.mesh) + pipeline.high_noise_transformer = cls.quantize_transformer(config, high_noise_transformer, pipeline, pipeline.mesh) + return pipeline + + @classmethod + def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_only=False, load_transformer=True): + pipeline, _, _ = cls._load_and_init(config, restored_checkpoint, vae_only, load_transformer) + return pipeline + + def prepare_latents( + self, + image: jax.Array, + batch_size: int, + height: int, + width: int, + num_frames: int, + dtype: jnp.dtype, + rng: jax.Array, + latents: Optional[jax.Array] = None, + last_image: Optional[jax.Array] = None, + num_videos_per_prompt: int = 1, +) -> Tuple[jax.Array, jax.Array, Optional[jax.Array]]: + + if hasattr(image, "detach"): + image = image.detach().cpu().numpy() + image = jnp.array(image) + + if last_image is not None: + if hasattr(last_image, "detach"): + last_image = last_image.detach().cpu().numpy() + last_image = jnp.array(last_image) + + num_channels_latents = self.vae.z_dim + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + + shape = (batch_size, num_latent_frames, latent_height, latent_width, num_channels_latents) + + if latents is None: + latents = jax.random.normal(rng, shape=shape, dtype=jnp.float32) + else: + latents = latents.astype(dtype) + + latent_condition, _ = self.prepare_latents_i2v_base(image, num_frames, dtype, last_image) + mask_lat_size = jnp.ones((batch_size, 1, num_frames, latent_height, latent_width), dtype=dtype) + if last_image is None: + mask_lat_size = mask_lat_size.at[:, :, 1:, :, :].set(0) + else: + mask_lat_size = mask_lat_size.at[:, :, 1:-1, :, :].set(0) + + first_frame_mask = mask_lat_size[:, :, 0:1] + first_frame_mask = jnp.repeat(first_frame_mask, self.vae_scale_factor_temporal, axis=2) + mask_lat_size = jnp.concatenate([first_frame_mask, mask_lat_size[:, :, 1:]], axis=2) + mask_lat_size = mask_lat_size.reshape( + batch_size, 1, num_latent_frames, self.vae_scale_factor_temporal, latent_height, latent_width + ) + mask_lat_size = jnp.transpose(mask_lat_size, (0, 2, 4, 5, 3, 1)).squeeze(-1) + condition = jnp.concatenate([mask_lat_size, latent_condition], axis=-1) + return latents, condition, None + + def __call__( + self, + prompt: Union[str, List[str]], + image: PipelineImageInput, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_frames: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale_low: float = 3.0, + guidance_scale_high: float = 4.0, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + latents: Optional[jax.Array] = None, + prompt_embeds: Optional[jax.Array] = None, + negative_prompt_embeds: Optional[jax.Array] = None, + image_embeds: Optional[jax.Array] = None, + last_image: Optional[PipelineImageInput] = None, + output_type: Optional[str] = "np", + rng: Optional[jax.Array] = None, + ): + height = height or self.config.height + width = width or self.config.width + num_frames = num_frames or self.config.num_frames + + if num_frames % self.vae_scale_factor_temporal != 1: + max_logging.log( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. " + f"Rounding {num_frames} to the nearest valid number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + max_logging.log(f"Adjusted num_frames to: {num_frames}") + num_frames = max(num_frames, 1) + + prompt_embeds, negative_prompt_embeds, image_embeds, effective_batch_size = self._prepare_model_inputs_i2v( + prompt, image, negative_prompt, num_videos_per_prompt, max_sequence_length, + prompt_embeds, negative_prompt_embeds, image_embeds, last_image + ) + def _process_image_input(img_input, height, width, num_videos_per_prompt): + if img_input is None: + return None + tensor = self.video_processor.preprocess(img_input, height=height, width=width) + jax_array = jnp.array(tensor.cpu().numpy()) + if jax_array.ndim == 3: + jax_array = jax_array[None, ...] # Add batch dimension + if num_videos_per_prompt > 1: + jax_array = jnp.repeat(jax_array, num_videos_per_prompt, axis=0) + return jax_array + + image_tensor = _process_image_input(image, height, width, effective_batch_size) + last_image_tensor = _process_image_input(last_image, height, width, effective_batch_size) + + if rng is None: + rng = jax.random.key(self.config.seed) + latents_rng, inference_rng = jax.random.split(rng) + + # For WAN 2.2, image_embeds may be None (no CLIP image encoder) + # Use prompt_embeds dtype as fallback + latents_dtype = image_embeds.dtype if image_embeds is not None else prompt_embeds.dtype + + latents, condition, first_frame_mask = self.prepare_latents( + image=image_tensor, + batch_size=effective_batch_size, + height=height, + width=width, + num_frames=num_frames, + dtype=latents_dtype, + rng=latents_rng, + latents=latents, + last_image=last_image_tensor, + ) + + scheduler_state = self.scheduler.set_timesteps( + self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape + ) + + low_noise_graphdef, low_noise_state, low_noise_rest = nnx.split(self.low_noise_transformer, nnx.Param, ...) + high_noise_graphdef, high_noise_state, high_noise_rest = nnx.split(self.high_noise_transformer, nnx.Param, ...) + data_sharding = NamedSharding(self.mesh, P()) + if self.config.global_batch_size_to_train_on // self.config.per_device_batch_size == 0: + data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding)) + latents = jax.device_put(latents, data_sharding) + condition = jax.device_put(condition, data_sharding) + prompt_embeds = jax.device_put(prompt_embeds, data_sharding) + negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding) + # WAN 2.2 I2V doesn't use image_embeds (it's None), but we still need to pass it to the function + if image_embeds is not None: + image_embeds = jax.device_put(image_embeds, data_sharding) + if first_frame_mask is not None: + first_frame_mask = jax.device_put(first_frame_mask, data_sharding) + + + boundary_timestep = self.boundary_ratio * self.scheduler.config.num_train_timesteps + + p_run_inference = partial( + run_inference_2_2_i2v, + guidance_scale_low=guidance_scale_low, + guidance_scale_high=guidance_scale_high, + boundary=boundary_timestep, + num_inference_steps=num_inference_steps, + scheduler=self.scheduler, + image_embeds=image_embeds, + ) + + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + latents = p_run_inference( + low_noise_graphdef=low_noise_graphdef, low_noise_state=low_noise_state, low_noise_rest=low_noise_rest, + high_noise_graphdef=high_noise_graphdef, high_noise_state=high_noise_state, high_noise_rest=high_noise_rest, + latents=latents, condition=condition, + prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, + scheduler_state=scheduler_state, + ) + latents = jnp.transpose(latents, (0, 4, 1, 2, 3)) + latents = self._denormalize_latents(latents) + + if output_type == "latent": + return latents + return self._decode_latents_to_video(latents) + +def run_inference_2_2_i2v( + low_noise_graphdef, low_noise_state, low_noise_rest, + high_noise_graphdef, high_noise_state, high_noise_rest, + latents: jnp.array, + condition: jnp.array, + prompt_embeds: jnp.array, + negative_prompt_embeds: jnp.array, + image_embeds: jnp.array, + guidance_scale_low: float, + guidance_scale_high: float, + boundary: int, + num_inference_steps: int, + scheduler: FlaxUniPCMultistepScheduler, + scheduler_state, +): + do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0 + def high_noise_branch(operands): + latents_input, ts_input, pe_input, ie_input = operands + latents_input = jnp.transpose(latents_input, (0, 4, 1, 2, 3)) + noise_pred, latents_out = transformer_forward_pass( + high_noise_graphdef, high_noise_state, high_noise_rest, + latents_input, ts_input, pe_input, + do_classifier_free_guidance=do_classifier_free_guidance, guidance_scale=guidance_scale_high, + encoder_hidden_states_image=ie_input + ) + return noise_pred, latents_out + + def low_noise_branch(operands): + latents_input, ts_input, pe_input, ie_input = operands + latents_input = jnp.transpose(latents_input, (0, 4, 1, 2, 3)) + noise_pred, latents_out = transformer_forward_pass( + low_noise_graphdef, low_noise_state, low_noise_rest, + latents_input, ts_input, pe_input, + do_classifier_free_guidance=do_classifier_free_guidance, guidance_scale=guidance_scale_low, + encoder_hidden_states_image=ie_input + ) + return noise_pred, latents_out + + if do_classifier_free_guidance: + prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) + # WAN 2.2 I2V: image_embeds may be None since it doesn't use CLIP image encoder + if image_embeds is not None: + image_embeds = jnp.concatenate([image_embeds, image_embeds], axis=0) + condition = jnp.concatenate([condition] * 2) + + for step in range(num_inference_steps): + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + latents_input = latents + if do_classifier_free_guidance: + latents_input = jnp.concatenate([latents, latents], axis=0) + latent_model_input = jnp.concatenate([latents_input, condition], axis=-1) + timestep = jnp.broadcast_to(t, latents_input.shape[0]) + + use_high_noise = jnp.greater_equal(t, boundary) + noise_pred, _ = jax.lax.cond( + use_high_noise, + high_noise_branch, + low_noise_branch, + (latent_model_input, timestep, prompt_embeds, image_embeds) + ) + noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1)) + latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + return latents diff --git a/src/maxdiffusion/tests/wan_checkpointer_test.py b/src/maxdiffusion/tests/wan_checkpointer_test.py index 719716d7..81a38670 100644 --- a/src/maxdiffusion/tests/wan_checkpointer_test.py +++ b/src/maxdiffusion/tests/wan_checkpointer_test.py @@ -15,6 +15,9 @@ from unittest.mock import patch, MagicMock from maxdiffusion.checkpointing.wan_checkpointer_2_1 import WanCheckpointer2_1 from maxdiffusion.checkpointing.wan_checkpointer_2_2 import WanCheckpointer2_2 +from maxdiffusion.checkpointing.wan_checkpointer_i2v_2p1 import WanCheckpointerI2V_2_1 +from maxdiffusion.checkpointing.wan_checkpointer_i2v_2p2 import WanCheckpointerI2V_2_2 +from maxdiffusion.pipelines.wan.wan_pipeline_i2v_2p1 import WanPipelineI2V_2_1 class WanCheckpointer2_1Test(unittest.TestCase): """Tests for WAN 2.1 checkpointer.""" @@ -234,6 +237,215 @@ def test_load_checkpoint_with_optimizer_in_high_noise(self, mock_wan_pipeline, m self.assertEqual(opt_state["learning_rate"], 0.002) self.assertEqual(step, 1) +class WanCheckpointerI2V_2_1Test(unittest.TestCase): + """Tests for WAN 2.1 I2V checkpointer.""" + + def setUp(self): + self.config = MagicMock() + self.config.checkpoint_dir = "/tmp/wan_i2v_checkpoint_test" + self.config.dataset_type = "test_dataset" + self.config.model_type = "I2V" + + @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") + @patch.object(WanPipelineI2V_2_1, "from_pretrained", autospec=True) + def test_load_from_diffusers(self, mock_from_pretrained, mock_create_manager): + mock_manager = MagicMock() + mock_manager.latest_step.return_value = None + mock_create_manager.return_value = mock_manager + + mock_pipeline_instance = MagicMock() + mock_from_pretrained.return_value = mock_pipeline_instance + + checkpointer = WanCheckpointerI2V_2_1(config=self.config) + pipeline, opt_state, step = checkpointer.load_checkpoint(step=None) + + mock_manager.latest_step.assert_called_once() + mock_from_pretrained.assert_called_once_with(self.config) + self.assertEqual(pipeline, mock_pipeline_instance) + self.assertIsNone(opt_state) + self.assertIsNone(step) + + @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") + @patch.object(WanPipelineI2V_2_1, "from_checkpoint", autospec=True) + def test_load_checkpoint_no_optimizer(self, mock_from_checkpoint, mock_create_manager): + mock_manager = MagicMock() + mock_manager.latest_step.return_value = 1 + metadata_mock = MagicMock() + metadata_mock.wan_state = {} + mock_manager.item_metadata.return_value = metadata_mock + + restored_mock = MagicMock() + restored_mock.wan_state = {"params": {}} + restored_mock.wan_config = {} + restored_mock.keys.return_value = ["wan_state", "wan_config"] + + mock_manager.restore.return_value = restored_mock + mock_create_manager.return_value = mock_manager + + mock_pipeline_instance = MagicMock() + mock_from_checkpoint.return_value = mock_pipeline_instance + + checkpointer = WanCheckpointerI2V_2_1(config=self.config) + pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) + + mock_manager.restore.assert_called_once() + mock_from_checkpoint.assert_called_once_with(self.config, restored_mock) + self.assertEqual(pipeline, mock_pipeline_instance) + self.assertIsNone(opt_state) + self.assertEqual(step, 1) + + @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") + @patch.object(WanPipelineI2V_2_1, "from_checkpoint", autospec=True) + def test_load_checkpoint_with_optimizer(self, mock_from_checkpoint, mock_create_manager): + mock_manager = MagicMock() + mock_manager.latest_step.return_value = 1 + metadata_mock = MagicMock() + metadata_mock.wan_state = {} + mock_manager.item_metadata.return_value = metadata_mock + + restored_mock = MagicMock() + restored_mock.wan_state = {"params": {}, "opt_state": {"learning_rate": 0.001}} + restored_mock.wan_config = {} + restored_mock.keys.return_value = ["wan_state", "wan_config"] + + mock_manager.restore.return_value = restored_mock + mock_create_manager.return_value = mock_manager + + mock_pipeline_instance = MagicMock() + mock_from_checkpoint.return_value = mock_pipeline_instance + + checkpointer = WanCheckpointerI2V_2_1(config=self.config) + pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) + + mock_manager.restore.assert_called_once() + mock_from_checkpoint.assert_called_once_with(self.config, restored_mock) + self.assertEqual(pipeline, mock_pipeline_instance) + self.assertIsNotNone(opt_state) + self.assertEqual(opt_state["learning_rate"], 0.001) + self.assertEqual(step, 1) + +class WanCheckpointerI2V_2_2Test(unittest.TestCase): + """Tests for WAN 2.2 I2V checkpointer.""" + + def setUp(self): + self.config = MagicMock() + self.config.checkpoint_dir = "/tmp/wan_i2v_2_2_checkpoint_test" + self.config.dataset_type = "test_dataset" + self.config.model_type = "I2V" + + @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") + @patch("maxdiffusion.checkpointing.wan_checkpointer_i2v_2p2.WanPipelineI2V_2_2") + def test_load_from_diffusers(self, mock_wan_pipeline_i2v_2p2, mock_create_manager): + mock_manager = MagicMock() + mock_manager.latest_step.return_value = None + mock_create_manager.return_value = mock_manager + + mock_pipeline_instance = MagicMock() + mock_wan_pipeline_i2v_2p2.from_pretrained.return_value = mock_pipeline_instance + + checkpointer = WanCheckpointerI2V_2_2(config=self.config) + pipeline, opt_state, step = checkpointer.load_checkpoint(step=None) + + mock_manager.latest_step.assert_called_once() + mock_wan_pipeline_i2v_2p2.from_pretrained.assert_called_once_with(self.config) + self.assertEqual(pipeline, mock_pipeline_instance) + self.assertIsNone(opt_state) + self.assertIsNone(step) + + @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") + @patch("maxdiffusion.checkpointing.wan_checkpointer_i2v_2p2.WanPipelineI2V_2_2") + def test_load_checkpoint_no_optimizer(self, mock_wan_pipeline_i2v_2p2, mock_create_manager): + mock_manager = MagicMock() + mock_manager.latest_step.return_value = 1 + metadata_mock = MagicMock() + metadata_mock.low_noise_transformer_state = {} + metadata_mock.high_noise_transformer_state = {} + mock_manager.item_metadata.return_value = metadata_mock + + restored_mock = MagicMock() + restored_mock.low_noise_transformer_state = {"params": {}} + restored_mock.high_noise_transformer_state = {"params": {}} + restored_mock.wan_config = {} + restored_mock.keys.return_value = ["low_noise_transformer_state", "high_noise_transformer_state", "wan_config"] + + mock_manager.restore.return_value = restored_mock + mock_create_manager.return_value = mock_manager + + mock_pipeline_instance = MagicMock() + mock_wan_pipeline_i2v_2p2.from_checkpoint.return_value = mock_pipeline_instance + + checkpointer = WanCheckpointerI2V_2_2(config=self.config) + pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) + + mock_manager.restore.assert_called_once() + mock_wan_pipeline_i2v_2p2.from_checkpoint.assert_called_once_with(self.config, restored_mock) + self.assertEqual(pipeline, mock_pipeline_instance) + self.assertIsNone(opt_state) + self.assertEqual(step, 1) + + @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") + @patch("maxdiffusion.checkpointing.wan_checkpointer_i2v_2p2.WanPipelineI2V_2_2") + def test_load_checkpoint_with_optimizer_in_low_noise(self, mock_wan_pipeline_i2v_2p2, mock_create_manager): + mock_manager = MagicMock() + mock_manager.latest_step.return_value = 1 + metadata_mock = MagicMock() + metadata_mock.low_noise_transformer_state = {} + metadata_mock.high_noise_transformer_state = {} + mock_manager.item_metadata.return_value = metadata_mock + + restored_mock = MagicMock() + restored_mock.low_noise_transformer_state = {"params": {}, "opt_state": {"learning_rate": 0.001}} + restored_mock.high_noise_transformer_state = {"params": {}} + restored_mock.wan_config = {} + restored_mock.keys.return_value = ["low_noise_transformer_state", "high_noise_transformer_state", "wan_config"] + + mock_manager.restore.return_value = restored_mock + mock_create_manager.return_value = mock_manager + + mock_pipeline_instance = MagicMock() + mock_wan_pipeline_i2v_2p2.from_checkpoint.return_value = mock_pipeline_instance + + checkpointer = WanCheckpointerI2V_2_2(config=self.config) + pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) + + mock_manager.restore.assert_called_once() + mock_wan_pipeline_i2v_2p2.from_checkpoint.assert_called_once_with(self.config, restored_mock) + self.assertEqual(pipeline, mock_pipeline_instance) + self.assertIsNotNone(opt_state) + self.assertEqual(opt_state["learning_rate"], 0.001) + self.assertEqual(step, 1) + + @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") + @patch("maxdiffusion.checkpointing.wan_checkpointer_i2v_2p2.WanPipelineI2V_2_2") + def test_load_checkpoint_with_optimizer_in_high_noise(self, mock_wan_pipeline_i2v_2p2, mock_create_manager): + mock_manager = MagicMock() + mock_manager.latest_step.return_value = 1 + metadata_mock = MagicMock() + metadata_mock.low_noise_transformer_state = {} + metadata_mock.high_noise_transformer_state = {} + mock_manager.item_metadata.return_value = metadata_mock + + restored_mock = MagicMock() + restored_mock.low_noise_transformer_state = {"params": {}} + restored_mock.high_noise_transformer_state = {"params": {}, "opt_state": {"learning_rate": 0.002}} + restored_mock.wan_config = {} + restored_mock.keys.return_value = ["low_noise_transformer_state", "high_noise_transformer_state", "wan_config"] + + mock_manager.restore.return_value = restored_mock + mock_create_manager.return_value = mock_manager + + mock_pipeline_instance = MagicMock() + mock_wan_pipeline_i2v_2p2.from_checkpoint.return_value = mock_pipeline_instance + + checkpointer = WanCheckpointerI2V_2_2(config=self.config) + pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) + + mock_manager.restore.assert_called_once() + mock_wan_pipeline_i2v_2p2.from_checkpoint.assert_called_once_with(self.config, restored_mock) + self.assertEqual(pipeline, mock_pipeline_instance) + self.assertIsNotNone(opt_state) + self.assertEqual(opt_state["learning_rate"], 0.002) + self.assertEqual(step, 1) class WanCheckpointerEdgeCasesTest(unittest.TestCase): """Tests for edge cases and error handling."""