Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 90 additions & 4 deletions src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,13 @@
from ..pipelines.wan.wan_pipeline_2_2 import WanPipeline2_2
from .. import max_logging
import orbax.checkpoint as ocp
from maxdiffusion.checkpointing.checkpointing_utils import add_sharding_to_struct, get_cpu_mesh_and_sharding
from flax import nnx
from maxdiffusion.checkpointing.checkpointing_utils import (
add_sharding_to_struct,
get_cpu_mesh_and_sharding,
create_orbax_checkpoint_manager,
WAN_CHECKPOINT,
)
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer


Expand Down Expand Up @@ -83,20 +89,100 @@ def load_diffusers_checkpoint(self):
pipeline = WanPipeline2_2.from_pretrained(self.config)
return pipeline

def _get_pretrained_orbax_dir(self) -> str:
  """Return the configured pretrained-weights orbax cache directory.

  Reads ``pretrained_orbax_dir`` from ``self.config``. A missing attribute and
  an explicit ``None`` (e.g. an empty YAML value) both yield ``""``, so the
  result always honors the ``str`` annotation and callers can keep using a
  plain truthiness check.
  """
  return getattr(self.config, "pretrained_orbax_dir", "") or ""
Comment on lines +92 to +93
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this function? We could use getattr(self.config, "pretrained_orbax_dir", "") directly, since it is only needed in one place. Also, should the default value be None instead of ""?


def save_pretrained_checkpoint(self, pretrained_dir: str, pipeline: WanPipeline2_2):
  """Write the pipeline's transformer parameters to an orbax checkpoint.

  Only the ``nnx.Param`` state of the low- and high-noise transformers is
  stored (no optimizer state), together with the low-noise transformer's JSON
  config. Everything is saved as step 0 through a non-async checkpoint manager
  rooted at ``pretrained_dir``.

  Args:
    pretrained_dir: Directory handed to the orbax checkpoint manager.
    pipeline: Pipeline whose two transformers are serialized.
  """
  max_logging.log(f"Saving pretrained WAN 2.2 weights to orbax at {pretrained_dir}")
  manager = create_orbax_checkpoint_manager(
      pretrained_dir,
      enable_checkpointing=True,
      save_interval_steps=1,
      checkpoint_type=WAN_CHECKPOINT,
      use_async=False,
  )

  low_transformer = pipeline.low_noise_transformer
  high_transformer = pipeline.high_noise_transformer

  # nnx.split yields (graphdef, Param state, remaining state); only the Param state is kept.
  _graphdef, low_param_state, _rest = nnx.split(low_transformer, nnx.Param, ...)
  _graphdef, high_param_state, _rest = nnx.split(high_transformer, nnx.Param, ...)
  low_weights = low_param_state.to_pure_dict()
  high_weights = high_param_state.to_pure_dict()
  transformer_config = json.loads(low_transformer.to_json_string())

  checkpoint_items = ocp.args.Composite(
      wan_config=ocp.args.JsonSave(transformer_config),
      low_noise_transformer_state=ocp.args.StandardSave(low_weights),
      high_noise_transformer_state=ocp.args.StandardSave(high_weights),
  )
  manager.save(0, args=checkpoint_items)
  manager.wait_until_finished()
  max_logging.log(f"Pretrained weights saved to {pretrained_dir}")

def load_pretrained_from_orbax(self, pretrained_dir: str) -> Tuple[Optional[object], Optional[int]]:
  """Try to restore cached pretrained weights from ``pretrained_dir``.

  Args:
    pretrained_dir: Directory handed to the orbax checkpoint manager.

  Returns:
    ``(restored, step)`` on success. ``(None, None)`` when the manager reports
    no step, or when any exception is raised along the way (the error is
    logged, not propagated).
  """
  try:
    manager = create_orbax_checkpoint_manager(
        pretrained_dir,
        enable_checkpointing=True,
        save_interval_steps=1,
        checkpoint_type=WAN_CHECKPOINT,
        use_async=False,
    )
    latest_step = manager.latest_step()
    if latest_step is None:
      max_logging.log(f"No pretrained orbax checkpoint found in {pretrained_dir}")
      return None, None
    max_logging.log(f"Found pretrained orbax checkpoint (step {latest_step}) in {pretrained_dir}")

    mesh, replicated_sharding = get_cpu_mesh_and_sharding()
    item_metadata = manager.item_metadata(latest_step)
    low_struct = item_metadata.low_noise_transformer_state
    high_struct = item_metadata.high_noise_transformer_state

    def replicated_like(tree):
      # Same tree structure, every leaf replaced by the replicated sharding.
      return jax.tree_util.tree_map(lambda _leaf: replicated_sharding, tree)

    low_shardings = replicated_like(low_struct)
    high_shardings = replicated_like(high_struct)
    with mesh:
      low_target = jax.tree_util.tree_map(add_sharding_to_struct, low_struct, low_shardings)
      high_target = jax.tree_util.tree_map(add_sharding_to_struct, high_struct, high_shardings)

    max_logging.log("Restoring pretrained WAN 2.2 weights from orbax")
    restore_items = ocp.args.Composite(
        wan_config=ocp.args.JsonRestore(),
        low_noise_transformer_state=ocp.args.StandardRestore(low_target),
        high_noise_transformer_state=ocp.args.StandardRestore(high_target),
    )
    checkpoint = manager.restore(latest_step, args=restore_items)
    return checkpoint, latest_step
  except Exception as err:  # pylint: disable=broad-except
    # Best-effort cache: any failure falls back to the caller's other load paths.
    max_logging.log(f"Failed to load pretrained orbax checkpoint from {pretrained_dir}: {err}")
    return None, None

def load_checkpoint(self, step=None) -> Tuple[WanPipeline2_2, Optional[dict], Optional[int]]:
  """Build the WAN 2.2 pipeline from the first available source.

  Sources are tried in this order:
    1. The pretrained orbax cache in ``pretrained_orbax_dir`` (if configured).
    2. A training checkpoint via ``load_wan_configs_from_orbax``.
    3. Diffusers weights via ``load_diffusers_checkpoint``; when a cache
       directory is configured, the loaded weights are then saved into it.

  Args:
    step: Checkpoint step forwarded to ``load_wan_configs_from_orbax``.

  Returns:
    ``(pipeline, opt_state, step)``. ``opt_state`` is ``None`` unless a
    training checkpoint with an ``"opt_state"`` entry in one of the two
    transformer states was restored. On the cache path, ``step`` is the step
    found in the cache directory.
  """
  pretrained_dir = self._get_pretrained_orbax_dir()

  # 1. Fast path: load from pretrained orbax cache (skips diffusers entirely).
  # NOTE(review): this runs before the training-checkpoint lookup, so an existing
  # cache wins over a checkpoint in checkpoint_dir and no optimizer state is
  # returned -- confirm this is intended when resuming training.
  if pretrained_dir:
    restored, loaded_step = self.load_pretrained_from_orbax(pretrained_dir)
    if restored is not None:
      max_logging.log("Loading WAN 2.2 pipeline from pretrained orbax checkpoint")
      pipeline = WanPipeline2_2.from_checkpoint(self.config, restored)
      return pipeline, None, loaded_step

  # 2. Try training checkpoint from checkpoint_dir.
  restored_checkpoint, step = self.load_wan_configs_from_orbax(step)
  if restored_checkpoint:
    max_logging.log("Loading WAN pipeline from training checkpoint")
    pipeline = WanPipeline2_2.from_checkpoint(self.config, restored_checkpoint)
    # The optimizer state may be stored under either transformer's state.
    opt_state = None
    if "opt_state" in restored_checkpoint.low_noise_transformer_state:
      opt_state = restored_checkpoint.low_noise_transformer_state["opt_state"]
    elif "opt_state" in restored_checkpoint.high_noise_transformer_state:
      opt_state = restored_checkpoint.high_noise_transformer_state["opt_state"]
    return pipeline, opt_state, step

  # 3. Slow path: load from diffusers, then cache to orbax for next time.
  max_logging.log("No checkpoint found, loading pipeline from diffusers.")
  pipeline = self.load_diffusers_checkpoint()
  if pretrained_dir:
    self.save_pretrained_checkpoint(pretrained_dir, pipeline)
  return pipeline, None, step

Expand Down
94 changes: 90 additions & 4 deletions src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p2.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,13 @@
from ..pipelines.wan.wan_pipeline_i2v_2p2 import WanPipelineI2V_2_2
from .. import max_logging
import orbax.checkpoint as ocp
from maxdiffusion.checkpointing.checkpointing_utils import add_sharding_to_struct, get_cpu_mesh_and_sharding
from flax import nnx
from maxdiffusion.checkpointing.checkpointing_utils import (
add_sharding_to_struct,
get_cpu_mesh_and_sharding,
create_orbax_checkpoint_manager,
WAN_CHECKPOINT,
)
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer


Expand Down Expand Up @@ -83,20 +89,100 @@ def load_diffusers_checkpoint(self):
pipeline = WanPipelineI2V_2_2.from_pretrained(self.config)
return pipeline

def _get_pretrained_orbax_dir(self) -> str:
  """Return the configured pretrained-weights orbax cache directory.

  Reads ``pretrained_orbax_dir`` from ``self.config``. A missing attribute and
  an explicit ``None`` (e.g. an empty YAML value) both yield ``""``, so the
  result always honors the ``str`` annotation and callers can keep using a
  plain truthiness check.
  """
  return getattr(self.config, "pretrained_orbax_dir", "") or ""
Comment on lines +92 to +93
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above


def save_pretrained_checkpoint(self, pretrained_dir: str, pipeline: WanPipelineI2V_2_2):
  """Write the I2V pipeline's transformer parameters to an orbax checkpoint.

  Only the ``nnx.Param`` state of the low- and high-noise transformers is
  stored (no optimizer state), together with the low-noise transformer's JSON
  config. Everything is saved as step 0 through a non-async checkpoint manager
  rooted at ``pretrained_dir``.

  Args:
    pretrained_dir: Directory handed to the orbax checkpoint manager.
    pipeline: Pipeline whose two transformers are serialized.
  """
  max_logging.log(f"Saving pretrained WAN 2.2 I2V weights to orbax at {pretrained_dir}")
  manager = create_orbax_checkpoint_manager(
      pretrained_dir,
      enable_checkpointing=True,
      save_interval_steps=1,
      checkpoint_type=WAN_CHECKPOINT,
      use_async=False,
  )

  low_transformer = pipeline.low_noise_transformer
  high_transformer = pipeline.high_noise_transformer

  # nnx.split yields (graphdef, Param state, remaining state); only the Param state is kept.
  _graphdef, low_param_state, _rest = nnx.split(low_transformer, nnx.Param, ...)
  _graphdef, high_param_state, _rest = nnx.split(high_transformer, nnx.Param, ...)
  low_weights = low_param_state.to_pure_dict()
  high_weights = high_param_state.to_pure_dict()
  transformer_config = json.loads(low_transformer.to_json_string())

  checkpoint_items = ocp.args.Composite(
      wan_config=ocp.args.JsonSave(transformer_config),
      low_noise_transformer_state=ocp.args.StandardSave(low_weights),
      high_noise_transformer_state=ocp.args.StandardSave(high_weights),
  )
  manager.save(0, args=checkpoint_items)
  manager.wait_until_finished()
  max_logging.log(f"Pretrained weights saved to {pretrained_dir}")

def load_pretrained_from_orbax(self, pretrained_dir: str) -> Tuple[Optional[object], Optional[int]]:
  """Try to restore cached pretrained I2V weights from ``pretrained_dir``.

  Args:
    pretrained_dir: Directory handed to the orbax checkpoint manager.

  Returns:
    ``(restored, step)`` on success. ``(None, None)`` when the manager reports
    no step, or when any exception is raised along the way (the error is
    logged, not propagated).
  """
  try:
    manager = create_orbax_checkpoint_manager(
        pretrained_dir,
        enable_checkpointing=True,
        save_interval_steps=1,
        checkpoint_type=WAN_CHECKPOINT,
        use_async=False,
    )
    latest_step = manager.latest_step()
    if latest_step is None:
      max_logging.log(f"No pretrained orbax checkpoint found in {pretrained_dir}")
      return None, None
    max_logging.log(f"Found pretrained orbax checkpoint (step {latest_step}) in {pretrained_dir}")

    mesh, replicated_sharding = get_cpu_mesh_and_sharding()
    item_metadata = manager.item_metadata(latest_step)
    low_struct = item_metadata.low_noise_transformer_state
    high_struct = item_metadata.high_noise_transformer_state

    def replicated_like(tree):
      # Same tree structure, every leaf replaced by the replicated sharding.
      return jax.tree_util.tree_map(lambda _leaf: replicated_sharding, tree)

    low_shardings = replicated_like(low_struct)
    high_shardings = replicated_like(high_struct)
    with mesh:
      low_target = jax.tree_util.tree_map(add_sharding_to_struct, low_struct, low_shardings)
      high_target = jax.tree_util.tree_map(add_sharding_to_struct, high_struct, high_shardings)

    max_logging.log("Restoring pretrained WAN 2.2 I2V weights from orbax")
    restore_items = ocp.args.Composite(
        wan_config=ocp.args.JsonRestore(),
        low_noise_transformer_state=ocp.args.StandardRestore(low_target),
        high_noise_transformer_state=ocp.args.StandardRestore(high_target),
    )
    checkpoint = manager.restore(latest_step, args=restore_items)
    return checkpoint, latest_step
  except Exception as err:  # pylint: disable=broad-except
    # Best-effort cache: any failure falls back to the caller's other load paths.
    max_logging.log(f"Failed to load pretrained orbax checkpoint from {pretrained_dir}: {err}")
    return None, None

def load_checkpoint(self, step=None) -> Tuple[WanPipelineI2V_2_2, Optional[dict], Optional[int]]:
  """Build the WAN 2.2 I2V pipeline from the first available source.

  Sources are tried in this order:
    1. The pretrained orbax cache in ``pretrained_orbax_dir`` (if configured).
    2. A training checkpoint via ``load_wan_configs_from_orbax``.
    3. Diffusers weights via ``load_diffusers_checkpoint``; when a cache
       directory is configured, the loaded weights are then saved into it.

  Args:
    step: Checkpoint step forwarded to ``load_wan_configs_from_orbax``.

  Returns:
    ``(pipeline, opt_state, step)``. ``opt_state`` is ``None`` unless a
    training checkpoint with an ``"opt_state"`` entry in one of the two
    transformer states was restored. On the cache path, ``step`` is the step
    found in the cache directory.
  """
  pretrained_dir = self._get_pretrained_orbax_dir()

  # 1. Fast path: load from pretrained orbax cache (skips diffusers entirely).
  # NOTE(review): this runs before the training-checkpoint lookup, so an existing
  # cache wins over a checkpoint in checkpoint_dir and no optimizer state is
  # returned -- confirm this is intended when resuming training.
  if pretrained_dir:
    restored, loaded_step = self.load_pretrained_from_orbax(pretrained_dir)
    if restored is not None:
      max_logging.log("Loading WAN 2.2 I2V pipeline from pretrained orbax checkpoint")
      pipeline = WanPipelineI2V_2_2.from_checkpoint(self.config, restored)
      return pipeline, None, loaded_step

  # 2. Try training checkpoint from checkpoint_dir.
  restored_checkpoint, step = self.load_wan_configs_from_orbax(step)
  if restored_checkpoint:
    max_logging.log("Loading WAN pipeline from training checkpoint")
    pipeline = WanPipelineI2V_2_2.from_checkpoint(self.config, restored_checkpoint)
    # The optimizer state may be stored under either transformer's state.
    opt_state = None
    if "opt_state" in restored_checkpoint.low_noise_transformer_state:
      opt_state = restored_checkpoint.low_noise_transformer_state["opt_state"]
    elif "opt_state" in restored_checkpoint.high_noise_transformer_state:
      opt_state = restored_checkpoint.high_noise_transformer_state["opt_state"]
    return pipeline, opt_state, step

  # 3. Slow path: load from diffusers, then cache to orbax for next time.
  max_logging.log("No checkpoint found, loading pipeline from diffusers.")
  pipeline = self.load_diffusers_checkpoint()
  if pretrained_dir:
    self.save_pretrained_checkpoint(pretrained_dir, pipeline)
  return pipeline, None, step

Expand Down
4 changes: 4 additions & 0 deletions src/maxdiffusion/configs/base_wan_27b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,10 @@ names_which_can_be_offloaded: []
# checkpoint every number of samples, -1 means don't checkpoint.
checkpoint_every: -1
checkpoint_dir: ""
# Directory to cache pretrained weights as an orbax checkpoint for fast inference loads.
# On first run (slow, diffusers load), weights are saved here automatically.
# On subsequent runs, weights are loaded from here instead (~10x faster).
pretrained_orbax_dir: ""
# enables one replica to read the ckpt then broadcast to the rest
enable_single_replica_ckpt_restoring: False

Expand Down
4 changes: 4 additions & 0 deletions src/maxdiffusion/configs/base_wan_i2v_27b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,10 @@ names_which_can_be_offloaded: []
# checkpoint every number of samples, -1 means don't checkpoint.
checkpoint_every: -1
checkpoint_dir: ""
# Directory to cache pretrained weights as an orbax checkpoint for fast inference loads.
# On first run (slow, diffusers load), weights are saved here automatically.
# On subsequent runs, weights are loaded from here instead (~10x faster).
pretrained_orbax_dir: ""
# enables one replica to read the ckpt then broadcast to the rest
enable_single_replica_ckpt_restoring: False

Expand Down
19 changes: 17 additions & 2 deletions src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,20 +48,35 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t
common_components = cls._create_common_components(config, vae_only)
low_noise_transformer, high_noise_transformer = None, None
if not vae_only and load_transformer:
# Restructure the combined checkpoint into per-transformer checkpoints.
# create_sharded_logical_transformer expects {"wan_config": ..., "wan_state": ...}.
if restored_checkpoint is not None:
low_noise_ckpt = {
"wan_config": restored_checkpoint["wan_config"],
"wan_state": restored_checkpoint["low_noise_transformer_state"],
}
high_noise_ckpt = {
"wan_config": restored_checkpoint["wan_config"],
"wan_state": restored_checkpoint["high_noise_transformer_state"],
}
else:
low_noise_ckpt = None
high_noise_ckpt = None

low_noise_transformer = super().load_transformer(
devices_array=common_components["devices_array"],
mesh=common_components["mesh"],
rngs=common_components["rngs"],
config=config,
restored_checkpoint=restored_checkpoint,
restored_checkpoint=low_noise_ckpt,
subfolder="transformer_2",
)
high_noise_transformer = super().load_transformer(
devices_array=common_components["devices_array"],
mesh=common_components["mesh"],
rngs=common_components["rngs"],
config=config,
restored_checkpoint=restored_checkpoint,
restored_checkpoint=high_noise_ckpt,
subfolder="transformer",
)

Expand Down
19 changes: 17 additions & 2 deletions src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,20 +51,35 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t
low_noise_transformer, high_noise_transformer = None, None
if not vae_only:
if load_transformer:
# Restructure the combined checkpoint into per-transformer checkpoints.
# create_sharded_logical_transformer expects {"wan_config": ..., "wan_state": ...}.
if restored_checkpoint is not None:
high_noise_ckpt = {
"wan_config": restored_checkpoint["wan_config"],
"wan_state": restored_checkpoint["high_noise_transformer_state"],
}
low_noise_ckpt = {
"wan_config": restored_checkpoint["wan_config"],
"wan_state": restored_checkpoint["low_noise_transformer_state"],
}
else:
high_noise_ckpt = None
low_noise_ckpt = None

high_noise_transformer = super().load_transformer(
devices_array=common_components["devices_array"],
mesh=common_components["mesh"],
rngs=common_components["rngs"],
config=config,
restored_checkpoint=restored_checkpoint,
restored_checkpoint=high_noise_ckpt,
subfolder="transformer",
)
low_noise_transformer = super().load_transformer(
devices_array=common_components["devices_array"],
mesh=common_components["mesh"],
rngs=common_components["rngs"],
config=config,
restored_checkpoint=restored_checkpoint,
restored_checkpoint=low_noise_ckpt,
subfolder="transformer_2",
)

Expand Down
Loading