From 63d874d7b7f0dcddbc669e886307a24c93bddf44 Mon Sep 17 00:00:00 2001 From: Lancer Date: Thu, 2 Apr 2026 23:10:52 +0800 Subject: [PATCH 01/11] Add LongCat-AudioDiT pipeline Signed-off-by: Lancer --- docs/source/en/_toctree.yml | 2 + .../en/api/pipelines/longcat_audio_dit.md | 63 ++ docs/source/en/api/pipelines/overview.md | 1 + src/diffusers/__init__.py | 6 + src/diffusers/models/__init__.py | 2 + src/diffusers/models/autoencoders/__init__.py | 1 + .../autoencoder_longcat_audio_dit.py | 394 ++++++++++++ src/diffusers/models/transformers/__init__.py | 1 + .../transformer_longcat_audio_dit.py | 582 ++++++++++++++++++ src/diffusers/pipelines/__init__.py | 2 + .../pipelines/longcat_audio_dit/__init__.py | 40 ++ .../pipeline_longcat_audio_dit.py | 432 +++++++++++++ ...st_models_transformer_longcat_audio_dit.py | 27 + .../test_longcat_audio_dit.py | 178 ++++++ 14 files changed, 1731 insertions(+) create mode 100644 docs/source/en/api/pipelines/longcat_audio_dit.md create mode 100644 src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py create mode 100644 src/diffusers/models/transformers/transformer_longcat_audio_dit.py create mode 100644 src/diffusers/pipelines/longcat_audio_dit/__init__.py create mode 100644 src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py create mode 100644 tests/models/transformers/test_models_transformer_longcat_audio_dit.py create mode 100644 tests/pipelines/longcat_audio_dit/test_longcat_audio_dit.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 7582a56505f7..885e2aa27181 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -488,6 +488,8 @@ title: AudioLDM 2 - local: api/pipelines/stable_audio title: Stable Audio + - local: api/pipelines/longcat_audio_dit + title: LongCat-AudioDiT title: Audio - sections: - local: api/pipelines/animatediff diff --git a/docs/source/en/api/pipelines/longcat_audio_dit.md b/docs/source/en/api/pipelines/longcat_audio_dit.md new file mode 100644 index 000000000000..b605bb4ae672 --- /dev/null +++ b/docs/source/en/api/pipelines/longcat_audio_dit.md @@ -0,0 +1,63 @@ + + +# LongCat-AudioDiT + +LongCat-AudioDiT is a text-to-audio diffusion model from Meituan LongCat. The diffusers integration exposes a standard [`DiffusionPipeline`] interface for text-conditioned audio generation. + +This pipeline supports loading the original flat LongCat checkpoint layout from either a local directory or a Hugging Face Hub repository containing: + +- `config.json` +- `model.safetensors` + +The loader builds the text encoder, transformer, and VAE from `config.json`, restores component weights from `model.safetensors`, and ties the shared UMT5 embedding when needed. + +This pipeline was adapted from the LongCat-AudioDiT reference implementation: https://github.com/meituan-longcat/LongCat-AudioDiT + +## Usage + +```py +import torch +from diffusers import LongCatAudioDiTPipeline + +repo_id = "" +tokenizer_path = os.environ["LONGCAT_AUDIO_DIT_TOKENIZER_PATH"] + +pipe = LongCatAudioDiTPipeline.from_pretrained( + repo_id, + tokenizer=tokenizer_path, + torch_dtype=torch.float16, + local_files_only=True, +) +pipe = pipe.to("cuda") + +audio = pipe( + prompt="A calm ocean wave ambience with soft wind in the background.", + audio_end_in_s=2.0, + num_inference_steps=16, + guidance_scale=4.0, + output_type="pt", +).audios +``` + +## Tips + +- `audio_end_in_s` is the most direct way to control output duration. +- `output_type="pt"` returns a PyTorch tensor shaped `(batch, channels, samples)`. +- If your tokenizer path is local-only, pass it explicitly to `from_pretrained(...)`. + +## LongCatAudioDiTPipeline + +[[autodoc]] LongCatAudioDiTPipeline + - all + - __call__ + - from_pretrained diff --git a/docs/source/en/api/pipelines/overview.md b/docs/source/en/api/pipelines/overview.md index c3e493c63d6a..2d5c4ff74039 100644 --- a/docs/source/en/api/pipelines/overview.md +++ b/docs/source/en/api/pipelines/overview.md @@ -29,6 +29,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an |---|---| | [AnimateDiff](animatediff) | text2video | | [AudioLDM2](audioldm2) | text2audio | +| [LongCat-AudioDiT](longcat_audio_dit) | text2audio | | [AuraFlow](aura_flow) | text2image | | [Bria 3.2](bria_3_2) | text2image | | [CogVideoX](cogvideox) | text2video | diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 0f74c0bbcb4a..b48a7f0a1c46 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -212,6 +212,7 @@ "AutoencoderKLTemporalDecoder", "AutoencoderKLWan", "AutoencoderOobleck", + "LongCatAudioDiTVae", "AutoencoderRAE", "AutoencoderTiny", "AutoencoderVidTok", @@ -253,6 +254,7 @@ "Kandinsky5Transformer3DModel", "LatteTransformer3DModel", "LongCatImageTransformer2DModel", + "LongCatAudioDiTTransformer", "LTX2VideoTransformer3DModel", "LTXVideoTransformer3DModel", "Lumina2Transformer2DModel", @@ -594,6 +596,7 @@ "LLaDA2PipelineOutput", "LongCatImageEditPipeline", "LongCatImagePipeline", + "LongCatAudioDiTPipeline", "LTX2ConditionPipeline", "LTX2ImageToVideoPipeline", "LTX2LatentUpsamplePipeline", @@ -1007,6 +1010,7 @@ AutoencoderKLTemporalDecoder, AutoencoderKLWan, AutoencoderOobleck, + LongCatAudioDiTVae, AutoencoderRAE, AutoencoderTiny, AutoencoderVidTok, @@ -1048,6 +1052,7 @@ Kandinsky5Transformer3DModel, LatteTransformer3DModel, LongCatImageTransformer2DModel, + LongCatAudioDiTTransformer, LTX2VideoTransformer3DModel, LTXVideoTransformer3DModel, Lumina2Transformer2DModel, @@ -1365,6 +1370,7 @@ LLaDA2PipelineOutput, LongCatImageEditPipeline, LongCatImagePipeline, + LongCatAudioDiTPipeline, LTX2ConditionPipeline, LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 7ded56049833..2b24b53a7035 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -51,6 +51,7 @@ _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"] _import_structure["autoencoders.autoencoder_kl_wan"] = ["AutoencoderKLWan"] _import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"] + _import_structure["autoencoders.autoencoder_longcat_audio_dit"] = ["LongCatAudioDiTVae"] _import_structure["autoencoders.autoencoder_rae"] = ["AutoencoderRAE"] _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"] _import_structure["autoencoders.autoencoder_vidtok"] = ["AutoencoderVidTok"] @@ -112,6 +113,7 @@ _import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"] _import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"] _import_structure["transformers.transformer_longcat_image"] = ["LongCatImageTransformer2DModel"] + _import_structure["transformers.transformer_longcat_audio_dit"] = ["LongCatAudioDiTTransformer"] _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"] _import_structure["transformers.transformer_ltx2"] = ["LTX2VideoTransformer3DModel"] _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"] diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index 609146ec340d..803b27285a42 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -20,6 +20,7 @@ from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder from .autoencoder_kl_wan import AutoencoderKLWan from .autoencoder_oobleck import AutoencoderOobleck +from .autoencoder_longcat_audio_dit import LongCatAudioDiTVae from .autoencoder_rae import AutoencoderRAE from .autoencoder_tiny import AutoencoderTiny from .autoencoder_vidtok import AutoencoderVidTok diff --git a/src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py b/src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py new file mode 100644 index 000000000000..9ab0a0d27470 --- /dev/null +++ b/src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py @@ -0,0 +1,394 @@ +# Copyright 2026 MeiTuan LongCat-AudioDiT Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Adapted from the LongCat-AudioDiT reference implementation: +# https://github.com/meituan-longcat/LongCat-AudioDiT + +import math +from dataclasses import dataclass + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils import weight_norm + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput +from ...utils.torch_utils import randn_tensor +from ..modeling_utils import ModelMixin +from .vae import AutoencoderMixin + + +def _wn_conv1d(in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, bias=True): + return weight_norm(nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)) + + +def _wn_conv_transpose1d(*args, **kwargs): + return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) + + +class Snake1d(nn.Module): + def __init__(self, channels: int, alpha_logscale: bool = True): + super().__init__() + self.alpha_logscale = alpha_logscale + self.alpha = nn.Parameter(torch.zeros(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + alpha = self.alpha[None, :, None] + beta = self.beta[None, :, None] + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + return hidden_states + (1.0 / (beta + 1e-9)) * torch.sin(hidden_states * alpha).pow(2) + + +def _get_vae_activation(name: str, channels: int = 0) -> nn.Module: + if name == "elu": + return nn.ELU() + if name == "snake": + return Snake1d(channels) + raise ValueError(f"Unknown activation: {name}") + + +def _pixel_unshuffle_1d(hidden_states: torch.Tensor, factor: int) -> torch.Tensor: + batch, channels, width = hidden_states.size() + return ( + hidden_states.view(batch, channels, width // factor, factor) + .permute(0, 1, 3, 2) + .contiguous() + .view(batch, channels * factor, width // factor) + ) + + +def _pixel_shuffle_1d(hidden_states: torch.Tensor, factor: int) -> torch.Tensor: + batch, channels, width = hidden_states.size() + return ( + hidden_states.view(batch, channels // factor, factor, width) + .permute(0, 1, 3, 2) + .contiguous() + .view(batch, channels // factor, width * factor) + ) + + +class DownsampleShortcut(nn.Module): + def __init__(self, in_channels: int, out_channels: int, factor: int): + super().__init__() + self.factor = factor + self.group_size = in_channels * factor // out_channels + self.out_channels = out_channels + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = _pixel_unshuffle_1d(hidden_states, self.factor) + batch, _channels, width = hidden_states.shape + return hidden_states.view(batch, self.out_channels, self.group_size, width).mean(dim=2) + + +class UpsampleShortcut(nn.Module): + def __init__(self, in_channels: int, out_channels: int, factor: int): + super().__init__() + self.factor = factor + self.repeats = out_channels * factor // in_channels + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = hidden_states.repeat_interleave(self.repeats, dim=1) + return _pixel_shuffle_1d(hidden_states, self.factor) + + +class VaeResidualUnit(nn.Module): + def __init__( + self, in_channels: int, out_channels: int, dilation: int, kernel_size: int = 7, use_snake: bool = False + ): + super().__init__() + padding = (dilation * (kernel_size - 1)) // 2 + activation = "snake" if use_snake else "elu" + self.layers = nn.Sequential( + _get_vae_activation(activation, channels=out_channels), + _wn_conv1d(in_channels, out_channels, kernel_size, dilation=dilation, padding=padding), + _get_vae_activation(activation, channels=out_channels), + _wn_conv1d(out_channels, out_channels, kernel_size=1), + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return hidden_states + self.layers(hidden_states) + + +class VaeEncoderBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + stride: int, + use_snake: bool = False, + downsample_shortcut: str = "none", + ): + super().__init__() + layers = [ + VaeResidualUnit(in_channels, in_channels, dilation=1, use_snake=use_snake), + VaeResidualUnit(in_channels, in_channels, dilation=3, use_snake=use_snake), + VaeResidualUnit(in_channels, in_channels, dilation=9, use_snake=use_snake), + ] + activation = "snake" if use_snake else "elu" + layers.append(_get_vae_activation(activation, channels=in_channels)) + layers.append( + _wn_conv1d(in_channels, out_channels, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)) + ) + self.layers = nn.Sequential(*layers) + self.residual = ( + DownsampleShortcut(in_channels, out_channels, stride) if downsample_shortcut == "averaging" else None + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if self.residual is None: + return self.layers(hidden_states) + return self.layers(hidden_states) + self.residual(hidden_states) + + +class VaeDecoderBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + stride: int, + use_snake: bool = False, + upsample_shortcut: str = "none", + ): + super().__init__() + activation = "snake" if use_snake else "elu" + layers = [ + _get_vae_activation(activation, channels=in_channels), + _wn_conv_transpose1d( + in_channels, out_channels, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2) + ), + VaeResidualUnit(out_channels, out_channels, dilation=1, use_snake=use_snake), + VaeResidualUnit(out_channels, out_channels, dilation=3, use_snake=use_snake), + VaeResidualUnit(out_channels, out_channels, dilation=9, use_snake=use_snake), + ] + self.layers = nn.Sequential(*layers) + self.residual = ( + UpsampleShortcut(in_channels, out_channels, stride) if upsample_shortcut == "duplicating" else None + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if self.residual is None: + return self.layers(hidden_states) + return self.layers(hidden_states) + self.residual(hidden_states) + + +class AudioDiTVaeEncoder(nn.Module): + def __init__( + self, + in_channels: int = 1, + channels: int = 128, + c_mults=None, + strides=None, + latent_dim: int = 64, + encoder_latent_dim: int = 128, + use_snake: bool = True, + downsample_shortcut: str = "averaging", + out_shortcut: str = "averaging", + ): + super().__init__() + c_mults = [1] + (c_mults or [1, 2, 4, 8, 16]) + strides = list(strides or [2] * (len(c_mults) - 1)) + if len(strides) < len(c_mults) - 1: + strides.extend([strides[-1] if strides else 2] * (len(c_mults) - 1 - len(strides))) + else: + strides = strides[: len(c_mults) - 1] + channels_base = channels + layers = [_wn_conv1d(in_channels, c_mults[0] * channels_base, kernel_size=7, padding=3)] + for idx in range(len(c_mults) - 1): + layers.append( + VaeEncoderBlock( + c_mults[idx] * channels_base, + c_mults[idx + 1] * channels_base, + strides[idx], + use_snake=use_snake, + downsample_shortcut=downsample_shortcut, + ) + ) + layers.append(_wn_conv1d(c_mults[-1] * channels_base, encoder_latent_dim, kernel_size=3, padding=1)) + self.layers = nn.Sequential(*layers) + self.shortcut = ( + DownsampleShortcut(c_mults[-1] * channels_base, encoder_latent_dim, 1) + if out_shortcut == "averaging" + else None + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if self.shortcut is None: + return self.layers(hidden_states) + hidden_states = self.layers[:-1](hidden_states) + return self.layers[-1](hidden_states) + self.shortcut(hidden_states) + + +class AudioDiTVaeDecoder(nn.Module): + def __init__( + self, + in_channels: int = 1, + channels: int = 128, + c_mults=None, + strides=None, + latent_dim: int = 64, + use_snake: bool = True, + in_shortcut: str = "duplicating", + final_tanh: bool = False, + upsample_shortcut: str = "duplicating", + ): + super().__init__() + c_mults = [1] + (c_mults or [1, 2, 4, 8, 16]) + strides = list(strides or [2] * (len(c_mults) - 1)) + if len(strides) < len(c_mults) - 1: + strides.extend([strides[-1] if strides else 2] * (len(c_mults) - 1 - len(strides))) + else: + strides = strides[: len(c_mults) - 1] + channels_base = channels + + self.shortcut = ( + UpsampleShortcut(latent_dim, c_mults[-1] * channels_base, 1) if in_shortcut == "duplicating" else None + ) + + layers = [_wn_conv1d(latent_dim, c_mults[-1] * channels_base, kernel_size=7, padding=3)] + for idx in range(len(c_mults) - 1, 0, -1): + layers.append( + VaeDecoderBlock( + c_mults[idx] * channels_base, + c_mults[idx - 1] * channels_base, + strides[idx - 1], + use_snake=use_snake, + upsample_shortcut=upsample_shortcut, + ) + ) + activation = "snake" if use_snake else "elu" + layers.append(_get_vae_activation(activation, channels=c_mults[0] * channels_base)) + layers.append(_wn_conv1d(c_mults[0] * channels_base, in_channels, kernel_size=7, padding=3, bias=False)) + layers.append(nn.Tanh() if final_tanh else nn.Identity()) + self.layers = nn.Sequential(*layers) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if self.shortcut is None: + return self.layers(hidden_states) + hidden_states = self.shortcut(hidden_states) + self.layers[0](hidden_states) + return self.layers[1:](hidden_states) + + +@dataclass +class LongCatAudioDiTVaeEncoderOutput(BaseOutput): + latents: torch.Tensor + + +@dataclass +class LongCatAudioDiTVaeDecoderOutput(BaseOutput): + sample: torch.Tensor + + +class LongCatAudioDiTVae(ModelMixin, AutoencoderMixin, ConfigMixin): + @register_to_config + def __init__( + self, + in_channels: int = 1, + channels: int = 128, + c_mults=None, + strides=None, + latent_dim: int = 64, + encoder_latent_dim: int = 128, + use_snake: bool = True, + downsample_shortcut: str = "averaging", + upsample_shortcut: str = "duplicating", + out_shortcut: str = "averaging", + in_shortcut: str = "duplicating", + final_tanh: bool = False, + downsampling_ratio: int = 2048, + sample_rate: int = 24000, + scale: float = 0.71, + ): + super().__init__() + self.encoder = AudioDiTVaeEncoder( + in_channels=in_channels, + channels=channels, + c_mults=c_mults, + strides=strides, + latent_dim=latent_dim, + encoder_latent_dim=encoder_latent_dim, + use_snake=use_snake, + downsample_shortcut=downsample_shortcut, + out_shortcut=out_shortcut, + ) + self.decoder = AudioDiTVaeDecoder( + in_channels=in_channels, + channels=channels, + c_mults=c_mults, + strides=strides, + latent_dim=latent_dim, + use_snake=use_snake, + in_shortcut=in_shortcut, + final_tanh=final_tanh, + upsample_shortcut=upsample_shortcut, + ) + + @property + def sampling_rate(self) -> int: + return self.config.sample_rate + + def encode( + self, + sample: torch.Tensor, + sample_posterior: bool = True, + return_dict: bool = True, + generator: torch.Generator | None = None, + ) -> LongCatAudioDiTVaeEncoderOutput | tuple[torch.Tensor]: + encoder_dtype = next(self.encoder.parameters()).dtype + if sample.dtype != encoder_dtype: + sample = sample.to(encoder_dtype) + encoded = self.encoder(sample) + mean, scale_param = encoded.chunk(2, dim=1) + std = F.softplus(scale_param) + 1e-4 + if sample_posterior: + noise = randn_tensor(mean.shape, generator=generator, device=mean.device, dtype=mean.dtype) + latents = mean + std * noise + else: + latents = mean + latents = latents / self.config.scale + if encoder_dtype != torch.float32: + latents = latents.float() + if not return_dict: + return (latents,) + return LongCatAudioDiTVaeEncoderOutput(latents=latents) + + def decode( + self, latents: torch.Tensor, return_dict: bool = True + ) -> LongCatAudioDiTVaeDecoderOutput | tuple[torch.Tensor]: + decoder_dtype = next(self.decoder.parameters()).dtype + latents = latents * self.config.scale + if latents.dtype != decoder_dtype: + latents = latents.to(decoder_dtype) + decoded = self.decoder(latents) + if decoder_dtype != torch.float32: + decoded = decoded.float() + if not return_dict: + return (decoded,) + return LongCatAudioDiTVaeDecoderOutput(sample=decoded) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: torch.Generator | None = None, + ) -> LongCatAudioDiTVaeDecoderOutput | tuple[torch.Tensor]: + latents = self.encode(sample, sample_posterior=sample_posterior, return_dict=True, generator=generator).latents + decoded = self.decode(latents, return_dict=True).sample + if not return_dict: + return (decoded,) + return LongCatAudioDiTVaeDecoderOutput(sample=decoded) diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 45157ee91808..ae91c5a54e49 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -38,6 +38,7 @@ from .transformer_longcat_image import LongCatImageTransformer2DModel from .transformer_ltx import LTXVideoTransformer3DModel from .transformer_ltx2 import LTX2VideoTransformer3DModel + from .transformer_longcat_audio_dit import LongCatAudioDiTTransformer from .transformer_lumina2 import Lumina2Transformer2DModel from .transformer_mochi import MochiTransformer3DModel from .transformer_omnigen import OmniGenTransformer2DModel diff --git a/src/diffusers/models/transformers/transformer_longcat_audio_dit.py b/src/diffusers/models/transformers/transformer_longcat_audio_dit.py new file mode 100644 index 000000000000..bfcf9a8f7a3b --- /dev/null +++ b/src/diffusers/models/transformers/transformer_longcat_audio_dit.py @@ -0,0 +1,582 @@ +# Copyright 2026 MeiTuan LongCat-AudioDiT Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Adapted from the LongCat-AudioDiT reference implementation: +# https://github.com/meituan-longcat/LongCat-AudioDiT + +import math +from dataclasses import dataclass + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput +from ...utils.torch_utils import maybe_allow_in_graph +from ..modeling_utils import ModelMixin + + +@dataclass +class LongCatAudioDiTTransformerOutput(BaseOutput): + sample: torch.Tensor + + +class AudioDiTRMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + normalized = hidden_states.float() * torch.rsqrt( + hidden_states.float().pow(2).mean(dim=-1, keepdim=True) + self.eps + ) + return normalized.to(hidden_states.dtype) * self.weight + + +class AudioDiTSinusPositionEmbedding(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, timesteps: torch.Tensor, scale: float = 1000.0) -> torch.Tensor: + device = timesteps.device + half_dim = self.dim // 2 + exponent = math.log(10000) / max(half_dim - 1, 1) + embeddings = torch.exp(torch.arange(half_dim, device=device).float() * -exponent) + embeddings = scale * timesteps.unsqueeze(1) * embeddings.unsqueeze(0) + return torch.cat((embeddings.sin(), embeddings.cos()), dim=-1) + + +class AudioDiTTimestepEmbedding(nn.Module): + def __init__(self, dim: int, freq_embed_dim: int = 256): + super().__init__() + self.time_embed = AudioDiTSinusPositionEmbedding(freq_embed_dim) + self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) + + def forward(self, timestep: torch.Tensor) -> torch.Tensor: + hidden_states = self.time_embed(timestep) + return self.time_mlp(hidden_states.to(timestep.dtype)) + + +class AudioDiTRotaryEmbedding(nn.Module): + def __init__(self, dim: int, max_position_embeddings: int = 2048, base: float = 100000.0): + super().__init__() + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self._cos = None + self._sin = None + self._cached_len = 0 + self._cached_device = None + + def _build(self, seq_len: int, device: torch.device, dtype: torch.dtype): + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) + steps = torch.arange(seq_len, dtype=torch.int64).type_as(inv_freq) + freqs = torch.outer(steps, inv_freq) + embeddings = torch.cat((freqs, freqs), dim=-1) + self._cos = embeddings.cos().to(dtype=dtype, device=device) + self._sin = embeddings.sin().to(dtype=dtype, device=device) + self._cached_len = seq_len + self._cached_device = device + + def forward(self, hidden_states: torch.Tensor, seq_len: int | None = None) -> tuple[torch.Tensor, torch.Tensor]: + seq_len = hidden_states.shape[1] if seq_len is None else seq_len + if self._cos is None or seq_len > self._cached_len or self._cached_device != hidden_states.device: + self._build(max(seq_len, self.max_position_embeddings), hidden_states.device, hidden_states.dtype) + return self._cos[:seq_len].to(hidden_states.dtype), self._sin[:seq_len].to(hidden_states.dtype) + + +def _rotate_half(hidden_states: torch.Tensor) -> torch.Tensor: + first, second = hidden_states.chunk(2, dim=-1) + return torch.cat((-second, first), dim=-1) + + +def _apply_rotary_emb(hidden_states: torch.Tensor, rope: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + cos, sin = rope + cos = cos[None, None].to(hidden_states.device) + sin = sin[None, None].to(hidden_states.device) + return (hidden_states.float() * cos + _rotate_half(hidden_states).float() * sin).to(hidden_states.dtype) + + +class AudioDiTGRN(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, dim)) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + gx = torch.norm(hidden_states, p=2, dim=1, keepdim=True) + nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (hidden_states * nx) + self.beta + hidden_states + + +class AudioDiTConvNeXtV2Block(nn.Module): + def __init__( + self, + dim: int, + intermediate_dim: int, + dilation: int = 1, + kernel_size: int = 7, + bias: bool = True, + eps: float = 1e-6, + ): + super().__init__() + padding = (dilation * (kernel_size - 1)) // 2 + self.dwconv = nn.Conv1d( + dim, dim, kernel_size=kernel_size, padding=padding, groups=dim, dilation=dilation, bias=bias + ) + self.norm = nn.LayerNorm(dim, eps=eps) + self.pwconv1 = nn.Linear(dim, intermediate_dim, bias=bias) + self.act = nn.SiLU() + self.grn = AudioDiTGRN(intermediate_dim) + self.pwconv2 = nn.Linear(intermediate_dim, dim, bias=bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + residual = hidden_states + hidden_states = self.dwconv(hidden_states.transpose(1, 2)).transpose(1, 2) + hidden_states = self.norm(hidden_states) + hidden_states = self.pwconv1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.grn(hidden_states) + hidden_states = self.pwconv2(hidden_states) + return residual + hidden_states + + +class AudioDiTEmbedder(nn.Module): + def __init__(self, in_dim: int, out_dim: int): + super().__init__() + self.proj = nn.Sequential(nn.Linear(in_dim, out_dim), nn.SiLU(), nn.Linear(out_dim, out_dim)) + + def forward(self, hidden_states: torch.Tensor, mask: torch.BoolTensor | None = None) -> torch.Tensor: + if mask is not None: + hidden_states = hidden_states.masked_fill(mask.logical_not().unsqueeze(-1), 0.0) + hidden_states = self.proj(hidden_states) + if mask is not None: + hidden_states = hidden_states.masked_fill(mask.logical_not().unsqueeze(-1), 0.0) + return hidden_states + + +class AudioDiTAdaLNMLP(nn.Module): + def __init__(self, in_dim: int, out_dim: int, bias: bool = True): + super().__init__() + self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(in_dim, out_dim, bias=bias)) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.mlp(hidden_states) + + +class AudioDiTAdaLayerNormZeroFinal(nn.Module): + def __init__(self, dim: int, bias: bool = True, eps: float = 1e-6): + super().__init__() + self.silu = nn.SiLU() + self.linear = nn.Linear(dim, dim * 2, bias=bias) + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + + def forward(self, hidden_states: torch.Tensor, embedding: torch.Tensor) -> torch.Tensor: + embedding = self.linear(self.silu(embedding)) + scale, shift = torch.chunk(embedding, 2, dim=-1) + hidden_states = self.norm(hidden_states.float()).type_as(hidden_states) + if scale.ndim == 2: + hidden_states = hidden_states * (1 + scale)[:, None, :] + shift[:, None, :] + else: + hidden_states = hidden_states * (1 + scale) + shift + return hidden_states + + +def _modulate( + hidden_states: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor, eps: float = 1e-6 +) -> torch.Tensor: + hidden_states = F.layer_norm(hidden_states.float(), (hidden_states.shape[-1],), eps=eps).type_as(hidden_states) + if scale.ndim == 2: + return hidden_states * (1 + scale[:, None]) + shift[:, None] + return hidden_states * (1 + scale) + shift + + +def _masked_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + query_mask: torch.BoolTensor | None = None, + key_mask: torch.BoolTensor | None = None, +) -> torch.Tensor: + attn_mask = None + if key_mask is not None: + attn_mask = key_mask[:, None, None, :].expand(-1, query.shape[1], query.shape[2], -1) + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False + ) + if query_mask is not None: + hidden_states = hidden_states * query_mask[:, None, :, None].to(hidden_states.dtype) + return hidden_states + + +class AudioDiTSelfAttention(nn.Module): + def __init__( + self, + dim: int, + heads: int, + dim_head: int, + dropout: float = 0.0, + bias: bool = True, + qk_norm: bool = False, + eps: float = 1e-6, + ): + super().__init__() + self.heads = heads + self.inner_dim = dim_head * heads + self.to_q = nn.Linear(dim, self.inner_dim, bias=bias) + self.to_k = nn.Linear(dim, self.inner_dim, bias=bias) + self.to_v = nn.Linear(dim, self.inner_dim, bias=bias) + self.qk_norm = qk_norm + if qk_norm: + self.q_norm = AudioDiTRMSNorm(self.inner_dim, eps=eps) + self.k_norm = AudioDiTRMSNorm(self.inner_dim, eps=eps) + self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, dim, bias=bias), nn.Dropout(dropout)]) + + def forward( + self, hidden_states: torch.Tensor, mask: torch.BoolTensor | None = None, rope: tuple | None = None + ) -> torch.Tensor: + batch_size = hidden_states.shape[0] + query = self.to_q(hidden_states) + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + if self.qk_norm: + query = self.q_norm(query) + key = self.k_norm(key) + head_dim = self.inner_dim // self.heads + query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + if rope is not None: + query = _apply_rotary_emb(query, rope) + key = _apply_rotary_emb(key, rope) + hidden_states = _masked_attention(query, key, value, query_mask=mask, key_mask=mask) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.inner_dim).to(query.dtype) + hidden_states = self.to_out[0](hidden_states) + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + +class AudioDiTCrossAttention(nn.Module): + def __init__( + self, + q_dim: int, + kv_dim: int, + heads: int, + dim_head: int, + dropout: float = 0.0, + bias: bool = True, + qk_norm: bool = False, + eps: float = 1e-6, + ): + super().__init__() + self.heads = heads + self.inner_dim = dim_head * heads + self.to_q = nn.Linear(q_dim, self.inner_dim, bias=bias) + self.to_k = nn.Linear(kv_dim, self.inner_dim, bias=bias) + self.to_v = nn.Linear(kv_dim, self.inner_dim, bias=bias) + self.qk_norm = qk_norm + if qk_norm: + self.q_norm = AudioDiTRMSNorm(self.inner_dim, eps=eps) + self.k_norm = AudioDiTRMSNorm(self.inner_dim, eps=eps) + self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, q_dim, bias=bias), nn.Dropout(dropout)]) + + def forward( + self, + hidden_states: torch.Tensor, + cond: torch.Tensor, + mask: torch.BoolTensor | None = None, + cond_mask: torch.BoolTensor | None = None, + rope: tuple | None = None, + cond_rope: tuple | None = None, + ) -> torch.Tensor: + batch_size = hidden_states.shape[0] + query = self.to_q(hidden_states) + key = self.to_k(cond) + value = self.to_v(cond) + if self.qk_norm: + query = self.q_norm(query) + key = self.k_norm(key) + head_dim = self.inner_dim // self.heads + query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + if rope is not None: + query = _apply_rotary_emb(query, rope) + if cond_rope is not None: + key = _apply_rotary_emb(key, cond_rope) + hidden_states = _masked_attention(query, key, value, query_mask=mask, key_mask=cond_mask) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.inner_dim).to(query.dtype) + hidden_states = self.to_out[0](hidden_states) + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + +class AudioDiTFeedForward(nn.Module): + def __init__(self, dim: int, mult: float = 4.0, dropout: float = 0.0, bias: bool = True): + super().__init__() + inner_dim = int(dim * mult) + self.ff = nn.Sequential( + nn.Linear(dim, inner_dim, bias=bias), + nn.GELU(approximate="tanh"), + nn.Dropout(dropout), + nn.Linear(inner_dim, dim, bias=bias), + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.ff(hidden_states) + + +@maybe_allow_in_graph +class AudioDiTBlock(nn.Module): + def __init__( + self, + dim: int, + cond_dim: int, + heads: int, + dim_head: int, + dropout: float = 0.0, + bias: bool = True, + qk_norm: bool = False, + eps: float = 1e-6, + cross_attn: bool = True, + cross_attn_norm: bool = False, + adaln_type: str = "global", + adaln_use_text_cond: bool = True, + ff_mult: float = 4.0, + ): + super().__init__() + self.adaln_type = adaln_type + self.adaln_use_text_cond = adaln_use_text_cond + if adaln_type == "local": + self.adaln_mlp = AudioDiTAdaLNMLP(dim, dim * 6, bias=True) + elif adaln_type == "global": + self.adaln_scale_shift = nn.Parameter(torch.randn(dim * 6) / dim**0.5) + self.self_attn = AudioDiTSelfAttention( + dim, heads, dim_head, dropout=dropout, bias=bias, qk_norm=qk_norm, eps=eps + ) + self.use_cross_attn = cross_attn + if cross_attn: + self.cross_attn = AudioDiTCrossAttention( + dim, cond_dim, heads, dim_head, dropout=dropout, bias=bias, qk_norm=qk_norm, eps=eps + ) + self.cross_attn_norm = ( + nn.LayerNorm(dim, elementwise_affine=True, eps=eps) if cross_attn_norm else nn.Identity() + ) + self.cross_attn_norm_c = ( + nn.LayerNorm(cond_dim, elementwise_affine=True, eps=eps) if cross_attn_norm else nn.Identity() + ) + self.ffn = AudioDiTFeedForward(dim=dim, mult=ff_mult, dropout=dropout, bias=bias) + + def forward( + self, + hidden_states: torch.Tensor, + timestep_embed: torch.Tensor, + cond: torch.Tensor, + mask: torch.BoolTensor | None = None, + cond_mask: torch.BoolTensor | None = None, + rope: tuple | None = None, + cond_rope: tuple | None = None, + adaln_global_out: torch.Tensor | None = None, + ) -> torch.Tensor: + if self.adaln_type == "local" and adaln_global_out is None: + if self.adaln_use_text_cond: + denom = cond_mask.sum(1, keepdim=True).clamp(min=1).to(cond.dtype) + cond_mean = cond.sum(1) / denom + norm_cond = timestep_embed + cond_mean + else: + norm_cond = timestep_embed + adaln_out = self.adaln_mlp(norm_cond) + gate_sa, scale_sa, shift_sa, gate_ffn, scale_ffn, shift_ffn = torch.chunk(adaln_out, 6, dim=-1) + else: + adaln_out = adaln_global_out + self.adaln_scale_shift.unsqueeze(0) + gate_sa, scale_sa, shift_sa, gate_ffn, scale_ffn, shift_ffn = torch.chunk(adaln_out, 6, dim=-1) + + norm_hidden_states = _modulate(hidden_states, scale_sa, shift_sa) + attn_output = self.self_attn(norm_hidden_states, mask=mask, rope=rope) + hidden_states = hidden_states + gate_sa.unsqueeze(1) * attn_output + + if self.use_cross_attn: + cross_output = self.cross_attn( + hidden_states=self.cross_attn_norm(hidden_states), + cond=self.cross_attn_norm_c(cond), + mask=mask, + cond_mask=cond_mask, + rope=rope, + cond_rope=cond_rope, + ) + hidden_states = hidden_states + cross_output + + norm_hidden_states = _modulate(hidden_states, scale_ffn, shift_ffn) + ff_output = self.ffn(norm_hidden_states) + hidden_states = hidden_states + gate_ffn.unsqueeze(1) * ff_output + return hidden_states + + +class LongCatAudioDiTTransformer(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = False + + @register_to_config + def __init__( + self, + dit_dim: int = 1536, + dit_depth: int = 24, + dit_heads: int = 24, + dit_text_dim: int = 768, + latent_dim: int = 64, + dropout: float = 0.0, + bias: bool = True, + cross_attn: bool = True, + adaln_type: str = "global", + adaln_use_text_cond: bool = True, + long_skip: bool = True, + text_conv: bool = True, + qk_norm: bool = True, + cross_attn_norm: bool = False, + eps: float = 1e-6, + use_latent_condition: bool = True, + ): + super().__init__() + dim = dit_dim + dim_head = dim // dit_heads + self.long_skip = long_skip + self.adaln_type = adaln_type + self.adaln_use_text_cond = adaln_use_text_cond + self.time_embed = AudioDiTTimestepEmbedding(dim) + self.input_embed = AudioDiTEmbedder(latent_dim, dim) + self.text_embed = AudioDiTEmbedder(dit_text_dim, dim) + self.rotary_embed = AudioDiTRotaryEmbedding(dim_head, 2048, base=100000.0) + self.blocks = nn.ModuleList( + [ + AudioDiTBlock( + dim=dim, + cond_dim=dim, + heads=dit_heads, + dim_head=dim_head, + dropout=dropout, + bias=bias, + qk_norm=qk_norm, + eps=eps, + cross_attn=cross_attn, + cross_attn_norm=cross_attn_norm, + adaln_type=adaln_type, + adaln_use_text_cond=adaln_use_text_cond, + ff_mult=4.0, + ) + for _ in range(dit_depth) + ] + ) + self.norm_out = AudioDiTAdaLayerNormZeroFinal(dim, bias=bias, eps=eps) + self.proj_out = nn.Linear(dim, latent_dim) + if adaln_type == "global": + self.adaln_global_mlp = AudioDiTAdaLNMLP(dim, dim * 6, bias=True) + self.text_conv = text_conv + if text_conv: + self.text_conv_layer = nn.Sequential( + *[AudioDiTConvNeXtV2Block(dim, dim * 2, bias=bias, eps=eps) for _ in range(4)] + ) + self.use_latent_condition = use_latent_condition + if use_latent_condition: + self.latent_embed = AudioDiTEmbedder(latent_dim, dim) + self.latent_cond_embedder = AudioDiTEmbedder(dim * 2, dim) + self._initialize_weights(bias=bias) + + def _initialize_weights(self, bias: bool = True): + if self.adaln_type == "local": + for block in self.blocks: + nn.init.constant_(block.adaln_mlp.mlp[-1].weight, 0) + if bias: + nn.init.constant_(block.adaln_mlp.mlp[-1].bias, 0) + elif self.adaln_type == "global": + nn.init.constant_(self.adaln_global_mlp.mlp[-1].weight, 0) + if bias: + nn.init.constant_(self.adaln_global_mlp.mlp[-1].bias, 0) + nn.init.constant_(self.norm_out.linear.weight, 0) + nn.init.constant_(self.proj_out.weight, 0) + if bias: + nn.init.constant_(self.norm_out.linear.bias, 0) + nn.init.constant_(self.proj_out.bias, 0) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_attention_mask: torch.BoolTensor, + timestep: torch.Tensor, + attention_mask: torch.BoolTensor | None = None, + latent_cond: torch.Tensor | None = None, + return_dict: bool = True, + ) -> LongCatAudioDiTTransformerOutput | tuple[torch.Tensor]: + dtype = next(self.parameters()).dtype + hidden_states = hidden_states.to(dtype) + encoder_hidden_states = encoder_hidden_states.to(dtype) + timestep = timestep.to(dtype) + batch_size = hidden_states.shape[0] + if timestep.ndim == 0: + timestep = timestep.repeat(batch_size) + timestep_embed = self.time_embed(timestep) + text_mask = encoder_attention_mask.bool() + encoder_hidden_states = self.text_embed(encoder_hidden_states, text_mask) + if self.text_conv: + encoder_hidden_states = self.text_conv_layer(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.masked_fill(text_mask.logical_not().unsqueeze(-1), 0.0) + hidden_states = self.input_embed(hidden_states, attention_mask) + if self.use_latent_condition and latent_cond is not None: + latent_cond = self.latent_embed(latent_cond.to(dtype), attention_mask) + hidden_states = self.latent_cond_embedder(torch.cat([hidden_states, latent_cond], dim=-1)) + residual = hidden_states.clone() if self.long_skip else None + rope = self.rotary_embed(hidden_states, hidden_states.shape[1]) + cond_rope = self.rotary_embed(encoder_hidden_states, encoder_hidden_states.shape[1]) + if self.adaln_type == "global": + if self.adaln_use_text_cond: + text_len = text_mask.sum(1).clamp(min=1).to(encoder_hidden_states.dtype) + text_mean = encoder_hidden_states.sum(1) / text_len.unsqueeze(1) + norm_cond = timestep_embed + text_mean + else: + norm_cond = timestep_embed + adaln_global_out = self.adaln_global_mlp(norm_cond) + for block in self.blocks: + hidden_states = block( + hidden_states=hidden_states, + timestep_embed=timestep_embed, + cond=encoder_hidden_states, + mask=attention_mask, + cond_mask=text_mask, + rope=rope, + cond_rope=cond_rope, + adaln_global_out=adaln_global_out, + ) + else: + norm_cond = timestep_embed + for block in self.blocks: + hidden_states = block( + hidden_states=hidden_states, + timestep_embed=timestep_embed, + cond=encoder_hidden_states, + mask=attention_mask, + cond_mask=text_mask, + rope=rope, + cond_rope=cond_rope, + ) + if self.long_skip: + hidden_states = hidden_states + residual + hidden_states = self.norm_out(hidden_states, norm_cond) + hidden_states = self.proj_out(hidden_states) + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(-1).to(hidden_states.dtype) + if not return_dict: + return (hidden_states,) + return LongCatAudioDiTTransformerOutput(sample=hidden_states) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 05aad6e349f6..154c28d6bc24 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -326,6 +326,7 @@ _import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"] _import_structure["lucy"] = ["LucyEditPipeline"] _import_structure["longcat_image"] = ["LongCatImagePipeline", "LongCatImageEditPipeline"] + _import_structure["longcat_audio_dit"] = ["LongCatAudioDiTPipeline"] _import_structure["marigold"].extend( [ "MarigoldDepthPipeline", @@ -751,6 +752,7 @@ ) from .llada2 import LLaDA2Pipeline, LLaDA2PipelineOutput from .longcat_image import LongCatImageEditPipeline, LongCatImagePipeline + from .longcat_audio_dit import LongCatAudioDiTPipeline from .ltx import ( LTXConditionPipeline, LTXI2VLongMultiPromptPipeline, diff --git a/src/diffusers/pipelines/longcat_audio_dit/__init__.py b/src/diffusers/pipelines/longcat_audio_dit/__init__.py new file mode 100644 index 000000000000..61cb89b4140f --- /dev/null +++ b/src/diffusers/pipelines/longcat_audio_dit/__init__.py @@ -0,0 +1,40 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa: F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure['pipeline_longcat_audio_dit'] = ['LongCatAudioDiTPipeline'] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_longcat_audio_dit import LongCatAudioDiTPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()['__file__'], _import_structure, module_spec=__spec__) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py b/src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py new file mode 100644 index 000000000000..56d02cf49468 --- /dev/null +++ b/src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py @@ -0,0 +1,432 @@ +# Copyright 2026 MeiTuan LongCat-AudioDiT Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Adapted from the LongCat-AudioDiT reference implementation: +# https://github.com/meituan-longcat/LongCat-AudioDiT + +import json +import re +from pathlib import Path +from typing import Any + +import torch +import torch.nn.functional as F +from huggingface_hub import hf_hub_download +from huggingface_hub.utils import validate_hf_hub_args +from safetensors.torch import load_file +from torch.nn.utils.rnn import pad_sequence +from transformers import PreTrainedTokenizerBase, T5Tokenizer, UMT5Config, UMT5EncoderModel + +from ...models import LongCatAudioDiTTransformer, LongCatAudioDiTVae +from ...utils import HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging +from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline + + +logger = logging.get_logger(__name__) + + +def _lens_to_mask(lengths: torch.Tensor, length: int | None = None) -> torch.BoolTensor: + if length is None: + length = int(lengths.amax().item()) + seq = torch.arange(length, device=lengths.device) + return seq[None, :] < lengths[:, None] + + +def _normalize_text(text: str) -> str: + text = text.lower() + text = re.sub(r'["“”‘’]', " ", text) + text = re.sub(r"\s+", " ", text) + return text.strip() + + +def _approx_duration_from_text(text: str, max_duration: float = 30.0) -> float: + en_dur_per_char = 0.082 + zh_dur_per_char = 0.21 + text = re.sub(r"\s+", "", text) + num_zh = num_en = num_other = 0 + for char in text: + if "一" <= char <= "鿿": + num_zh += 1 + elif char.isalpha(): + num_en += 1 + else: + num_other += 1 + if num_zh > num_en: + num_zh += num_other + else: + num_en += num_other + return min(max_duration, num_zh * zh_dur_per_char + num_en * en_dur_per_char) + + +def _approx_batch_duration_from_prompts(prompts: list[str]) -> float: + if not prompts: + return 0.0 + return max(_approx_duration_from_text(prompt) for prompt in prompts) + + +def _extract_prefixed_state_dict(state_dict: dict[str, torch.Tensor], prefix: str) -> dict[str, torch.Tensor]: + prefix = f"{prefix}." + return {key[len(prefix):]: value for key, value in state_dict.items() if key.startswith(prefix)} + + +def _load_longcat_tokenizer( + pretrained_model_name_or_path: str | Path, + text_encoder_model: str | None, + tokenizer: PreTrainedTokenizerBase | str | Path | None, + local_files_only: bool | None, + subfolder: str | None = None, +) -> PreTrainedTokenizerBase: + if isinstance(tokenizer, PreTrainedTokenizerBase): + return tokenizer + + tokenizer_source: str | Path | None = tokenizer + if tokenizer_source is None: + pretrained_path = Path(pretrained_model_name_or_path) + local_tokenizer_dir = pretrained_path / (subfolder or "") / "tokenizer" + if pretrained_path.exists() and local_tokenizer_dir.is_dir(): + tokenizer_source = local_tokenizer_dir + else: + tokenizer_source = text_encoder_model or pretrained_model_name_or_path + + if tokenizer_source is None: + raise ValueError("Could not determine tokenizer source for LongCatAudioDiT.") + + tokenizer_kwargs = {"local_files_only": local_files_only} + if not isinstance(tokenizer_source, Path) and tokenizer_source == pretrained_model_name_or_path and subfolder: + tokenizer_kwargs["subfolder"] = subfolder + return T5Tokenizer.from_pretrained(tokenizer_source, **tokenizer_kwargs) + + +def _resolve_longcat_file( + pretrained_model_name_or_path: str | Path, + filename: str, + cache_dir: str | Path | None = None, + force_download: bool = False, + proxies: dict[str, str] | None = None, + local_files_only: bool | None = None, + token: str | bool | None = None, + revision: str | None = None, + subfolder: str | None = None, + local_dir: str | Path | None = None, + local_dir_use_symlinks: str | bool = "auto", + user_agent: dict[str, str] | None = None, +) -> str: + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + if Path(pretrained_model_name_or_path).is_dir(): + base = Path(pretrained_model_name_or_path) + if subfolder is not None: + base = base / subfolder + file_path = base / filename + if not file_path.is_file(): + raise EnvironmentError(f"Error no file named {filename} found in directory {base}.") + return str(file_path) + + try: + return hf_hub_download( + pretrained_model_name_or_path, + filename=filename, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + subfolder=subfolder, + revision=revision, + local_dir=local_dir, + local_dir_use_symlinks=local_dir_use_symlinks, + user_agent=user_agent, + ) + except Exception as err: + raise EnvironmentError( + f"Can't load {filename} for '{pretrained_model_name_or_path}'. If you were trying to load it from " + f"'{HUGGINGFACE_CO_RESOLVE_ENDPOINT}', make sure the repo exists or that your local path is correct." + ) from err + + +class LongCatAudioDiTPipeline(DiffusionPipeline): + model_cpu_offload_seq = "text_encoder->transformer->vae" + + def __init__( + self, + vae: LongCatAudioDiTVae, + text_encoder: UMT5EncoderModel, + tokenizer: PreTrainedTokenizerBase, + transformer: LongCatAudioDiTTransformer, + ): + super().__init__() + self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=transformer) + self.sample_rate = getattr(vae.config, "sample_rate", 24000) + self.latent_hop = getattr(vae.config, "downsampling_ratio", 2048) + self.latent_dim = getattr(transformer.config, "latent_dim", 64) + self.max_wav_duration = 30.0 + self.text_norm_feat = True + self.text_add_embed = True + + @classmethod + @validate_hf_hub_args + def from_pretrained( + cls, + pretrained_model_name_or_path: str | Path, + tokenizer: PreTrainedTokenizerBase | str | Path | None = None, + torch_dtype: torch.dtype | None = None, + local_files_only: bool | None = None, + **kwargs: Any, + ) -> "LongCatAudioDiTPipeline": + cache_dir = kwargs.pop("cache_dir", None) + local_dir = kwargs.pop("local_dir", None) + local_dir_use_symlinks = kwargs.pop("local_dir_use_symlinks", "auto") + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + if kwargs: + logger.warning("Ignoring unsupported LongCatAudioDiTPipeline.from_pretrained kwargs: %s", sorted(kwargs)) + + config_path = _resolve_longcat_file( + pretrained_model_name_or_path, + "config.json", + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + local_dir=local_dir, + local_dir_use_symlinks=local_dir_use_symlinks, + user_agent={"file_type": "config"}, + ) + weights_path = _resolve_longcat_file( + pretrained_model_name_or_path, + "model.safetensors", + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + local_dir=local_dir, + local_dir_use_symlinks=local_dir_use_symlinks, + user_agent={"file_type": "weights"}, + ) + + with open(config_path) as handle: + config = json.load(handle) + + text_encoder_config = UMT5Config.from_dict(config["text_encoder_config"]) + text_encoder = UMT5EncoderModel(text_encoder_config) + transformer = LongCatAudioDiTTransformer( + dit_dim=config["dit_dim"], + dit_depth=config["dit_depth"], + dit_heads=config["dit_heads"], + dit_text_dim=config["dit_text_dim"], + latent_dim=config["latent_dim"], + dropout=config.get("dit_dropout", 0.0), + bias=config.get("dit_bias", True), + cross_attn=config.get("dit_cross_attn", True), + adaln_type=config.get("dit_adaln_type", "global"), + adaln_use_text_cond=config.get("dit_adaln_use_text_cond", True), + long_skip=config.get("dit_long_skip", True), + text_conv=config.get("dit_text_conv", True), + qk_norm=config.get("dit_qk_norm", True), + cross_attn_norm=config.get("dit_cross_attn_norm", False), + eps=config.get("dit_eps", 1e-6), + use_latent_condition=config.get("dit_use_latent_condition", True), + ) + vae_config = dict(config["vae_config"]) + vae_config.pop("model_type", None) + vae = LongCatAudioDiTVae(**vae_config) + + state_dict = load_file(weights_path) + transformer.load_state_dict(_extract_prefixed_state_dict(state_dict, "transformer"), strict=True) + vae.load_state_dict(_extract_prefixed_state_dict(state_dict, "vae"), strict=True) + text_missing, text_unexpected = text_encoder.load_state_dict( + _extract_prefixed_state_dict(state_dict, "text_encoder"), strict=False + ) + allowed_missing = {"shared.weight"} + unexpected_missing = set(text_missing) - allowed_missing + if unexpected_missing: + raise RuntimeError(f"Unexpected missing LongCatAudioDiT text encoder weights: {sorted(unexpected_missing)}") + if text_unexpected: + raise RuntimeError(f"Unexpected LongCatAudioDiT text encoder weights: {sorted(text_unexpected)}") + if "shared.weight" in text_missing: + text_encoder.shared.weight.data.copy_(text_encoder.encoder.embed_tokens.weight.data) + + tokenizer = _load_longcat_tokenizer( + pretrained_model_name_or_path, + config.get("text_encoder_model"), + tokenizer, + local_files_only=local_files_only, + subfolder=subfolder, + ) + + if torch_dtype is not None: + text_encoder = text_encoder.to(dtype=torch_dtype) + transformer = transformer.to(dtype=torch_dtype) + vae = vae.to(dtype=torch_dtype) + + pipe = cls(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=transformer) + pipe.sample_rate = config.get("sampling_rate", pipe.sample_rate) + pipe.latent_hop = config.get("latent_hop", pipe.latent_hop) + pipe.max_wav_duration = config.get("max_wav_duration", pipe.max_wav_duration) + pipe.text_norm_feat = config.get("text_norm_feat", pipe.text_norm_feat) + pipe.text_add_embed = config.get("text_add_embed", pipe.text_add_embed) + return pipe + + def encode_prompt(self, prompt: str | list[str], device: torch.device) -> tuple[torch.Tensor, torch.Tensor]: + if isinstance(prompt, str): + prompt = [prompt] + model_max_length = getattr(self.tokenizer, "model_max_length", 512) + if not isinstance(model_max_length, int) or model_max_length <= 0 or model_max_length > 32768: + model_max_length = 512 + text_inputs = self.tokenizer( + prompt, + padding="longest", + truncation=True, + max_length=model_max_length, + return_tensors="pt", + ) + input_ids = text_inputs.input_ids.to(device) + attention_mask = text_inputs.attention_mask.to(device) + with torch.no_grad(): + output = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True) + prompt_embeds = output.last_hidden_state + if self.text_norm_feat: + prompt_embeds = F.layer_norm(prompt_embeds, (prompt_embeds.shape[-1],), eps=1e-6) + if self.text_add_embed and getattr(output, "hidden_states", None): + first_hidden = output.hidden_states[0] + if self.text_norm_feat: + first_hidden = F.layer_norm(first_hidden, (first_hidden.shape[-1],), eps=1e-6) + prompt_embeds = prompt_embeds + first_hidden + lengths = attention_mask.sum(dim=1).to(device) + return prompt_embeds.float(), lengths + + def prepare_latents( + self, + batch_size: int, + duration: int, + device: torch.device, + dtype: torch.dtype, + generator: torch.Generator | list[torch.Generator] | None = None, + ) -> torch.Tensor: + latents = [ + torch.randn( + duration, + self.latent_dim, + device=device, + dtype=dtype, + generator=generator if isinstance(generator, torch.Generator) else None, + ) + for _ in range(batch_size) + ] + return pad_sequence(latents, padding_value=0.0, batch_first=True) + + @torch.no_grad() + def __call__( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + audio_end_in_s: float | None = None, + duration: int | None = None, + num_inference_steps: int = 16, + guidance_scale: float = 4.0, + generator: torch.Generator | list[torch.Generator] | None = None, + output_type: str = "np", + return_dict: bool = True, + ): + if prompt is None: + prompt = [] + elif isinstance(prompt, str): + prompt = [prompt] + else: + prompt = list(prompt) + batch_size = len(prompt) + if batch_size == 0: + raise ValueError("`prompt` must contain at least one prompt.") + + device = self._execution_device + normalized_prompts = [_normalize_text(text) for text in prompt] + if duration is None: + if audio_end_in_s is not None: + duration = int(audio_end_in_s * self.sample_rate // self.latent_hop) + else: + duration = int( + _approx_batch_duration_from_prompts(normalized_prompts) * self.sample_rate // self.latent_hop + ) + max_duration = int(self.max_wav_duration * self.sample_rate // self.latent_hop) + duration = max(1, min(duration, max_duration)) + + text_condition, text_condition_len = self.encode_prompt(normalized_prompts, device) + duration_tensor = torch.full((batch_size,), duration, device=device, dtype=torch.long) + mask = _lens_to_mask(duration_tensor) + text_mask = _lens_to_mask(text_condition_len, length=text_condition.shape[1]) + + if negative_prompt is None: + neg_text = torch.zeros_like(text_condition) + neg_text_len = text_condition_len + neg_text_mask = text_mask + else: + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * batch_size + neg_text, neg_text_len = self.encode_prompt(negative_prompt, device) + neg_text_mask = _lens_to_mask(neg_text_len, length=neg_text.shape[1]) + + latent_cond = torch.zeros(batch_size, duration, self.latent_dim, device=device, dtype=text_condition.dtype) + latents = self.prepare_latents(batch_size, duration, device, text_condition.dtype, generator=generator) + num_inference_steps = max(2, num_inference_steps) + timesteps = torch.linspace(0, 1, num_inference_steps, device=device, dtype=text_condition.dtype) + sample = latents + + def model_step(curr_t: torch.Tensor, current_sample: torch.Tensor) -> torch.Tensor: + pred = self.transformer( + hidden_states=current_sample, + encoder_hidden_states=text_condition, + encoder_attention_mask=text_mask, + timestep=curr_t.expand(batch_size), + attention_mask=mask, + latent_cond=latent_cond, + ).sample + if guidance_scale < 1e-5: + return pred + null_pred = self.transformer( + hidden_states=current_sample, + encoder_hidden_states=neg_text, + encoder_attention_mask=neg_text_mask, + timestep=curr_t.expand(batch_size), + attention_mask=mask, + latent_cond=latent_cond, + ).sample + return pred + (pred - null_pred) * guidance_scale + + for idx in range(len(timesteps) - 1): + curr_t = timesteps[idx] + dt = timesteps[idx + 1] - timesteps[idx] + sample = sample + model_step(curr_t, sample) * dt + + if output_type == "latent": + if not return_dict: + return (sample,) + return AudioPipelineOutput(audios=sample) + + waveform = self.vae.decode(sample.permute(0, 2, 1)).sample + if output_type == "np": + waveform = waveform.cpu().float().numpy() + elif output_type != "pt": + raise ValueError(f"Unsupported output_type: {output_type}") + + if not return_dict: + return (waveform,) + return AudioPipelineOutput(audios=waveform) diff --git a/tests/models/transformers/test_models_transformer_longcat_audio_dit.py b/tests/models/transformers/test_models_transformer_longcat_audio_dit.py new file mode 100644 index 000000000000..64ab4f06ef44 --- /dev/null +++ b/tests/models/transformers/test_models_transformer_longcat_audio_dit.py @@ -0,0 +1,27 @@ +import torch + +from diffusers import LongCatAudioDiTTransformer + + +def test_longcat_audio_transformer_forward_shape(): + model = LongCatAudioDiTTransformer( + dit_dim=64, + dit_depth=2, + dit_heads=4, + dit_text_dim=32, + latent_dim=8, + text_conv=False, + ) + hidden_states = torch.randn(2, 16, 8) + encoder_hidden_states = torch.randn(2, 10, 32) + encoder_attention_mask = torch.ones(2, 10, dtype=torch.bool) + timestep = torch.tensor([1.0, 1.0]) + + output = model( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + ) + + assert output.sample.shape == hidden_states.shape diff --git a/tests/pipelines/longcat_audio_dit/test_longcat_audio_dit.py b/tests/pipelines/longcat_audio_dit/test_longcat_audio_dit.py new file mode 100644 index 000000000000..974c0882c106 --- /dev/null +++ b/tests/pipelines/longcat_audio_dit/test_longcat_audio_dit.py @@ -0,0 +1,178 @@ +import json +import os +from pathlib import Path + +import pytest +import torch +from safetensors.torch import save_file +from transformers import UMT5Config, UMT5EncoderModel + +from diffusers import LongCatAudioDiTPipeline, LongCatAudioDiTTransformer, LongCatAudioDiTVae +from tests.testing_utils import require_torch_accelerator, slow, torch_device + + +class DummyTokenizer: + model_max_length = 16 + + def __call__(self, texts, padding="longest", truncation=True, max_length=None, return_tensors="pt"): + if isinstance(texts, str): + texts = [texts] + batch = len(texts) + return type( + "TokenBatch", + (), + { + "input_ids": torch.ones(batch, 4, dtype=torch.long), + "attention_mask": torch.ones(batch, 4, dtype=torch.long), + }, + ) + + +def _build_components(): + text_encoder = UMT5EncoderModel(UMT5Config(d_model=32, num_layers=1, num_heads=4, d_ff=64, vocab_size=128)) + transformer = LongCatAudioDiTTransformer( + dit_dim=64, + dit_depth=2, + dit_heads=4, + dit_text_dim=32, + latent_dim=8, + text_conv=False, + ) + vae = LongCatAudioDiTVae( + in_channels=1, + channels=16, + c_mults=[1, 2], + strides=[2], + latent_dim=8, + encoder_latent_dim=16, + downsampling_ratio=2, + sample_rate=24000, + ) + return text_encoder, transformer, vae + + +def test_longcat_audio_dit_vae_import(): + assert LongCatAudioDiTVae is not None + + +def test_longcat_audio_pipeline_constructs(): + text_encoder, transformer, vae = _build_components() + pipe = LongCatAudioDiTPipeline( + vae=vae, text_encoder=text_encoder, tokenizer=DummyTokenizer(), transformer=transformer + ) + assert pipe is not None + + +def test_longcat_audio_pipeline_forward_pt_output(): + text_encoder, transformer, vae = _build_components() + pipe = LongCatAudioDiTPipeline( + vae=vae, text_encoder=text_encoder, tokenizer=DummyTokenizer(), transformer=transformer + ) + + result = pipe( + prompt="soft ocean ambience", audio_end_in_s=0.1, num_inference_steps=2, guidance_scale=1.0, output_type="pt" + ) + + assert result.audios.ndim == 3 + assert result.audios.shape[0] == 1 + assert result.audios.shape[1] == 1 + assert result.audios.shape[-1] > 0 + + +def test_longcat_audio_pipeline_from_pretrained_local_dir(tmp_path, monkeypatch): + text_encoder, transformer, vae = _build_components() + model_dir = tmp_path / "longcat-audio-dit" + model_dir.mkdir() + + config = { + "dit_dim": 64, + "dit_depth": 2, + "dit_heads": 4, + "dit_text_dim": 32, + "latent_dim": 8, + "dit_dropout": 0.0, + "dit_bias": True, + "dit_cross_attn": True, + "dit_adaln_type": "global", + "dit_adaln_use_text_cond": True, + "dit_long_skip": True, + "dit_text_conv": False, + "dit_qk_norm": True, + "dit_cross_attn_norm": False, + "dit_eps": 1e-6, + "dit_use_latent_condition": True, + "sampling_rate": 24000, + "latent_hop": 2, + "max_wav_duration": 30.0, + "text_norm_feat": True, + "text_add_embed": True, + "text_encoder_model": "dummy-umt5", + "text_encoder_config": text_encoder.config.to_dict(), + "vae_config": {**dict(vae.config), "model_type": "longcat_audio_dit_vae"}, + } + with (model_dir / "config.json").open("w") as handle: + json.dump(config, handle) + + state_dict = {} + state_dict.update({f"text_encoder.{k}": v for k, v in text_encoder.state_dict().items() if k != "shared.weight"}) + state_dict.update({f"transformer.{k}": v for k, v in transformer.state_dict().items()}) + state_dict.update({f"vae.{k}": v for k, v in vae.state_dict().items()}) + save_file(state_dict, model_dir / "model.safetensors") + + monkeypatch.setattr( + "diffusers.pipelines.longcat_audio_dit.pipeline_longcat_audio_dit.T5Tokenizer.from_pretrained", + lambda *args, **kwargs: DummyTokenizer(), + ) + + pipe = LongCatAudioDiTPipeline.from_pretrained(model_dir, local_files_only=True) + result = pipe( + prompt="soft ocean ambience", audio_end_in_s=0.1, num_inference_steps=2, guidance_scale=1.0, output_type="pt" + ) + + assert isinstance(pipe, LongCatAudioDiTPipeline) + assert pipe.sample_rate == 24000 + assert pipe.latent_hop == 2 + assert result.audios.ndim == 3 + assert result.audios.shape[-1] > 0 + + +def test_longcat_audio_top_level_imports(): + assert LongCatAudioDiTPipeline is not None + assert LongCatAudioDiTTransformer is not None + assert LongCatAudioDiTVae is not None + + +@slow +@require_torch_accelerator +def test_longcat_audio_pipeline_from_pretrained_real_local_weights(): + model_path = Path(os.getenv("LONGCAT_AUDIO_DIT_MODEL_PATH", "/data/models/meituan-longcat/LongCat-AudioDiT-1B")) + tokenizer_path_env = os.getenv("LONGCAT_AUDIO_DIT_TOKENIZER_PATH") + if tokenizer_path_env is None: + pytest.skip("LONGCAT_AUDIO_DIT_TOKENIZER_PATH is not set") + tokenizer_path = Path(tokenizer_path_env) + + if not model_path.exists(): + pytest.skip(f"LongCat-AudioDiT model path not found: {model_path}") + if not tokenizer_path.exists(): + pytest.skip(f"LongCat-AudioDiT tokenizer path not found: {tokenizer_path}") + + pipe = LongCatAudioDiTPipeline.from_pretrained( + model_path, + tokenizer=tokenizer_path, + torch_dtype=torch.float16, + local_files_only=True, + ) + pipe = pipe.to(torch_device) + + result = pipe( + prompt="A calm ocean wave ambience with soft wind in the background.", + audio_end_in_s=2.0, + num_inference_steps=2, + guidance_scale=4.0, + output_type="pt", + ) + + assert result.audios.ndim == 3 + assert result.audios.shape[0] == 1 + assert result.audios.shape[1] == 1 + assert result.audios.shape[-1] > 0 From d2a2621b27307d65d3259c857e810fc979c1d46a Mon Sep 17 00:00:00 2001 From: Lancer Date: Thu, 2 Apr 2026 23:35:31 +0800 Subject: [PATCH 02/11] upd Signed-off-by: Lancer --- .../en/api/pipelines/longcat_audio_dit.md | 20 +- .../pipeline_longcat_audio_dit.py | 31 +- .../test_longcat_audio_dit.py | 281 +++++++++++------- 3 files changed, 206 insertions(+), 126 deletions(-) diff --git a/docs/source/en/api/pipelines/longcat_audio_dit.md b/docs/source/en/api/pipelines/longcat_audio_dit.md index b605bb4ae672..86488416727e 100644 --- a/docs/source/en/api/pipelines/longcat_audio_dit.md +++ b/docs/source/en/api/pipelines/longcat_audio_dit.md @@ -26,34 +26,32 @@ This pipeline was adapted from the LongCat-AudioDiT reference implementation: ht ## Usage ```py +import soundfile as sf import torch from diffusers import LongCatAudioDiTPipeline -repo_id = "" -tokenizer_path = os.environ["LONGCAT_AUDIO_DIT_TOKENIZER_PATH"] - -pipe = LongCatAudioDiTPipeline.from_pretrained( - repo_id, - tokenizer=tokenizer_path, +pipeline = LongCatAudioDiTPipeline.from_pretrained( + "meituan-longcat/LongCat-AudioDiT-1B", torch_dtype=torch.float16, - local_files_only=True, ) -pipe = pipe.to("cuda") +pipeline = pipeline.to("cuda") -audio = pipe( +audio = pipeline( prompt="A calm ocean wave ambience with soft wind in the background.", - audio_end_in_s=2.0, + audio_end_in_s=5.0, num_inference_steps=16, guidance_scale=4.0, output_type="pt", ).audios + +output = audio[0, 0].float().cpu().numpy() +sf.write("longcat.wav", output, pipeline.sample_rate) ``` ## Tips - `audio_end_in_s` is the most direct way to control output duration. - `output_type="pt"` returns a PyTorch tensor shaped `(batch, channels, samples)`. -- If your tokenizer path is local-only, pass it explicitly to `from_pretrained(...)`. ## LongCatAudioDiTPipeline diff --git a/src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py b/src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py index 56d02cf49468..beed359e0ca3 100644 --- a/src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py +++ b/src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py @@ -26,7 +26,7 @@ from huggingface_hub.utils import validate_hf_hub_args from safetensors.torch import load_file from torch.nn.utils.rnn import pad_sequence -from transformers import PreTrainedTokenizerBase, T5Tokenizer, UMT5Config, UMT5EncoderModel +from transformers import AutoTokenizer, PreTrainedTokenizerBase, UMT5Config, UMT5EncoderModel from ...models import LongCatAudioDiTTransformer, LongCatAudioDiTVae from ...utils import HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging @@ -105,7 +105,7 @@ def _load_longcat_tokenizer( tokenizer_kwargs = {"local_files_only": local_files_only} if not isinstance(tokenizer_source, Path) and tokenizer_source == pretrained_model_name_or_path and subfolder: tokenizer_kwargs["subfolder"] = subfolder - return T5Tokenizer.from_pretrained(tokenizer_source, **tokenizer_kwargs) + return AutoTokenizer.from_pretrained(tokenizer_source, **tokenizer_kwargs) def _resolve_longcat_file( @@ -278,6 +278,10 @@ def from_pretrained( transformer = transformer.to(dtype=torch_dtype) vae = vae.to(dtype=torch_dtype) + text_encoder.eval() + transformer.eval() + vae.eval() + pipe = cls(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=transformer) pipe.sample_rate = config.get("sampling_rate", pipe.sample_rate) pipe.latent_hop = config.get("latent_hop", pipe.latent_hop) @@ -322,15 +326,24 @@ def prepare_latents( dtype: torch.dtype, generator: torch.Generator | list[torch.Generator] | None = None, ) -> torch.Tensor: + if isinstance(generator, list): + if len(generator) != batch_size: + raise ValueError( + f"Expected {batch_size} generators for batch size {batch_size}, but got {len(generator)}." + ) + generators = generator + else: + generators = [generator] * batch_size + latents = [ torch.randn( duration, self.latent_dim, device=device, dtype=dtype, - generator=generator if isinstance(generator, torch.Generator) else None, + generator=generators[idx], ) - for _ in range(batch_size) + for idx in range(batch_size) ] return pad_sequence(latents, padding_value=0.0, batch_first=True) @@ -381,6 +394,12 @@ def __call__( else: if isinstance(negative_prompt, str): negative_prompt = [negative_prompt] * batch_size + else: + negative_prompt = list(negative_prompt) + if len(negative_prompt) != batch_size: + raise ValueError( + f"`negative_prompt` must have batch size {batch_size}, but got {len(negative_prompt)} prompts." + ) neg_text, neg_text_len = self.encode_prompt(negative_prompt, device) neg_text_mask = _lens_to_mask(neg_text_len, length=neg_text.shape[1]) @@ -399,7 +418,7 @@ def model_step(curr_t: torch.Tensor, current_sample: torch.Tensor) -> torch.Tens attention_mask=mask, latent_cond=latent_cond, ).sample - if guidance_scale < 1e-5: + if guidance_scale <= 1.0: return pred null_pred = self.transformer( hidden_states=current_sample, @@ -409,7 +428,7 @@ def model_step(curr_t: torch.Tensor, current_sample: torch.Tensor) -> torch.Tens attention_mask=mask, latent_cond=latent_cond, ).sample - return pred + (pred - null_pred) * guidance_scale + return null_pred + (pred - null_pred) * guidance_scale for idx in range(len(timesteps) - 1): curr_t = timesteps[idx] diff --git a/tests/pipelines/longcat_audio_dit/test_longcat_audio_dit.py b/tests/pipelines/longcat_audio_dit/test_longcat_audio_dit.py index 974c0882c106..74fbe502e7ca 100644 --- a/tests/pipelines/longcat_audio_dit/test_longcat_audio_dit.py +++ b/tests/pipelines/longcat_audio_dit/test_longcat_audio_dit.py @@ -1,8 +1,22 @@ +# Copyright 2026 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import json import os +import unittest from pathlib import Path -import pytest import torch from safetensors.torch import save_file from transformers import UMT5Config, UMT5EncoderModel @@ -28,112 +42,161 @@ def __call__(self, texts, padding="longest", truncation=True, max_length=None, r ) -def _build_components(): - text_encoder = UMT5EncoderModel(UMT5Config(d_model=32, num_layers=1, num_heads=4, d_ff=64, vocab_size=128)) - transformer = LongCatAudioDiTTransformer( - dit_dim=64, - dit_depth=2, - dit_heads=4, - dit_text_dim=32, - latent_dim=8, - text_conv=False, - ) - vae = LongCatAudioDiTVae( - in_channels=1, - channels=16, - c_mults=[1, 2], - strides=[2], - latent_dim=8, - encoder_latent_dim=16, - downsampling_ratio=2, - sample_rate=24000, - ) - return text_encoder, transformer, vae - - -def test_longcat_audio_dit_vae_import(): - assert LongCatAudioDiTVae is not None - - -def test_longcat_audio_pipeline_constructs(): - text_encoder, transformer, vae = _build_components() - pipe = LongCatAudioDiTPipeline( - vae=vae, text_encoder=text_encoder, tokenizer=DummyTokenizer(), transformer=transformer - ) - assert pipe is not None - - -def test_longcat_audio_pipeline_forward_pt_output(): - text_encoder, transformer, vae = _build_components() - pipe = LongCatAudioDiTPipeline( - vae=vae, text_encoder=text_encoder, tokenizer=DummyTokenizer(), transformer=transformer - ) - - result = pipe( - prompt="soft ocean ambience", audio_end_in_s=0.1, num_inference_steps=2, guidance_scale=1.0, output_type="pt" - ) - - assert result.audios.ndim == 3 - assert result.audios.shape[0] == 1 - assert result.audios.shape[1] == 1 - assert result.audios.shape[-1] > 0 - - -def test_longcat_audio_pipeline_from_pretrained_local_dir(tmp_path, monkeypatch): - text_encoder, transformer, vae = _build_components() - model_dir = tmp_path / "longcat-audio-dit" - model_dir.mkdir() - - config = { - "dit_dim": 64, - "dit_depth": 2, - "dit_heads": 4, - "dit_text_dim": 32, - "latent_dim": 8, - "dit_dropout": 0.0, - "dit_bias": True, - "dit_cross_attn": True, - "dit_adaln_type": "global", - "dit_adaln_use_text_cond": True, - "dit_long_skip": True, - "dit_text_conv": False, - "dit_qk_norm": True, - "dit_cross_attn_norm": False, - "dit_eps": 1e-6, - "dit_use_latent_condition": True, - "sampling_rate": 24000, - "latent_hop": 2, - "max_wav_duration": 30.0, - "text_norm_feat": True, - "text_add_embed": True, - "text_encoder_model": "dummy-umt5", - "text_encoder_config": text_encoder.config.to_dict(), - "vae_config": {**dict(vae.config), "model_type": "longcat_audio_dit_vae"}, - } - with (model_dir / "config.json").open("w") as handle: - json.dump(config, handle) - - state_dict = {} - state_dict.update({f"text_encoder.{k}": v for k, v in text_encoder.state_dict().items() if k != "shared.weight"}) - state_dict.update({f"transformer.{k}": v for k, v in transformer.state_dict().items()}) - state_dict.update({f"vae.{k}": v for k, v in vae.state_dict().items()}) - save_file(state_dict, model_dir / "model.safetensors") - - monkeypatch.setattr( - "diffusers.pipelines.longcat_audio_dit.pipeline_longcat_audio_dit.T5Tokenizer.from_pretrained", - lambda *args, **kwargs: DummyTokenizer(), - ) - - pipe = LongCatAudioDiTPipeline.from_pretrained(model_dir, local_files_only=True) - result = pipe( - prompt="soft ocean ambience", audio_end_in_s=0.1, num_inference_steps=2, guidance_scale=1.0, output_type="pt" - ) +@require_torch_accelerator +class LongCatAudioDiTPipelineFastTests(unittest.TestCase): + pipeline_class = LongCatAudioDiTPipeline + + def get_dummy_components(self): + torch.manual_seed(0) + text_encoder = UMT5EncoderModel(UMT5Config(d_model=32, num_layers=1, num_heads=4, d_ff=64, vocab_size=128)) + transformer = LongCatAudioDiTTransformer( + dit_dim=64, + dit_depth=2, + dit_heads=4, + dit_text_dim=32, + latent_dim=8, + text_conv=False, + ) + vae = LongCatAudioDiTVae( + in_channels=1, + channels=16, + c_mults=[1, 2], + strides=[2], + latent_dim=8, + encoder_latent_dim=16, + downsampling_ratio=2, + sample_rate=24000, + ) - assert isinstance(pipe, LongCatAudioDiTPipeline) - assert pipe.sample_rate == 24000 - assert pipe.latent_hop == 2 - assert result.audios.ndim == 3 - assert result.audios.shape[-1] > 0 + return { + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": DummyTokenizer(), + "transformer": transformer, + } + + def get_dummy_inputs(self, device, seed=0, prompt="soft ocean ambience"): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + return { + "prompt": prompt, + "audio_end_in_s": 0.1, + "num_inference_steps": 2, + "guidance_scale": 1.0, + "generator": generator, + "output_type": "pt", + } + + def test_inference(self): + device = "cpu" + pipe = self.pipeline_class(**self.get_dummy_components()) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + output = pipe(**self.get_dummy_inputs(device)).audios + + self.assertEqual(output.ndim, 3) + self.assertEqual(output.shape[0], 1) + self.assertEqual(output.shape[1], 1) + self.assertGreater(output.shape[-1], 0) + + def test_inference_batch_single_identical(self): + device = "cpu" + pipe = self.pipeline_class(**self.get_dummy_components()) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + output1 = pipe(**self.get_dummy_inputs(device, seed=42)).audios + output2 = pipe(**self.get_dummy_inputs(device, seed=42)).audios + + self.assertTrue(torch.allclose(output1, output2, atol=1e-4)) + + def test_inference_batch_multiple_prompts(self): + device = "cpu" + pipe = self.pipeline_class(**self.get_dummy_components()) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device=device).manual_seed(42) + output = pipe( + prompt=["soft ocean ambience", "gentle rain ambience"], + audio_end_in_s=0.1, + num_inference_steps=2, + guidance_scale=1.0, + generator=generator, + output_type="pt", + ).audios + + self.assertEqual(output.ndim, 3) + self.assertEqual(output.shape[0], 2) + self.assertEqual(output.shape[1], 1) + self.assertGreater(output.shape[-1], 0) + + def test_from_pretrained_local_dir(self): + import tempfile + from unittest.mock import patch + + device = "cpu" + components = self.get_dummy_components() + text_encoder = components["text_encoder"] + transformer = components["transformer"] + vae = components["vae"] + + with tempfile.TemporaryDirectory() as tmp_dir: + model_dir = Path(tmp_dir) / "longcat-audio-dit" + model_dir.mkdir() + + config = { + "dit_dim": 64, + "dit_depth": 2, + "dit_heads": 4, + "dit_text_dim": 32, + "latent_dim": 8, + "dit_dropout": 0.0, + "dit_bias": True, + "dit_cross_attn": True, + "dit_adaln_type": "global", + "dit_adaln_use_text_cond": True, + "dit_long_skip": True, + "dit_text_conv": False, + "dit_qk_norm": True, + "dit_cross_attn_norm": False, + "dit_eps": 1e-6, + "dit_use_latent_condition": True, + "sampling_rate": 24000, + "latent_hop": 2, + "max_wav_duration": 30.0, + "text_norm_feat": True, + "text_add_embed": True, + "text_encoder_model": "dummy-umt5", + "text_encoder_config": text_encoder.config.to_dict(), + "vae_config": {**dict(vae.config), "model_type": "longcat_audio_dit_vae"}, + } + with (model_dir / "config.json").open("w") as handle: + json.dump(config, handle) + + state_dict = {} + state_dict.update({f"text_encoder.{k}": v for k, v in text_encoder.state_dict().items() if k != "shared.weight"}) + state_dict.update({f"transformer.{k}": v for k, v in transformer.state_dict().items()}) + state_dict.update({f"vae.{k}": v for k, v in vae.state_dict().items()}) + save_file(state_dict, model_dir / "model.safetensors") + + with patch( + "diffusers.pipelines.longcat_audio_dit.pipeline_longcat_audio_dit.AutoTokenizer.from_pretrained", + return_value=DummyTokenizer(), + ): + pipe = LongCatAudioDiTPipeline.from_pretrained(model_dir, local_files_only=True) + + output = pipe(**self.get_dummy_inputs(device, seed=0)).audios + + self.assertIsInstance(pipe, LongCatAudioDiTPipeline) + self.assertEqual(pipe.sample_rate, 24000) + self.assertEqual(pipe.latent_hop, 2) + self.assertEqual(output.ndim, 3) + self.assertGreater(output.shape[-1], 0) def test_longcat_audio_top_level_imports(): @@ -148,13 +211,13 @@ def test_longcat_audio_pipeline_from_pretrained_real_local_weights(): model_path = Path(os.getenv("LONGCAT_AUDIO_DIT_MODEL_PATH", "/data/models/meituan-longcat/LongCat-AudioDiT-1B")) tokenizer_path_env = os.getenv("LONGCAT_AUDIO_DIT_TOKENIZER_PATH") if tokenizer_path_env is None: - pytest.skip("LONGCAT_AUDIO_DIT_TOKENIZER_PATH is not set") + raise unittest.SkipTest("LONGCAT_AUDIO_DIT_TOKENIZER_PATH is not set") tokenizer_path = Path(tokenizer_path_env) if not model_path.exists(): - pytest.skip(f"LongCat-AudioDiT model path not found: {model_path}") + raise unittest.SkipTest(f"LongCat-AudioDiT model path not found: {model_path}") if not tokenizer_path.exists(): - pytest.skip(f"LongCat-AudioDiT tokenizer path not found: {tokenizer_path}") + raise unittest.SkipTest(f"LongCat-AudioDiT tokenizer path not found: {tokenizer_path}") pipe = LongCatAudioDiTPipeline.from_pretrained( model_path, From 354c983318593bee8983bec6dba713a574df5c8b Mon Sep 17 00:00:00 2001 From: Lancer Date: Mon, 6 Apr 2026 09:15:49 +0800 Subject: [PATCH 03/11] upd --- .../transformer_longcat_audio_dit.py | 171 ++++++++++++------ .../pipeline_longcat_audio_dit.py | 32 ++++ ...st_models_transformer_longcat_audio_dit.py | 27 +++ .../test_longcat_audio_dit.py | 17 +- 4 files changed, 189 insertions(+), 58 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_longcat_audio_dit.py b/src/diffusers/models/transformers/transformer_longcat_audio_dit.py index bfcf9a8f7a3b..f9e9388aff55 100644 --- a/src/diffusers/models/transformers/transformer_longcat_audio_dit.py +++ b/src/diffusers/models/transformers/transformer_longcat_audio_dit.py @@ -25,6 +25,8 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import BaseOutput from ...utils.torch_utils import maybe_allow_in_graph +from ..attention import AttentionModuleMixin +from ..attention_dispatch import dispatch_attention_fn from ..modeling_utils import ModelMixin @@ -106,8 +108,8 @@ def _rotate_half(hidden_states: torch.Tensor) -> torch.Tensor: def _apply_rotary_emb(hidden_states: torch.Tensor, rope: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: cos, sin = rope - cos = cos[None, None].to(hidden_states.device) - sin = sin[None, None].to(hidden_states.device) + cos = cos[None, :, None].to(hidden_states.device) + sin = sin[None, :, None].to(hidden_states.device) return (hidden_states.float() * cos + _rotate_half(hidden_states).float() * sin).to(hidden_states.dtype) @@ -205,25 +207,55 @@ def _modulate( return hidden_states * (1 + scale) + shift -def _masked_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - query_mask: torch.BoolTensor | None = None, - key_mask: torch.BoolTensor | None = None, -) -> torch.Tensor: - attn_mask = None - if key_mask is not None: - attn_mask = key_mask[:, None, None, :].expand(-1, query.shape[1], query.shape[2], -1) - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False - ) - if query_mask is not None: - hidden_states = hidden_states * query_mask[:, None, :, None].to(hidden_states.dtype) - return hidden_states - - -class AudioDiTSelfAttention(nn.Module): +class AudioDiTSelfAttnProcessor: + _attention_backend = None + _parallel_config = None + + def __call__( + self, + attn: "AudioDiTSelfAttention", + hidden_states: torch.Tensor, + mask: torch.BoolTensor | None = None, + rope: tuple | None = None, + ) -> torch.Tensor: + batch_size = hidden_states.shape[0] + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + if attn.qk_norm: + query = attn.q_norm(query) + key = attn.k_norm(key) + + head_dim = attn.inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, attn.heads, head_dim) + value = value.view(batch_size, -1, attn.heads, head_dim) + + if rope is not None: + query = _apply_rotary_emb(query, rope) + key = _apply_rotary_emb(key, rope) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=mask, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + if mask is not None: + hidden_states = hidden_states * mask[:, :, None, None].to(hidden_states.dtype) + + hidden_states = hidden_states.flatten(2, 3).to(query.dtype) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class AudioDiTSelfAttention(nn.Module, AttentionModuleMixin): + _default_processor_cls = AudioDiTSelfAttnProcessor + _available_processors = [AudioDiTSelfAttnProcessor] def __init__( self, dim: int, @@ -245,32 +277,67 @@ def __init__( self.q_norm = AudioDiTRMSNorm(self.inner_dim, eps=eps) self.k_norm = AudioDiTRMSNorm(self.inner_dim, eps=eps) self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, dim, bias=bias), nn.Dropout(dropout)]) + self.set_processor(self._default_processor_cls()) def forward( self, hidden_states: torch.Tensor, mask: torch.BoolTensor | None = None, rope: tuple | None = None + ) -> torch.Tensor: + return self.processor(self, hidden_states, mask=mask, rope=rope) + + +class AudioDiTCrossAttnProcessor: + _attention_backend = None + _parallel_config = None + + def __call__( + self, + attn: "AudioDiTCrossAttention", + hidden_states: torch.Tensor, + cond: torch.Tensor, + mask: torch.BoolTensor | None = None, + cond_mask: torch.BoolTensor | None = None, + rope: tuple | None = None, + cond_rope: tuple | None = None, ) -> torch.Tensor: batch_size = hidden_states.shape[0] - query = self.to_q(hidden_states) - key = self.to_k(hidden_states) - value = self.to_v(hidden_states) - if self.qk_norm: - query = self.q_norm(query) - key = self.k_norm(key) - head_dim = self.inner_dim // self.heads - query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + query = attn.to_q(hidden_states) + key = attn.to_k(cond) + value = attn.to_v(cond) + + if attn.qk_norm: + query = attn.q_norm(query) + key = attn.k_norm(key) + + head_dim = attn.inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, attn.heads, head_dim) + value = value.view(batch_size, -1, attn.heads, head_dim) + if rope is not None: query = _apply_rotary_emb(query, rope) - key = _apply_rotary_emb(key, rope) - hidden_states = _masked_attention(query, key, value, query_mask=mask, key_mask=mask) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.inner_dim).to(query.dtype) - hidden_states = self.to_out[0](hidden_states) - hidden_states = self.to_out[1](hidden_states) + if cond_rope is not None: + key = _apply_rotary_emb(key, cond_rope) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=cond_mask, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + if mask is not None: + hidden_states = hidden_states * mask[:, :, None, None].to(hidden_states.dtype) + + hidden_states = hidden_states.flatten(2, 3).to(query.dtype) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) return hidden_states -class AudioDiTCrossAttention(nn.Module): +class AudioDiTCrossAttention(nn.Module, AttentionModuleMixin): + _default_processor_cls = AudioDiTCrossAttnProcessor + _available_processors = [AudioDiTCrossAttnProcessor] def __init__( self, q_dim: int, @@ -293,6 +360,7 @@ def __init__( self.q_norm = AudioDiTRMSNorm(self.inner_dim, eps=eps) self.k_norm = AudioDiTRMSNorm(self.inner_dim, eps=eps) self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, q_dim, bias=bias), nn.Dropout(dropout)]) + self.set_processor(self._default_processor_cls()) def forward( self, @@ -303,26 +371,15 @@ def forward( rope: tuple | None = None, cond_rope: tuple | None = None, ) -> torch.Tensor: - batch_size = hidden_states.shape[0] - query = self.to_q(hidden_states) - key = self.to_k(cond) - value = self.to_v(cond) - if self.qk_norm: - query = self.q_norm(query) - key = self.k_norm(key) - head_dim = self.inner_dim // self.heads - query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) - if rope is not None: - query = _apply_rotary_emb(query, rope) - if cond_rope is not None: - key = _apply_rotary_emb(key, cond_rope) - hidden_states = _masked_attention(query, key, value, query_mask=mask, key_mask=cond_mask) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.inner_dim).to(query.dtype) - hidden_states = self.to_out[0](hidden_states) - hidden_states = self.to_out[1](hidden_states) - return hidden_states + return self.processor( + self, + hidden_states, + cond=cond, + mask=mask, + cond_mask=cond_mask, + rope=rope, + cond_rope=cond_rope, + ) class AudioDiTFeedForward(nn.Module): diff --git a/src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py b/src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py index beed359e0ca3..938a33106b5d 100644 --- a/src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py +++ b/src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py @@ -191,6 +191,38 @@ def from_pretrained( token = kwargs.pop("token", None) revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) + try: + cls.load_config( + pretrained_model_name_or_path, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + local_dir=local_dir, + local_dir_use_symlinks=local_dir_use_symlinks, + ) + except (EnvironmentError, OSError, ValueError): + pass + else: + return super().from_pretrained( + pretrained_model_name_or_path, + tokenizer=tokenizer, + torch_dtype=torch_dtype, + local_files_only=local_files_only, + cache_dir=cache_dir, + local_dir=local_dir, + local_dir_use_symlinks=local_dir_use_symlinks, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + **kwargs, + ) + if kwargs: logger.warning("Ignoring unsupported LongCatAudioDiTPipeline.from_pretrained kwargs: %s", sorted(kwargs)) diff --git a/tests/models/transformers/test_models_transformer_longcat_audio_dit.py b/tests/models/transformers/test_models_transformer_longcat_audio_dit.py index 64ab4f06ef44..858f7e8484d1 100644 --- a/tests/models/transformers/test_models_transformer_longcat_audio_dit.py +++ b/tests/models/transformers/test_models_transformer_longcat_audio_dit.py @@ -25,3 +25,30 @@ def test_longcat_audio_transformer_forward_shape(): ) assert output.sample.shape == hidden_states.shape + + +def test_longcat_audio_transformer_masked_forward(): + model = LongCatAudioDiTTransformer( + dit_dim=64, + dit_depth=2, + dit_heads=4, + dit_text_dim=32, + latent_dim=8, + text_conv=False, + ) + hidden_states = torch.randn(2, 16, 8) + encoder_hidden_states = torch.randn(2, 10, 32) + encoder_attention_mask = torch.tensor([[1] * 10, [1] * 6 + [0] * 4], dtype=torch.bool) + attention_mask = torch.tensor([[1] * 16, [1] * 9 + [0] * 7], dtype=torch.bool) + timestep = torch.tensor([1.0, 1.0]) + + output = model( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + attention_mask=attention_mask, + ) + + assert output.sample.shape == hidden_states.shape + assert torch.all(output.sample[1, 9:] == 0) diff --git a/tests/pipelines/longcat_audio_dit/test_longcat_audio_dit.py b/tests/pipelines/longcat_audio_dit/test_longcat_audio_dit.py index 74fbe502e7ca..ce16e9e26aab 100644 --- a/tests/pipelines/longcat_audio_dit/test_longcat_audio_dit.py +++ b/tests/pipelines/longcat_audio_dit/test_longcat_audio_dit.py @@ -42,7 +42,6 @@ def __call__(self, texts, padding="longest", truncation=True, max_length=None, r ) -@require_torch_accelerator class LongCatAudioDiTPipelineFastTests(unittest.TestCase): pipeline_class = LongCatAudioDiTPipeline @@ -135,6 +134,22 @@ def test_inference_batch_multiple_prompts(self): self.assertEqual(output.shape[1], 1) self.assertGreater(output.shape[-1], 0) + def test_save_pretrained_roundtrip(self): + import tempfile + + device = "cpu" + pipe = self.pipeline_class(**self.get_dummy_components()) + pipe.to(device) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipe.save_pretrained(tmp_dir) + reloaded = self.pipeline_class.from_pretrained(tmp_dir, tokenizer=DummyTokenizer(), local_files_only=True) + output = reloaded(**self.get_dummy_inputs(device, seed=0)).audios + + self.assertIsInstance(reloaded, LongCatAudioDiTPipeline) + self.assertEqual(output.ndim, 3) + self.assertGreater(output.shape[-1], 0) + def test_from_pretrained_local_dir(self): import tempfile from unittest.mock import patch From d278357119b67221c15c03664f9d15e459d1ab89 Mon Sep 17 00:00:00 2001 From: Lancer <402430575@qq.com> Date: Wed, 8 Apr 2026 07:46:41 +0800 Subject: [PATCH 04/11] Apply suggestions from code review Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- .../autoencoder_longcat_audio_dit.py | 56 +++++++++++-------- .../transformer_longcat_audio_dit.py | 12 ++-- 2 files changed, 40 insertions(+), 28 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py b/src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py index 9ab0a0d27470..f7f73555290b 100644 --- a/src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py +++ b/src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py @@ -56,10 +56,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def _get_vae_activation(name: str, channels: int = 0) -> nn.Module: if name == "elu": - return nn.ELU() - if name == "snake": - return Snake1d(channels) - raise ValueError(f"Unknown activation: {name}") + act = nn.ELU() + elif name == "snake": + act = Snake1d(channels) + else: + raise ValueError(f"Unknown activation: {name}") + return act def _pixel_unshuffle_1d(hidden_states: torch.Tensor, factor: int) -> torch.Tensor: @@ -150,9 +152,11 @@ def __init__( ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - if self.residual is None: - return self.layers(hidden_states) - return self.layers(hidden_states) + self.residual(hidden_states) + output_hidden_states = self.layers(hidden_states) + if self.residual is not None: + residual = self.residual(hidden_states) + output_hidden_states = output_hidden_states + residual + return output_hidden_states class VaeDecoderBlock(nn.Module): @@ -181,9 +185,11 @@ def __init__( ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - if self.residual is None: - return self.layers(hidden_states) - return self.layers(hidden_states) + self.residual(hidden_states) + output_hidden_states = self.layers(hidden_states) + if self.residual is not None: + residual = self.residual(hidden_states) + output_hidden_states = output_hidden_states + residual + return output_hidden_states class AudioDiTVaeEncoder(nn.Module): @@ -191,8 +197,8 @@ def __init__( self, in_channels: int = 1, channels: int = 128, - c_mults=None, - strides=None, + c_mults: list[int] | None = None, + strides: list[int] | None = None, latent_dim: int = 64, encoder_latent_dim: int = 128, use_snake: bool = True, @@ -227,10 +233,12 @@ def __init__( ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - if self.shortcut is None: - return self.layers(hidden_states) hidden_states = self.layers[:-1](hidden_states) - return self.layers[-1](hidden_states) + self.shortcut(hidden_states) + output_hidden_states = self.layers[-1](hidden_states) + if self.shortcut is not None: + shortcut = self.shortcut(hidden_states) + output_hidden_states = output_hidden_states + shortcut + return output_hidden_states class AudioDiTVaeDecoder(nn.Module): @@ -238,8 +246,8 @@ def __init__( self, in_channels: int = 1, channels: int = 128, - c_mults=None, - strides=None, + c_mults: list[int] | None = None, + strides: list[int] | None = None, latent_dim: int = 64, use_snake: bool = True, in_shortcut: str = "duplicating", @@ -277,10 +285,12 @@ def __init__( self.layers = nn.Sequential(*layers) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - if self.shortcut is None: - return self.layers(hidden_states) - hidden_states = self.shortcut(hidden_states) + self.layers[0](hidden_states) - return self.layers[1:](hidden_states) + hidden_states = self.layers[:-1](hidden_states) + output_hidden_states = self.layers[-1](hidden_states) + if self.shortcut is not None: + shortcut = self.shortcut(hidden_states) + output_hidden_states = output_hidden_states + shortcut + return output_hidden_states @dataclass @@ -299,8 +309,8 @@ def __init__( self, in_channels: int = 1, channels: int = 128, - c_mults=None, - strides=None, + c_mults: list[int] | None = None, + strides: list[int] | None = None, latent_dim: int = 64, encoder_latent_dim: int = 128, use_snake: bool = True, diff --git a/src/diffusers/models/transformers/transformer_longcat_audio_dit.py b/src/diffusers/models/transformers/transformer_longcat_audio_dit.py index f9e9388aff55..b17126993181 100644 --- a/src/diffusers/models/transformers/transformer_longcat_audio_dit.py +++ b/src/diffusers/models/transformers/transformer_longcat_audio_dit.py @@ -293,11 +293,11 @@ def __call__( self, attn: "AudioDiTCrossAttention", hidden_states: torch.Tensor, - cond: torch.Tensor, - mask: torch.BoolTensor | None = None, - cond_mask: torch.BoolTensor | None = None, - rope: tuple | None = None, - cond_rope: tuple | None = None, + encoder_hidden_states: torch.Tensor, + post_attention_mask: torch.BoolTensor | None = None, + attention_mask: torch.BoolTensor | None = None, + audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + prompt_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, ) -> torch.Tensor: batch_size = hidden_states.shape[0] query = attn.to_q(hidden_states) @@ -422,9 +422,11 @@ def __init__( self.adaln_mlp = AudioDiTAdaLNMLP(dim, dim * 6, bias=True) elif adaln_type == "global": self.adaln_scale_shift = nn.Parameter(torch.randn(dim * 6) / dim**0.5) + self.self_attn = AudioDiTSelfAttention( dim, heads, dim_head, dropout=dropout, bias=bias, qk_norm=qk_norm, eps=eps ) + self.use_cross_attn = cross_attn if cross_attn: self.cross_attn = AudioDiTCrossAttention( From 800d2d3580308cfba9a99896414cdfbb4b572001 Mon Sep 17 00:00:00 2001 From: Lancer Date: Wed, 8 Apr 2026 11:58:23 +0800 Subject: [PATCH 05/11] upd Signed-off-by: Lancer --- .../autoencoder_longcat_audio_dit.py | 10 ++-- .../transformer_longcat_audio_dit.py | 48 +++++++++---------- 2 files changed, 28 insertions(+), 30 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py b/src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py index f7f73555290b..f3b84f7870f1 100644 --- a/src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py +++ b/src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py @@ -285,12 +285,10 @@ def __init__( self.layers = nn.Sequential(*layers) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.layers[:-1](hidden_states) - output_hidden_states = self.layers[-1](hidden_states) - if self.shortcut is not None: - shortcut = self.shortcut(hidden_states) - output_hidden_states = output_hidden_states + shortcut - return output_hidden_states + if self.shortcut is None: + return self.layers(hidden_states) + hidden_states = self.shortcut(hidden_states) + self.layers[0](hidden_states) + return self.layers[1:](hidden_states) @dataclass diff --git a/src/diffusers/models/transformers/transformer_longcat_audio_dit.py b/src/diffusers/models/transformers/transformer_longcat_audio_dit.py index b17126993181..7e02325c12fa 100644 --- a/src/diffusers/models/transformers/transformer_longcat_audio_dit.py +++ b/src/diffusers/models/transformers/transformer_longcat_audio_dit.py @@ -301,8 +301,8 @@ def __call__( ) -> torch.Tensor: batch_size = hidden_states.shape[0] query = attn.to_q(hidden_states) - key = attn.to_k(cond) - value = attn.to_v(cond) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) if attn.qk_norm: query = attn.q_norm(query) @@ -313,21 +313,21 @@ def __call__( key = key.view(batch_size, -1, attn.heads, head_dim) value = value.view(batch_size, -1, attn.heads, head_dim) - if rope is not None: - query = _apply_rotary_emb(query, rope) - if cond_rope is not None: - key = _apply_rotary_emb(key, cond_rope) + if audio_rotary_emb is not None: + query = _apply_rotary_emb(query, audio_rotary_emb) + if prompt_rotary_emb is not None: + key = _apply_rotary_emb(key, prompt_rotary_emb) hidden_states = dispatch_attention_fn( query, key, value, - attn_mask=cond_mask, + attn_mask=attention_mask, backend=self._attention_backend, parallel_config=self._parallel_config, ) - if mask is not None: - hidden_states = hidden_states * mask[:, :, None, None].to(hidden_states.dtype) + if post_attention_mask is not None: + hidden_states = hidden_states * post_attention_mask[:, :, None, None].to(hidden_states.dtype) hidden_states = hidden_states.flatten(2, 3).to(query.dtype) hidden_states = attn.to_out[0](hidden_states) @@ -365,20 +365,20 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - cond: torch.Tensor, - mask: torch.BoolTensor | None = None, - cond_mask: torch.BoolTensor | None = None, - rope: tuple | None = None, - cond_rope: tuple | None = None, + encoder_hidden_states: torch.Tensor, + post_attention_mask: torch.BoolTensor | None = None, + attention_mask: torch.BoolTensor | None = None, + audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + prompt_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, ) -> torch.Tensor: return self.processor( self, hidden_states, - cond=cond, - mask=mask, - cond_mask=cond_mask, - rope=rope, - cond_rope=cond_rope, + encoder_hidden_states=encoder_hidden_states, + post_attention_mask=post_attention_mask, + attention_mask=attention_mask, + audio_rotary_emb=audio_rotary_emb, + prompt_rotary_emb=prompt_rotary_emb, ) @@ -471,11 +471,11 @@ def forward( if self.use_cross_attn: cross_output = self.cross_attn( hidden_states=self.cross_attn_norm(hidden_states), - cond=self.cross_attn_norm_c(cond), - mask=mask, - cond_mask=cond_mask, - rope=rope, - cond_rope=cond_rope, + encoder_hidden_states=self.cross_attn_norm_c(cond), + post_attention_mask=mask, + attention_mask=cond_mask, + audio_rotary_emb=rope, + prompt_rotary_emb=cond_rope, ) hidden_states = hidden_states + cross_output From 06351829381fb33da50c9cd60fed5c685d485728 Mon Sep 17 00:00:00 2001 From: Lancer Date: Thu, 9 Apr 2026 02:04:17 +0800 Subject: [PATCH 06/11] upd Signed-off-by: Lancer --- src/diffusers/models/__init__.py | 2 + .../autoencoder_longcat_audio_dit.py | 77 ++++---- .../transformer_longcat_audio_dit.py | 180 +++++++----------- .../pipeline_longcat_audio_dit.py | 161 +++++++++------- 4 files changed, 198 insertions(+), 222 deletions(-) diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 2b24b53a7035..df586294102c 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -180,6 +180,7 @@ AutoencoderTiny, AutoencoderVidTok, ConsistencyDecoderVAE, + LongCatAudioDiTVae, VQModel, ) from .cache_utils import CacheMixin @@ -232,6 +233,7 @@ HunyuanVideoTransformer3DModel, Kandinsky5Transformer3DModel, LatteTransformer3DModel, + LongCatAudioDiTTransformer, LongCatImageTransformer2DModel, LTX2VideoTransformer3DModel, LTXVideoTransformer3DModel, diff --git a/src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py b/src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py index f3b84f7870f1..bd7538f9a510 100644 --- a/src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py +++ b/src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py @@ -64,16 +64,6 @@ def _get_vae_activation(name: str, channels: int = 0) -> nn.Module: return act -def _pixel_unshuffle_1d(hidden_states: torch.Tensor, factor: int) -> torch.Tensor: - batch, channels, width = hidden_states.size() - return ( - hidden_states.view(batch, channels, width // factor, factor) - .permute(0, 1, 3, 2) - .contiguous() - .view(batch, channels * factor, width // factor) - ) - - def _pixel_shuffle_1d(hidden_states: torch.Tensor, factor: int) -> torch.Tensor: batch, channels, width = hidden_states.size() return ( @@ -92,9 +82,14 @@ def __init__(self, in_channels: int, out_channels: int, factor: int): self.out_channels = out_channels def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = _pixel_unshuffle_1d(hidden_states, self.factor) - batch, _channels, width = hidden_states.shape - return hidden_states.view(batch, self.out_channels, self.group_size, width).mean(dim=2) + batch, channels, width = hidden_states.shape + hidden_states = ( + hidden_states.view(batch, channels, width // self.factor, self.factor) + .permute(0, 1, 3, 2) + .contiguous() + .view(batch, channels * self.factor, width // self.factor) + ) + return hidden_states.view(batch, self.out_channels, self.group_size, width // self.factor).mean(dim=2) class UpsampleShortcut(nn.Module): @@ -110,15 +105,14 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class VaeResidualUnit(nn.Module): def __init__( - self, in_channels: int, out_channels: int, dilation: int, kernel_size: int = 7, use_snake: bool = False + self, in_channels: int, out_channels: int, dilation: int, kernel_size: int = 7, act_fn: str = "snake" ): super().__init__() padding = (dilation * (kernel_size - 1)) // 2 - activation = "snake" if use_snake else "elu" self.layers = nn.Sequential( - _get_vae_activation(activation, channels=out_channels), + _get_vae_activation(act_fn, channels=out_channels), _wn_conv1d(in_channels, out_channels, kernel_size, dilation=dilation, padding=padding), - _get_vae_activation(activation, channels=out_channels), + _get_vae_activation(act_fn, channels=out_channels), _wn_conv1d(out_channels, out_channels, kernel_size=1), ) @@ -132,17 +126,16 @@ def __init__( in_channels: int, out_channels: int, stride: int, - use_snake: bool = False, + act_fn: str = "snake", downsample_shortcut: str = "none", ): super().__init__() layers = [ - VaeResidualUnit(in_channels, in_channels, dilation=1, use_snake=use_snake), - VaeResidualUnit(in_channels, in_channels, dilation=3, use_snake=use_snake), - VaeResidualUnit(in_channels, in_channels, dilation=9, use_snake=use_snake), + VaeResidualUnit(in_channels, in_channels, dilation=1, act_fn=act_fn), + VaeResidualUnit(in_channels, in_channels, dilation=3, act_fn=act_fn), + VaeResidualUnit(in_channels, in_channels, dilation=9, act_fn=act_fn), ] - activation = "snake" if use_snake else "elu" - layers.append(_get_vae_activation(activation, channels=in_channels)) + layers.append(_get_vae_activation(act_fn, channels=in_channels)) layers.append( _wn_conv1d(in_channels, out_channels, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)) ) @@ -165,19 +158,18 @@ def __init__( in_channels: int, out_channels: int, stride: int, - use_snake: bool = False, + act_fn: str = "snake", upsample_shortcut: str = "none", ): super().__init__() - activation = "snake" if use_snake else "elu" layers = [ - _get_vae_activation(activation, channels=in_channels), + _get_vae_activation(act_fn, channels=in_channels), _wn_conv_transpose1d( in_channels, out_channels, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2) ), - VaeResidualUnit(out_channels, out_channels, dilation=1, use_snake=use_snake), - VaeResidualUnit(out_channels, out_channels, dilation=3, use_snake=use_snake), - VaeResidualUnit(out_channels, out_channels, dilation=9, use_snake=use_snake), + VaeResidualUnit(out_channels, out_channels, dilation=1, act_fn=act_fn), + VaeResidualUnit(out_channels, out_channels, dilation=3, act_fn=act_fn), + VaeResidualUnit(out_channels, out_channels, dilation=9, act_fn=act_fn), ] self.layers = nn.Sequential(*layers) self.residual = ( @@ -201,7 +193,7 @@ def __init__( strides: list[int] | None = None, latent_dim: int = 64, encoder_latent_dim: int = 128, - use_snake: bool = True, + act_fn: str = "snake", downsample_shortcut: str = "averaging", out_shortcut: str = "averaging", ): @@ -220,7 +212,7 @@ def __init__( c_mults[idx] * channels_base, c_mults[idx + 1] * channels_base, strides[idx], - use_snake=use_snake, + act_fn=act_fn, downsample_shortcut=downsample_shortcut, ) ) @@ -249,7 +241,7 @@ def __init__( c_mults: list[int] | None = None, strides: list[int] | None = None, latent_dim: int = 64, - use_snake: bool = True, + act_fn: str = "snake", in_shortcut: str = "duplicating", final_tanh: bool = False, upsample_shortcut: str = "duplicating", @@ -274,12 +266,11 @@ def __init__( c_mults[idx] * channels_base, c_mults[idx - 1] * channels_base, strides[idx - 1], - use_snake=use_snake, + act_fn=act_fn, upsample_shortcut=upsample_shortcut, ) ) - activation = "snake" if use_snake else "elu" - layers.append(_get_vae_activation(activation, channels=c_mults[0] * channels_base)) + layers.append(_get_vae_activation(act_fn, channels=c_mults[0] * channels_base)) layers.append(_wn_conv1d(c_mults[0] * channels_base, in_channels, kernel_size=7, padding=3, bias=False)) layers.append(nn.Tanh() if final_tanh else nn.Identity()) self.layers = nn.Sequential(*layers) @@ -311,7 +302,8 @@ def __init__( strides: list[int] | None = None, latent_dim: int = 64, encoder_latent_dim: int = 128, - use_snake: bool = True, + act_fn: str | None = None, + use_snake: bool | None = None, downsample_shortcut: str = "averaging", upsample_shortcut: str = "duplicating", out_shortcut: str = "averaging", @@ -322,6 +314,11 @@ def __init__( scale: float = 0.71, ): super().__init__() + if act_fn is None: + if use_snake is None: + act_fn = "snake" + else: + act_fn = "snake" if use_snake else "elu" self.encoder = AudioDiTVaeEncoder( in_channels=in_channels, channels=channels, @@ -329,7 +326,7 @@ def __init__( strides=strides, latent_dim=latent_dim, encoder_latent_dim=encoder_latent_dim, - use_snake=use_snake, + act_fn=act_fn, downsample_shortcut=downsample_shortcut, out_shortcut=out_shortcut, ) @@ -339,16 +336,12 @@ def __init__( c_mults=c_mults, strides=strides, latent_dim=latent_dim, - use_snake=use_snake, + act_fn=act_fn, in_shortcut=in_shortcut, final_tanh=final_tanh, upsample_shortcut=upsample_shortcut, ) - @property - def sampling_rate(self) -> int: - return self.config.sample_rate - def encode( self, sample: torch.Tensor, diff --git a/src/diffusers/models/transformers/transformer_longcat_audio_dit.py b/src/diffusers/models/transformers/transformer_longcat_audio_dit.py index 7e02325c12fa..dbc6c88648ca 100644 --- a/src/diffusers/models/transformers/transformer_longcat_audio_dit.py +++ b/src/diffusers/models/transformers/transformer_longcat_audio_dit.py @@ -24,30 +24,17 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import BaseOutput -from ...utils.torch_utils import maybe_allow_in_graph +from ...utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph from ..attention import AttentionModuleMixin from ..attention_dispatch import dispatch_attention_fn from ..modeling_utils import ModelMixin +from ..normalization import RMSNorm @dataclass class LongCatAudioDiTTransformerOutput(BaseOutput): sample: torch.Tensor - -class AudioDiTRMSNorm(nn.Module): - def __init__(self, dim: int, eps: float = 1e-6): - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - normalized = hidden_states.float() * torch.rsqrt( - hidden_states.float().pow(2).mean(dim=-1, keepdim=True) + self.eps - ) - return normalized.to(hidden_states.dtype) * self.weight - - class AudioDiTSinusPositionEmbedding(nn.Module): def __init__(self, dim: int): super().__init__() @@ -79,26 +66,21 @@ def __init__(self, dim: int, max_position_embeddings: int = 2048, base: float = self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base - self._cos = None - self._sin = None - self._cached_len = 0 - self._cached_device = None - def _build(self, seq_len: int, device: torch.device, dtype: torch.dtype): + @lru_cache_unless_export(maxsize=128) + def _build(self, seq_len: int, device: torch.device | None = None) -> tuple[torch.Tensor, torch.Tensor]: inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) - steps = torch.arange(seq_len, dtype=torch.int64).type_as(inv_freq) + if device is not None: + inv_freq = inv_freq.to(device) + steps = torch.arange(seq_len, dtype=torch.int64, device=inv_freq.device).type_as(inv_freq) freqs = torch.outer(steps, inv_freq) embeddings = torch.cat((freqs, freqs), dim=-1) - self._cos = embeddings.cos().to(dtype=dtype, device=device) - self._sin = embeddings.sin().to(dtype=dtype, device=device) - self._cached_len = seq_len - self._cached_device = device + return embeddings.cos().contiguous(), embeddings.sin().contiguous() def forward(self, hidden_states: torch.Tensor, seq_len: int | None = None) -> tuple[torch.Tensor, torch.Tensor]: seq_len = hidden_states.shape[1] if seq_len is None else seq_len - if self._cos is None or seq_len > self._cached_len or self._cached_device != hidden_states.device: - self._build(max(seq_len, self.max_position_embeddings), hidden_states.device, hidden_states.dtype) - return self._cos[:seq_len].to(hidden_states.dtype), self._sin[:seq_len].to(hidden_states.dtype) + cos, sin = self._build(max(seq_len, self.max_position_embeddings), hidden_states.device) + return cos[:seq_len].to(dtype=hidden_states.dtype), sin[:seq_len].to(dtype=hidden_states.dtype) def _rotate_half(hidden_states: torch.Tensor) -> torch.Tensor: @@ -198,22 +180,13 @@ def forward(self, hidden_states: torch.Tensor, embedding: torch.Tensor) -> torch return hidden_states -def _modulate( - hidden_states: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor, eps: float = 1e-6 -) -> torch.Tensor: - hidden_states = F.layer_norm(hidden_states.float(), (hidden_states.shape[-1],), eps=eps).type_as(hidden_states) - if scale.ndim == 2: - return hidden_states * (1 + scale[:, None]) + shift[:, None] - return hidden_states * (1 + scale) + shift - - class AudioDiTSelfAttnProcessor: _attention_backend = None _parallel_config = None def __call__( self, - attn: "AudioDiTSelfAttention", + attn: "AudioDiTAttention", hidden_states: torch.Tensor, mask: torch.BoolTensor | None = None, rope: tuple | None = None, @@ -253,36 +226,55 @@ def __call__( return hidden_states -class AudioDiTSelfAttention(nn.Module, AttentionModuleMixin): - _default_processor_cls = AudioDiTSelfAttnProcessor - _available_processors = [AudioDiTSelfAttnProcessor] +class AudioDiTAttention(nn.Module, AttentionModuleMixin): def __init__( self, - dim: int, + q_dim: int, + kv_dim: int | None, heads: int, dim_head: int, dropout: float = 0.0, bias: bool = True, qk_norm: bool = False, eps: float = 1e-6, + processor: AttentionModuleMixin | None = None, ): super().__init__() + kv_dim = q_dim if kv_dim is None else kv_dim self.heads = heads self.inner_dim = dim_head * heads - self.to_q = nn.Linear(dim, self.inner_dim, bias=bias) - self.to_k = nn.Linear(dim, self.inner_dim, bias=bias) - self.to_v = nn.Linear(dim, self.inner_dim, bias=bias) + self.to_q = nn.Linear(q_dim, self.inner_dim, bias=bias) + self.to_k = nn.Linear(kv_dim, self.inner_dim, bias=bias) + self.to_v = nn.Linear(kv_dim, self.inner_dim, bias=bias) self.qk_norm = qk_norm if qk_norm: - self.q_norm = AudioDiTRMSNorm(self.inner_dim, eps=eps) - self.k_norm = AudioDiTRMSNorm(self.inner_dim, eps=eps) - self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, dim, bias=bias), nn.Dropout(dropout)]) - self.set_processor(self._default_processor_cls()) + self.q_norm = RMSNorm(self.inner_dim, eps=eps) + self.k_norm = RMSNorm(self.inner_dim, eps=eps) + self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, q_dim, bias=bias), nn.Dropout(dropout)]) + self.set_processor(processor or AudioDiTSelfAttnProcessor()) def forward( - self, hidden_states: torch.Tensor, mask: torch.BoolTensor | None = None, rope: tuple | None = None + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + post_attention_mask: torch.BoolTensor | None = None, + attention_mask: torch.BoolTensor | None = None, + audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + prompt_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + mask: torch.BoolTensor | None = None, + rope: tuple | None = None, ) -> torch.Tensor: - return self.processor(self, hidden_states, mask=mask, rope=rope) + if encoder_hidden_states is None: + return self.processor(self, hidden_states, mask=mask, rope=rope) + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + post_attention_mask=post_attention_mask, + attention_mask=attention_mask, + audio_rotary_emb=audio_rotary_emb, + prompt_rotary_emb=prompt_rotary_emb, + ) class AudioDiTCrossAttnProcessor: @@ -291,7 +283,7 @@ class AudioDiTCrossAttnProcessor: def __call__( self, - attn: "AudioDiTCrossAttention", + attn: "AudioDiTAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, post_attention_mask: torch.BoolTensor | None = None, @@ -335,53 +327,6 @@ def __call__( return hidden_states -class AudioDiTCrossAttention(nn.Module, AttentionModuleMixin): - _default_processor_cls = AudioDiTCrossAttnProcessor - _available_processors = [AudioDiTCrossAttnProcessor] - def __init__( - self, - q_dim: int, - kv_dim: int, - heads: int, - dim_head: int, - dropout: float = 0.0, - bias: bool = True, - qk_norm: bool = False, - eps: float = 1e-6, - ): - super().__init__() - self.heads = heads - self.inner_dim = dim_head * heads - self.to_q = nn.Linear(q_dim, self.inner_dim, bias=bias) - self.to_k = nn.Linear(kv_dim, self.inner_dim, bias=bias) - self.to_v = nn.Linear(kv_dim, self.inner_dim, bias=bias) - self.qk_norm = qk_norm - if qk_norm: - self.q_norm = AudioDiTRMSNorm(self.inner_dim, eps=eps) - self.k_norm = AudioDiTRMSNorm(self.inner_dim, eps=eps) - self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, q_dim, bias=bias), nn.Dropout(dropout)]) - self.set_processor(self._default_processor_cls()) - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - post_attention_mask: torch.BoolTensor | None = None, - attention_mask: torch.BoolTensor | None = None, - audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, - prompt_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, - ) -> torch.Tensor: - return self.processor( - self, - hidden_states, - encoder_hidden_states=encoder_hidden_states, - post_attention_mask=post_attention_mask, - attention_mask=attention_mask, - audio_rotary_emb=audio_rotary_emb, - prompt_rotary_emb=prompt_rotary_emb, - ) - - class AudioDiTFeedForward(nn.Module): def __init__(self, dim: int, mult: float = 4.0, dropout: float = 0.0, bias: bool = True): super().__init__() @@ -423,14 +368,22 @@ def __init__( elif adaln_type == "global": self.adaln_scale_shift = nn.Parameter(torch.randn(dim * 6) / dim**0.5) - self.self_attn = AudioDiTSelfAttention( - dim, heads, dim_head, dropout=dropout, bias=bias, qk_norm=qk_norm, eps=eps + self.self_attn = AudioDiTAttention( + dim, None, heads, dim_head, dropout=dropout, bias=bias, qk_norm=qk_norm, eps=eps ) self.use_cross_attn = cross_attn if cross_attn: - self.cross_attn = AudioDiTCrossAttention( - dim, cond_dim, heads, dim_head, dropout=dropout, bias=bias, qk_norm=qk_norm, eps=eps + self.cross_attn = AudioDiTAttention( + dim, + cond_dim, + heads, + dim_head, + dropout=dropout, + bias=bias, + qk_norm=qk_norm, + eps=eps, + processor=AudioDiTCrossAttnProcessor(), ) self.cross_attn_norm = ( nn.LayerNorm(dim, elementwise_affine=True, eps=eps) if cross_attn_norm else nn.Identity() @@ -464,7 +417,8 @@ def forward( adaln_out = adaln_global_out + self.adaln_scale_shift.unsqueeze(0) gate_sa, scale_sa, shift_sa, gate_ffn, scale_ffn, shift_ffn = torch.chunk(adaln_out, 6, dim=-1) - norm_hidden_states = _modulate(hidden_states, scale_sa, shift_sa) + norm_hidden_states = F.layer_norm(hidden_states.float(), (hidden_states.shape[-1],), eps=1e-6).type_as(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_sa[:, None]) + shift_sa[:, None] attn_output = self.self_attn(norm_hidden_states, mask=mask, rope=rope) hidden_states = hidden_states + gate_sa.unsqueeze(1) * attn_output @@ -479,7 +433,8 @@ def forward( ) hidden_states = hidden_states + cross_output - norm_hidden_states = _modulate(hidden_states, scale_ffn, shift_ffn) + norm_hidden_states = F.layer_norm(hidden_states.float(), (hidden_states.shape[-1],), eps=1e-6).type_as(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_ffn[:, None]) + shift_ffn[:, None] ff_output = self.ffn(norm_hidden_states) hidden_states = hidden_states + gate_ffn.unsqueeze(1) * ff_output return hidden_states @@ -511,9 +466,6 @@ def __init__( super().__init__() dim = dit_dim dim_head = dim // dit_heads - self.long_skip = long_skip - self.adaln_type = adaln_type - self.adaln_use_text_cond = adaln_use_text_cond self.time_embed = AudioDiTTimestepEmbedding(dim) self.input_embed = AudioDiTEmbedder(latent_dim, dim) self.text_embed = AudioDiTEmbedder(dit_text_dim, dim) @@ -554,12 +506,12 @@ def __init__( self._initialize_weights(bias=bias) def _initialize_weights(self, bias: bool = True): - if self.adaln_type == "local": + if self.config.adaln_type == "local": for block in self.blocks: nn.init.constant_(block.adaln_mlp.mlp[-1].weight, 0) if bias: nn.init.constant_(block.adaln_mlp.mlp[-1].bias, 0) - elif self.adaln_type == "global": + elif self.config.adaln_type == "global": nn.init.constant_(self.adaln_global_mlp.mlp[-1].weight, 0) if bias: nn.init.constant_(self.adaln_global_mlp.mlp[-1].bias, 0) @@ -596,11 +548,11 @@ def forward( if self.use_latent_condition and latent_cond is not None: latent_cond = self.latent_embed(latent_cond.to(dtype), attention_mask) hidden_states = self.latent_cond_embedder(torch.cat([hidden_states, latent_cond], dim=-1)) - residual = hidden_states.clone() if self.long_skip else None + residual = hidden_states.clone() if self.config.long_skip else None rope = self.rotary_embed(hidden_states, hidden_states.shape[1]) cond_rope = self.rotary_embed(encoder_hidden_states, encoder_hidden_states.shape[1]) - if self.adaln_type == "global": - if self.adaln_use_text_cond: + if self.config.adaln_type == "global": + if self.config.adaln_use_text_cond: text_len = text_mask.sum(1).clamp(min=1).to(encoder_hidden_states.dtype) text_mean = encoder_hidden_states.sum(1) / text_len.unsqueeze(1) norm_cond = timestep_embed + text_mean @@ -630,7 +582,7 @@ def forward( rope=rope, cond_rope=cond_rope, ) - if self.long_skip: + if self.config.long_skip: hidden_states = hidden_states + residual hidden_states = self.norm_out(hidden_states, norm_cond) hidden_states = self.proj_out(hidden_states) diff --git a/src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py b/src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py index 938a33106b5d..50158e5c9a24 100644 --- a/src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py +++ b/src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py @@ -25,11 +25,11 @@ from huggingface_hub import hf_hub_download from huggingface_hub.utils import validate_hf_hub_args from safetensors.torch import load_file -from torch.nn.utils.rnn import pad_sequence from transformers import AutoTokenizer, PreTrainedTokenizerBase, UMT5Config, UMT5EncoderModel from ...models import LongCatAudioDiTTransformer, LongCatAudioDiTVae from ...utils import HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging +from ...utils.torch_utils import randn_tensor from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline @@ -50,7 +50,12 @@ def _normalize_text(text: str) -> str: return text.strip() -def _approx_duration_from_text(text: str, max_duration: float = 30.0) -> float: +def _approx_duration_from_text(text: str | list[str], max_duration: float = 30.0) -> float: + if isinstance(text, list): + if not text: + return 0.0 + return max(_approx_duration_from_text(prompt, max_duration=max_duration) for prompt in text) + en_dur_per_char = 0.082 zh_dur_per_char = 0.21 text = re.sub(r"\s+", "", text) @@ -69,12 +74,6 @@ def _approx_duration_from_text(text: str, max_duration: float = 30.0) -> float: return min(max_duration, num_zh * zh_dur_per_char + num_en * en_dur_per_char) -def _approx_batch_duration_from_prompts(prompts: list[str]) -> float: - if not prompts: - return 0.0 - return max(_approx_duration_from_text(prompt) for prompt in prompts) - - def _extract_prefixed_state_dict(state_dict: dict[str, torch.Tensor], prefix: str) -> dict[str, torch.Tensor]: prefix = f"{prefix}." return {key[len(prefix):]: value for key, value in state_dict.items() if key.startswith(prefix)} @@ -165,9 +164,15 @@ def __init__( transformer: LongCatAudioDiTTransformer, ): super().__init__() - self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=transformer) + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + ) self.sample_rate = getattr(vae.config, "sample_rate", 24000) self.latent_hop = getattr(vae.config, "downsampling_ratio", 2048) + self.vae_scale_factor = self.latent_hop self.latent_dim = getattr(transformer.config, "latent_dim", 64) self.max_wav_duration = 30.0 self.text_norm_feat = True @@ -317,6 +322,7 @@ def from_pretrained( pipe = cls(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=transformer) pipe.sample_rate = config.get("sampling_rate", pipe.sample_rate) pipe.latent_hop = config.get("latent_hop", pipe.latent_hop) + pipe.vae_scale_factor = pipe.latent_hop pipe.max_wav_duration = config.get("max_wav_duration", pipe.max_wav_duration) pipe.text_norm_feat = config.get("text_norm_feat", pipe.text_norm_feat) pipe.text_add_embed = config.get("text_add_embed", pipe.text_add_embed) @@ -357,62 +363,90 @@ def prepare_latents( device: torch.device, dtype: torch.dtype, generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, ) -> torch.Tensor: - if isinstance(generator, list): - if len(generator) != batch_size: + if latents is not None: + if latents.ndim != 3: + raise ValueError(f"`latents` must have shape (batch_size, duration, latent_dim), but got {tuple(latents.shape)}.") + if latents.shape[0] != batch_size: + raise ValueError(f"`latents` must have batch size {batch_size}, but got {latents.shape[0]}.") + if latents.shape[2] != self.latent_dim: + raise ValueError(f"`latents` must have latent_dim {self.latent_dim}, but got {latents.shape[2]}.") + return latents.to(device=device, dtype=dtype) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError(f"Expected {batch_size} generators for batch size {batch_size}, but got {len(generator)}.") + + return randn_tensor( + (batch_size, duration, self.latent_dim), generator=generator, device=device, dtype=dtype + ) + + def check_inputs( + self, + prompt: list[str], + negative_prompt: str | list[str] | None, + output_type: str, + ) -> None: + if len(prompt) == 0: + raise ValueError("`prompt` must contain at least one prompt.") + + if output_type not in {"np", "pt", "latent"}: + raise ValueError(f"Unsupported output_type: {output_type}") + + if negative_prompt is not None and not isinstance(negative_prompt, str): + negative_prompt = list(negative_prompt) + if len(negative_prompt) != len(prompt): raise ValueError( - f"Expected {batch_size} generators for batch size {batch_size}, but got {len(generator)}." + f"`negative_prompt` must have batch size {len(prompt)}, but got {len(negative_prompt)} prompts." ) - generators = generator - else: - generators = [generator] * batch_size - - latents = [ - torch.randn( - duration, - self.latent_dim, - device=device, - dtype=dtype, - generator=generators[idx], - ) - for idx in range(batch_size) - ] - return pad_sequence(latents, padding_value=0.0, batch_first=True) @torch.no_grad() def __call__( self, prompt: str | list[str], negative_prompt: str | list[str] | None = None, - audio_end_in_s: float | None = None, - duration: int | None = None, + audio_duration_s: float | None = None, + latents: torch.Tensor | None = None, num_inference_steps: int = 16, guidance_scale: float = 4.0, generator: torch.Generator | list[torch.Generator] | None = None, output_type: str = "np", return_dict: bool = True, ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): Prompt or prompts that guide audio generation. + negative_prompt (`str` or `list[str]`, *optional*): Negative prompt(s) for classifier-free guidance. + audio_duration_s (`float`, *optional*): Target audio duration in seconds. Ignored when `latents` is provided. + latents (`torch.Tensor`, *optional*): Pre-generated noisy latents of shape `(batch_size, duration, latent_dim)`. + num_inference_steps (`int`, defaults to 16): Number of denoising steps. Values below 2 are promoted to 2. + guidance_scale (`float`, defaults to 4.0): Guidance scale for classifier-free guidance. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): Random generator(s). + output_type (`str`, defaults to `"np"`): Output format: `"np"`, `"pt"`, or `"latent"`. + return_dict (`bool`, defaults to `True`): Whether to return `AudioPipelineOutput`. + """ if prompt is None: prompt = [] elif isinstance(prompt, str): prompt = [prompt] else: prompt = list(prompt) + self.check_inputs(prompt, negative_prompt, output_type) batch_size = len(prompt) - if batch_size == 0: - raise ValueError("`prompt` must contain at least one prompt.") device = self._execution_device normalized_prompts = [_normalize_text(text) for text in prompt] - if duration is None: - if audio_end_in_s is not None: - duration = int(audio_end_in_s * self.sample_rate // self.latent_hop) - else: - duration = int( - _approx_batch_duration_from_prompts(normalized_prompts) * self.sample_rate // self.latent_hop - ) - max_duration = int(self.max_wav_duration * self.sample_rate // self.latent_hop) - duration = max(1, min(duration, max_duration)) + if latents is not None: + duration = latents.shape[1] + elif audio_duration_s is not None: + duration = int(audio_duration_s * self.sample_rate // self.vae_scale_factor) + else: + duration = int(_approx_duration_from_text(normalized_prompts) * self.sample_rate // self.vae_scale_factor) + max_duration = int(self.max_wav_duration * self.sample_rate // self.vae_scale_factor) + if latents is None: + duration = max(1, min(duration, max_duration)) text_condition, text_condition_len = self.encode_prompt(normalized_prompts, device) duration_tensor = torch.full((batch_size,), duration, device=device, dtype=torch.long) @@ -428,44 +462,41 @@ def __call__( negative_prompt = [negative_prompt] * batch_size else: negative_prompt = list(negative_prompt) - if len(negative_prompt) != batch_size: - raise ValueError( - f"`negative_prompt` must have batch size {batch_size}, but got {len(negative_prompt)} prompts." - ) neg_text, neg_text_len = self.encode_prompt(negative_prompt, device) neg_text_mask = _lens_to_mask(neg_text_len, length=neg_text.shape[1]) latent_cond = torch.zeros(batch_size, duration, self.latent_dim, device=device, dtype=text_condition.dtype) - latents = self.prepare_latents(batch_size, duration, device, text_condition.dtype, generator=generator) - num_inference_steps = max(2, num_inference_steps) + latents = self.prepare_latents( + batch_size, duration, device, text_condition.dtype, generator=generator, latents=latents + ) + if num_inference_steps < 2: + logger.warning("`num_inference_steps`=%s is not supported; using 2 instead.", num_inference_steps) + num_inference_steps = 2 timesteps = torch.linspace(0, 1, num_inference_steps, device=device, dtype=text_condition.dtype) sample = latents - def model_step(curr_t: torch.Tensor, current_sample: torch.Tensor) -> torch.Tensor: + for idx in range(len(timesteps) - 1): + curr_t = timesteps[idx] + dt = timesteps[idx + 1] - timesteps[idx] pred = self.transformer( - hidden_states=current_sample, + hidden_states=sample, encoder_hidden_states=text_condition, encoder_attention_mask=text_mask, timestep=curr_t.expand(batch_size), attention_mask=mask, latent_cond=latent_cond, ).sample - if guidance_scale <= 1.0: - return pred - null_pred = self.transformer( - hidden_states=current_sample, - encoder_hidden_states=neg_text, - encoder_attention_mask=neg_text_mask, - timestep=curr_t.expand(batch_size), - attention_mask=mask, - latent_cond=latent_cond, - ).sample - return null_pred + (pred - null_pred) * guidance_scale - - for idx in range(len(timesteps) - 1): - curr_t = timesteps[idx] - dt = timesteps[idx + 1] - timesteps[idx] - sample = sample + model_step(curr_t, sample) * dt + if guidance_scale > 1.0: + null_pred = self.transformer( + hidden_states=sample, + encoder_hidden_states=neg_text, + encoder_attention_mask=neg_text_mask, + timestep=curr_t.expand(batch_size), + attention_mask=mask, + latent_cond=latent_cond, + ).sample + pred = null_pred + (pred - null_pred) * guidance_scale + sample = sample + pred * dt if output_type == "latent": if not return_dict: @@ -475,8 +506,6 @@ def model_step(curr_t: torch.Tensor, current_sample: torch.Tensor) -> torch.Tens waveform = self.vae.decode(sample.permute(0, 2, 1)).sample if output_type == "np": waveform = waveform.cpu().float().numpy() - elif output_type != "pt": - raise ValueError(f"Unsupported output_type: {output_type}") if not return_dict: return (waveform,) From d283af4daf5ab3a8a3473ce35f172dca310f3c27 Mon Sep 17 00:00:00 2001 From: Lancer Date: Fri, 10 Apr 2026 16:05:18 +0800 Subject: [PATCH 07/11] upd Signed-off-by: Lancer --- .../transformer_longcat_audio_dit.py | 31 +-- .../pipeline_longcat_audio_dit.py | 6 +- ...st_models_transformer_longcat_audio_dit.py | 169 +++++++++++----- tests/pipelines/longcat_audio_dit/__init__.py | 0 .../test_longcat_audio_dit.py | 180 +++++++++++------- 5 files changed, 250 insertions(+), 136 deletions(-) create mode 100644 tests/pipelines/longcat_audio_dit/__init__.py diff --git a/src/diffusers/models/transformers/transformer_longcat_audio_dit.py b/src/diffusers/models/transformers/transformer_longcat_audio_dit.py index dbc6c88648ca..7f2c042f8489 100644 --- a/src/diffusers/models/transformers/transformer_longcat_audio_dit.py +++ b/src/diffusers/models/transformers/transformer_longcat_audio_dit.py @@ -188,8 +188,8 @@ def __call__( self, attn: "AudioDiTAttention", hidden_states: torch.Tensor, - mask: torch.BoolTensor | None = None, - rope: tuple | None = None, + attention_mask: torch.BoolTensor | None = None, + audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, ) -> torch.Tensor: batch_size = hidden_states.shape[0] query = attn.to_q(hidden_states) @@ -205,20 +205,20 @@ def __call__( key = key.view(batch_size, -1, attn.heads, head_dim) value = value.view(batch_size, -1, attn.heads, head_dim) - if rope is not None: - query = _apply_rotary_emb(query, rope) - key = _apply_rotary_emb(key, rope) + if audio_rotary_emb is not None: + query = _apply_rotary_emb(query, audio_rotary_emb) + key = _apply_rotary_emb(key, audio_rotary_emb) hidden_states = dispatch_attention_fn( query, key, value, - attn_mask=mask, + attn_mask=attention_mask, backend=self._attention_backend, parallel_config=self._parallel_config, ) - if mask is not None: - hidden_states = hidden_states * mask[:, :, None, None].to(hidden_states.dtype) + if attention_mask is not None: + hidden_states = hidden_states * attention_mask[:, :, None, None].to(hidden_states.dtype) hidden_states = hidden_states.flatten(2, 3).to(query.dtype) hidden_states = attn.to_out[0](hidden_states) @@ -261,11 +261,14 @@ def forward( attention_mask: torch.BoolTensor | None = None, audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, prompt_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, - mask: torch.BoolTensor | None = None, - rope: tuple | None = None, ) -> torch.Tensor: if encoder_hidden_states is None: - return self.processor(self, hidden_states, mask=mask, rope=rope) + return self.processor( + self, + hidden_states, + attention_mask=attention_mask, + audio_rotary_emb=audio_rotary_emb, + ) return self.processor( self, hidden_states, @@ -419,7 +422,11 @@ def forward( norm_hidden_states = F.layer_norm(hidden_states.float(), (hidden_states.shape[-1],), eps=1e-6).type_as(hidden_states) norm_hidden_states = norm_hidden_states * (1 + scale_sa[:, None]) + shift_sa[:, None] - attn_output = self.self_attn(norm_hidden_states, mask=mask, rope=rope) + attn_output = self.self_attn( + norm_hidden_states, + attention_mask=mask, + audio_rotary_emb=rope, + ) hidden_states = hidden_states + gate_sa.unsqueeze(1) * attn_output if self.use_cross_attn: diff --git a/src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py b/src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py index 50158e5c9a24..43492fe21010 100644 --- a/src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py +++ b/src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py @@ -171,8 +171,7 @@ def __init__( transformer=transformer, ) self.sample_rate = getattr(vae.config, "sample_rate", 24000) - self.latent_hop = getattr(vae.config, "downsampling_ratio", 2048) - self.vae_scale_factor = self.latent_hop + self.vae_scale_factor = getattr(vae.config, "downsampling_ratio", 2048) self.latent_dim = getattr(transformer.config, "latent_dim", 64) self.max_wav_duration = 30.0 self.text_norm_feat = True @@ -321,8 +320,7 @@ def from_pretrained( pipe = cls(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=transformer) pipe.sample_rate = config.get("sampling_rate", pipe.sample_rate) - pipe.latent_hop = config.get("latent_hop", pipe.latent_hop) - pipe.vae_scale_factor = pipe.latent_hop + pipe.vae_scale_factor = config.get("vae_scale_factor", config.get("latent_hop", pipe.vae_scale_factor)) pipe.max_wav_duration = config.get("max_wav_duration", pipe.max_wav_duration) pipe.text_norm_feat = config.get("text_norm_feat", pipe.text_norm_feat) pipe.text_add_embed = config.get("text_add_embed", pipe.text_add_embed) diff --git a/tests/models/transformers/test_models_transformer_longcat_audio_dit.py b/tests/models/transformers/test_models_transformer_longcat_audio_dit.py index 858f7e8484d1..0a52653a8a7b 100644 --- a/tests/models/transformers/test_models_transformer_longcat_audio_dit.py +++ b/tests/models/transformers/test_models_transformer_longcat_audio_dit.py @@ -1,54 +1,125 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest import torch from diffusers import LongCatAudioDiTTransformer +from diffusers.utils.torch_utils import randn_tensor + +from ...testing_utils import enable_full_determinism, torch_device +from ..testing_utils import ( + AttentionTesterMixin, + BaseModelTesterConfig, + MemoryTesterMixin, + ModelTesterMixin, + TorchCompileTesterMixin, +) + + +enable_full_determinism() + + +class LongCatAudioDiTTransformerTesterConfig(BaseModelTesterConfig): + @property + def main_input_name(self) -> str: + return "hidden_states" + + @property + def model_class(self): + return LongCatAudioDiTTransformer + + @property + def output_shape(self) -> tuple[int, ...]: + return (16, 8) + + @property + def generator(self): + return torch.Generator("cpu").manual_seed(0) + + def get_init_dict(self) -> dict[str, int | bool | float | str]: + return { + "dit_dim": 64, + "dit_depth": 2, + "dit_heads": 4, + "dit_text_dim": 32, + "latent_dim": 8, + "text_conv": False, + } + + def get_dummy_inputs(self) -> dict[str, torch.Tensor]: + batch_size = 1 + sequence_length = 16 + encoder_sequence_length = 10 + latent_dim = 8 + text_dim = 32 + + return { + "hidden_states": randn_tensor( + (batch_size, sequence_length, latent_dim), generator=self.generator, device=torch_device + ), + "encoder_hidden_states": randn_tensor( + (batch_size, encoder_sequence_length, text_dim), generator=self.generator, device=torch_device + ), + "encoder_attention_mask": torch.ones( + batch_size, encoder_sequence_length, dtype=torch.bool, device=torch_device + ), + "attention_mask": torch.ones(batch_size, sequence_length, dtype=torch.bool, device=torch_device), + "timestep": torch.ones(batch_size, device=torch_device), + } + + +class TestLongCatAudioDiTTransformer(LongCatAudioDiTTransformerTesterConfig, ModelTesterMixin): + pass + + +class TestLongCatAudioDiTTransformerMemory(LongCatAudioDiTTransformerTesterConfig, MemoryTesterMixin): + def test_layerwise_casting_memory(self): + pytest.skip("LongCatAudioDiTTransformer does not support standard layerwise casting memory tests yet.") + + def test_layerwise_casting_training(self): + pytest.skip("LongCatAudioDiTTransformer does not support standard layerwise casting training tests yet.") + + def test_group_offloading_with_layerwise_casting(self, *args, **kwargs): + pytest.skip("LongCatAudioDiTTransformer does not support combined group offloading and layerwise casting tests yet.") + + +class TestLongCatAudioDiTTransformerCompile(LongCatAudioDiTTransformerTesterConfig, TorchCompileTesterMixin): + def test_torch_compile_repeated_blocks(self): + pytest.skip("LongCatAudioDiTTransformer does not define repeated blocks for regional compilation.") + + +class TestLongCatAudioDiTTransformerAttention(LongCatAudioDiTTransformerTesterConfig, AttentionTesterMixin): + pass + + +def test_longcat_audio_attention_uses_standard_self_attn_kwargs(): + from diffusers.models.transformers.transformer_longcat_audio_dit import AudioDiTAttention + + attn = AudioDiTAttention(q_dim=4, kv_dim=None, heads=1, dim_head=4, dropout=0.0, bias=False) + + eye = torch.eye(4) + with torch.no_grad(): + attn.to_q.weight.copy_(eye) + attn.to_k.weight.copy_(eye) + attn.to_v.weight.copy_(eye) + attn.to_out[0].weight.copy_(eye) + + hidden_states = torch.tensor([[[1.0, 0.0, 0.0, 0.0], [0.5, 0.5, 0.5, 0.5]]]) + attention_mask = torch.tensor([[True, False]]) + output = attn(hidden_states=hidden_states, attention_mask=attention_mask) -def test_longcat_audio_transformer_forward_shape(): - model = LongCatAudioDiTTransformer( - dit_dim=64, - dit_depth=2, - dit_heads=4, - dit_text_dim=32, - latent_dim=8, - text_conv=False, - ) - hidden_states = torch.randn(2, 16, 8) - encoder_hidden_states = torch.randn(2, 10, 32) - encoder_attention_mask = torch.ones(2, 10, dtype=torch.bool) - timestep = torch.tensor([1.0, 1.0]) - - output = model( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - timestep=timestep, - ) - - assert output.sample.shape == hidden_states.shape - - -def test_longcat_audio_transformer_masked_forward(): - model = LongCatAudioDiTTransformer( - dit_dim=64, - dit_depth=2, - dit_heads=4, - dit_text_dim=32, - latent_dim=8, - text_conv=False, - ) - hidden_states = torch.randn(2, 16, 8) - encoder_hidden_states = torch.randn(2, 10, 32) - encoder_attention_mask = torch.tensor([[1] * 10, [1] * 6 + [0] * 4], dtype=torch.bool) - attention_mask = torch.tensor([[1] * 16, [1] * 9 + [0] * 7], dtype=torch.bool) - timestep = torch.tensor([1.0, 1.0]) - - output = model( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - timestep=timestep, - attention_mask=attention_mask, - ) - - assert output.sample.shape == hidden_states.shape - assert torch.all(output.sample[1, 9:] == 0) + assert torch.allclose(output[:, 1], torch.zeros_like(output[:, 1])) diff --git a/tests/pipelines/longcat_audio_dit/__init__.py b/tests/pipelines/longcat_audio_dit/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/longcat_audio_dit/test_longcat_audio_dit.py b/tests/pipelines/longcat_audio_dit/test_longcat_audio_dit.py index ce16e9e26aab..9010edfa49f4 100644 --- a/tests/pipelines/longcat_audio_dit/test_longcat_audio_dit.py +++ b/tests/pipelines/longcat_audio_dit/test_longcat_audio_dit.py @@ -22,7 +22,13 @@ from transformers import UMT5Config, UMT5EncoderModel from diffusers import LongCatAudioDiTPipeline, LongCatAudioDiTTransformer, LongCatAudioDiTVae -from tests.testing_utils import require_torch_accelerator, slow, torch_device + +from ...testing_utils import enable_full_determinism, require_torch_accelerator, slow, torch_device +from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() class DummyTokenizer: @@ -42,8 +48,16 @@ def __call__(self, texts, padding="longest", truncation=True, max_length=None, r ) -class LongCatAudioDiTPipelineFastTests(unittest.TestCase): +class LongCatAudioDiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = LongCatAudioDiTPipeline + params = ( + TEXT_TO_AUDIO_PARAMS - {"audio_length_in_s", "prompt_embeds", "negative_prompt_embeds", "cross_attention_kwargs"} + ) | {"audio_duration_s"} + batch_params = TEXT_TO_AUDIO_BATCH_PARAMS + required_optional_params = PipelineTesterMixin.required_optional_params - {"num_images_per_prompt"} + test_attention_slicing = False + test_xformers_attention = False + supports_dduf = False def get_dummy_components(self): torch.manual_seed(0) @@ -82,7 +96,7 @@ def get_dummy_inputs(self, device, seed=0, prompt="soft ocean ambience"): return { "prompt": prompt, - "audio_end_in_s": 0.1, + "audio_duration_s": 0.1, "num_inference_steps": 2, "guidance_scale": 1.0, "generator": generator, @@ -102,39 +116,7 @@ def test_inference(self): self.assertEqual(output.shape[1], 1) self.assertGreater(output.shape[-1], 0) - def test_inference_batch_single_identical(self): - device = "cpu" - pipe = self.pipeline_class(**self.get_dummy_components()) - pipe.to(device) - pipe.set_progress_bar_config(disable=None) - - output1 = pipe(**self.get_dummy_inputs(device, seed=42)).audios - output2 = pipe(**self.get_dummy_inputs(device, seed=42)).audios - - self.assertTrue(torch.allclose(output1, output2, atol=1e-4)) - - def test_inference_batch_multiple_prompts(self): - device = "cpu" - pipe = self.pipeline_class(**self.get_dummy_components()) - pipe.to(device) - pipe.set_progress_bar_config(disable=None) - - generator = torch.Generator(device=device).manual_seed(42) - output = pipe( - prompt=["soft ocean ambience", "gentle rain ambience"], - audio_end_in_s=0.1, - num_inference_steps=2, - guidance_scale=1.0, - generator=generator, - output_type="pt", - ).audios - - self.assertEqual(output.ndim, 3) - self.assertEqual(output.shape[0], 2) - self.assertEqual(output.shape[1], 1) - self.assertGreater(output.shape[-1], 0) - - def test_save_pretrained_roundtrip(self): + def test_save_load_local(self): import tempfile device = "cpu" @@ -150,6 +132,57 @@ def test_save_pretrained_roundtrip(self): self.assertEqual(output.ndim, 3) self.assertGreater(output.shape[-1], 0) + def test_save_load_optional_components(self): + self.skipTest("LongCatAudioDiTPipeline does not define optional components.") + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(expected_max_diff=2e-3) + + def test_model_cpu_offload_forward_pass(self): + self.skipTest("LongCatAudioDiTPipeline offload coverage is not ready for the standard PipelineTesterMixin test.") + + def test_cpu_offload_forward_pass_twice(self): + self.skipTest("LongCatAudioDiTPipeline offload coverage is not ready for the standard PipelineTesterMixin test.") + + def test_sequential_cpu_offload_forward_pass(self): + self.skipTest("LongCatAudioDiTPipeline offload coverage is not ready for the standard PipelineTesterMixin test.") + + def test_sequential_offload_forward_pass_twice(self): + self.skipTest("LongCatAudioDiTPipeline offload coverage is not ready for the standard PipelineTesterMixin test.") + + def test_pipeline_level_group_offloading_inference(self): + self.skipTest("LongCatAudioDiTPipeline group offloading coverage is not ready for the standard PipelineTesterMixin test.") + + def test_pipeline_with_accelerator_device_map(self): + self.skipTest("LongCatAudioDiTPipeline fast tests use a dummy tokenizer, so device-map roundtrip coverage is skipped here.") + + def test_save_load_float16(self): + self.skipTest("LongCatAudioDiTPipeline fast tests use a dummy tokenizer, so float16 reload coverage is skipped here.") + + def test_num_images_per_prompt(self): + self.skipTest("LongCatAudioDiTPipeline does not support num_images_per_prompt.") + + def test_cfg(self): + self.skipTest("LongCatAudioDiTPipeline does not support generic CFG callback tests.") + + def test_callback_inputs(self): + self.skipTest("LongCatAudioDiTPipeline does not expose callback inputs.") + + def test_callback_cfg(self): + self.skipTest("LongCatAudioDiTPipeline does not expose callback CFG inputs.") + + def test_serialization_with_variants(self): + self.skipTest("LongCatAudioDiTPipeline fast tests use a dummy tokenizer that is not variant-serializable.") + + def test_loading_with_variants(self): + self.skipTest("LongCatAudioDiTPipeline fast tests use a dummy tokenizer that is not variant-serializable.") + + def test_loading_with_incorrect_variants_raises_error(self): + self.skipTest("LongCatAudioDiTPipeline fast tests use a dummy tokenizer that is not variant-serializable.") + + def test_encode_prompt_works_in_isolation(self): + self.skipTest("LongCatAudioDiTPipeline.encode_prompt has a custom signature.") + def test_from_pretrained_local_dir(self): import tempfile from unittest.mock import patch @@ -182,7 +215,7 @@ def test_from_pretrained_local_dir(self): "dit_eps": 1e-6, "dit_use_latent_condition": True, "sampling_rate": 24000, - "latent_hop": 2, + "vae_scale_factor": 2, "max_wav_duration": 30.0, "text_norm_feat": True, "text_add_embed": True, @@ -194,7 +227,9 @@ def test_from_pretrained_local_dir(self): json.dump(config, handle) state_dict = {} - state_dict.update({f"text_encoder.{k}": v for k, v in text_encoder.state_dict().items() if k != "shared.weight"}) + state_dict.update( + {f"text_encoder.{k}": v for k, v in text_encoder.state_dict().items() if k != "shared.weight"} + ) state_dict.update({f"transformer.{k}": v for k, v in transformer.state_dict().items()}) state_dict.update({f"vae.{k}": v for k, v in vae.state_dict().items()}) save_file(state_dict, model_dir / "model.safetensors") @@ -209,7 +244,7 @@ def test_from_pretrained_local_dir(self): self.assertIsInstance(pipe, LongCatAudioDiTPipeline) self.assertEqual(pipe.sample_rate, 24000) - self.assertEqual(pipe.latent_hop, 2) + self.assertEqual(pipe.vae_scale_factor, 2) self.assertEqual(output.ndim, 3) self.assertGreater(output.shape[-1], 0) @@ -222,35 +257,38 @@ def test_longcat_audio_top_level_imports(): @slow @require_torch_accelerator -def test_longcat_audio_pipeline_from_pretrained_real_local_weights(): - model_path = Path(os.getenv("LONGCAT_AUDIO_DIT_MODEL_PATH", "/data/models/meituan-longcat/LongCat-AudioDiT-1B")) - tokenizer_path_env = os.getenv("LONGCAT_AUDIO_DIT_TOKENIZER_PATH") - if tokenizer_path_env is None: - raise unittest.SkipTest("LONGCAT_AUDIO_DIT_TOKENIZER_PATH is not set") - tokenizer_path = Path(tokenizer_path_env) - - if not model_path.exists(): - raise unittest.SkipTest(f"LongCat-AudioDiT model path not found: {model_path}") - if not tokenizer_path.exists(): - raise unittest.SkipTest(f"LongCat-AudioDiT tokenizer path not found: {tokenizer_path}") - - pipe = LongCatAudioDiTPipeline.from_pretrained( - model_path, - tokenizer=tokenizer_path, - torch_dtype=torch.float16, - local_files_only=True, - ) - pipe = pipe.to(torch_device) - - result = pipe( - prompt="A calm ocean wave ambience with soft wind in the background.", - audio_end_in_s=2.0, - num_inference_steps=2, - guidance_scale=4.0, - output_type="pt", - ) - - assert result.audios.ndim == 3 - assert result.audios.shape[0] == 1 - assert result.audios.shape[1] == 1 - assert result.audios.shape[-1] > 0 +class LongCatAudioDiTPipelineSlowTests(unittest.TestCase): + pipeline_class = LongCatAudioDiTPipeline + + def test_longcat_audio_pipeline_from_pretrained_real_local_weights(self): + model_path = Path(os.getenv("LONGCAT_AUDIO_DIT_MODEL_PATH", "/data/models/meituan-longcat/LongCat-AudioDiT-1B")) + tokenizer_path_env = os.getenv("LONGCAT_AUDIO_DIT_TOKENIZER_PATH") + if tokenizer_path_env is None: + raise unittest.SkipTest("LONGCAT_AUDIO_DIT_TOKENIZER_PATH is not set") + tokenizer_path = Path(tokenizer_path_env) + + if not model_path.exists(): + raise unittest.SkipTest(f"LongCat-AudioDiT model path not found: {model_path}") + if not tokenizer_path.exists(): + raise unittest.SkipTest(f"LongCat-AudioDiT tokenizer path not found: {tokenizer_path}") + + pipe = LongCatAudioDiTPipeline.from_pretrained( + model_path, + tokenizer=tokenizer_path, + torch_dtype=torch.float16, + local_files_only=True, + ) + pipe = pipe.to(torch_device) + + result = pipe( + prompt="A calm ocean wave ambience with soft wind in the background.", + audio_duration_s=2.0, + num_inference_steps=2, + guidance_scale=4.0, + output_type="pt", + ) + + assert result.audios.ndim == 3 + assert result.audios.shape[0] == 1 + assert result.audios.shape[1] == 1 + assert result.audios.shape[-1] > 0 From a6e2e165d60f128f5c30ba82cde97506f439d7de Mon Sep 17 00:00:00 2001 From: Lancer Date: Fri, 10 Apr 2026 17:00:47 +0800 Subject: [PATCH 08/11] upd Signed-off-by: Lancer --- .../pipeline_longcat_audio_dit.py | 36 ++++++++++++++----- .../test_longcat_audio_dit.py | 25 ++++++++++++- 2 files changed, 52 insertions(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py b/src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py index 43492fe21010..12cafcccd45b 100644 --- a/src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py +++ b/src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py @@ -28,6 +28,7 @@ from transformers import AutoTokenizer, PreTrainedTokenizerBase, UMT5Config, UMT5EncoderModel from ...models import LongCatAudioDiTTransformer, LongCatAudioDiTVae +from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging from ...utils.torch_utils import randn_tensor from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline @@ -79,6 +80,12 @@ def _extract_prefixed_state_dict(state_dict: dict[str, torch.Tensor], prefix: st return {key[len(prefix):]: value for key, value in state_dict.items() if key.startswith(prefix)} +def _get_uniform_flow_match_scheduler_sigmas(num_inference_steps: int) -> list[float]: + num_inference_steps = max(int(num_inference_steps), 2) + num_updates = num_inference_steps - 1 + return torch.linspace(1.0, 1.0 / num_updates, num_updates, dtype=torch.float32).tolist() + + def _load_longcat_tokenizer( pretrained_model_name_or_path: str | Path, text_encoder_model: str | None, @@ -162,13 +169,17 @@ def __init__( text_encoder: UMT5EncoderModel, tokenizer: PreTrainedTokenizerBase, transformer: LongCatAudioDiTTransformer, + scheduler: FlowMatchEulerDiscreteScheduler | None = None, ): super().__init__() + if not isinstance(scheduler, FlowMatchEulerDiscreteScheduler): + scheduler = FlowMatchEulerDiscreteScheduler(shift=1.0, invert_sigmas=True) self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=transformer, + scheduler=scheduler, ) self.sample_rate = getattr(vae.config, "sample_rate", 24000) self.vae_scale_factor = getattr(vae.config, "downsampling_ratio", 2048) @@ -318,7 +329,11 @@ def from_pretrained( transformer.eval() vae.eval() - pipe = cls(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=transformer) + scheduler_config = {"shift": 1.0, "invert_sigmas": True} + scheduler_config.update(config.get("scheduler_config", {})) + scheduler = FlowMatchEulerDiscreteScheduler(**scheduler_config) + + pipe = cls(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=transformer, scheduler=scheduler) pipe.sample_rate = config.get("sampling_rate", pipe.sample_rate) pipe.vae_scale_factor = config.get("vae_scale_factor", config.get("latent_hop", pipe.vae_scale_factor)) pipe.max_wav_duration = config.get("max_wav_duration", pipe.max_wav_duration) @@ -470,17 +485,22 @@ def __call__( if num_inference_steps < 2: logger.warning("`num_inference_steps`=%s is not supported; using 2 instead.", num_inference_steps) num_inference_steps = 2 - timesteps = torch.linspace(0, 1, num_inference_steps, device=device, dtype=text_condition.dtype) + + self.scheduler.set_timesteps( + sigmas=_get_uniform_flow_match_scheduler_sigmas(num_inference_steps), + device=device, + ) + self.scheduler.set_begin_index(0) + timesteps = self.scheduler.timesteps sample = latents - for idx in range(len(timesteps) - 1): - curr_t = timesteps[idx] - dt = timesteps[idx + 1] - timesteps[idx] + for t in timesteps: + curr_t = (t / self.scheduler.config.num_train_timesteps).expand(batch_size).to(dtype=text_condition.dtype) pred = self.transformer( hidden_states=sample, encoder_hidden_states=text_condition, encoder_attention_mask=text_mask, - timestep=curr_t.expand(batch_size), + timestep=curr_t, attention_mask=mask, latent_cond=latent_cond, ).sample @@ -489,12 +509,12 @@ def __call__( hidden_states=sample, encoder_hidden_states=neg_text, encoder_attention_mask=neg_text_mask, - timestep=curr_t.expand(batch_size), + timestep=curr_t, attention_mask=mask, latent_cond=latent_cond, ).sample pred = null_pred + (pred - null_pred) * guidance_scale - sample = sample + pred * dt + sample = self.scheduler.step(pred, t, sample, return_dict=False)[0] if output_type == "latent": if not return_dict: diff --git a/tests/pipelines/longcat_audio_dit/test_longcat_audio_dit.py b/tests/pipelines/longcat_audio_dit/test_longcat_audio_dit.py index 9010edfa49f4..2f9a8491bce2 100644 --- a/tests/pipelines/longcat_audio_dit/test_longcat_audio_dit.py +++ b/tests/pipelines/longcat_audio_dit/test_longcat_audio_dit.py @@ -21,7 +21,10 @@ from safetensors.torch import save_file from transformers import UMT5Config, UMT5EncoderModel -from diffusers import LongCatAudioDiTPipeline, LongCatAudioDiTTransformer, LongCatAudioDiTVae +from diffusers import FlowMatchEulerDiscreteScheduler, LongCatAudioDiTPipeline, LongCatAudioDiTTransformer, LongCatAudioDiTVae +from diffusers.pipelines.longcat_audio_dit.pipeline_longcat_audio_dit import ( + _get_uniform_flow_match_scheduler_sigmas, +) from ...testing_utils import enable_full_determinism, require_torch_accelerator, slow, torch_device from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS @@ -183,6 +186,26 @@ def test_loading_with_incorrect_variants_raises_error(self): def test_encode_prompt_works_in_isolation(self): self.skipTest("LongCatAudioDiTPipeline.encode_prompt has a custom signature.") + def test_uniform_flow_match_scheduler_grid_matches_legacy_manual_updates(self): + num_inference_steps = 6 + scheduler = FlowMatchEulerDiscreteScheduler(shift=1.0, invert_sigmas=True) + scheduler.set_timesteps( + sigmas=_get_uniform_flow_match_scheduler_sigmas(num_inference_steps), device="cpu" + ) + + expected_timesteps = torch.linspace(0, 1, num_inference_steps, dtype=torch.float32)[:-1] + actual_timesteps = scheduler.timesteps / scheduler.config.num_train_timesteps + self.assertTrue(torch.allclose(actual_timesteps, expected_timesteps, atol=1e-6, rtol=0)) + + sample = torch.zeros(1, 2, 3) + model_output = torch.ones_like(sample) + expected = sample.clone() + for t0, t1, scheduler_t in zip(expected_timesteps[:-1], expected_timesteps[1:], scheduler.timesteps): + expected = expected + model_output * (t1 - t0) + sample = scheduler.step(model_output, scheduler_t, sample, return_dict=False)[0] + + self.assertTrue(torch.allclose(sample, expected, atol=1e-6, rtol=0)) + def test_from_pretrained_local_dir(self): import tempfile from unittest.mock import patch From 938733976b85f1958dc1b4b65551638033e98bcc Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Sat, 11 Apr 2026 03:29:54 +0000 Subject: [PATCH 09/11] Apply style fixes --- docs/source/en/_toctree.yml | 4 +- src/diffusers/__init__.py | 16 +++---- src/diffusers/models/__init__.py | 4 +- src/diffusers/models/autoencoders/__init__.py | 2 +- src/diffusers/models/transformers/__init__.py | 2 +- .../transformer_longcat_audio_dit.py | 9 +++- src/diffusers/pipelines/__init__.py | 2 +- .../pipelines/longcat_audio_dit/__init__.py | 4 +- .../pipeline_longcat_audio_dit.py | 28 +++++++---- ...st_models_transformer_longcat_audio_dit.py | 4 +- .../test_longcat_audio_dit.py | 46 +++++++++++++------ 11 files changed, 78 insertions(+), 43 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 885e2aa27181..184c768c7e7a 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -486,10 +486,10 @@ - sections: - local: api/pipelines/audioldm2 title: AudioLDM 2 - - local: api/pipelines/stable_audio - title: Stable Audio - local: api/pipelines/longcat_audio_dit title: LongCat-AudioDiT + - local: api/pipelines/stable_audio + title: Stable Audio title: Audio - sections: - local: api/pipelines/animatediff diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index b48a7f0a1c46..d278349ae841 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -178,13 +178,12 @@ ] ) _import_structure["image_processor"] = [ - "IPAdapterMaskProcessor", "InpaintProcessor", + "IPAdapterMaskProcessor", "PixArtImageProcessor", "VaeImageProcessor", "VaeImageProcessorLDM3D", ] - _import_structure["video_processor"] = ["VideoProcessor"] _import_structure["models"].extend( [ "AllegroTransformer3DModel", @@ -212,7 +211,6 @@ "AutoencoderKLTemporalDecoder", "AutoencoderKLWan", "AutoencoderOobleck", - "LongCatAudioDiTVae", "AutoencoderRAE", "AutoencoderTiny", "AutoencoderVidTok", @@ -253,8 +251,9 @@ "Kandinsky3UNet", "Kandinsky5Transformer3DModel", "LatteTransformer3DModel", - "LongCatImageTransformer2DModel", "LongCatAudioDiTTransformer", + "LongCatAudioDiTVae", + "LongCatImageTransformer2DModel", "LTX2VideoTransformer3DModel", "LTXVideoTransformer3DModel", "Lumina2Transformer2DModel", @@ -398,6 +397,7 @@ ] ) _import_structure["training_utils"] = ["EMAModel"] + _import_structure["video_processor"] = ["VideoProcessor"] try: if not (is_torch_available() and is_scipy_available()): @@ -594,9 +594,9 @@ "LEditsPPPipelineStableDiffusionXL", "LLaDA2Pipeline", "LLaDA2PipelineOutput", + "LongCatAudioDiTPipeline", "LongCatImageEditPipeline", "LongCatImagePipeline", - "LongCatAudioDiTPipeline", "LTX2ConditionPipeline", "LTX2ImageToVideoPipeline", "LTX2LatentUpsamplePipeline", @@ -1010,7 +1010,6 @@ AutoencoderKLTemporalDecoder, AutoencoderKLWan, AutoencoderOobleck, - LongCatAudioDiTVae, AutoencoderRAE, AutoencoderTiny, AutoencoderVidTok, @@ -1051,8 +1050,9 @@ Kandinsky3UNet, Kandinsky5Transformer3DModel, LatteTransformer3DModel, - LongCatImageTransformer2DModel, LongCatAudioDiTTransformer, + LongCatAudioDiTVae, + LongCatImageTransformer2DModel, LTX2VideoTransformer3DModel, LTXVideoTransformer3DModel, Lumina2Transformer2DModel, @@ -1368,9 +1368,9 @@ LEditsPPPipelineStableDiffusionXL, LLaDA2Pipeline, LLaDA2PipelineOutput, + LongCatAudioDiTPipeline, LongCatImageEditPipeline, LongCatImagePipeline, - LongCatAudioDiTPipeline, LTX2ConditionPipeline, LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index df586294102c..8abd36171fd0 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -50,8 +50,8 @@ _import_structure["autoencoders.autoencoder_kl_qwenimage"] = ["AutoencoderKLQwenImage"] _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"] _import_structure["autoencoders.autoencoder_kl_wan"] = ["AutoencoderKLWan"] - _import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"] _import_structure["autoencoders.autoencoder_longcat_audio_dit"] = ["LongCatAudioDiTVae"] + _import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"] _import_structure["autoencoders.autoencoder_rae"] = ["AutoencoderRAE"] _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"] _import_structure["autoencoders.autoencoder_vidtok"] = ["AutoencoderVidTok"] @@ -112,8 +112,8 @@ _import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"] _import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"] _import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"] - _import_structure["transformers.transformer_longcat_image"] = ["LongCatImageTransformer2DModel"] _import_structure["transformers.transformer_longcat_audio_dit"] = ["LongCatAudioDiTTransformer"] + _import_structure["transformers.transformer_longcat_image"] = ["LongCatImageTransformer2DModel"] _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"] _import_structure["transformers.transformer_ltx2"] = ["LTX2VideoTransformer3DModel"] _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"] diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index 803b27285a42..90dfa31fab6f 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -19,8 +19,8 @@ from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder from .autoencoder_kl_wan import AutoencoderKLWan -from .autoencoder_oobleck import AutoencoderOobleck from .autoencoder_longcat_audio_dit import LongCatAudioDiTVae +from .autoencoder_oobleck import AutoencoderOobleck from .autoencoder_rae import AutoencoderRAE from .autoencoder_tiny import AutoencoderTiny from .autoencoder_vidtok import AutoencoderVidTok diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index ae91c5a54e49..19a3c6091d0d 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -35,10 +35,10 @@ from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel from .transformer_hunyuanimage import HunyuanImageTransformer2DModel from .transformer_kandinsky import Kandinsky5Transformer3DModel + from .transformer_longcat_audio_dit import LongCatAudioDiTTransformer from .transformer_longcat_image import LongCatImageTransformer2DModel from .transformer_ltx import LTXVideoTransformer3DModel from .transformer_ltx2 import LTX2VideoTransformer3DModel - from .transformer_longcat_audio_dit import LongCatAudioDiTTransformer from .transformer_lumina2 import Lumina2Transformer2DModel from .transformer_mochi import MochiTransformer3DModel from .transformer_omnigen import OmniGenTransformer2DModel diff --git a/src/diffusers/models/transformers/transformer_longcat_audio_dit.py b/src/diffusers/models/transformers/transformer_longcat_audio_dit.py index 7f2c042f8489..6b85bc279816 100644 --- a/src/diffusers/models/transformers/transformer_longcat_audio_dit.py +++ b/src/diffusers/models/transformers/transformer_longcat_audio_dit.py @@ -35,6 +35,7 @@ class LongCatAudioDiTTransformerOutput(BaseOutput): sample: torch.Tensor + class AudioDiTSinusPositionEmbedding(nn.Module): def __init__(self, dim: int): super().__init__() @@ -420,7 +421,9 @@ def forward( adaln_out = adaln_global_out + self.adaln_scale_shift.unsqueeze(0) gate_sa, scale_sa, shift_sa, gate_ffn, scale_ffn, shift_ffn = torch.chunk(adaln_out, 6, dim=-1) - norm_hidden_states = F.layer_norm(hidden_states.float(), (hidden_states.shape[-1],), eps=1e-6).type_as(hidden_states) + norm_hidden_states = F.layer_norm(hidden_states.float(), (hidden_states.shape[-1],), eps=1e-6).type_as( + hidden_states + ) norm_hidden_states = norm_hidden_states * (1 + scale_sa[:, None]) + shift_sa[:, None] attn_output = self.self_attn( norm_hidden_states, @@ -440,7 +443,9 @@ def forward( ) hidden_states = hidden_states + cross_output - norm_hidden_states = F.layer_norm(hidden_states.float(), (hidden_states.shape[-1],), eps=1e-6).type_as(hidden_states) + norm_hidden_states = F.layer_norm(hidden_states.float(), (hidden_states.shape[-1],), eps=1e-6).type_as( + hidden_states + ) norm_hidden_states = norm_hidden_states * (1 + scale_ffn[:, None]) + shift_ffn[:, None] ff_output = self.ffn(norm_hidden_states) hidden_states = hidden_states + gate_ffn.unsqueeze(1) * ff_output diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 154c28d6bc24..db5fb1beec01 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -751,8 +751,8 @@ LEditsPPPipelineStableDiffusionXL, ) from .llada2 import LLaDA2Pipeline, LLaDA2PipelineOutput - from .longcat_image import LongCatImageEditPipeline, LongCatImagePipeline from .longcat_audio_dit import LongCatAudioDiTPipeline + from .longcat_image import LongCatImageEditPipeline, LongCatImagePipeline from .ltx import ( LTXConditionPipeline, LTXI2VLongMultiPromptPipeline, diff --git a/src/diffusers/pipelines/longcat_audio_dit/__init__.py b/src/diffusers/pipelines/longcat_audio_dit/__init__.py index 61cb89b4140f..b7c03a70371a 100644 --- a/src/diffusers/pipelines/longcat_audio_dit/__init__.py +++ b/src/diffusers/pipelines/longcat_audio_dit/__init__.py @@ -21,7 +21,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: - _import_structure['pipeline_longcat_audio_dit'] = ['LongCatAudioDiTPipeline'] + _import_structure["pipeline_longcat_audio_dit"] = ["LongCatAudioDiTPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -34,7 +34,7 @@ else: import sys - sys.modules[__name__] = _LazyModule(__name__, globals()['__file__'], _import_structure, module_spec=__spec__) + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) for name, value in _dummy_objects.items(): setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py b/src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py index 12cafcccd45b..788c859313dd 100644 --- a/src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py +++ b/src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py @@ -77,7 +77,7 @@ def _approx_duration_from_text(text: str | list[str], max_duration: float = 30.0 def _extract_prefixed_state_dict(state_dict: dict[str, torch.Tensor], prefix: str) -> dict[str, torch.Tensor]: prefix = f"{prefix}." - return {key[len(prefix):]: value for key, value in state_dict.items() if key.startswith(prefix)} + return {key[len(prefix) :]: value for key, value in state_dict.items() if key.startswith(prefix)} def _get_uniform_flow_match_scheduler_sigmas(num_inference_steps: int) -> list[float]: @@ -306,7 +306,9 @@ def from_pretrained( allowed_missing = {"shared.weight"} unexpected_missing = set(text_missing) - allowed_missing if unexpected_missing: - raise RuntimeError(f"Unexpected missing LongCatAudioDiT text encoder weights: {sorted(unexpected_missing)}") + raise RuntimeError( + f"Unexpected missing LongCatAudioDiT text encoder weights: {sorted(unexpected_missing)}" + ) if text_unexpected: raise RuntimeError(f"Unexpected LongCatAudioDiT text encoder weights: {sorted(text_unexpected)}") if "shared.weight" in text_missing: @@ -333,7 +335,9 @@ def from_pretrained( scheduler_config.update(config.get("scheduler_config", {})) scheduler = FlowMatchEulerDiscreteScheduler(**scheduler_config) - pipe = cls(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=transformer, scheduler=scheduler) + pipe = cls( + vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=transformer, scheduler=scheduler + ) pipe.sample_rate = config.get("sampling_rate", pipe.sample_rate) pipe.vae_scale_factor = config.get("vae_scale_factor", config.get("latent_hop", pipe.vae_scale_factor)) pipe.max_wav_duration = config.get("max_wav_duration", pipe.max_wav_duration) @@ -380,7 +384,9 @@ def prepare_latents( ) -> torch.Tensor: if latents is not None: if latents.ndim != 3: - raise ValueError(f"`latents` must have shape (batch_size, duration, latent_dim), but got {tuple(latents.shape)}.") + raise ValueError( + f"`latents` must have shape (batch_size, duration, latent_dim), but got {tuple(latents.shape)}." + ) if latents.shape[0] != batch_size: raise ValueError(f"`latents` must have batch size {batch_size}, but got {latents.shape[0]}.") if latents.shape[2] != self.latent_dim: @@ -388,11 +394,11 @@ def prepare_latents( return latents.to(device=device, dtype=dtype) if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError(f"Expected {batch_size} generators for batch size {batch_size}, but got {len(generator)}.") + raise ValueError( + f"Expected {batch_size} generators for batch size {batch_size}, but got {len(generator)}." + ) - return randn_tensor( - (batch_size, duration, self.latent_dim), generator=generator, device=device, dtype=dtype - ) + return randn_tensor((batch_size, duration, self.latent_dim), generator=generator, device=device, dtype=dtype) def check_inputs( self, @@ -432,8 +438,10 @@ def __call__( Args: prompt (`str` or `list[str]`): Prompt or prompts that guide audio generation. negative_prompt (`str` or `list[str]`, *optional*): Negative prompt(s) for classifier-free guidance. - audio_duration_s (`float`, *optional*): Target audio duration in seconds. Ignored when `latents` is provided. - latents (`torch.Tensor`, *optional*): Pre-generated noisy latents of shape `(batch_size, duration, latent_dim)`. + audio_duration_s (`float`, *optional*): + Target audio duration in seconds. Ignored when `latents` is provided. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents of shape `(batch_size, duration, latent_dim)`. num_inference_steps (`int`, defaults to 16): Number of denoising steps. Values below 2 are promoted to 2. guidance_scale (`float`, defaults to 4.0): Guidance scale for classifier-free guidance. generator (`torch.Generator` or `list[torch.Generator]`, *optional*): Random generator(s). diff --git a/tests/models/transformers/test_models_transformer_longcat_audio_dit.py b/tests/models/transformers/test_models_transformer_longcat_audio_dit.py index 0a52653a8a7b..b1767693faa3 100644 --- a/tests/models/transformers/test_models_transformer_longcat_audio_dit.py +++ b/tests/models/transformers/test_models_transformer_longcat_audio_dit.py @@ -93,7 +93,9 @@ def test_layerwise_casting_training(self): pytest.skip("LongCatAudioDiTTransformer does not support standard layerwise casting training tests yet.") def test_group_offloading_with_layerwise_casting(self, *args, **kwargs): - pytest.skip("LongCatAudioDiTTransformer does not support combined group offloading and layerwise casting tests yet.") + pytest.skip( + "LongCatAudioDiTTransformer does not support combined group offloading and layerwise casting tests yet." + ) class TestLongCatAudioDiTTransformerCompile(LongCatAudioDiTTransformerTesterConfig, TorchCompileTesterMixin): diff --git a/tests/pipelines/longcat_audio_dit/test_longcat_audio_dit.py b/tests/pipelines/longcat_audio_dit/test_longcat_audio_dit.py index 2f9a8491bce2..f94f5ab20226 100644 --- a/tests/pipelines/longcat_audio_dit/test_longcat_audio_dit.py +++ b/tests/pipelines/longcat_audio_dit/test_longcat_audio_dit.py @@ -21,7 +21,12 @@ from safetensors.torch import save_file from transformers import UMT5Config, UMT5EncoderModel -from diffusers import FlowMatchEulerDiscreteScheduler, LongCatAudioDiTPipeline, LongCatAudioDiTTransformer, LongCatAudioDiTVae +from diffusers import ( + FlowMatchEulerDiscreteScheduler, + LongCatAudioDiTPipeline, + LongCatAudioDiTTransformer, + LongCatAudioDiTVae, +) from diffusers.pipelines.longcat_audio_dit.pipeline_longcat_audio_dit import ( _get_uniform_flow_match_scheduler_sigmas, ) @@ -54,7 +59,8 @@ def __call__(self, texts, padding="longest", truncation=True, max_length=None, r class LongCatAudioDiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = LongCatAudioDiTPipeline params = ( - TEXT_TO_AUDIO_PARAMS - {"audio_length_in_s", "prompt_embeds", "negative_prompt_embeds", "cross_attention_kwargs"} + TEXT_TO_AUDIO_PARAMS + - {"audio_length_in_s", "prompt_embeds", "negative_prompt_embeds", "cross_attention_kwargs"} ) | {"audio_duration_s"} batch_params = TEXT_TO_AUDIO_BATCH_PARAMS required_optional_params = PipelineTesterMixin.required_optional_params - {"num_images_per_prompt"} @@ -142,25 +148,39 @@ def test_inference_batch_single_identical(self): self._test_inference_batch_single_identical(expected_max_diff=2e-3) def test_model_cpu_offload_forward_pass(self): - self.skipTest("LongCatAudioDiTPipeline offload coverage is not ready for the standard PipelineTesterMixin test.") + self.skipTest( + "LongCatAudioDiTPipeline offload coverage is not ready for the standard PipelineTesterMixin test." + ) def test_cpu_offload_forward_pass_twice(self): - self.skipTest("LongCatAudioDiTPipeline offload coverage is not ready for the standard PipelineTesterMixin test.") + self.skipTest( + "LongCatAudioDiTPipeline offload coverage is not ready for the standard PipelineTesterMixin test." + ) def test_sequential_cpu_offload_forward_pass(self): - self.skipTest("LongCatAudioDiTPipeline offload coverage is not ready for the standard PipelineTesterMixin test.") + self.skipTest( + "LongCatAudioDiTPipeline offload coverage is not ready for the standard PipelineTesterMixin test." + ) def test_sequential_offload_forward_pass_twice(self): - self.skipTest("LongCatAudioDiTPipeline offload coverage is not ready for the standard PipelineTesterMixin test.") + self.skipTest( + "LongCatAudioDiTPipeline offload coverage is not ready for the standard PipelineTesterMixin test." + ) def test_pipeline_level_group_offloading_inference(self): - self.skipTest("LongCatAudioDiTPipeline group offloading coverage is not ready for the standard PipelineTesterMixin test.") + self.skipTest( + "LongCatAudioDiTPipeline group offloading coverage is not ready for the standard PipelineTesterMixin test." + ) def test_pipeline_with_accelerator_device_map(self): - self.skipTest("LongCatAudioDiTPipeline fast tests use a dummy tokenizer, so device-map roundtrip coverage is skipped here.") + self.skipTest( + "LongCatAudioDiTPipeline fast tests use a dummy tokenizer, so device-map roundtrip coverage is skipped here." + ) def test_save_load_float16(self): - self.skipTest("LongCatAudioDiTPipeline fast tests use a dummy tokenizer, so float16 reload coverage is skipped here.") + self.skipTest( + "LongCatAudioDiTPipeline fast tests use a dummy tokenizer, so float16 reload coverage is skipped here." + ) def test_num_images_per_prompt(self): self.skipTest("LongCatAudioDiTPipeline does not support num_images_per_prompt.") @@ -189,9 +209,7 @@ def test_encode_prompt_works_in_isolation(self): def test_uniform_flow_match_scheduler_grid_matches_legacy_manual_updates(self): num_inference_steps = 6 scheduler = FlowMatchEulerDiscreteScheduler(shift=1.0, invert_sigmas=True) - scheduler.set_timesteps( - sigmas=_get_uniform_flow_match_scheduler_sigmas(num_inference_steps), device="cpu" - ) + scheduler.set_timesteps(sigmas=_get_uniform_flow_match_scheduler_sigmas(num_inference_steps), device="cpu") expected_timesteps = torch.linspace(0, 1, num_inference_steps, dtype=torch.float32)[:-1] actual_timesteps = scheduler.timesteps / scheduler.config.num_train_timesteps @@ -284,7 +302,9 @@ class LongCatAudioDiTPipelineSlowTests(unittest.TestCase): pipeline_class = LongCatAudioDiTPipeline def test_longcat_audio_pipeline_from_pretrained_real_local_weights(self): - model_path = Path(os.getenv("LONGCAT_AUDIO_DIT_MODEL_PATH", "/data/models/meituan-longcat/LongCat-AudioDiT-1B")) + model_path = Path( + os.getenv("LONGCAT_AUDIO_DIT_MODEL_PATH", "/data/models/meituan-longcat/LongCat-AudioDiT-1B") + ) tokenizer_path_env = os.getenv("LONGCAT_AUDIO_DIT_TOKENIZER_PATH") if tokenizer_path_env is None: raise unittest.SkipTest("LONGCAT_AUDIO_DIT_TOKENIZER_PATH is not set") From 5bed39e61c4d4f9080d557eb4f45aa6698c377ff Mon Sep 17 00:00:00 2001 From: Lancer Date: Sat, 11 Apr 2026 11:46:45 +0800 Subject: [PATCH 10/11] upd Signed-off-by: Lancer --- src/diffusers/utils/dummy_pt_objects.py | 30 +++++++++++++++++++ .../dummy_torch_and_transformers_objects.py | 15 ++++++++++ 2 files changed, 45 insertions(+) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index cf4fdc1bbdcc..390266c57930 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1361,6 +1361,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class LongCatAudioDiTTransformer(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class LongCatAudioDiTVae(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class LongCatImageTransformer2DModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 1e4d14566160..ae4eacfcb285 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -2252,6 +2252,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class LongCatAudioDiTPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class LongCatImageEditPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 9a05478c6831b83d7aa82da9f19dc807dfa83306 Mon Sep 17 00:00:00 2001 From: Lancer Date: Sat, 11 Apr 2026 20:59:34 +0800 Subject: [PATCH 11/11] upd Signed-off-by: Lancer --- .../autoencoder_longcat_audio_dit.py | 5 + .../transformer_longcat_audio_dit.py | 6 +- .../pipeline_longcat_audio_dit.py | 151 +++++++++++------- ...st_models_transformer_longcat_audio_dit.py | 12 +- .../test_longcat_audio_dit.py | 89 +++-------- 5 files changed, 121 insertions(+), 142 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py b/src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py index bd7538f9a510..455599a30f60 100644 --- a/src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py +++ b/src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py @@ -25,6 +25,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import BaseOutput +from ...utils.accelerate_utils import apply_forward_hook from ...utils.torch_utils import randn_tensor from ..modeling_utils import ModelMixin from .vae import AutoencoderMixin @@ -293,6 +294,8 @@ class LongCatAudioDiTVaeDecoderOutput(BaseOutput): class LongCatAudioDiTVae(ModelMixin, AutoencoderMixin, ConfigMixin): + _supports_group_offloading = False + @register_to_config def __init__( self, @@ -342,6 +345,7 @@ def __init__( upsample_shortcut=upsample_shortcut, ) + @apply_forward_hook def encode( self, sample: torch.Tensor, @@ -367,6 +371,7 @@ def encode( return (latents,) return LongCatAudioDiTVaeEncoderOutput(latents=latents) + @apply_forward_hook def decode( self, latents: torch.Tensor, return_dict: bool = True ) -> LongCatAudioDiTVaeDecoderOutput | tuple[torch.Tensor]: diff --git a/src/diffusers/models/transformers/transformer_longcat_audio_dit.py b/src/diffusers/models/transformers/transformer_longcat_audio_dit.py index 6b85bc279816..4262f8fbfdc8 100644 --- a/src/diffusers/models/transformers/transformer_longcat_audio_dit.py +++ b/src/diffusers/models/transformers/transformer_longcat_audio_dit.py @@ -454,6 +454,7 @@ def forward( class LongCatAudioDiTTransformer(ModelMixin, ConfigMixin): _supports_gradient_checkpointing = False + _repeated_blocks = ["AudioDiTBlock"] @register_to_config def __init__( @@ -543,8 +544,7 @@ def forward( latent_cond: torch.Tensor | None = None, return_dict: bool = True, ) -> LongCatAudioDiTTransformerOutput | tuple[torch.Tensor]: - dtype = next(self.parameters()).dtype - hidden_states = hidden_states.to(dtype) + dtype = hidden_states.dtype encoder_hidden_states = encoder_hidden_states.to(dtype) timestep = timestep.to(dtype) batch_size = hidden_states.shape[0] @@ -558,7 +558,7 @@ def forward( encoder_hidden_states = encoder_hidden_states.masked_fill(text_mask.logical_not().unsqueeze(-1), 0.0) hidden_states = self.input_embed(hidden_states, attention_mask) if self.use_latent_condition and latent_cond is not None: - latent_cond = self.latent_embed(latent_cond.to(dtype), attention_mask) + latent_cond = self.latent_embed(latent_cond.to(hidden_states.dtype), attention_mask) hidden_states = self.latent_cond_embedder(torch.cat([hidden_states, latent_cond], dim=-1)) residual = hidden_states.clone() if self.config.long_skip else None rope = self.rotary_embed(hidden_states, hidden_states.shape[1]) diff --git a/src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py b/src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py index 788c859313dd..2cdd78b284f0 100644 --- a/src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py +++ b/src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py @@ -18,7 +18,7 @@ import json import re from pathlib import Path -from typing import Any +from typing import Any, Callable import torch import torch.nn.functional as F @@ -52,27 +52,30 @@ def _normalize_text(text: str) -> str: def _approx_duration_from_text(text: str | list[str], max_duration: float = 30.0) -> float: - if isinstance(text, list): - if not text: - return 0.0 - return max(_approx_duration_from_text(prompt, max_duration=max_duration) for prompt in text) + if not text: + return 0.0 + if isinstance(text, str): + text = [text] en_dur_per_char = 0.082 zh_dur_per_char = 0.21 - text = re.sub(r"\s+", "", text) - num_zh = num_en = num_other = 0 - for char in text: - if "一" <= char <= "鿿": - num_zh += 1 - elif char.isalpha(): - num_en += 1 + durations = [] + for prompt in text: + prompt = re.sub(r"\s+", "", prompt) + num_zh = num_en = num_other = 0 + for char in prompt: + if "一" <= char <= "鿿": + num_zh += 1 + elif char.isalpha(): + num_en += 1 + else: + num_other += 1 + if num_zh > num_en: + num_zh += num_other else: - num_other += 1 - if num_zh > num_en: - num_zh += num_other - else: - num_en += num_other - return min(max_duration, num_zh * zh_dur_per_char + num_en * en_dur_per_char) + num_en += num_other + durations.append(num_zh * zh_dur_per_char + num_en * en_dur_per_char) + return min(max_duration, max(durations)) if durations else 0.0 def _extract_prefixed_state_dict(state_dict: dict[str, torch.Tensor], prefix: str) -> dict[str, torch.Tensor]: @@ -80,12 +83,6 @@ def _extract_prefixed_state_dict(state_dict: dict[str, torch.Tensor], prefix: st return {key[len(prefix) :]: value for key, value in state_dict.items() if key.startswith(prefix)} -def _get_uniform_flow_match_scheduler_sigmas(num_inference_steps: int) -> list[float]: - num_inference_steps = max(int(num_inference_steps), 2) - num_updates = num_inference_steps - 1 - return torch.linspace(1.0, 1.0 / num_updates, num_updates, dtype=torch.float32).tolist() - - def _load_longcat_tokenizer( pretrained_model_name_or_path: str | Path, text_encoder_model: str | None, @@ -162,6 +159,7 @@ def _resolve_longcat_file( class LongCatAudioDiTPipeline(DiffusionPipeline): model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] def __init__( self, @@ -188,6 +186,14 @@ def __init__( self.text_norm_feat = True self.text_add_embed = True + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + @classmethod @validate_hf_hub_args def from_pretrained( @@ -371,7 +377,7 @@ def encode_prompt(self, prompt: str | list[str], device: torch.device) -> tuple[ first_hidden = F.layer_norm(first_hidden, (first_hidden.shape[-1],), eps=1e-6) prompt_embeds = prompt_embeds + first_hidden lengths = attention_mask.sum(dim=1).to(device) - return prompt_embeds.float(), lengths + return prompt_embeds, lengths def prepare_latents( self, @@ -405,6 +411,7 @@ def check_inputs( prompt: list[str], negative_prompt: str | list[str] | None, output_type: str, + callback_on_step_end_tensor_inputs: list[str] | None = None, ) -> None: if len(prompt) == 0: raise ValueError("`prompt` must contain at least one prompt.") @@ -412,6 +419,14 @@ def check_inputs( if output_type not in {"np", "pt", "latent"}: raise ValueError(f"Unsupported output_type: {output_type}") + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found " + f"{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if negative_prompt is not None and not isinstance(negative_prompt, str): negative_prompt = list(negative_prompt) if len(negative_prompt) != len(prompt): @@ -431,6 +446,8 @@ def __call__( generator: torch.Generator | list[torch.Generator] | None = None, output_type: str = "np", return_dict: bool = True, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], ): r""" Function invoked when calling the pipeline for generation. @@ -442,11 +459,16 @@ def __call__( Target audio duration in seconds. Ignored when `latents` is provided. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents of shape `(batch_size, duration, latent_dim)`. - num_inference_steps (`int`, defaults to 16): Number of denoising steps. Values below 2 are promoted to 2. + num_inference_steps (`int`, defaults to 16): Number of denoising steps. guidance_scale (`float`, defaults to 4.0): Guidance scale for classifier-free guidance. generator (`torch.Generator` or `list[torch.Generator]`, *optional*): Random generator(s). output_type (`str`, defaults to `"np"`): Output format: `"np"`, `"pt"`, or `"latent"`. return_dict (`bool`, defaults to `True`): Whether to return `AudioPipelineOutput`. + callback_on_step_end (`Callable`, *optional*): + A function called at the end of each denoising step with the pipeline, step index, timestep, and tensor + inputs specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, defaults to `["latents"]`): + Tensor inputs passed to `callback_on_step_end`. """ if prompt is None: prompt = [] @@ -454,8 +476,9 @@ def __call__( prompt = [prompt] else: prompt = list(prompt) - self.check_inputs(prompt, negative_prompt, output_type) + self.check_inputs(prompt, negative_prompt, output_type, callback_on_step_end_tensor_inputs) batch_size = len(prompt) + self._guidance_scale = guidance_scale device = self._execution_device normalized_prompts = [_normalize_text(text) for text in prompt] @@ -469,69 +492,77 @@ def __call__( if latents is None: duration = max(1, min(duration, max_duration)) - text_condition, text_condition_len = self.encode_prompt(normalized_prompts, device) + prompt_embeds, prompt_embeds_len = self.encode_prompt(normalized_prompts, device) duration_tensor = torch.full((batch_size,), duration, device=device, dtype=torch.long) mask = _lens_to_mask(duration_tensor) - text_mask = _lens_to_mask(text_condition_len, length=text_condition.shape[1]) + text_mask = _lens_to_mask(prompt_embeds_len, length=prompt_embeds.shape[1]) if negative_prompt is None: - neg_text = torch.zeros_like(text_condition) - neg_text_len = text_condition_len - neg_text_mask = text_mask + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_prompt_embeds_len = prompt_embeds_len + negative_prompt_embeds_mask = text_mask else: if isinstance(negative_prompt, str): negative_prompt = [negative_prompt] * batch_size else: negative_prompt = list(negative_prompt) - neg_text, neg_text_len = self.encode_prompt(negative_prompt, device) - neg_text_mask = _lens_to_mask(neg_text_len, length=neg_text.shape[1]) + negative_prompt_embeds, negative_prompt_embeds_len = self.encode_prompt(negative_prompt, device) + negative_prompt_embeds_mask = _lens_to_mask( + negative_prompt_embeds_len, length=negative_prompt_embeds.shape[1] + ) - latent_cond = torch.zeros(batch_size, duration, self.latent_dim, device=device, dtype=text_condition.dtype) + latent_cond = torch.zeros(batch_size, duration, self.latent_dim, device=device, dtype=prompt_embeds.dtype) latents = self.prepare_latents( - batch_size, duration, device, text_condition.dtype, generator=generator, latents=latents + batch_size, duration, device, prompt_embeds.dtype, generator=generator, latents=latents ) - if num_inference_steps < 2: - logger.warning("`num_inference_steps`=%s is not supported; using 2 instead.", num_inference_steps) - num_inference_steps = 2 + if num_inference_steps < 1: + raise ValueError("num_inference_steps must be a positive integer.") - self.scheduler.set_timesteps( - sigmas=_get_uniform_flow_match_scheduler_sigmas(num_inference_steps), - device=device, - ) + sigmas = torch.linspace(1.0, 1.0 / num_inference_steps, num_inference_steps, dtype=torch.float32).tolist() + self.scheduler.set_timesteps(sigmas=sigmas, device=device) self.scheduler.set_begin_index(0) timesteps = self.scheduler.timesteps - sample = latents + self._num_timesteps = len(timesteps) - for t in timesteps: - curr_t = (t / self.scheduler.config.num_train_timesteps).expand(batch_size).to(dtype=text_condition.dtype) + for i, t in enumerate(timesteps): + curr_t = (t / self.scheduler.config.num_train_timesteps).expand(batch_size).to(dtype=prompt_embeds.dtype) pred = self.transformer( - hidden_states=sample, - encoder_hidden_states=text_condition, + hidden_states=latents, + encoder_hidden_states=prompt_embeds, encoder_attention_mask=text_mask, timestep=curr_t, attention_mask=mask, latent_cond=latent_cond, ).sample - if guidance_scale > 1.0: + if self.guidance_scale > 1.0: null_pred = self.transformer( - hidden_states=sample, - encoder_hidden_states=neg_text, - encoder_attention_mask=neg_text_mask, + hidden_states=latents, + encoder_hidden_states=negative_prompt_embeds, + encoder_attention_mask=negative_prompt_embeds_mask, timestep=curr_t, attention_mask=mask, latent_cond=latent_cond, ).sample - pred = null_pred + (pred - null_pred) * guidance_scale - sample = self.scheduler.step(pred, t, sample, return_dict=False)[0] + pred = null_pred + (pred - null_pred) * self.guidance_scale + latents = self.scheduler.step(pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) if output_type == "latent": - if not return_dict: - return (sample,) - return AudioPipelineOutput(audios=sample) + waveform = latents + else: + waveform = self.vae.decode(latents.permute(0, 2, 1)).sample + if output_type == "np": + waveform = waveform.cpu().float().numpy() - waveform = self.vae.decode(sample.permute(0, 2, 1)).sample - if output_type == "np": - waveform = waveform.cpu().float().numpy() + self.maybe_free_model_hooks() if not return_dict: return (waveform,) diff --git a/tests/models/transformers/test_models_transformer_longcat_audio_dit.py b/tests/models/transformers/test_models_transformer_longcat_audio_dit.py index b1767693faa3..b418a3068449 100644 --- a/tests/models/transformers/test_models_transformer_longcat_audio_dit.py +++ b/tests/models/transformers/test_models_transformer_longcat_audio_dit.py @@ -87,20 +87,14 @@ class TestLongCatAudioDiTTransformer(LongCatAudioDiTTransformerTesterConfig, Mod class TestLongCatAudioDiTTransformerMemory(LongCatAudioDiTTransformerTesterConfig, MemoryTesterMixin): def test_layerwise_casting_memory(self): - pytest.skip("LongCatAudioDiTTransformer does not support standard layerwise casting memory tests yet.") - - def test_layerwise_casting_training(self): - pytest.skip("LongCatAudioDiTTransformer does not support standard layerwise casting training tests yet.") - - def test_group_offloading_with_layerwise_casting(self, *args, **kwargs): pytest.skip( - "LongCatAudioDiTTransformer does not support combined group offloading and layerwise casting tests yet." + "LongCatAudioDiTTransformer tiny test config does not provide stable layerwise casting peak memory " + "coverage." ) class TestLongCatAudioDiTTransformerCompile(LongCatAudioDiTTransformerTesterConfig, TorchCompileTesterMixin): - def test_torch_compile_repeated_blocks(self): - pytest.skip("LongCatAudioDiTTransformer does not define repeated blocks for regional compilation.") + pass class TestLongCatAudioDiTTransformerAttention(LongCatAudioDiTTransformerTesterConfig, AttentionTesterMixin): diff --git a/tests/pipelines/longcat_audio_dit/test_longcat_audio_dit.py b/tests/pipelines/longcat_audio_dit/test_longcat_audio_dit.py index f94f5ab20226..2c26475c7af2 100644 --- a/tests/pipelines/longcat_audio_dit/test_longcat_audio_dit.py +++ b/tests/pipelines/longcat_audio_dit/test_longcat_audio_dit.py @@ -19,7 +19,7 @@ import torch from safetensors.torch import save_file -from transformers import UMT5Config, UMT5EncoderModel +from transformers import AutoTokenizer, UMT5Config, UMT5EncoderModel from diffusers import ( FlowMatchEulerDiscreteScheduler, @@ -27,10 +27,6 @@ LongCatAudioDiTTransformer, LongCatAudioDiTVae, ) -from diffusers.pipelines.longcat_audio_dit.pipeline_longcat_audio_dit import ( - _get_uniform_flow_match_scheduler_sigmas, -) - from ...testing_utils import enable_full_determinism, require_torch_accelerator, slow, torch_device from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS from ..test_pipelines_common import PipelineTesterMixin @@ -39,23 +35,6 @@ enable_full_determinism() -class DummyTokenizer: - model_max_length = 16 - - def __call__(self, texts, padding="longest", truncation=True, max_length=None, return_tensors="pt"): - if isinstance(texts, str): - texts = [texts] - batch = len(texts) - return type( - "TokenBatch", - (), - { - "input_ids": torch.ones(batch, 4, dtype=torch.long), - "attention_mask": torch.ones(batch, 4, dtype=torch.long), - }, - ) - - class LongCatAudioDiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = LongCatAudioDiTPipeline params = ( @@ -70,7 +49,10 @@ class LongCatAudioDiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase): def get_dummy_components(self): torch.manual_seed(0) - text_encoder = UMT5EncoderModel(UMT5Config(d_model=32, num_layers=1, num_heads=4, d_ff=64, vocab_size=128)) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + text_encoder = UMT5EncoderModel( + UMT5Config(d_model=32, num_layers=1, num_heads=4, d_ff=64, vocab_size=tokenizer.vocab_size) + ) transformer = LongCatAudioDiTTransformer( dit_dim=64, dit_depth=2, @@ -93,7 +75,7 @@ def get_dummy_components(self): return { "vae": vae, "text_encoder": text_encoder, - "tokenizer": DummyTokenizer(), + "tokenizer": tokenizer, "transformer": transformer, } @@ -134,16 +116,13 @@ def test_save_load_local(self): with tempfile.TemporaryDirectory() as tmp_dir: pipe.save_pretrained(tmp_dir) - reloaded = self.pipeline_class.from_pretrained(tmp_dir, tokenizer=DummyTokenizer(), local_files_only=True) + reloaded = self.pipeline_class.from_pretrained(tmp_dir, local_files_only=True) output = reloaded(**self.get_dummy_inputs(device, seed=0)).audios self.assertIsInstance(reloaded, LongCatAudioDiTPipeline) self.assertEqual(output.ndim, 3) self.assertGreater(output.shape[-1], 0) - def test_save_load_optional_components(self): - self.skipTest("LongCatAudioDiTPipeline does not define optional components.") - def test_inference_batch_single_identical(self): self._test_inference_batch_single_identical(expected_max_diff=2e-3) @@ -159,12 +138,14 @@ def test_cpu_offload_forward_pass_twice(self): def test_sequential_cpu_offload_forward_pass(self): self.skipTest( - "LongCatAudioDiTPipeline offload coverage is not ready for the standard PipelineTesterMixin test." + "LongCatAudioDiTPipeline uses `torch.nn.utils.weight_norm`, which is not compatible with " + "sequential offloading." ) def test_sequential_offload_forward_pass_twice(self): self.skipTest( - "LongCatAudioDiTPipeline offload coverage is not ready for the standard PipelineTesterMixin test." + "LongCatAudioDiTPipeline uses `torch.nn.utils.weight_norm`, which is not compatible with " + "sequential offloading." ) def test_pipeline_level_group_offloading_inference(self): @@ -172,53 +153,26 @@ def test_pipeline_level_group_offloading_inference(self): "LongCatAudioDiTPipeline group offloading coverage is not ready for the standard PipelineTesterMixin test." ) - def test_pipeline_with_accelerator_device_map(self): - self.skipTest( - "LongCatAudioDiTPipeline fast tests use a dummy tokenizer, so device-map roundtrip coverage is skipped here." - ) - - def test_save_load_float16(self): - self.skipTest( - "LongCatAudioDiTPipeline fast tests use a dummy tokenizer, so float16 reload coverage is skipped here." - ) - def test_num_images_per_prompt(self): self.skipTest("LongCatAudioDiTPipeline does not support num_images_per_prompt.") - def test_cfg(self): - self.skipTest("LongCatAudioDiTPipeline does not support generic CFG callback tests.") - - def test_callback_inputs(self): - self.skipTest("LongCatAudioDiTPipeline does not expose callback inputs.") - - def test_callback_cfg(self): - self.skipTest("LongCatAudioDiTPipeline does not expose callback CFG inputs.") - - def test_serialization_with_variants(self): - self.skipTest("LongCatAudioDiTPipeline fast tests use a dummy tokenizer that is not variant-serializable.") - - def test_loading_with_variants(self): - self.skipTest("LongCatAudioDiTPipeline fast tests use a dummy tokenizer that is not variant-serializable.") - - def test_loading_with_incorrect_variants_raises_error(self): - self.skipTest("LongCatAudioDiTPipeline fast tests use a dummy tokenizer that is not variant-serializable.") - def test_encode_prompt_works_in_isolation(self): self.skipTest("LongCatAudioDiTPipeline.encode_prompt has a custom signature.") - def test_uniform_flow_match_scheduler_grid_matches_legacy_manual_updates(self): + def test_uniform_flow_match_scheduler_grid_matches_manual_updates(self): num_inference_steps = 6 scheduler = FlowMatchEulerDiscreteScheduler(shift=1.0, invert_sigmas=True) - scheduler.set_timesteps(sigmas=_get_uniform_flow_match_scheduler_sigmas(num_inference_steps), device="cpu") + sigmas = torch.linspace(1.0, 1.0 / num_inference_steps, num_inference_steps, dtype=torch.float32).tolist() + scheduler.set_timesteps(sigmas=sigmas, device="cpu") - expected_timesteps = torch.linspace(0, 1, num_inference_steps, dtype=torch.float32)[:-1] + expected_grid = torch.linspace(0, 1, num_inference_steps + 1, dtype=torch.float32) actual_timesteps = scheduler.timesteps / scheduler.config.num_train_timesteps - self.assertTrue(torch.allclose(actual_timesteps, expected_timesteps, atol=1e-6, rtol=0)) + self.assertTrue(torch.allclose(actual_timesteps, expected_grid[:-1], atol=1e-6, rtol=0)) sample = torch.zeros(1, 2, 3) model_output = torch.ones_like(sample) expected = sample.clone() - for t0, t1, scheduler_t in zip(expected_timesteps[:-1], expected_timesteps[1:], scheduler.timesteps): + for t0, t1, scheduler_t in zip(expected_grid[:-1], expected_grid[1:], scheduler.timesteps): expected = expected + model_output * (t1 - t0) sample = scheduler.step(model_output, scheduler_t, sample, return_dict=False)[0] @@ -226,8 +180,6 @@ def test_uniform_flow_match_scheduler_grid_matches_legacy_manual_updates(self): def test_from_pretrained_local_dir(self): import tempfile - from unittest.mock import patch - device = "cpu" components = self.get_dummy_components() text_encoder = components["text_encoder"] @@ -237,6 +189,7 @@ def test_from_pretrained_local_dir(self): with tempfile.TemporaryDirectory() as tmp_dir: model_dir = Path(tmp_dir) / "longcat-audio-dit" model_dir.mkdir() + components["tokenizer"].save_pretrained(model_dir / "tokenizer") config = { "dit_dim": 64, @@ -275,11 +228,7 @@ def test_from_pretrained_local_dir(self): state_dict.update({f"vae.{k}": v for k, v in vae.state_dict().items()}) save_file(state_dict, model_dir / "model.safetensors") - with patch( - "diffusers.pipelines.longcat_audio_dit.pipeline_longcat_audio_dit.AutoTokenizer.from_pretrained", - return_value=DummyTokenizer(), - ): - pipe = LongCatAudioDiTPipeline.from_pretrained(model_dir, local_files_only=True) + pipe = LongCatAudioDiTPipeline.from_pretrained(model_dir, local_files_only=True) output = pipe(**self.get_dummy_inputs(device, seed=0)).audios