From 46e96dad411d1862fb95ac14b7b91045fb79f907 Mon Sep 17 00:00:00 2001 From: Nina Shvetsova Date: Fri, 15 May 2026 15:12:43 +0000 Subject: [PATCH] Add `text_encoder_dtype` and `compile_text_encoder` config parameters for Wan text encoders. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Mitigates PR#397 by enabling the text encoder dtype to be configured separately from `weights_dtype`. Separating these parameters makes it possible to compile/load the text encoder in float32 (as it was before #397 by default) while keeping the model weights in bfloat16, and provides a configuration parameter to make the text encoder compilation optional (as it wasn’t compiled before #397). This addresses problems since some environments experience issues when `torch.compile` is run on it in hermetic/packaged setups. --- src/maxdiffusion/configs/base_wan_14b.yml | 5 +++++ src/maxdiffusion/configs/base_wan_1_3b.yml | 5 +++++ src/maxdiffusion/configs/base_wan_27b.yml | 5 +++++ src/maxdiffusion/configs/base_wan_animate.yml | 5 +++++ src/maxdiffusion/configs/base_wan_i2v_14b.yml | 5 +++++ src/maxdiffusion/configs/base_wan_i2v_27b.yml | 5 +++++ src/maxdiffusion/pipelines/wan/wan_pipeline.py | 6 ++++-- 7 files changed, 34 insertions(+), 2 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index f432928a..319bfbc7 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -41,6 +41,11 @@ revision: '' weights_dtype: 'bfloat16' # This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype) activations_dtype: 'bfloat16' +# The dtype for text_encoder model during load/compile +text_encoder_dtype: 'float32' + +# Whether to compile the text_encoder with torch.compile +compile_text_encoder: False # Replicates vae across devices instead of using the model's sharding annotations for sharding. replicate_vae: False diff --git a/src/maxdiffusion/configs/base_wan_1_3b.yml b/src/maxdiffusion/configs/base_wan_1_3b.yml index 0e055265..3134ed93 100644 --- a/src/maxdiffusion/configs/base_wan_1_3b.yml +++ b/src/maxdiffusion/configs/base_wan_1_3b.yml @@ -41,6 +41,11 @@ revision: '' weights_dtype: 'bfloat16' # This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype) activations_dtype: 'bfloat16' +# The dtype for text_encoder model during load/compile +text_encoder_dtype: 'float32' + +# Whether to compile the text_encoder with torch.compile +compile_text_encoder: False # Replicates vae across devices instead of using the model's sharding annotations for sharding. replicate_vae: False diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml index bf29fa86..dfe300dd 100644 --- a/src/maxdiffusion/configs/base_wan_27b.yml +++ b/src/maxdiffusion/configs/base_wan_27b.yml @@ -41,6 +41,11 @@ revision: '' weights_dtype: 'bfloat16' # This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype) activations_dtype: 'bfloat16' +# The dtype for text_encoder model during load/compile +text_encoder_dtype: 'float32' + +# Whether to compile the text_encoder with torch.compile +compile_text_encoder: False # Replicates vae across devices instead of using the model's sharding annotations for sharding. replicate_vae: False diff --git a/src/maxdiffusion/configs/base_wan_animate.yml b/src/maxdiffusion/configs/base_wan_animate.yml index 8f95c855..7b3334c7 100644 --- a/src/maxdiffusion/configs/base_wan_animate.yml +++ b/src/maxdiffusion/configs/base_wan_animate.yml @@ -41,6 +41,11 @@ revision: '' weights_dtype: 'bfloat16' # This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype) activations_dtype: 'bfloat16' +# The dtype for text_encoder model during load/compile +text_encoder_dtype: 'float32' + +# Whether to compile the text_encoder with torch.compile +compile_text_encoder: False # Replicates vae across devices instead of using the model's sharding annotations for sharding. replicate_vae: False diff --git a/src/maxdiffusion/configs/base_wan_i2v_14b.yml b/src/maxdiffusion/configs/base_wan_i2v_14b.yml index ca2d239a..f722e04e 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_14b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_14b.yml @@ -41,6 +41,11 @@ revision: '' weights_dtype: 'bfloat16' # This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype) activations_dtype: 'bfloat16' +# The dtype for text_encoder model during load/compile +text_encoder_dtype: 'float32' + +# Whether to compile the text_encoder with torch.compile +compile_text_encoder: False # Replicates vae across devices instead of using the model's sharding annotations for sharding. replicate_vae: False diff --git a/src/maxdiffusion/configs/base_wan_i2v_27b.yml b/src/maxdiffusion/configs/base_wan_i2v_27b.yml index 90799524..0aa533b4 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_27b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_27b.yml @@ -41,6 +41,11 @@ revision: '' weights_dtype: 'bfloat16' # This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype) activations_dtype: 'bfloat16' +# The dtype for text_encoder model during load/compile +text_encoder_dtype: 'float32' + +# Whether to compile the text_encoder with torch.compile +compile_text_encoder: False # Replicates vae across devices instead of using the model's sharding annotations for sharding. replicate_vae: False diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 5a5cfa29..c304ee42 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -270,13 +270,15 @@ def __init__( @classmethod def load_text_encoder(cls, config: HyperParameters): - torch_dtype = getattr(torch, str(config.weights_dtype), torch.float32) + text_encoder_dtype = getattr(config, "text_encoder_dtype", "float32") + torch_dtype = getattr(torch, str(text_encoder_dtype), torch.float32) text_encoder = UMT5EncoderModel.from_pretrained( config.pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=torch_dtype, ) - text_encoder = torch.compile(text_encoder) + if getattr(config, "compile_text_encoder", True): + text_encoder = torch.compile(text_encoder) return text_encoder @classmethod