diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py index ca59d2a0..5b4c0b41 100644 --- a/diffsynth/pipelines/wan_video.py +++ b/diffsynth/pipelines/wan_video.py @@ -122,11 +122,15 @@ def from_pretrained( model_config.model_id = redirect_dict[model_config.origin_file_pattern][0] model_config.origin_file_pattern = redirect_dict[model_config.origin_file_pattern][1] - # Initialize pipeline - pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype) if use_usp: from ..utils.xfuser import initialize_usp initialize_usp(device) + import torch.distributed as dist + from ..core.device.npu_compatible_device import get_device_name + if dist.is_available() and dist.is_initialized(): + device = get_device_name() + # Initialize pipeline + pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype) model_pool = pipe.download_and_load_models(model_configs, vram_limit) # Fetch models