diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index 52f607e10..43d8cd1de 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -527,8 +527,10 @@ def forward(self, clip_embdding = self.img_emb(clip_feature) context = torch.cat([clip_embdding, context], dim=1) - x, (f, h, w) = self.patchify(x) - + x = self.patchify(x) + f, h, w = x.shape[-3], x.shape[-2], x.shape[-1] + x = x.flatten(2).transpose(1, 2) + freqs = torch.cat([ self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), diff --git a/diffsynth/utils/xfuser/xdit_context_parallel.py b/diffsynth/utils/xfuser/xdit_context_parallel.py index abf0f3fef..40b18821a 100644 --- a/diffsynth/utils/xfuser/xdit_context_parallel.py +++ b/diffsynth/utils/xfuser/xdit_context_parallel.py @@ -81,8 +81,10 @@ def usp_dit_forward(self, clip_embdding = self.img_emb(clip_feature) context = torch.cat([clip_embdding, context], dim=1) - x, (f, h, w) = self.patchify(x) - + x = self.patchify(x) + f, h, w = x.shape[-3], x.shape[-2], x.shape[-1] + x = x.flatten(2).transpose(1, 2) + freqs = torch.cat([ self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),