Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
a2456ae
WAN Img2Vid Implementation
prishajain1 Jan 10, 2026
0b6e69a
Removed randn_tensor function import
prishajain1 Jan 10, 2026
6e74e2b
logical_axis rules and attention_sharding_uniform added in config files
prishajain1 Jan 10, 2026
6a3262f
removed attn_mask from FlaxWanAttn call
prishajain1 Jan 10, 2026
9e5d828
fix to prevent load_image_encoder from running for wan 2.2 iv
prishajain1 Jan 10, 2026
520ed97
boundary_ratio removed from generate_wan.py
prishajain1 Jan 10, 2026
b40c1bb
testing with 720p
prishajain1 Jan 10, 2026
70d1780
model restored
prishajain1 Jan 10, 2026
bf6eccf
attn_mask correction
prishajain1 Jan 10, 2026
35be2e7
transformer corrected in wan 2.2 t2v and config files updated
prishajain1 Jan 10, 2026
2358605
revert
prishajain1 Jan 11, 2026
f9deaf3
corrected
prishajain1 Jan 11, 2026
1d48f8e
import added in wan_checkpointer_test.py
prishajain1 Jan 11, 2026
3522291
wan_checkpointer_test.py corrected
prishajain1 Jan 11, 2026
7df34b3
wan_checkpointer_test.py corrected
prishajain1 Jan 11, 2026
1e92718
wan_checkpointer_test.py corrected
prishajain1 Jan 11, 2026
231b379
removed redundance img attn mask
prishajain1 Jan 11, 2026
bafb313
Fix for multiple videos
prishajain1 Jan 12, 2026
04ded63
Fix for multiple videos
prishajain1 Jan 12, 2026
4c3f0b0
Fix for multiple videos
prishajain1 Jan 12, 2026
2d631b5
Fix for multiple videos
prishajain1 Jan 12, 2026
6f1ab12
removed redundant args
prishajain1 Jan 12, 2026
ca150aa
removed redundant args
prishajain1 Jan 12, 2026
264dcf1
trying dot attn fix
prishajain1 Jan 13, 2026
0c62ef1
reverting fix to see if that was the issue
prishajain1 Jan 13, 2026
ef6ead2
fix verified
prishajain1 Jan 13, 2026
e5c6324
updated comments
prishajain1 Jan 13, 2026
705434c
Added sharding
prishajain1 Jan 14, 2026
f9cc8f8
sharding added
prishajain1 Jan 14, 2026
52d415b
ruff checks
prishajain1 Jan 14, 2026
b0dab1a
README updated
prishajain1 Jan 14, 2026
3a4463b
sharding
prishajain1 Jan 14, 2026
7ec026b
ruff check
prishajain1 Jan 14, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
[![Unit Tests](https://github.com/google/maxtext/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/AI-Hypercomputer/maxdiffusion/actions/workflows/UnitTests.yml)

# What's new?
- **`2026/1/15`**: Wan2.1 and Wan2.2 Img2vid generation is now supported
- **`2025/11/11`**: Wan2.2 txt2vid generation is now supported
- **`2025/10/10`**: Wan2.1 txt2vid training and generation is now supported.
- **`2025/10/14`**: NVIDIA DGX Spark Flux support.
Expand Down Expand Up @@ -482,19 +483,38 @@ To generate images, run the following command:

Although not required, attaching an external disk is recommended as weights take up a lot of disk space. [Follow these instructions if you would like to attach an external disk](https://cloud.google.com/tpu/docs/attach-durable-block-storage).

### Text2Vid

```bash
HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/
LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_14b.yml attention="flash" num_inference_steps=50 num_frames=81 width=1280 height=720 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=5.0 enable_profiler=True run_name=wan-inference-testing-720p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445
```

### Img2Vid

```bash
HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/
LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_i2v_14b.yml attention="flash" num_inference_steps=30 num_frames=81 width=832 height=480 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=3.0 enable_profiler=True run_name=wan-i2v-inference-testing-480p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445
```

## Wan2.2

Although not required, attaching an external disk is recommended as weights take up a lot of disk space. [Follow these instructions if you would like to attach an external disk](https://cloud.google.com/tpu/docs/attach-durable-block-storage).

### Text2Vid

```bash
HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/
LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_27b.yml attention="flash" num_inference_steps=50 num_frames=81 width=1280 height=720 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=5.0 enable_profiler=True run_name=wan-inference-testing-720p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445
```

### Img2Vid

```bash
HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/
LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_i2v_27b.yml attention="flash" num_inference_steps=30 num_frames=81 width=832 height=480 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=3.0 enable_profiler=True run_name=wan-i2v-inference-testing-480p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445
```

## Flux

First make sure you have permissions to access the Flux repos in Huggingface.
Expand Down
4 changes: 3 additions & 1 deletion src/maxdiffusion/checkpointing/wan_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager)
from ..pipelines.wan.wan_pipeline_2_1 import WanPipeline2_1
from ..pipelines.wan.wan_pipeline_2_2 import WanPipeline2_2
from ..pipelines.wan.wan_pipeline_i2v_2p1 import WanPipelineI2V_2_1
from ..pipelines.wan.wan_pipeline_i2v_2p2 import WanPipelineI2V_2_2
from .. import max_logging, max_utils
import orbax.checkpoint as ocp

Expand Down Expand Up @@ -59,7 +61,7 @@ def load_diffusers_checkpoint(self):
raise NotImplementedError

@abstractmethod
def load_checkpoint(self, step=None) -> Tuple[Optional[WanPipeline2_1 | WanPipeline2_2], Optional[dict], Optional[int]]:
def load_checkpoint(self, step=None) -> Tuple[Optional[WanPipeline2_1 | WanPipeline2_2 | WanPipelineI2V_2_1 | WanPipelineI2V_2_2], Optional[dict], Optional[int]]:
raise NotImplementedError

@abstractmethod
Expand Down
95 changes: 95 additions & 0 deletions src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import json
import jax
import numpy as np
from typing import Optional, Tuple
from ..pipelines.wan.wan_pipeline_i2v_2p1 import WanPipelineI2V_2_1
from .. import max_logging
import orbax.checkpoint as ocp
from etils import epath
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer

class WanCheckpointerI2V_2_1(WanCheckpointer):

def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
if step is None:
step = self.checkpoint_manager.latest_step()
max_logging.log(f"Latest WAN checkpoint step: {step}")
if step is None:
max_logging.log("No WAN checkpoint found.")
return None, None
max_logging.log(f"Loading WAN checkpoint from step {step}")
metadatas = self.checkpoint_manager.item_metadata(step)
transformer_metadata = metadatas.wan_state
abstract_tree_structure_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, transformer_metadata)
params_restore = ocp.args.PyTreeRestore(
restore_args=jax.tree.map(
lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
abstract_tree_structure_params,
)
)

max_logging.log("Restoring WAN checkpoint")
restored_checkpoint = self.checkpoint_manager.restore(
directory=epath.Path(self.config.checkpoint_dir),
step=step,
args=ocp.args.Composite(
wan_state=params_restore,
wan_config=ocp.args.JsonRestore(),
),
)
max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}")
max_logging.log(f"restored checkpoint wan_state {restored_checkpoint.wan_state.keys()}")
max_logging.log(f"optimizer found in checkpoint {'opt_state' in restored_checkpoint.wan_state.keys()}")
max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}")
return restored_checkpoint, step

def load_diffusers_checkpoint(self):
pipeline = WanPipelineI2V_2_1.from_pretrained(self.config)
return pipeline

def load_checkpoint(self, step=None) -> Tuple[WanPipelineI2V_2_1, Optional[dict], Optional[int]]:
restored_checkpoint, step = self.load_wan_configs_from_orbax(step)
opt_state = None
if restored_checkpoint:
max_logging.log("Loading WAN pipeline from checkpoint")
pipeline = WanPipelineI2V_2_1.from_checkpoint(self.config, restored_checkpoint)
if "opt_state" in restored_checkpoint.wan_state.keys():
opt_state = restored_checkpoint.wan_state["opt_state"]
else:
max_logging.log("No checkpoint found, loading default pipeline.")
pipeline = self.load_diffusers_checkpoint()

return pipeline, opt_state, step

def save_checkpoint(self, train_step, pipeline: WanPipelineI2V_2_1, train_states: dict):
"""Saves the training state and model configurations."""

def config_to_json(model_or_config):
return json.loads(model_or_config.to_json_string())

max_logging.log(f"Saving checkpoint for step {train_step}")
items = {
"wan_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)),
}

items["wan_state"] = ocp.args.PyTreeSave(train_states)

# Save the checkpoint
self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))
max_logging.log(f"Checkpoint for step {train_step} saved.")
114 changes: 114 additions & 0 deletions src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import json
import jax
import numpy as np
from typing import Optional, Tuple
from ..pipelines.wan.wan_pipeline_i2v_2p2 import WanPipelineI2V_2_2
from .. import max_logging
import orbax.checkpoint as ocp
from etils import epath
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer

class WanCheckpointerI2V_2_2(WanCheckpointer):

def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
if step is None:
step = self.checkpoint_manager.latest_step()
max_logging.log(f"Latest WAN checkpoint step: {step}")
if step is None:
max_logging.log("No WAN checkpoint found.")
return None, None
max_logging.log(f"Loading WAN checkpoint from step {step}")
metadatas = self.checkpoint_manager.item_metadata(step)

# Handle low_noise_transformer
low_noise_transformer_metadata = metadatas.low_noise_transformer_state
abstract_tree_structure_low_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, low_noise_transformer_metadata)
low_params_restore = ocp.args.PyTreeRestore(
restore_args=jax.tree.map(
lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
abstract_tree_structure_low_params,
)
)

# Handle high_noise_transformer
high_noise_transformer_metadata = metadatas.high_noise_transformer_state
abstract_tree_structure_high_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, high_noise_transformer_metadata)
high_params_restore = ocp.args.PyTreeRestore(
restore_args=jax.tree.map(
lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
abstract_tree_structure_high_params,
)
)

max_logging.log("Restoring WAN 2.2 checkpoint")
restored_checkpoint = self.checkpoint_manager.restore(
directory=epath.Path(self.config.checkpoint_dir),
step=step,
args=ocp.args.Composite(
low_noise_transformer_state=low_params_restore,
high_noise_transformer_state=high_params_restore,
wan_config=ocp.args.JsonRestore(),
),
)
max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}")
max_logging.log(f"restored checkpoint low_noise_transformer_state {restored_checkpoint.low_noise_transformer_state.keys()}")
max_logging.log(f"restored checkpoint high_noise_transformer_state {restored_checkpoint.high_noise_transformer_state.keys()}")
max_logging.log(f"optimizer found in low_noise checkpoint {'opt_state' in restored_checkpoint.low_noise_transformer_state.keys()}")
max_logging.log(f"optimizer found in high_noise checkpoint {'opt_state' in restored_checkpoint.high_noise_transformer_state.keys()}")
max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}")
return restored_checkpoint, step

def load_diffusers_checkpoint(self):
pipeline = WanPipelineI2V_2_2.from_pretrained(self.config)
return pipeline

def load_checkpoint(self, step=None) -> Tuple[WanPipelineI2V_2_2, Optional[dict], Optional[int]]:
restored_checkpoint, step = self.load_wan_configs_from_orbax(step)
opt_state = None
if restored_checkpoint:
max_logging.log("Loading WAN pipeline from checkpoint")
pipeline = WanPipelineI2V_2_2.from_checkpoint(self.config, restored_checkpoint)
# Check for optimizer state in either transformer
if "opt_state" in restored_checkpoint.low_noise_transformer_state.keys():
opt_state = restored_checkpoint.low_noise_transformer_state["opt_state"]
elif "opt_state" in restored_checkpoint.high_noise_transformer_state.keys():
opt_state = restored_checkpoint.high_noise_transformer_state["opt_state"]
else:
max_logging.log("No checkpoint found, loading default pipeline.")
pipeline = self.load_diffusers_checkpoint()

return pipeline, opt_state, step

def save_checkpoint(self, train_step, pipeline: WanPipelineI2V_2_2, train_states: dict):
"""Saves the training state and model configurations."""

def config_to_json(model_or_config):
return json.loads(model_or_config.to_json_string())

max_logging.log(f"Saving checkpoint for step {train_step}")
items = {
"wan_config": ocp.args.JsonSave(config_to_json(pipeline.low_noise_transformer)),
}

items["low_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["low_noise_transformer"])
items["high_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["high_noise_transformer"])

# Save the checkpoint
self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))
max_logging.log(f"Checkpoint for step {train_step} saved.")
1 change: 1 addition & 0 deletions src/maxdiffusion/configs/base_wan_14b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ log_period: 100

pretrained_model_name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers'
model_name: wan2.1
model_type: 'T2V'

# Overrides the transformer from pretrained_model_name_or_path
wan_transformer_pretrained_model_name_or_path: ''
Expand Down
1 change: 1 addition & 0 deletions src/maxdiffusion/configs/base_wan_27b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ log_period: 100

pretrained_model_name_or_path: 'Wan-AI/Wan2.2-T2V-A14B-Diffusers'
model_name: wan2.2
model_type: 'T2V'

# Overrides the transformer from pretrained_model_name_or_path
wan_transformer_pretrained_model_name_or_path: ''
Expand Down
Loading
Loading