def _register(self, cls):
    """Bind this metadata object to ``cls`` and record it in :class:`TransformerBlockRegistry`.

    This is the hook that makes ``@register_metadata(TransformerBlockMetadata(...))`` usable on block classes
    that opt into the decorator pattern (e.g. Flux).

    The registry dict is written to directly rather than via ``TransformerBlockRegistry.register``: going
    through ``register`` would trigger the lazy bulk-init while the decorated class's module is still being
    imported, and that bulk-init imports from the very same module (circular import).

    Args:
        cls: The block class being decorated.
    """
    # Direct dict write on purpose -- see the docstring for why ``register`` is avoided.
    registry = TransformerBlockRegistry._registry
    registry[cls] = self
    # Make the link navigable in both directions: class -> metadata and metadata -> class.
    cls._block_metadata = self
    self._cls = cls
from ..models.transformers.transformer_bria import BriaTransformerBlock from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock - from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock from ..models.transformers.transformer_hunyuan_video import ( HunyuanVideoSingleTransformerBlock, HunyuanVideoTokenReplaceSingleTransformerBlock, diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py new file mode 100644 index 000000000000..f10faeba73a1 --- /dev/null +++ b/src/diffusers/loaders/lora.py @@ -0,0 +1,997 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import collections +import functools +import json +import os +from collections import defaultdict +from contextlib import contextmanager +from dataclasses import dataclass, field +from functools import partial +from pathlib import Path +from typing import Callable, Dict, List, Literal, Optional, Set, Union + +import safetensors +import torch +from huggingface_hub import model_info +from huggingface_hub.constants import HF_HUB_OFFLINE + +from ..hooks.group_offloading import ( + _GROUP_OFFLOADING, + _LAYER_EXECUTION_TRACKER, + _LAZY_PREFETCH_GROUP_OFFLOADING, + _apply_group_offloading, + _get_top_level_group_offload_hook, + _maybe_remove_and_reapply_group_offloading, +) +from ..hooks.hooks import HookRegistry +from ..models.model_loading_utils import load_state_dict +from ..utils import ( + HUB_KWARGS, + USE_PEFT_BACKEND, + _get_model_file, + delete_adapter_layers, + deprecate, + get_adapter_name, + is_accelerate_available, + is_peft_available, + is_peft_version, + logging, + recurse_remove_peft_layers, + set_weights_and_activate_adapters, +) +from ..utils.state_dict_utils import _load_sft_state_dict_metadata +from .unet_loader_utils import _maybe_expand_lora_scales + + +if is_accelerate_available(): + from accelerate.hooks import AlignDevicesHook, CpuOffload, add_hook_to_module, remove_hook_from_module + + +if is_peft_available(): + from peft import LoraConfig, PeftConfig, inject_adapter_in_model, set_peft_model_state_dict + from peft.tuners.tuners_utils import BaseTunerLayer + from peft.utils import get_peft_model_state_dict + from peft.utils.hotswap import ( + check_hotswap_configs_compatible, + hotswap_adapter_from_state_dict, + prepare_model_for_compiled_hotswap, + ) + + +logger = logging.get_logger(__name__) + + +# Minimum PEFT version this mixin relies on. Bumping this lets us delete the +# version-fallback branches scattered through the methods (DoRA, lora_bias, +# hotswap, set_adapter hasattr, etc.). 
+_MIN_PEFT_VERSION_FOR_LORA = "0.14.1" +_HAS_REQUIRED_PEFT = USE_PEFT_BACKEND and is_peft_version(">=", _MIN_PEFT_VERSION_FOR_LORA) + +LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" +LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" +LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata" + + +def _normalize_lora_suffixes(state_dict: Dict[str, "torch.Tensor"]) -> Dict[str, "torch.Tensor"]: + """Rewrite ``.lora_down/.lora_up`` (kohya-ish) suffixes to ``.lora_A/.lora_B`` (diffusers). + + Universal — every LoRA state dict goes through this regardless of model. Module-level so both :class:`LoRAHandler` + (in its ``map_to_diffusers`` dispatcher) and :class:`LoRAModelMixin` (as a public ``normalize_lora_suffixes`` + utility) can call it without circular references. + """ + out: Dict[str, "torch.Tensor"] = {} + for k, v in state_dict.items(): + new_k = ( + k.replace(".lora_down.weight", ".lora_A.weight") + .replace(".lora_up.weight", ".lora_B.weight") + .replace(".down.weight", ".lora_A.weight") + .replace(".up.weight", ".lora_B.weight") + ) + out[new_k] = v + return out + + +@dataclass +class LoRAHandler: + """Composition-style holder for a model class's LoRA conversion configuration. + + Attached as the ``_lora`` class attribute on :class:`LoRAModelMixin` (overridden per-model). Holds the per-model + foreign-format conversion data. Public conversion utilities (``normalize_lora_suffixes``, ``detect_lora_format``) + live on :class:`LoRAModelMixin` and read from this handler. + + Attributes: + format_keys: Map of format name (``"kohya"``, ``"xlabs"``, ...) to identifying key substrings. The first + format whose substrings appear in the state dict wins. + map_lora_to_diffusers_fn: Callable ``(state_dict, **kwargs) -> state_dict`` that rewrites foreign-format + keys to diffusers naming. ``None`` for models that only ingest diffusers-native LoRAs. 
+ """ + + format_keys: Dict[str, Set[str]] = field(default_factory=dict) + map_lora_to_diffusers_fn: Optional[Callable[..., Dict[str, "torch.Tensor"]]] = None + + def map_to_diffusers(self, state_dict: Dict[str, "torch.Tensor"], **kwargs) -> Dict[str, "torch.Tensor"]: + """Run the per-model converter (or pass through if none is registered). + + Callers are expected to call :meth:`LoRAModelMixin.normalize_lora_suffixes` separately before this — the + kohya-style suffix normalization is universal and isn't this handler's responsibility. + """ + if self.map_lora_to_diffusers_fn is None: + return state_dict + return self.map_lora_to_diffusers_fn(state_dict, **kwargs) + + +# Per-class hook for expanding adapter weights before activation. Models that need +# expansion (currently only UNet variants) register here; everything else falls +# through to the identity default so new transformers don't need an entry. +_SET_ADAPTER_SCALE_FN_MAPPING = defaultdict( + lambda: (lambda model_cls, weights: weights), + { + "UNet2DConditionModel": _maybe_expand_lora_scales, + "UNetMotionModel": _maybe_expand_lora_scales, + }, +) + + +def _requires_peft(method): + """Guard a method with a uniform PEFT availability + minimum-version check.""" + + @functools.wraps(method) + def wrapper(self, *args, **kwargs): + if not _HAS_REQUIRED_PEFT: + raise ValueError( + f"`{method.__name__}()` requires PEFT >= {_MIN_PEFT_VERSION_FOR_LORA}. " + "Please install or upgrade PEFT: `pip install -U peft`." 
+ ) + return method(self, *args, **kwargs) + + return wrapper + + +def _fuse_lora_apply(module, lora_scale=1.0, safe_fusing=False, adapter_names=None): + """Per-module callback for ``self.apply(...)`` in ``fuse_lora``.""" + if not isinstance(module, BaseTunerLayer): + return + if lora_scale != 1.0: + module.scale_layer(lora_scale) + module.merge(safe_merge=safe_fusing, adapter_names=adapter_names) + + +def _unfuse_lora_apply(module): + if isinstance(module, BaseTunerLayer): + module.unmerge() + + +def _serialize_lora_adapter_metadata(peft_config): + """Convert a ``PeftConfig`` to a JSON string suitable for the safetensors metadata blob. + + PEFT configs may contain ``set`` values (which JSON can't serialize); coerce those to lists first. + """ + cfg = peft_config.to_dict() + for key, value in cfg.items(): + if isinstance(value, set): + cfg[key] = list(value) + return json.dumps(cfg, indent=2, sort_keys=True) + + +def _scope_state_dict_to_adapter(state_dict, adapter_name): + """Rewrite ``lora_A.weight`` / ``lora_B.weight`` keys to include the adapter name + (the format expected by ``hotswap_adapter_from_state_dict``).""" + out = {} + for k, v in state_dict.items(): + if k.endswith("lora_A.weight") or k.endswith("lora_B.weight"): + k = k[: -len(".weight")] + f".{adapter_name}.weight" + elif k.endswith("lora_B.bias"): # lora_bias=True option + k = k[: -len(".bias")] + f".{adapter_name}.bias" + out[k] = v + return out + + +def _split_majority_and_outliers(value_dict): + """Return ``(majority, outliers)`` for ``value_dict``. + + ``majority`` is the most common value (or the lone value if all are equal, or None for an empty dict). ``outliers`` + is a sub-dict of the items whose value differs from the majority — empty when every value matches. 
def _create_lora_config(state_dict, network_alphas, rank_dict, metadata=None):
    """Construct the PEFT ``LoraConfig`` describing a LoRA state dict.

    Args:
        state_dict: LoRA state dict; its keys are used to derive ``target_modules`` and to detect DoRA
            (``lora_magnitude_vector``) and LoRA-bias (``lora_B`` keys ending in ``.bias``) parameters.
        network_alphas: Optional mapping of alpha keys to alpha values. When falsy, ``lora_alpha`` defaults
            to the majority rank.
        rank_dict: Mapping of ``lora_B`` keys to ranks.
        metadata: Optional kwargs for ``LoraConfig``. When not ``None`` it is used verbatim and nothing is
            inferred from the other arguments.

    Returns:
        A ``LoraConfig``. When inferred, the most common rank / alpha become the global ``r`` /
        ``lora_alpha`` and every deviating module is listed in ``rank_pattern`` / ``alpha_pattern``.
    """
    # A config shipped alongside the adapter takes precedence over anything we could infer.
    if metadata is not None:
        return LoraConfig(**metadata)

    default_rank, odd_ranks = _split_majority_and_outliers(rank_dict)

    default_alpha, alpha_pattern = default_rank, {}
    if network_alphas:
        default_alpha, odd_alphas = _split_majority_and_outliers(network_alphas)
        if odd_alphas:
            # The key flavour is decided from the first outlier only:
            # PEFT-converted alpha keys (UNet / transformer LoRAs) carry ``.lora_A.``;
            # raw kohya-style alphas (legacy text-encoder LoRAs) carry ``.down.``.
            if ".lora_A." in next(iter(odd_alphas)):

                def module_of(alpha_key):
                    return alpha_key.split(".lora_A.")[0].replace(".alpha", "")

            else:

                def module_of(alpha_key):
                    return ".".join(alpha_key.split(".down.")[0].split(".")[:-1])

            alpha_pattern = {module_of(alpha_key): alpha for alpha_key, alpha in odd_alphas.items()}

    return LoraConfig(
        r=default_rank,
        lora_alpha=default_alpha,
        rank_pattern={rank_key.split(".lora_B.")[0]: rank for rank_key, rank in odd_ranks.items()},
        alpha_pattern=alpha_pattern,
        # Everything in front of the first ``.lora`` is the module the adapter layer attaches to.
        target_modules=list({param_name.split(".lora")[0] for param_name in state_dict}),
        use_dora=any("lora_magnitude_vector" in param_name for param_name in state_dict),
        lora_bias=any("lora_B" in param_name and param_name.endswith(".bias") for param_name in state_dict),
    )
'.join(lora_unexpected_keys)}. " + ) + missing_keys = getattr(incompatible_keys, "missing_keys", None) + if missing_keys: + lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k] + if lora_missing_keys: + warn_msg += ( + f"Loading adapter weights from state_dict led to missing keys in the model: " + f"{', '.join(lora_missing_keys)}." + ) + if warn_msg: + logger.warning(warn_msg) + + +def _fetch_state_dict(pretrained_model_name_or_path_or_dict, weight_name=None, **hub_kwargs): + """Load a LoRA state dict from a path/repo/dict. + + Safetensors only — pickle (``.bin``) LoRAs are no longer supported. Re-save legacy checkpoints with + ``safetensors.torch.save_file`` or load them manually with ``torch.load`` and pass the resulting dict. + + ``hub_kwargs`` are the download / file-discovery options forwarded to ``_get_model_file`` (see ``HUB_KWARGS`` for + the canonical set). + """ + if isinstance(pretrained_model_name_or_path_or_dict, dict): + return pretrained_model_name_or_path_or_dict + + source = pretrained_model_name_or_path_or_dict + local_files_only = hub_kwargs.get("local_files_only") + name = weight_name or _best_guess_weight_name(source, ".safetensors", local_files_only) + model_file = _get_model_file(source, weights_name=name or LORA_WEIGHT_NAME_SAFE, **hub_kwargs) + return load_state_dict(model_file) + + +def _fetch_lora_metadata(pretrained_model_name_or_path_or_dict, weight_name=None, **hub_kwargs): + """Load LoRA adapter metadata from a safetensors file's sidecar. + + Returns ``None`` for non-safetensors sources (dicts, ``.bin`` files, missing sidecar). 
The hub layer caches the + file, so calling this after + """ + if isinstance(pretrained_model_name_or_path_or_dict, dict): + return None + + source = pretrained_model_name_or_path_or_dict + local_files_only = hub_kwargs.get("local_files_only") + name = weight_name or _best_guess_weight_name(source, ".safetensors", local_files_only) + if not name or not name.endswith(".safetensors"): + return None + try: + model_file = _get_model_file(source, weights_name=name, **hub_kwargs) + return _load_sft_state_dict_metadata(model_file) + except (IOError, safetensors.SafetensorError): + return None + + +def _best_guess_weight_name( + pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False +): + if local_files_only or HF_HUB_OFFLINE: + raise ValueError("When using the offline mode, you must specify a `weight_name`.") + + if os.path.isfile(pretrained_model_name_or_path_or_dict): + return None + if os.path.isdir(pretrained_model_name_or_path_or_dict): + targeted_files = [f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)] + else: + files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings + targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)] + + # Strip non-LoRA files: scheduler/optimizer state, intermediate checkpoints. + unallowed = {"scheduler", "optimizer", "checkpoint"} + targeted_files = [f for f in targeted_files if not any(s in f for s in unallowed)] + + # Prefer the canonical filenames if present. + for canonical in (LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE): + if any(f.endswith(canonical) for f in targeted_files): + targeted_files = [f for f in targeted_files if f.endswith(canonical)] + break + + if not targeted_files: + return None + if len(targeted_files) > 1: + logger.warning( + f"Provided path contains more than one weights file in the {file_extension} format. 
" + f"`{targeted_files[0]}` is going to be loaded; for precise control, specify a `weight_name` " + "in `load_lora_weights`." + ) + return targeted_files[0] + + +class LoRAModelMixin: + """ + Single mixin for everything LoRA on a diffusers model: PEFT adapter lifecycle (load / fuse / unfuse / set / delete + / hotswap) plus foreign-format conversion (kohya / xlabs / bfl / kontext / etc.) into diffusers naming. + + Per-model conversion knobs live in a :class:`LoRAHandler` declared in the model's ``lora.py`` (e.g. ``FLUX_LORA``) + and assigned to the model class as ``_lora = FLUX_LORA``. The default no-op handler just normalizes + ``.lora_down/.lora_up`` → ``.lora_A/.lora_B`` suffixes and returns the state dict unchanged. + + Install the latest version of PEFT, and use this mixin to: + + - Attach new adapters in the model. + - Attach multiple adapters and iteratively activate/deactivate them. + - Activate/deactivate all adapters from the model. + - Get a list of the active adapters. + """ + + # Runtime PEFT state — set during adapter load / hotswap setup. + _hf_peft_config_loaded = False + _lora_hotswap_kwargs: Optional[dict] = None + + # Per-model LoRA conversion config. Defaults to a no-op handler (only suffix normalization, no foreign-format + # conversion). Models override by assigning ``_lora = FLUX_LORA`` (etc.) in their class body. 
def detect_lora_format(self, state_dict: Dict[str, "torch.Tensor"]) -> Optional[str]:
    """Identify which registered foreign LoRA format ``state_dict`` is written in.

    Formats are read from ``self._lora.format_keys`` (format name -> identifying substrings) and tried in
    that mapping's iteration order. A format matches as soon as any one of its substrings occurs inside any
    key of ``state_dict``.

    Args:
        state_dict: LoRA state dict whose keys are inspected; values are ignored.

    Returns:
        The name of the first matching format, or ``None`` when no format is registered or none matches
        (e.g. the keys already use diffusers naming).
    """
    registered = self._lora.format_keys
    if not registered:
        return None

    key_names = list(state_dict)

    def appears(markers):
        # True when at least one marker substring is found in at least one key.
        return any(marker in name for marker in markers for name in key_names)

    return next((fmt for fmt, markers in registered.items() if appears(markers)), None)
def _load_adapter_from_pretrained(
    self,
    pretrained_model_name_or_path_or_dict,
    adapter_name=None,
    prefix="transformer",
    hotswap: bool = False,
    **kwargs,
):
    r"""
    Loads a LoRA adapter into the underlying model.

    The state dict is fetched (or taken as-is when a dict is passed), its suffixes are normalized, the
    per-model foreign-format converter is applied, and the result is filtered down to the keys starting with
    ``prefix``. If nothing is left at that point a warning is logged and the method returns without touching
    the model.

    Parameters:
        pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
            Can be either:

                - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                  the Hub.
                - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                  with [`ModelMixin.save_pretrained`].
                - A [torch state
                  dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

        adapter_name (`str`, *optional*):
            Name under which the adapter is registered. With `hotswap=True` this must be the name of an
            already loaded adapter.
        prefix (`str`, *optional*): Prefix to filter the state dict. Pass `None` to use every key.

        weight_name (`str`, *optional*):
            File name of the LoRA weights. When omitted, a best guess is made from the files found at the
            source.
        cache_dir (`Union[str, os.PathLike]`, *optional*):
            Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
            is not used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force the (re-)download of the model weights and configuration files, overriding the
            cached versions if they exist.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        local_files_only (`bool`, *optional*, defaults to `False`):
            Whether to only load local model weights and configuration files or not. If set to `True`, the model
            won't be downloaded from the Hub.
        token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
            `diffusers-cli login` (stored in `~/.huggingface`) is used.
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
            allowed by Git.
        subfolder (`str`, *optional*, defaults to `""`):
            The subfolder location of a model file within a larger model repository on the Hub or locally.
        network_alphas (`Dict[str, float]`):
            The value of the network alpha used for stable learning and preventing underflow. This value has the
            same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
            link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
            Requires a `prefix` and cannot be combined with `metadata`.
        low_cpu_mem_usage (`bool`, *optional*):
            Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
            weights.
        hotswap : (`bool`, *optional*)
            Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
            in-place. This means that, instead of loading an additional adapter, this will take the existing
            adapter weights and replace them with the weights of the new adapter. This can be faster and more
            memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
            torch.compile, loading the new adapter does not require recompilation of the model. When using
            hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.

            If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
            to call an additional method before loading the adapter:

            ```py
            pipeline = ...  # load diffusers pipeline
            max_rank = ...  # the highest rank among all LoRAs that you want to load
            # call *before* compiling and loading the LoRA adapter
            pipeline.enable_lora_hotswap(target_rank=max_rank)
            pipeline.load_lora_weights(file_name)
            # optionally compile the model now
            ```

            Note that hotswapping adapters of the text encoder is not yet supported. There are some further
            limitations to this technique, which are documented here:
            https://huggingface.co/docs/peft/main/en/package_reference/hotswap
        metadata:
            LoRA adapter metadata. When supplied, the metadata inferred through the state dict isn't used to
            initialize `LoraConfig`.

    Raises:
        ValueError: If `network_alphas` is given without a `prefix` or together with `metadata`, if
            `adapter_name` is already in use and `hotswap` is `False`, or if `hotswap` is `True` but no
            adapter named `adapter_name` exists.
    """
    hub_kwargs = {k: kwargs.pop(k, default) for k, default in HUB_KWARGS.items()}
    hub_kwargs["user_agent"] = {"file_type": "attn_procs_weights", "framework": "pytorch"}

    weight_name = kwargs.pop("weight_name", None)
    network_alphas = kwargs.pop("network_alphas", None)
    low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
    metadata = kwargs.pop("metadata", None)

    def _warn_no_lora_keys():
        # Shared by the two "nothing to load" exits below.
        model_class_name = self.__class__.__name__
        logger.warning(
            f"No LoRA keys associated to {model_class_name} found with the {prefix=}. "
            "This is safe to ignore if LoRA state dict didn't originally have any "
            f"{model_class_name} related params. You can also try specifying `prefix=None` "
            "to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
            "https://github.com/huggingface/diffusers/issues/new"
        )

    if isinstance(pretrained_model_name_or_path_or_dict, dict):
        state_dict = pretrained_model_name_or_path_or_dict
    else:
        source = pretrained_model_name_or_path_or_dict
        name = weight_name or _best_guess_weight_name(source, ".safetensors", hub_kwargs.get("local_files_only"))
        model_file = _get_model_file(source, weights_name=name or LORA_WEIGHT_NAME_SAFE, **hub_kwargs)
        state_dict = load_state_dict(model_file)

    # Universal suffix normalization first (kohya-style ``.lora_down/.lora_up`` → ``.lora_A/.lora_B``), then
    # run the per-model foreign-format converter (no-op when none is registered).
    state_dict = self.normalize_lora_suffixes(state_dict)
    state_dict = self._lora.map_to_diffusers(state_dict)
    if not state_dict:
        _warn_no_lora_keys()
        return

    metadata = metadata or _fetch_lora_metadata(
        pretrained_model_name_or_path_or_dict, weight_name=weight_name, **hub_kwargs
    )

    if network_alphas is not None and prefix is None:
        raise ValueError("`network_alphas` cannot be specified when `prefix` is None.")

    if network_alphas and metadata:
        raise ValueError("Both `network_alphas` and `metadata` cannot be specified.")

    if prefix is not None:
        state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
        if metadata is not None:
            metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")}
        # The checkpoint may only carry keys for *other* components (e.g. text-encoder-only LoRAs). This is
        # the situation the warning describes, so stop here instead of building a config with no rank and
        # no target modules.
        if not state_dict:
            _warn_no_lora_keys()
            return

    if adapter_name in getattr(self, "peft_config", {}) and not hotswap:
        raise ValueError(
            f"Adapter name {adapter_name} already in use in the model - please select a new adapter name."
        )
    if adapter_name not in getattr(self, "peft_config", {}) and hotswap:
        raise ValueError(
            f"Trying to hotswap LoRA adapter '{adapter_name}' but there is no existing adapter by that name. "
            "Please choose an existing adapter name or set `hotswap=False` to prevent hotswapping."
        )

    rank = {}
    for key, val in state_dict.items():
        # Cannot figure out rank from lora layers that don't have at least 2 dimensions.
        # Bias layers in LoRA only have a single dimension
        if "lora_B" in key and val.ndim > 1:
            # See https://github.com/huggingface/peft/pull/2419 for the `^` symbol.
            # Disambiguates module names sharing a common prefix
            # (e.g. `proj_out.weight` vs `blocks.transformer.proj_out.weight`).
            rank[f"^{key}"] = val.shape[1]

    if network_alphas is not None and len(network_alphas) >= 1:
        # `prefix` is guaranteed to be set here (validated above).
        network_alphas = {
            k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k.startswith(f"{prefix}.")
        }

    lora_config = _create_lora_config(state_dict, network_alphas, rank, metadata=metadata)

    # Mutating the model would otherwise fight with active offload hooks; the
    # context manager strips them for the duration and restores them on exit.
    peft_kwargs = {"low_cpu_mem_usage": low_cpu_mem_usage}
    with _offloading_disabled(self):
        if hotswap:
            # `_hotswap_adapter` raises on incompatible keys, so there is nothing to report afterwards.
            self._hotswap_adapter(state_dict, lora_config, adapter_name)
            incompatible_keys = None
        else:
            incompatible_keys = self._inject_adapter(state_dict, lora_config, adapter_name, peft_kwargs)
            self._maybe_apply_deferred_hotswap_prep(lora_config)

    _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name)
def _rollback_adapter(self, adapter_name, error):
    """Undo a failed adapter load so no partial adapter state is left on the model.

    When the model has a ``peft_config`` attribute, every ``BaseTunerLayer`` on which ``adapter_name``
    occurs within the name of an active adapter has that adapter deleted, and the adapter's entry is
    dropped from ``peft_config``. The failure is logged in every case; re-raising is left to the caller.

    Args:
        adapter_name: Name of the adapter whose load failed.
        error: The exception that caused the failure (only used in the log message).
    """
    if hasattr(self, "peft_config"):
        tuner_layers = (layer for layer in self.modules() if isinstance(layer, BaseTunerLayer))
        for layer in tuner_layers:
            for active_name in layer.active_adapters:
                if adapter_name in active_name:
                    layer.delete_adapter(adapter_name)
        # Tolerate the entry being absent: the failure may have happened before it was registered.
        self.peft_config.pop(adapter_name, None)

    logger.error(f"Loading {adapter_name} was unsuccessful with the following error: \n{error}")
@_requires_peft
def set_adapters(
    self,
    adapter_names: Union[List[str], str],
    weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None,
):
    """
    Set the currently active adapters for use in the diffusion network (e.g. unet, transformer, etc.).

    Args:
        adapter_names (`List[str]` or `str`):
            The names of the adapters to use. A single string is treated as a one-element list.
        weights (`float`, `Dict`, or a `list` of `float` / `Dict` / `None`, *optional*):
            The adapter weight(s) to apply. A value that is not a `list` is repeated for every adapter; a `list`
            must contain exactly one entry per adapter. `None` (either as the whole argument or as a list
            entry) means `1.0`.

    Raises:
        ValueError: If `weights` is a list whose length differs from the number of adapter names.

    Example:

    ```py
    from diffusers import AutoPipelineForText2Image
    import torch

    pipeline = AutoPipelineForText2Image.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    ).to("cuda")
    pipeline.load_lora_weights(
        "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
    )
    pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
    pipeline.unet.set_adapters(["cinematic", "pixel"], weights=[0.5, 0.5])
    ```
    """
    adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names

    # Expand weights into a list, one entry per adapter
    # examples for e.g. 2 adapters:  [{...}, 7] -> [{...}, 7] ; 7 -> [7, 7] ; None -> [None, None]
    if not isinstance(weights, list):
        weights = [weights] * len(adapter_names)

    if len(adapter_names) != len(weights):
        raise ValueError(
            f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}."
        )

    # Set None values to default of 1.0
    # e.g. [{...}, 7] -> [{...}, 7] ; [None, None] -> [1.0, 1.0]
    weights = [w if w is not None else 1.0 for w in weights]

    # The expansion function is looked up by the concrete class name.
    # e.g. [{...}, 7] -> [{expanded dict...}, 7]
    scale_expansion_fn = _SET_ADAPTER_SCALE_FN_MAPPING[self.__class__.__name__]
    weights = scale_expansion_fn(self, weights)

    set_weights_and_activate_adapters(self, adapter_names, weights)
@_requires_peft
def disable_adapters(self) -> None:
    r"""
    Turn off every adapter attached to the model so that inference falls back to the base model only.

    Each submodule that is a ``BaseTunerLayer`` gets ``enable_adapters(enabled=False)`` called on it.

    If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
    [documentation](https://huggingface.co/docs/peft).

    Raises:
        ValueError: If ``self._hf_peft_config_loaded`` is falsy, i.e. no adapter has been loaded.
    """
    if not self._hf_peft_config_loaded:
        raise ValueError("No adapter loaded. Please load an adapter first.")

    tuner_layers = (layer for _name, layer in self.named_modules() if isinstance(layer, BaseTunerLayer))
    for layer in tuner_layers:
        layer.enable_adapters(enabled=False)
@_requires_peft
def active_adapters(self) -> List[str]:
    """Collect the active adapter names of every PEFT layer and return them sorted, without duplicates.

    Raises:
        ValueError: If ``self._hf_peft_config_loaded`` is falsy, i.e. no adapter has been loaded.
    """
    if not self._hf_peft_config_loaded:
        raise ValueError("No adapter loaded. Please load an adapter first.")

    collected = set()
    for layer in self.modules():
        if isinstance(layer, BaseTunerLayer):
            current = layer.active_adapter
            # ``active_adapter`` is either a single name or an iterable of names.
            if isinstance(current, str):
                collected.add(current)
            else:
                collected.update(current)
    return sorted(collected)
def enable_lora_hotswap(
    self, target_rank: int = 128, check_compiled: Literal["error", "warn", "ignore"] = "error"
) -> None:
    """Record the settings needed to hotswap LoRA adapters later on.

    Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
    the loaded adapters differ. The settings are stored on ``self._lora_hotswap_kwargs``.

    Args:
        target_rank (`int`, *optional*, defaults to `128`):
            The highest rank among all the adapters that will be loaded.

        check_compiled (`str`, *optional*, defaults to `"error"`):
            What to do when ``self.peft_config`` is already non-empty at call time (an adapter was loaded
            first), which should generally be avoided. The options are:
              - "error" (default): raise an error
              - "warn": issue a warning
              - "ignore": do nothing

    Raises:
        ValueError: If ``check_compiled`` is not one of the three accepted values.
        RuntimeError: If an adapter is already loaded and ``check_compiled`` is ``"error"``.
    """
    allowed_modes = ("error", "warn", "ignore")
    if check_compiled not in allowed_modes:
        raise ValueError(
            f"check_compiled should be one of 'error', 'warn', or 'ignore', got '{check_compiled}' instead."
        )

    adapter_already_loaded = bool(getattr(self, "peft_config", {}))
    if adapter_already_loaded and check_compiled == "error":
        raise RuntimeError("Call `enable_lora_hotswap` before loading the first adapter.")
    if adapter_already_loaded and check_compiled == "warn":
        logger.warning(
            "It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
        )

    self._lora_hotswap_kwargs = dict(target_rank=target_rank, check_compiled=check_compiled)
list(checkpoint.keys()) - - for k in keys: - if "model.diffusion_model." in k: - checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) - - num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1 # noqa: C401 - num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1 # noqa: C401 - mlp_ratio = 4.0 - inner_dim = 3072 - - # in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale; - # while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation - def swap_scale_shift(weight): - shift, scale = weight.chunk(2, dim=0) - new_weight = torch.cat([scale, shift], dim=0) - return new_weight - - ## time_text_embed.timestep_embedder <- time_in - converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop( - "time_in.in_layer.weight" - ) - converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("time_in.in_layer.bias") - converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop( - "time_in.out_layer.weight" - ) - converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("time_in.out_layer.bias") - - ## time_text_embed.text_embedder <- vector_in - converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("vector_in.in_layer.weight") - converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("vector_in.in_layer.bias") - converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop( - "vector_in.out_layer.weight" - ) - converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("vector_in.out_layer.bias") - - # guidance - has_guidance = any("guidance" in k for k in checkpoint) - if has_guidance: - 
converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = checkpoint.pop( - "guidance_in.in_layer.weight" - ) - converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = checkpoint.pop( - "guidance_in.in_layer.bias" - ) - converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = checkpoint.pop( - "guidance_in.out_layer.weight" - ) - converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = checkpoint.pop( - "guidance_in.out_layer.bias" - ) - - # context_embedder - converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight") - converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias") - - # x_embedder - converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight") - converted_state_dict["x_embedder.bias"] = checkpoint.pop("img_in.bias") - - # double transformer blocks - for i in range(num_layers): - block_prefix = f"transformer_blocks.{i}." - # norms. - ## norm1 - converted_state_dict[f"{block_prefix}norm1.linear.weight"] = checkpoint.pop( - f"double_blocks.{i}.img_mod.lin.weight" - ) - converted_state_dict[f"{block_prefix}norm1.linear.bias"] = checkpoint.pop( - f"double_blocks.{i}.img_mod.lin.bias" - ) - ## norm1_context - converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = checkpoint.pop( - f"double_blocks.{i}.txt_mod.lin.weight" - ) - converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = checkpoint.pop( - f"double_blocks.{i}.txt_mod.lin.bias" - ) - # Q, K, V - sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0) - context_q, context_k, context_v = torch.chunk( - checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0 - ) - sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk( - checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0 - ) - context_q_bias, context_k_bias, context_v_bias = torch.chunk( - 
checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0 - ) - converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q]) - converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias]) - converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k]) - converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias]) - converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v]) - converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias]) - converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q]) - converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias]) - converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k]) - converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias]) - converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v]) - converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias]) - # qk_norm - converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop( - f"double_blocks.{i}.img_attn.norm.query_norm.scale" - ) - converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop( - f"double_blocks.{i}.img_attn.norm.key_norm.scale" - ) - converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = checkpoint.pop( - f"double_blocks.{i}.txt_attn.norm.query_norm.scale" - ) - converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = checkpoint.pop( - f"double_blocks.{i}.txt_attn.norm.key_norm.scale" - ) - # ff img_mlp - converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = checkpoint.pop( - f"double_blocks.{i}.img_mlp.0.weight" - ) - converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.0.bias") - converted_state_dict[f"{block_prefix}ff.net.2.weight"] = 
checkpoint.pop(f"double_blocks.{i}.img_mlp.2.weight") - converted_state_dict[f"{block_prefix}ff.net.2.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.bias") - converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = checkpoint.pop( - f"double_blocks.{i}.txt_mlp.0.weight" - ) - converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = checkpoint.pop( - f"double_blocks.{i}.txt_mlp.0.bias" - ) - converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = checkpoint.pop( - f"double_blocks.{i}.txt_mlp.2.weight" - ) - converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = checkpoint.pop( - f"double_blocks.{i}.txt_mlp.2.bias" - ) - # output projections. - converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = checkpoint.pop( - f"double_blocks.{i}.img_attn.proj.weight" - ) - converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = checkpoint.pop( - f"double_blocks.{i}.img_attn.proj.bias" - ) - converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = checkpoint.pop( - f"double_blocks.{i}.txt_attn.proj.weight" - ) - converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = checkpoint.pop( - f"double_blocks.{i}.txt_attn.proj.bias" - ) - - # single transformer blocks - for i in range(num_single_layers): - block_prefix = f"single_transformer_blocks.{i}." 
- # norm.linear <- single_blocks.0.modulation.lin - converted_state_dict[f"{block_prefix}norm.linear.weight"] = checkpoint.pop( - f"single_blocks.{i}.modulation.lin.weight" - ) - converted_state_dict[f"{block_prefix}norm.linear.bias"] = checkpoint.pop( - f"single_blocks.{i}.modulation.lin.bias" - ) - # Q, K, V, mlp - mlp_hidden_dim = int(inner_dim * mlp_ratio) - split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim) - q, k, v, mlp = torch.split(checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0) - q_bias, k_bias, v_bias, mlp_bias = torch.split( - checkpoint.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0 - ) - converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q]) - converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias]) - converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k]) - converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias]) - converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v]) - converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias]) - converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp]) - converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias]) - # qk norm - converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop( - f"single_blocks.{i}.norm.query_norm.scale" - ) - converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop( - f"single_blocks.{i}.norm.key_norm.scale" - ) - # output projections. 
- converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight") - converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias") - - converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight") - converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias") - converted_state_dict["norm_out.linear.weight"] = swap_scale_shift( - checkpoint.pop("final_layer.adaLN_modulation.1.weight") - ) - converted_state_dict["norm_out.linear.bias"] = swap_scale_shift( - checkpoint.pop("final_layer.adaLN_modulation.1.bias") - ) - - return converted_state_dict - - def convert_ltx_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "vae" not in key} diff --git a/src/diffusers/loaders/weight_mapping.py b/src/diffusers/loaders/weight_mapping.py new file mode 100644 index 000000000000..29e72b2c26ad --- /dev/null +++ b/src/diffusers/loaders/weight_mapping.py @@ -0,0 +1,210 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Reusable infrastructure for converting model checkpoints between original and diffusers naming conventions. 
@dataclass
class WeightMappingHandler:
    """Composition-style holder for a model class's weight-mapping configuration and helpers.

    Attached as the ``_weight_mapping`` class attribute on :class:`ModelMixin` (overridden per-model). Owns all the
    data (available configs, prefixes, converter callables) and all the methods (rename, detect, normalize) for
    single-file checkpoint loading. Internal callers reach it via ``cls._weight_mapping.X``.

    Attributes:
        original_format_keys: Distinctive keys whose presence indicates the state_dict is in the original
            (pre-diffusers) format. Used by :meth:`is_original_format`. Any iterable of keys is accepted.
        prefixes_to_remove: Foreign prefixes (e.g. ``["model.diffusion_model."]``) that
            :meth:`normalize_state_dict_keys` strips. Defaults to a copy of the shared ``PREFIXES_TO_REMOVE`` list.
        available_configs: Map of short config name to hub repo id
            (e.g. ``{"flux-dev": "black-forest-labs/FLUX.1-dev"}``).
        default_config: Config name (key into ``available_configs``) used when ``detect_config_fn`` is
            unregistered or returns nothing.
        default_subfolder: Default ``subfolder`` to use when fetching configs (e.g. ``"transformer"``).
        map_to_diffusers_fn: Callable ``(state_dict, **kwargs) -> state_dict`` performing full key conversion.
            ``None`` for prefix-only models.
        map_from_diffusers_fn: Reverse callable (diffusers -> original format).
        detect_config_fn: ``(handler, state_dict) -> Optional[str]`` returning a config name from
            ``available_configs``, or ``None`` to fall back to ``default_config``.
    """

    original_format_keys: set = field(default_factory=set)
    prefixes_to_remove: list = field(default_factory=lambda: list(PREFIXES_TO_REMOVE))
    available_configs: dict = field(default_factory=dict)
    default_config: Optional[str] = None
    default_subfolder: str = "transformer"
    map_to_diffusers_fn: Optional[Callable] = None
    map_from_diffusers_fn: Optional[Callable] = None
    detect_config_fn: Optional[Callable] = None

    # ---- single-file capability ----

    @property
    def supports_single_file(self) -> bool:
        """Whether ``from_single_file(path)`` works for this model with no extra arguments.

        True exactly when ``default_config`` is set, so config resolution always has a fallback (with or without
        a successful ``detect_config_fn`` call).
        """
        return self.default_config is not None

    # ---- key utilities ----

    @staticmethod
    def rename_key(key: str, patterns: dict) -> str:
        """Apply every rename pattern to ``key``, in the iteration order of ``patterns``.

        Each ``old -> new`` pair is applied with ``str.replace`` to the result of the previous one, so *all*
        patterns run (there is no "first match wins" short-circuit) and a later pattern can match text produced
        by an earlier replacement. Order ``patterns`` accordingly.
        """
        for old, new in patterns.items():
            key = key.replace(old, new)
        return key

    def is_original_format(self, state_dict: dict) -> bool:
        """Check whether ``state_dict`` contains at least one registered ``original_format_keys`` marker.

        Returns ``False`` when no markers are registered or none is present, i.e. "no positive evidence of
        original format".
        """
        # Membership is tested against the dict directly: no throw-away key set is built, and the markers may be
        # any iterable (set, list, tuple) rather than strictly a ``set``.
        return any(marker in state_dict for marker in self.original_format_keys)

    def normalize_state_dict_keys(self, state_dict: dict) -> dict:
        """Strip known foreign prefixes (e.g. ``model.diffusion_model.``) from state_dict keys.

        For each key only the first matching entry of ``prefixes_to_remove`` is stripped, and only once. The
        input dict is returned as-is when no prefixes are registered; otherwise a new dict is built. If two keys
        normalize to the same name, the one that comes later in iteration order wins.
        """
        if not self.prefixes_to_remove:
            return state_dict
        normalized = {}
        for key, value in state_dict.items():
            matched = next((prefix for prefix in self.prefixes_to_remove if key.startswith(prefix)), None)
            normalized[key if matched is None else key[len(matched) :]] = value
        return normalized

    # ---- config resolution ----

    def detect_config(self, state_dict: dict) -> Optional[str]:
        """Return ``self.detect_config_fn(self, state_dict)``, or ``None`` when no detector is registered."""
        if self.detect_config_fn is None:
            return None
        return self.detect_config_fn(self, state_dict)

    def get_model_config(self, state_dict: dict) -> str:
        """Resolve the hub repo id whose config best matches this state_dict.

        Resolution order:
            1. Run :meth:`detect_config` (if a detector is registered).
            2. If detection yields nothing, fall back to ``default_config``; a warning is logged when a detector
               was registered but came back empty.
            3. Look up the chosen name in ``available_configs`` to get the hub repo id.

        Raises:
            ValueError: If no config name can be determined, or the resolved name is not a key of
                ``available_configs``.
        """
        detected = self.detect_config(state_dict)
        has_detector = self.detect_config_fn is not None
        # The warning uses the same truthiness test as the fallback below, so an empty-string result from a
        # detector is reported instead of being swapped for the default silently.
        if not detected and self.default_config is not None and has_detector:
            logger.warning(
                f"Could not auto-detect a config for this state_dict; falling back to default_config="
                f"'{self.default_config}' ({self.available_configs.get(self.default_config)}). "
                f"If this is the wrong architecture, pass `config=` to `from_single_file(...)` "
                f"explicitly. Known configs: {sorted(self.available_configs)}."
            )
        config_name = detected or self.default_config
        if config_name is None:
            available = sorted(self.available_configs) or "(none)"
            detection_status = (
                "registered, but returned None for this state_dict" if has_detector else "no detect_config_fn registered"
            )
            raise ValueError(
                "Could not determine which config to load for this state_dict.\n"
                "\n"
                f"  Detection: {detection_status}\n"
                "  Default config: not set\n"
                f"  Available configs: {available}\n"
                "\n"
                "To fix this, either:\n"
                "  - pass `config=...` explicitly to `from_single_file(...)` to skip auto-detection, OR\n"
                "  - update the model's `WeightMappingHandler` to set `detect_config_fn` (returns a name from "
                "`available_configs`), and/or set `default_config` to a name in `available_configs`."
            )
        if config_name not in self.available_configs:
            raise ValueError(
                f"Resolved config name '{config_name}' is not a key of `available_configs` "
                f"(available: {sorted(self.available_configs)})."
            )
        return self.available_configs[config_name]

    # ---- conversion ----

    def map_to_diffusers(self, state_dict: dict, **kwargs) -> dict:
        """Convert state_dict from original format to diffusers format.

        No-op (returns ``state_dict`` unchanged) if no converter callable is registered.
        """
        if self.map_to_diffusers_fn is None:
            return state_dict
        return self.map_to_diffusers_fn(state_dict, **kwargs)

    def maybe_convert_state_dict(self, model, state_dict: dict) -> dict:
        """Bring ``state_dict`` to diffusers naming if it isn't already. Two phases:

        1. :meth:`normalize_state_dict_keys` -- strip known prefixes (no-op if none are registered).
        2. :meth:`map_to_diffusers` -- full key conversion, only invoked when the prefix-stripped keys do not
           already cover every key of ``model.state_dict()``. No-op if no converter callable is registered.
        """
        state_dict = self.normalize_state_dict_keys(state_dict)
        model_keys = set(model.state_dict().keys())
        # If every model key is already present (equal key sets included), no renaming pass is needed; any
        # additional state_dict keys are left in place for the caller to report.
        if model_keys.issubset(state_dict.keys()):
            return state_dict
        return self.map_to_diffusers(state_dict)

    def map_from_diffusers(self, state_dict: dict, **kwargs) -> dict:
        """Convert state_dict from diffusers format to original format.

        Raises:
            NotImplementedError: If no ``map_from_diffusers_fn`` callable is registered.
        """
        if self.map_from_diffusers_fn is None:
            raise NotImplementedError("No `map_from_diffusers_fn` callable registered for this model.")
        return self.map_from_diffusers_fn(state_dict, **kwargs)
contextmanager +from contextlib import ExitStack, contextmanager, nullcontext from functools import wraps from pathlib import Path from typing import Any, Callable, ContextManager, Type @@ -38,6 +39,8 @@ from typing_extensions import Self from .. import __version__ +from ..configuration_utils import ConfigMixin +from ..loaders.weight_mapping import WeightMappingHandler from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer from ..quantizers.quantization_config import QuantizationMethod from ..utils import ( @@ -45,6 +48,7 @@ FLASHPACK_WEIGHTS_NAME, FLAX_WEIGHTS_NAME, HF_ENABLE_PARALLEL_LOADING, + HUB_KWARGS, SAFE_WEIGHTS_INDEX_NAME, SAFETENSORS_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, @@ -229,34 +233,229 @@ def _skip_init(*args, **kwargs): setattr(torch.nn.init, name, init_func) -class ModelMixin(torch.nn.Module, PushToHubMixin): +# Base URL for the diffusers docs. Used by each mixin's ``_metadata`` classmethod to build per-capability +# docs links. Adjust as the docs layout evolves — links are reader hints, nothing depends on them +# programmatically. +DOCS_BASE = "https://huggingface.co/docs/diffusers/main/en" + + +class ModelMetadata: + """Snapshot of a model class's feature attributes. + + Constructed by :meth:`ModelMixin.metadata` — walks ``cls.__mro__`` collecting rows from each mixin's ``_metadata`` + classmethod and exposes the raw values as attributes: + + >>> meta = FluxTransformer2DModel.metadata() >>> meta._supports_ip_adapter True >>> meta._lora ['bfl', 'kohya', + 'kontext', 'xlabs'] >>> '_supports_cache' in meta True + + ``repr(meta)`` (and ``print(meta)``) render a formatted table. Call :meth:`describe` to print the verbose variant + with descriptions and docs links. + """ + + # Internal storage is name-mangled (``self.__rows`` → ``self._ModelMetadata__rows``) so ``dir(meta)`` and + # tab-completion show only the feature attributes + ``describe``, not the snapshot's bookkeeping fields. 
+ def __init__(self, rows: dict[str, tuple[Any, str, str, str]], cls_name: str): + self.__rows = rows + self.__cls_name = cls_name + for attr, (value, _display, _doc, _link) in rows.items(): + setattr(self, attr, value) + + def __iter__(self): + return iter(self.__rows) + + def __contains__(self, key): + return key in self.__rows + + def __len__(self): + return len(self.__rows) + + def __dir__(self): + return list(self.__rows) + ["describe", "keys", "values", "items"] + + def keys(self): + """Names of the feature attributes this snapshot exposes.""" + return self.__rows.keys() + + def values(self): + """Raw values for each feature attribute (same as ``meta.<attr>`` access).""" + return (info[0] for info in self.__rows.values()) + + def items(self): + """Pairs of ``(attribute_name, value)`` for each feature attribute.""" + return ((attr, info[0]) for attr, info in self.__rows.items()) + + def __repr__(self) -> str: + return self._render(verbose=False) + + def describe(self, verbose: bool = False) -> None: + """Print the formatted capability table. 
``verbose=True`` adds descriptions and docs links per row.""" + print(self._render(verbose=verbose)) + + def _render(self, verbose: bool) -> str: + if not self.__rows: + return f"{self.__cls_name}: no feature attributes declared" + + is_tty = sys.stdout.isatty() + bold = "\033[1m" if is_tty else "" + dim = "\033[2m" if is_tty else "" + cyan = "\033[36m" if is_tty else "" + underline = "\033[4m" if is_tty else "" + reset = "\033[0m" if is_tty else "" + + attr_w = max(len(attr) for attr in self.__rows) + title = f"{self.__cls_name} feature attributes" + rule_width = max(len(title), attr_w + 2 + max(len(row[1]) for row in self.__rows.values())) + lines = [ + f"{bold}{title}{reset}", + f"{dim}{'─' * rule_width}{reset}", + ] + + rows = list(self.__rows.items()) + for i, (attr, (_value, display, doc, link)) in enumerate(rows): + lines.append(f" {bold}{cyan}{attr:<{attr_w}}{reset} {display}") + if verbose: + if doc: + lines.append(f" {dim}{doc}{reset}") + if link: + lines.append(f" {dim}See {underline}{link}{reset}") + if i < len(rows) - 1: + lines.append("") + return "\n".join(lines) + + +def register_metadata(metadata): + """Generic class decorator that attaches metadata to the decorated class. + + Dispatches via ``metadata._register(cls)`` — each metadata dataclass owns its own attachment logic. Currently used + by :class:`~diffusers.hooks._helpers.TransformerBlockMetadata` to register block-level metadata into + :class:`TransformerBlockRegistry`:: + + @register_metadata(TransformerBlockMetadata(return_hidden_states_index=1, ...)) class + FluxTransformerBlock(nn.Module): + ... + + Model-level capabilities are declared as plain class attributes on :class:`ModelMixin` (and on subsystem mixins + like :class:`LoRAModelMixin` or model-specific ones like ``FluxIPAdapterMixin``) — no decorator needed. 
+ """ + + def wrap(cls): + metadata._register(cls) + return cls + + return wrap + + +# Deprecation message reused across the per-backend attention helpers on ``ModelMixin`` (npu / xla / xformers). +# These have been superseded by the unified ``set_attention_backend(...)`` / ``reset_attention_backend()`` API; +# each call site supplies its specific replacement call as ``{replacement}``. +_ATTENTION_API_DEPRECATION_MSG = "`ModelMixin.{name}` is deprecated. Use `{replacement}` instead." + + +class ModelMixin(torch.nn.Module, ConfigMixin, PushToHubMixin): r""" Base class for all models. [`ModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and saving models. - - **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`]. """ config_name = CONFIG_NAME _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] - _supports_gradient_checkpointing = False - _keys_to_ignore_on_load_unexpected = None - _no_split_modules = None - _keep_in_fp32_modules = None - _skip_layerwise_casting_patterns = None - _supports_group_offloading = True - _repeated_blocks = [] - _parallel_config = None - _cp_plan = None _skip_keys = None + _supports_gradient_checkpointing: bool = False + _no_split_modules: list[str] | None = None + _keep_in_fp32_modules: list[str] | None = None + _skip_layerwise_casting_patterns: tuple[str, ...] 
| list[str] | None = None + _supports_group_offloading: bool = True + _repeated_blocks: list[str] = [] + _cp_plan: dict[str, Any] | None = None + _keys_to_ignore_on_load_unexpected: list[str] | None = None + _weight_mapping: WeightMappingHandler = WeightMappingHandler() + def __init__(self): super().__init__() self._gradient_checkpointing_func = None + @classmethod + def _metadata(cls) -> dict[str, tuple[Any, str, str, str]]: + """Return ``ModelMixin``-level rows for the metadata snapshot.""" + rows: dict[str, tuple[Any, str, str, str]] = {} + if cls._supports_gradient_checkpointing: + rows["_supports_gradient_checkpointing"] = ( + True, + "True", + "Trades compute for memory by recomputing activations during backward.", + f"{DOCS_BASE}/optimization/memory#gradient-checkpointing", + ) + if cls._supports_group_offloading: + rows["_supports_group_offloading"] = ( + True, + "True", + "Stage parameter groups on CPU/disk and stream them to the accelerator for inference.", + f"{DOCS_BASE}/optimization/memory#group-offloading", + ) + if cls._no_split_modules: + rows["_no_split_modules"] = ( + list(cls._no_split_modules), + ", ".join(cls._no_split_modules), + "Block class names that must stay on a single device under `device_map='auto'` sharding.", + f"{DOCS_BASE}/training/distributed_inference#device-map", + ) + if cls._keep_in_fp32_modules: + rows["_keep_in_fp32_modules"] = ( + list(cls._keep_in_fp32_modules), + ", ".join(cls._keep_in_fp32_modules), + "Submodule name patterns that remain in fp32 even when the model is cast to fp16/bf16.", + f"{DOCS_BASE}/optimization/fp16#mixed-precision", + ) + if cls._skip_layerwise_casting_patterns: + rows["_skip_layerwise_casting_patterns"] = ( + tuple(cls._skip_layerwise_casting_patterns), + ", ".join(cls._skip_layerwise_casting_patterns), + "Parameter name substrings excluded from layerwise dtype casting (embeddings, norms, ...).", + f"{DOCS_BASE}/optimization/memory#layerwise-casting", + ) + if cls._repeated_blocks: + 
rows["_repeated_blocks"] = ( + list(cls._repeated_blocks), + ", ".join(cls._repeated_blocks), + "Block class names safe to `torch.compile` once and reuse — enables `compile_repeated_blocks`.", + f"{DOCS_BASE}/optimization/torch2.0", + ) + if cls._cp_plan: + rows["_cp_plan"] = ( + True, + "True", + "Support context parallel inference.", + f"{DOCS_BASE}/training/distributed_inference#context-parallelism", + ) + if cls._weight_mapping.supports_single_file: + configs = sorted(cls._weight_mapping.available_configs) + rows["_weight_mapping"] = ( + configs, + ", ".join(configs), + "Auto-resolvable configs for `from_single_file(path)` (no `config=` argument required).", + f"{DOCS_BASE}/api/loaders/single_file", + ) + return rows + + @classmethod + def metadata(cls) -> "ModelMetadata": + """Return a :class:`ModelMetadata` snapshot of this class's feature attributes.""" + merged: dict[str, tuple[Any, str, str, str]] = {} + for mixin in cls.__mro__: + method = mixin.__dict__.get("_metadata") + if method is None: + continue + + for attr, info in method.__func__(cls).items(): + merged.setdefault(attr, info) + + return ModelMetadata(merged, cls.__name__) + def __getattr__(self, name: str) -> Any: """The only reason we overwrite `getattr` here is to gracefully deprecate accessing config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite @@ -324,6 +523,14 @@ def set_use_npu_flash_attention(self, valid: bool) -> None: r""" Set the switch for the npu flash attention. 
""" + deprecate( + "ModelMixin.set_use_npu_flash_attention", + "1.0.0", + _ATTENTION_API_DEPRECATION_MSG.format( + name="set_use_npu_flash_attention", + replacement='set_attention_backend("_native_npu") / reset_attention_backend()', + ), + ) def fn_recursive_set_npu_flash_attention(module: torch.nn.Module): if hasattr(module, "set_use_npu_flash_attention"): @@ -341,6 +548,14 @@ def enable_npu_flash_attention(self) -> None: Enable npu flash attention from torch_npu """ + deprecate( + "ModelMixin.enable_npu_flash_attention", + "1.0.0", + _ATTENTION_API_DEPRECATION_MSG.format( + name="enable_npu_flash_attention", + replacement='set_attention_backend("_native_npu")', + ), + ) self.set_use_npu_flash_attention(True) def disable_npu_flash_attention(self) -> None: @@ -348,11 +563,28 @@ def disable_npu_flash_attention(self) -> None: disable npu flash attention from torch_npu """ + deprecate( + "ModelMixin.disable_npu_flash_attention", + "1.0.0", + _ATTENTION_API_DEPRECATION_MSG.format( + name="disable_npu_flash_attention", + replacement="reset_attention_backend()", + ), + ) self.set_use_npu_flash_attention(False) def set_use_xla_flash_attention( self, use_xla_flash_attention: bool, partition_spec: Callable | None = None, **kwargs ) -> None: + deprecate( + "ModelMixin.set_use_xla_flash_attention", + "1.0.0", + _ATTENTION_API_DEPRECATION_MSG.format( + name="set_use_xla_flash_attention", + replacement='set_attention_backend("_native_xla") / reset_attention_backend()', + ), + ) + # Recursively walk through all the children. # Any children which exposes the set_use_xla_flash_attention method # gets the message @@ -371,15 +603,40 @@ def enable_xla_flash_attention(self, partition_spec: Callable | None = None, **k r""" Enable the flash attention pallals kernel for torch_xla. 
""" + deprecate( + "ModelMixin.enable_xla_flash_attention", + "1.0.0", + _ATTENTION_API_DEPRECATION_MSG.format( + name="enable_xla_flash_attention", + replacement='set_attention_backend("_native_xla")', + ), + ) self.set_use_xla_flash_attention(True, partition_spec, **kwargs) def disable_xla_flash_attention(self): r""" Disable the flash attention pallals kernel for torch_xla. """ + deprecate( + "ModelMixin.disable_xla_flash_attention", + "1.0.0", + _ATTENTION_API_DEPRECATION_MSG.format( + name="disable_xla_flash_attention", + replacement="reset_attention_backend()", + ), + ) self.set_use_xla_flash_attention(False) def set_use_memory_efficient_attention_xformers(self, valid: bool, attention_op: Callable | None = None) -> None: + deprecate( + "ModelMixin.set_use_memory_efficient_attention_xformers", + "1.0.0", + _ATTENTION_API_DEPRECATION_MSG.format( + name="set_use_memory_efficient_attention_xformers", + replacement='set_attention_backend("xformers") / reset_attention_backend()', + ), + ) + # Recursively walk through all the children. # Any children which exposes the set_use_memory_efficient_attention_xformers method # gets the message @@ -424,12 +681,28 @@ def enable_xformers_memory_efficient_attention(self, attention_op: Callable | No >>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp) ``` """ + deprecate( + "ModelMixin.enable_xformers_memory_efficient_attention", + "1.0.0", + _ATTENTION_API_DEPRECATION_MSG.format( + name="enable_xformers_memory_efficient_attention", + replacement='set_attention_backend("xformers")', + ), + ) self.set_use_memory_efficient_attention_xformers(True, attention_op) def disable_xformers_memory_efficient_attention(self) -> None: r""" Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/). 
""" + deprecate( + "ModelMixin.disable_xformers_memory_efficient_attention", + "1.0.0", + _ATTENTION_API_DEPRECATION_MSG.format( + name="disable_xformers_memory_efficient_attention", + replacement="reset_attention_backend()", + ), + ) self.set_use_memory_efficient_attention_xformers(False) def enable_layerwise_casting( @@ -1450,6 +1723,233 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike | None return model + @classmethod + @validate_hf_hub_args + def from_single_file(cls, pretrained_model_link_or_path_or_dict: str | None = None, **kwargs) -> Self: + r""" + Instantiate a model from pretrained weights saved in the original `.ckpt` or `.safetensors` format. The model + is set in evaluation mode (`model.eval()`) by default. + + Parameters: + pretrained_model_link_or_path_or_dict (`str`, *optional*): + Can be either: + - A link to the `.safetensors` or `.ckpt` file (for example + `"https://huggingface.co//blob/main/.safetensors"`) on the Hub. + - A path to a local *file* containing the weights of the component model. + - A state dict containing the component model weights. + config (`str`, *optional*): + - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline hosted + on the Hub. + - A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline component + configs in Diffusers format. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + torch_dtype (`torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. 
+ cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained weights and not initializing the weights. + disable_mmap (`bool`, *optional*, defaults to `False`): + Whether to disable mmap when loading a Safetensors model. + + Returns: + The instantiated model. + + Example: + ```python + >>> from diffusers import FluxTransformer2DModel + + >>> ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors" + >>> model = FluxTransformer2DModel.from_single_file(ckpt_path) + ``` + """ + from ..loaders.single_file_utils import ( + SingleFileComponentError, + load_single_file_checkpoint, + ) + + # The ``WeightMappingHandler`` is attached as ``cls._weight_mapping`` — either overridden by the model + # (e.g. ``_weight_mapping = FLUX_WEIGHT_MAPPING``) or inherited as the empty default from ``ModelMixin``. + # Its ``supports_single_file`` property checks that the model declared a ``default_config`` so config + # resolution always succeeds with no extra args from the user. + _weight_mapping = cls._weight_mapping + if not _weight_mapping.supports_single_file: + raise ValueError( + f"`{cls.__name__}.from_single_file` is not supported. 
" + "The model's `WeightMappingHandler` must declare `default_config` (a key into " + "`available_configs`) so we can resolve which architecture to instantiate when the user " + "doesn't pass `config=` explicitly. Use `from_pretrained` if the model is already in " + "diffusers format." + ) + default_subfolder = _weight_mapping.default_subfolder + + pretrained_model_link_or_path = kwargs.get("pretrained_model_link_or_path", None) + if pretrained_model_link_or_path is not None: + deprecation_message = ( + "Please use `pretrained_model_link_or_path_or_dict` argument instead for model classes" + ) + deprecate("pretrained_model_link_or_path", "1.0.0", deprecation_message) + pretrained_model_link_or_path_or_dict = pretrained_model_link_or_path + + # Hub-download kwargs (cache_dir / force_download / proxies / local_files_only / token / + # revision / subfolder) consolidated via the canonical ``HUB_KWARGS`` defaults. + hub_kwargs = {k: kwargs.pop(k, default) for k, default in HUB_KWARGS.items()} + + config = kwargs.pop("config", None) + config_revision = kwargs.pop("config_revision", None) + torch_dtype = kwargs.pop("torch_dtype", None) + quantization_config = kwargs.pop("quantization_config", None) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) + disable_mmap = kwargs.pop("disable_mmap", False) + device_map = kwargs.pop("device_map", None) + + user_agent = {"diffusers": __version__, "file_type": "single_file", "framework": "pytorch"} + if quantization_config is not None: + user_agent["quant"] = quantization_config.quant_method.value + + if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): + logger.warning( + f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`." 
+ ) + torch_dtype = torch.float32 + + if isinstance(pretrained_model_link_or_path_or_dict, dict): + state_dict = pretrained_model_link_or_path_or_dict + else: + # ``load_single_file_checkpoint`` takes everything in ``HUB_KWARGS`` except ``subfolder``. + state_dict = load_single_file_checkpoint( + pretrained_model_link_or_path_or_dict, + disable_mmap=disable_mmap, + user_agent=user_agent, + **{k: v for k, v in hub_kwargs.items() if k != "subfolder"}, + ) + + # Normalize state_dict keys via the weight-mapping handler (strip known prefixes; no-op if none registered). + state_dict = _weight_mapping.normalize_state_dict_keys(state_dict) + + if quantization_config is not None: + hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config) + hf_quantizer.validate_environment() + torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype) + else: + hf_quantizer = None + + if config is not None: + if isinstance(config, str): + default_pretrained_model_config_name = config + else: + raise ValueError( + "Invalid `config` argument. Please provide a string representing a repo id " + "or path to a local Diffusers model repo." + ) + else: + default_pretrained_model_config_name = _weight_mapping.get_model_config(state_dict) + if default_subfolder is not None: + hub_kwargs["subfolder"] = default_subfolder + + # ``load_config`` consumes the hub-download kwargs; ``config_revision`` (if set) overrides + # the file ``revision`` for resolving the config repo specifically. 
+ diffusers_model_config = cls.load_config( + pretrained_model_name_or_path=default_pretrained_model_config_name, + **{**hub_kwargs, "revision": config_revision}, + ) + expected_kwargs, optional_kwargs = cls._get_signature_keys(cls) + model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs} + diffusers_model_config.update(model_kwargs) + + if is_accelerate_available(): + from accelerate import init_empty_weights + + ctx = init_empty_weights if low_cpu_mem_usage else nullcontext + else: + ctx = nullcontext + + with ctx(): + model = cls.from_config(diffusers_model_config) + + use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and ( + (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules") + ) + if use_keep_in_fp32_modules: + keep_in_fp32_modules = cls._keep_in_fp32_modules + if not isinstance(keep_in_fp32_modules, list): + keep_in_fp32_modules = [keep_in_fp32_modules] + else: + keep_in_fp32_modules = [] + + # ``normalize_state_dict_keys`` already ran earlier (before model creation) for detection; this call is + # idempotent and runs the full converter only if keys still don't match the freshly-built model. + state_dict = _weight_mapping.maybe_convert_state_dict(model, state_dict) + + if not state_dict: + raise SingleFileComponentError( + f"Failed to load {cls.__name__}. Weights for this component appear to be missing in the checkpoint." 
+ ) + + loaded_keys = list(state_dict.keys()) + + if hf_quantizer is not None: + hf_quantizer.preprocess_model( + model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules + ) + + device_map = _determine_device_map(model, device_map, None, torch_dtype, keep_in_fp32_modules, hf_quantizer) + if hf_quantizer is not None: + hf_quantizer.validate_environment(device_map=device_map) + + ( + model, + missing_keys, + unexpected_keys, + mismatched_keys, + offload_index, + error_msgs, + ) = cls._load_pretrained_model( + model, + state_dict, + None, + None, + loaded_keys, + low_cpu_mem_usage=low_cpu_mem_usage, + device_map=device_map, + dtype=torch_dtype, + hf_quantizer=hf_quantizer, + keep_in_fp32_modules=keep_in_fp32_modules, + ) + + if device_map is not None: + from accelerate import dispatch_model + + device_map_kwargs = { + "device_map": device_map, + "offload_index": offload_index, + } + dispatch_model(model, **device_map_kwargs) + + if hf_quantizer is not None: + hf_quantizer.postprocess_model(model) + model.hf_quantizer = hf_quantizer + + if torch_dtype is not None and hf_quantizer is None: + model.to(torch_dtype) + + model.eval() + + return model + # Adapted from `transformers`. 
@wraps(torch.nn.Module.cuda) def cuda(self, *args, **kwargs): @@ -1651,7 +2151,6 @@ def enable_parallelism( ) config.setup(rank, world_size, device, mesh=mesh) - self._parallel_config = config for module in self.modules(): if not isinstance(module, attention_classes): diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 156b54e7f07d..6ad7f6ae551c 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -1,7 +1,57 @@ -from ...utils import is_torch_available +import sys +import types + +from ...utils import deprecate, is_torch_available + + +class _DeprecatedModuleAlias(types.ModuleType): + """Backwards-compat alias for a transformer module that's been moved to a subpackage. + + Lives only in ``sys.modules`` — no stub file. Emits a one-time ``deprecate`` warning on first attribute access, + then forwards every attribute lookup to the new target module. Used when a flat ``transformer_.py`` is split + into a ``/`` subpackage and we want the old import path to keep working for a release cycle. + """ + + def __init__(self, old_dotted_path: str, target: types.ModuleType): + super().__init__(target.__name__, target.__doc__) + # Bypass __getattr__ when writing internals. + self.__dict__["_target"] = target + self.__dict__["_old_path"] = old_dotted_path + self.__dict__["_warned"] = False + + def __getattr__(self, name): + if not self.__dict__["_warned"]: + self.__dict__["_warned"] = True + old = self.__dict__["_old_path"] + new = self.__dict__["_target"].__name__ + deprecate( + old, + "1.0.0", + f"Importing from `{old}` is deprecated. Import from `{new}` instead.", + standard_warn=True, + stacklevel=3, + ) + return getattr(self.__dict__["_target"], name) + + +def _register_legacy_module_alias(old_name: str, new_name: str) -> None: + """Register ``old_name`` as a deprecated alias for the already-loaded ``new_name`` submodule. 
+ + Both names are relative to ``diffusers.models.transformers``. The new submodule must already be in ``sys.modules`` + (loaded by a prior ``from . import ...`` in this file). + """ + old_dotted = f"{__name__}.{old_name}" + target = sys.modules[f"{__name__}.{new_name}"] + sys.modules[old_dotted] = _DeprecatedModuleAlias(old_dotted, target) if is_torch_available(): + # Load flux first and install the legacy alias before any other transformer module imports, + # since some of them still pull from `transformer_flux` during their own load. + from .flux import FluxTransformer2DModel + + _register_legacy_module_alias("transformer_flux", "flux") + from .ace_step_transformer import AceStepTransformer1DModel from .auraflow_transformer_2d import AuraFlowTransformer2DModel from .cogvideox_transformer_3d import CogVideoXTransformer3DModel @@ -27,7 +77,6 @@ from .transformer_cosmos import CosmosTransformer3DModel from .transformer_easyanimate import EasyAnimateTransformer3DModel from .transformer_ernie_image import ErnieImageTransformer2DModel - from .transformer_flux import FluxTransformer2DModel from .transformer_flux2 import Flux2Transformer2DModel from .transformer_glm_image import GlmImageTransformer2DModel from .transformer_helios import HeliosTransformer3DModel diff --git a/src/diffusers/models/transformers/flux/__init__.py b/src/diffusers/models/transformers/flux/__init__.py new file mode 100644 index 000000000000..477b552bcbb0 --- /dev/null +++ b/src/diffusers/models/transformers/flux/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2025 Black Forest Labs, The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .model import ( + FluxAttention, + FluxAttnProcessor, + FluxIPAdapterAttnProcessor, + FluxPosEmbed, + FluxSingleTransformerBlock, + FluxTransformer2DModel, + FluxTransformerBlock, +) diff --git a/src/diffusers/models/transformers/flux/_ip_adapter.py b/src/diffusers/models/transformers/flux/_ip_adapter.py new file mode 100644 index 000000000000..eb679d09ec4f --- /dev/null +++ b/src/diffusers/models/transformers/flux/_ip_adapter.py @@ -0,0 +1,175 @@ +# Copyright 2025 Black Forest Labs, The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Flux-specific IP-Adapter loading. + +IP-Adapter behavior — what's in the state dict, what the attn processors look like, which blocks they bind to — varies +enough across models that a generic mixin can't really capture the orchestration. Flux owns its own +``_load_ip_adapter_weights`` here, including the loop over blocks, the choice to skip single-stream blocks, and the +projection-dim computation. 
+ +``FluxIPAdapterMixin`` is added to ``FluxTransformer2DModel``'s bases in ``flux/model.py``. Models that don't support +IP-Adapter simply don't inherit anything — there's no opt-in handler default to override. +""" + +from contextlib import nullcontext + +from ....models.embeddings import ImageProjection, MultiIPAdapterImageProjection +from ....models.model_loading_utils import load_model_dict_into_meta +from ....models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, DOCS_BASE +from ....utils import is_accelerate_available, is_torch_version, logging +from ....utils.torch_utils import empty_device_cache + + +logger = logging.get_logger(__name__) + + +def _resolve_init_context(low_cpu_mem_usage): + """Return ``(init_context, low_cpu_mem_usage)`` — disables low-cpu init if accelerate is missing.""" + if low_cpu_mem_usage: + if is_accelerate_available(): + from accelerate import init_empty_weights + + if not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch " + "version or set `low_cpu_mem_usage=False`." + ) + return init_empty_weights, True + + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the " + "environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install " + "`accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip " + "install accelerate\n```\n." + ) + return nullcontext, False + + +def _convert_image_proj(model, state_dict, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT): + """Build a Flux ``ImageProjection`` from an IP-Adapter ``image_proj`` state dict.""" + init_context, low_cpu_mem_usage = _resolve_init_context(low_cpu_mem_usage) + + # ``proj.weight`` rows == cross_attention_dim * num_image_text_embeds. The two + # supported configurations: 4 tokens (default) and 16 tokens (when rows == 65536). 
+ num_image_text_embeds = 16 if state_dict["proj.weight"].shape[0] == 65536 else 4 + clip_embeddings_dim = state_dict["proj.weight"].shape[-1] + cross_attention_dim = state_dict["proj.weight"].shape[0] // num_image_text_embeds + + with init_context(): + image_projection = ImageProjection( + cross_attention_dim=cross_attention_dim, + image_embed_dim=clip_embeddings_dim, + num_image_text_embeds=num_image_text_embeds, + ) + + updated_state_dict = {key.replace("proj", "image_embeds"): value for key, value in state_dict.items()} + + if low_cpu_mem_usage: + load_model_dict_into_meta( + image_projection, updated_state_dict, device_map={"": model.device}, dtype=model.dtype + ) + empty_device_cache() + else: + image_projection.load_state_dict(updated_state_dict, strict=True) + + return image_projection + + +def _convert_attn_processors(model, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT): + """Build the IP-Adapter attn-processor dict for a ``FluxTransformer2DModel``. + + Single-stream blocks keep their existing processor; double-stream blocks get a ``FluxIPAdapterAttnProcessor`` + loaded with the per-state-dict ``to_k_ip`` / ``to_v_ip`` weights. 
+ """ + from .model import FluxIPAdapterAttnProcessor + + init_context, low_cpu_mem_usage = _resolve_init_context(low_cpu_mem_usage) + + attn_procs = {} + key_id = 0 + for name in model.attn_processors: + if name.startswith("single_transformer_blocks"): + attn_procs[name] = model.attn_processors[name].__class__() + continue + + num_image_text_embeds = [16 if sd["image_proj"]["proj.weight"].shape[0] == 65536 else 4 for sd in state_dicts] + + with init_context(): + attn_procs[name] = FluxIPAdapterAttnProcessor( + hidden_size=model.inner_dim, + cross_attention_dim=model.config.joint_attention_dim, + scale=1.0, + num_tokens=num_image_text_embeds, + dtype=model.dtype, + device=model.device, + ) + + value_dict = {} + for i, sd in enumerate(state_dicts): + value_dict[f"to_k_ip.{i}.weight"] = sd["ip_adapter"][f"{key_id}.to_k_ip.weight"] + value_dict[f"to_v_ip.{i}.weight"] = sd["ip_adapter"][f"{key_id}.to_v_ip.weight"] + value_dict[f"to_k_ip.{i}.bias"] = sd["ip_adapter"][f"{key_id}.to_k_ip.bias"] + value_dict[f"to_v_ip.{i}.bias"] = sd["ip_adapter"][f"{key_id}.to_v_ip.bias"] + + if low_cpu_mem_usage: + load_model_dict_into_meta(attn_procs[name], value_dict, device_map={"": model.device}, dtype=model.dtype) + else: + attn_procs[name].load_state_dict(value_dict) + + key_id += 1 + + empty_device_cache() + return attn_procs + + +class FluxIPAdapterMixin: + """Flux-specific IP-Adapter loader. Mixed into :class:`FluxTransformer2DModel`.""" + + _supports_ip_adapter = True + + @classmethod + def _metadata(cls): + """Contribute the ``_supports_ip_adapter`` row to the metadata describe() table.""" + return { + "_supports_ip_adapter": ( + True, + "True", + "Supports loading IP-Adapter weights (image-conditioning adapters).", + f"{DOCS_BASE}/using-diffusers/ip_adapter", + ) + } + + def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT): + """Install IP-Adapter weights on the Flux transformer. 
+ + ``state_dicts`` is a single state dict (or a list, for multi-adapter loading); each dict must contain + ``"image_proj"`` and ``"ip_adapter"`` sub-dicts. + """ + if not isinstance(state_dicts, list): + state_dicts = [state_dicts] + + self.encoder_hid_proj = None + + attn_procs = _convert_attn_processors(self, state_dicts, low_cpu_mem_usage=low_cpu_mem_usage) + self.set_attn_processor(attn_procs) + + image_projection_layers = [] + for state_dict in state_dicts: + image_projection_layer = _convert_image_proj( + self, state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage + ) + image_projection_layers.append(image_projection_layer) + + self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers) + self.config.encoder_hid_dim_type = "ip_image_proj" diff --git a/src/diffusers/models/transformers/flux/_lora.py b/src/diffusers/models/transformers/flux/_lora.py new file mode 100644 index 000000000000..9e821be9ed6f --- /dev/null +++ b/src/diffusers/models/transformers/flux/_lora.py @@ -0,0 +1,440 @@ +# Copyright 2025 Black Forest Labs, The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Flux LoRA conversion. 
+ +Each supported foreign format has a top-level ``map_{format}_to_diffusers`` entry point: + + - :func:`map_bfl_to_diffusers` — original BFL repo layout + - :func:`map_kontext_to_diffusers` — fal Kontext checkpoints (BFL + ``base_model.model.`` prefix) + - :func:`map_xlabs_to_diffusers` — XLabs ``.processor.qkv_lora`` / ``.processor.proj_lora`` shape + - :func:`map_kohya_to_diffusers` — kohya sd-scripts (and "mixture" / ``lora_transformer_*`` variants) + +Each entry point produces a state dict with diffusers naming. Internally they all funnel through +:func:`_map_to_diffusers`, which converts a BFL-style state dict (original Flux module names + ``.lora_A``/``.lora_B`` +suffixes) to diffusers names by reusing the rename / QKV-split / special-key tables in ``_weight_mapping.py`` and +applying LoRA-specific QKV semantics (``lora_A.weight`` is replicated for every split target; everything else chunks). + +A format-specific converter may also emit pre-converted diffusers keys directly when a key shape doesn't fit the +canonical intermediate (e.g., XLabs single-block QKV without a paired MLP LoRA). +""" + +import re + +import torch + +from ....loaders.lora import LoRAHandler +from ....utils import logging, state_dict_all_zero +from ._weight_mapping import ( + FLUX_QKV_SPLIT_PATTERNS, + FLUX_QKVMLP_SPLIT_PATTERN, + FLUX_QKVMLP_TARGETS, + FLUX_RENAME_PATTERNS, + FLUX_SPECIAL_KEYS, +) + + +logger = logging.get_logger(__name__) + + +# ============================================================================ +# Shared canonical -> diffusers converter +# ============================================================================ +# Canonical keys are BFL-style: original Flux module names + .lora_A/.lora_B +# suffixes. The shared converter handles three cases — pure renames, QKV splits, +# special transforms — by reusing the tables from _weight_mapping. 
def _map_to_diffusers(state_dict, inner_dim=3072, mlp_ratio=4.0):
    """Convert a BFL-style canonical LoRA state dict to diffusers naming.

    Every output key is prefixed with ``transformer.``. Keys are handled by the first rule that applies:

    1. No LoRA suffix (none of ``_LORA_SUFFIXES``): the whole key is renamed and the value passes through.
    2. Module path matches a fused QKV pattern: one output key per target. ``lora_A.weight`` is reused as-is for
       every target; any other tensor is chunked evenly along dim 0.
    3. Module path is a single-block fused ``linear1``: same as (2), but split into Q/K/V/MLP row counts.
    4. Module path is a special key: renamed to its fixed target; the transform is applied to the ``lora_B`` side.
    5. Anything else: plain rename.

    Args:
        state_dict: Canonical (BFL-style) LoRA state dict.
        inner_dim: Row count of each of Q, K and V inside the fused single-block projection.
        mlp_ratio: MLP hidden width as a multiple of ``inner_dim``.

    Returns:
        New dict with diffusers key names; ``state_dict`` is not modified.
    """
    out = {}
    qkvmlp_dims = (inner_dim, inner_dim, inner_dim, int(inner_dim * mlp_ratio))

    # Defined once instead of once per key: it does not depend on any per-key state.
    def _rename(path):
        # FLUX_RENAME_PATTERNS keys often end with "."; pad-and-strip so bare module paths
        # like "final_layer.linear" still match patterns like "final_layer.linear.".
        renamed = _apply_renames(path + ".", FLUX_RENAME_PATTERNS)
        return renamed[:-1] if renamed.endswith(".") else renamed

    for key, value in state_dict.items():
        # Split off the .lora_A/.lora_B suffix; non-LoRA keys pass through with renames.
        suffix = next((s for s in _LORA_SUFFIXES if key.endswith(s)), "")
        if not suffix:
            out[f"transformer.{_apply_renames(key, FLUX_RENAME_PATTERNS)}"] = value
            continue
        module_path = key[: -len(suffix)]
        # The down projection is shared by all targets of a fused layer, so it is reused rather than split.
        shared_down = suffix == ".lora_A.weight"

        qkv = next(((p, ts) for p, ts in _QKV_PATTERNS.items() if p in module_path), None)
        if qkv is not None:
            pattern, targets = qkv
            chunks = [value] * len(targets) if shared_down else list(torch.chunk(value, len(targets), dim=0))
            for target, chunk in zip(targets, chunks):
                out[f"transformer.{_rename(module_path.replace(pattern, target))}{suffix}"] = chunk
            continue

        if _QKVMLP_PATTERN in module_path and "single_blocks." in module_path:
            chunks = (
                [value] * len(_QKVMLP_TARGETS) if shared_down else list(torch.split(value, qkvmlp_dims, dim=0))
            )
            for target, chunk in zip(_QKVMLP_TARGETS, chunks):
                out[f"transformer.{_rename(module_path.replace(_QKVMLP_PATTERN, target))}{suffix}"] = chunk
            continue

        if module_path in _SPECIAL_MODULES:
            target_module, transform = _SPECIAL_MODULES[module_path]
            # The special transforms reorder the layer's output rows (dim 0). Only the lora_B tensors carry the
            # output dimension there; dim 0 of lora_A is the LoRA rank, and permuting it without permuting
            # lora_B's columns would change the product lora_B @ lora_A. So lora_A passes through untouched.
            out[f"transformer.{target_module}{suffix}"] = transform(value) if suffix.startswith(".lora_B.") else value
            continue

        out[f"transformer.{_rename(module_path)}{suffix}"] = value

    return out
+ canonical = {(k[len(prefix) :] if k.startswith(prefix) else k): v for k, v in state_dict.items()} + return _map_to_diffusers(canonical) + + +# ============================================================================ +# XLabs +# ============================================================================ +# XLabs key shape: [diffusion_model.]{double|single}_blocks.{i}.processor.{X}.{down|up}.weight +# Double-block X ∈ {qkv_lora1, qkv_lora2, proj_lora1, proj_lora2} — renameable to canonical +# BFL form. Single-block X ∈ {qkv_lora, proj_lora} — single blocks lack an MLP LoRA, so +# qkv keys can't be expressed as canonical "linear1" (QKV+MLP fused); we emit pre-converted +# diffusers keys for single-block extras and route only double-block keys through canonical. + +_XLABS_DOUBLE_RENAMES = { + ".processor.proj_lora1.": ".img_attn.proj.", + ".processor.proj_lora2.": ".txt_attn.proj.", + ".processor.qkv_lora1.": ".img_attn.qkv.", + ".processor.qkv_lora2.": ".txt_attn.qkv.", +} +_XLABS_SINGLE_QKV_TARGETS = ["attn.to_q", "attn.to_k", "attn.to_v"] + + +def map_xlabs_to_diffusers(state_dict): + """Convert a Flux LoRA state dict from XLabs format to diffusers naming.""" + canonical = {} + extras = {} + for key, value in state_dict.items(): + k = key.removeprefix("diffusion_model.") + + if "single_blocks." in k: + block = re.search(r"single_blocks\.(\d+)", k).group(1) + base = f"transformer.single_transformer_blocks.{block}" + suffix = ".lora_A.weight" if k.endswith(".lora_A.weight") else ".lora_B.weight" + if "proj_lora" in k: + extras[f"{base}.proj_out{suffix}"] = value + elif "qkv_lora" in k: + chunks = ( + [value] * len(_XLABS_SINGLE_QKV_TARGETS) + if suffix == ".lora_A.weight" + else list(torch.chunk(value, 3, dim=0)) + ) + for t, chunk in zip(_XLABS_SINGLE_QKV_TARGETS, chunks): + extras[f"{base}.{t}{suffix}"] = chunk + continue + + # Double block: rename to canonical BFL-style; shared converter handles the QKV split. 
+ for old, new in _XLABS_DOUBLE_RENAMES.items(): + k = k.replace(old, new) + canonical[k] = value + + converted = _map_to_diffusers(canonical) if canonical else {} + return {**converted, **extras} + + +# ============================================================================ +# Kohya (sd-scripts + "mixture" variant) +# ============================================================================ +# Kohya keys collapse dots into underscores in the module path, then append +# .lora_down/.lora_up/.alpha. We invert this with a single explicit suffix table +# (the original-name underscore <-> dot mapping isn't recoverable by rule), then +# apply alpha-driven scaling so canonical tensors are pre-scaled. + +# Kohya stub-suffix → BFL form. Block stubs (``double_blocks_{i}_`` and +# ``single_blocks_{i}_``) look up just the trailing here; everything +# else is a global stub that maps directly. No overlap between contexts. +_KOHYA_TO_BFL = { + # double_blocks_{i}_ + "img_attn_proj": "img_attn.proj", + "img_attn_qkv": "img_attn.qkv", + "img_mlp_0": "img_mlp.0", + "img_mlp_2": "img_mlp.2", + "img_mod_lin": "img_mod.lin", + "txt_attn_proj": "txt_attn.proj", + "txt_attn_qkv": "txt_attn.qkv", + "txt_mlp_0": "txt_mlp.0", + "txt_mlp_2": "txt_mlp.2", + "txt_mod_lin": "txt_mod.lin", + # single_blocks_{i}_ + "linear1": "linear1", + "linear2": "linear2", + "modulation_lin": "modulation.lin", + # Global stubs (used directly as canonical path) + "guidance_in_in_layer": "guidance_in.in_layer", + "guidance_in_out_layer": "guidance_in.out_layer", + "img_in": "img_in", + "txt_in": "txt_in", + "time_in_in_layer": "time_in.in_layer", + "time_in_out_layer": "time_in.out_layer", + "vector_in_in_layer": "vector_in.in_layer", + "vector_in_out_layer": "vector_in.out_layer", + "final_layer_linear": "final_layer.linear", + "final_layer_adaLN_modulation_1": "final_layer.adaLN_modulation.1", +} + + +def _kohya_scale(alpha, rank): + """Split alpha/rank into (down, up) factors so down*up == alpha/rank 
but stays bounded.""" + scale = alpha / rank + down, up = scale, 1.0 + while down * 2 < up: + down *= 2 + up /= 2 + return down, up + + +def _kohya_mixture_to_diffusers(state_dict): + """Convert Kohya mixture-format (``lora_transformer_*`` keys) directly to diffusers naming.""" + out = {} + unique = { + k.replace(".lora_A.weight", "").replace(".lora_B.weight", "").replace(".alpha", "") + for k in state_dict + if k.startswith("lora_transformer_") + } + + for k in unique: + if k.startswith("lora_transformer_single_transformer_blocks_"): + i = int(k.split("lora_transformer_single_transformer_blocks_")[-1].split("_")[0]) + diffusers_key = f"single_transformer_blocks.{i}" + elif k.startswith("lora_transformer_transformer_blocks_"): + i = int(k.split("lora_transformer_transformer_blocks_")[-1].split("_")[0]) + diffusers_key = f"transformer_blocks.{i}" + elif k.startswith("lora_transformer_context_embedder"): + diffusers_key = "context_embedder" + elif k.startswith("lora_transformer_norm_out_linear"): + diffusers_key = "norm_out.linear" + elif k.startswith("lora_transformer_proj_out"): + diffusers_key = "proj_out" + elif k.startswith("lora_transformer_x_embedder"): + diffusers_key = "x_embedder" + elif k.startswith("lora_transformer_time_text_embed_guidance_embedder_linear_"): + i = int(k.split("lora_transformer_time_text_embed_guidance_embedder_linear_")[-1]) + diffusers_key = f"time_text_embed.guidance_embedder.linear_{i}" + elif k.startswith("lora_transformer_time_text_embed_text_embedder_linear_"): + i = int(k.split("lora_transformer_time_text_embed_text_embedder_linear_")[-1]) + diffusers_key = f"time_text_embed.text_embedder.linear_{i}" + elif k.startswith("lora_transformer_time_text_embed_timestep_embedder_linear_"): + i = int(k.split("lora_transformer_time_text_embed_timestep_embedder_linear_")[-1]) + diffusers_key = f"time_text_embed.timestep_embedder.linear_{i}" + else: + raise NotImplementedError(f"Handling for key ({k}) is not implemented.") + + if "attn_" in 
k: + tail = k.split("attn_")[-1] + if "_to_out_0" in k: + diffusers_key += ".attn.to_out.0" + elif "_to_add_out" in k: + diffusers_key += ".attn.to_add_out" + elif any(qkv in k for qkv in ("to_q", "to_k", "to_v", "add_q_proj", "add_k_proj", "add_v_proj")): + diffusers_key += f".attn.{tail}" + + down = state_dict.pop(f"{k}.lora_A.weight") + up = state_dict.pop(f"{k}.lora_B.weight") + alpha = state_dict.pop(f"{k}.alpha") + d_scale, u_scale = _kohya_scale(alpha, down.shape[0]) + out[f"transformer.{diffusers_key}.lora_A.weight"] = down * d_scale + out[f"transformer.{diffusers_key}.lora_B.weight"] = up * u_scale + + leftover = [k for k in state_dict if not k.startswith("lora_unet_")] + if leftover: + logger.warning(f"Unsupported mixture keys ignored: {leftover}") + return out + + +def map_kohya_to_diffusers(state_dict): + """Convert a Flux LoRA state dict from Kohya format (sd-scripts or mixture) to diffusers naming.""" + # ---- Pre-filter: rename prefix, drop unsupported keys, collapse leading dots. ---- + state_dict = {k.replace("diffusion_model.", "lora_unet_"): v for k, v in state_dict.items()} + + drop_specs = [ + (lambda k: "position_embedding" in k, "position_embedding", "position_embedding"), + (lambda k: ".diff_b" in k and k.startswith("lora_unet_"), ".diff_b", "diff_b"), + (lambda k: ".norm" in k and ".diff" in k, ".diff", "diff"), + ] + for predicate, marker, label in drop_specs: + if not any(predicate(k) for k in state_dict): + continue + msg = ( + f"The `{label}` LoRA params are all zeros which make them ineffective. So, we will purge them out of " + "the current state dict to make loading possible." + if state_dict_all_zero(state_dict, marker) + else f"`{label}` keys found in the state dict are currently unsupported and will be filtered out. " + "Open an issue if this is a problem - https://github.com/huggingface/diffusers/issues/new." 
+ ) + logger.info(msg) + state_dict = {k: v for k, v in state_dict.items() if not predicate(k)} + + # Some keys come with dots in the prefix; collapse them up to lora_A/lora_B/alpha. + limit = ["lora_A", "lora_B"] + (["alpha"] if any("alpha" in k for k in state_dict) else []) + boundary_re = re.compile("(" + "|".join(re.escape(s) for s in limit) + ")") + + def _collapse_prefix(key): + match = boundary_re.search(key) + if not match: + return key.replace(".", "_") + i = match.start() + boundary = i - 1 if i > 0 and key[i - 1] == "." else i + return key[:boundary].replace(".", "_") + key[boundary:] + + state_dict = {_collapse_prefix(k): v for k, v in state_dict.items() if k.startswith("lora_unet_")} + + # ---- Mixture variant has its own prefix; route to its direct converter. ---- + if any( + k.startswith("lora_transformer_") and ("lora_down" in k or "lora_up" in k or "alpha" in k) for k in state_dict + ): + return _kohya_mixture_to_diffusers(state_dict) + + # ---- sd-scripts variant: group by stub, apply alpha scaling, rewrite to canonical. ---- + groups = {} # stub -> {"lora_A": full_key, "lora_B": full_key, "alpha": full_key} + for key in list(state_dict): + if not key.startswith("lora_unet_"): + continue + for kind in ("lora_A.weight", "lora_B.weight", "alpha"): + tail = "." + kind + if key.endswith(tail): + stub = key[len("lora_unet_") : -len(tail)] + groups.setdefault(stub, {})[kind.split(".")[0]] = key + break + + canonical = {} + for stub, group in groups.items(): + down_key, up_key = group.get("lora_A"), group.get("lora_B") + if down_key is None or up_key is None: + continue + rank = state_dict[down_key].shape[0] + alpha = state_dict.pop(group["alpha"]).item() if "alpha" in group else float(rank) + d_scale, u_scale = _kohya_scale(alpha, rank) + down = state_dict.pop(down_key) * d_scale + up = state_dict.pop(up_key) * u_scale + + # Map kohya stub → BFL canonical path. 
def map_lora_to_diffusers(state_dict, **kwargs):
    """Detect a Flux LoRA's source format and dispatch to its per-format converter.

    Already-converted (peft) state dicts pass through after filtering to ``transformer.*`` keys. Unknown formats (incl.
    diffusers-native LoRAs with raw ``.alpha`` keys) pass through unchanged so the pipeline's diffusers-native fallback
    can run.

    Args:
        state_dict: LoRA state dict in any supported source format.
        **kwargs: Accepted for interface compatibility; not used.

    Returns:
        State dict with diffusers naming, or ``state_dict`` itself when no known format marker is found.
    """
    if any(k.startswith("transformer.") for k in state_dict):
        return {k: v for k, v in state_dict.items() if k.startswith("transformer.")}

    # Format markers are matched as substrings. A Kontext key is a BFL key behind a ``base_model.model.`` prefix,
    # so the BFL markers match Kontext checkpoints too; Kontext therefore has to be probed first or those
    # checkpoints reach the BFL converter with the prefix still attached. ``sorted`` is stable, so all other
    # formats keep their table order.
    # NOTE(review): ``_FLUX_LORA_FORMAT_KEYS`` is also handed to ``LoRAHandler``; confirm its own format detection
    # is not affected by the same ordering.
    probe_order = sorted(_FLUX_LORA_FORMAT_KEYS, key=lambda fmt: fmt != "kontext")
    for fmt in probe_order:
        markers = _FLUX_LORA_FORMAT_KEYS[fmt]
        if any(marker in key for key in state_dict for marker in markers):
            return _FORMAT_DISPATCH[fmt](state_dict)
    return state_dict
def swap_scale_shift(weight: torch.Tensor) -> torch.Tensor:
    """Exchange the two halves of ``weight`` along dim 0.

    Used for AdaLayerNorm weights: the original layout stores (shift, scale) while diffusers stores (scale, shift).
    """
    leading_half, trailing_half = weight.chunk(2, dim=0)
    reordered = (trailing_half, leading_half)
    return torch.cat(reordered, dim=0)
(split + special), unified. +# -------------------------------------------------------------------------- +# Each entry is ``(source_substring, [target_substrings], forward_fn)``: +# - source/targets include surrounding dots so they only match at module boundaries +# (e.g. ".img_attn.qkv." matches both "X.img_attn.qkv.weight" and "X.img_attn.qkv.bias"). +# - len(targets) == 1 -> a unary transform (e.g. AdaLN scale/shift swap). +# - len(targets) > 1 -> a split transform (forward chunks the tensor). +# - forward_fn(tensor, **ctx) -> list[tensor] of length len(targets). +def _swap_to_list(v, **_): + return [swap_scale_shift(v)] + + +def _make_chunk(n): + return lambda v, **_: torch.chunk(v, n, dim=0) + + +def _qkvmlp_split(v, inner_dim=3072, **_): + return torch.split(v, [inner_dim, inner_dim, inner_dim, inner_dim * 4], dim=0) + + +FLUX_TRANSFORMS = [ + ("final_layer.adaLN_modulation.1.", ["norm_out.linear."], _swap_to_list), + (".img_attn.qkv.", [".attn.to_q.", ".attn.to_k.", ".attn.to_v."], _make_chunk(3)), + (".txt_attn.qkv.", [".attn.add_q_proj.", ".attn.add_k_proj.", ".attn.add_v_proj."], _make_chunk(3)), + (".linear1.", [".attn.to_q.", ".attn.to_k.", ".attn.to_v.", ".proj_mlp."], _qkvmlp_split), +] + + +# Backward-compat tables derived from FLUX_TRANSFORMS so the existing +# map_from_diffusers code (and any external readers, including lora.py) +# keep working without changes. 
+def _wrap_unary(fwd_fn): + return lambda v: fwd_fn(v)[0] + + +FLUX_SPECIAL_KEYS: dict[str, dict] = {} +FLUX_QKV_SPLIT_PATTERNS: dict[str, list[str]] = {} +FLUX_QKVMLP_SPLIT_PATTERN: str = "" +FLUX_QKVMLP_TARGETS: list[str] = [] +for _src, _tgts, _fwd in FLUX_TRANSFORMS: + if len(_tgts) == 1: + for _suffix in ("weight", "bias"): + FLUX_SPECIAL_KEYS[_src + _suffix] = { + "target": _tgts[0] + _suffix, + "transform": _wrap_unary(_fwd), + } + elif _src == ".linear1.": + FLUX_QKVMLP_SPLIT_PATTERN = _src + FLUX_QKVMLP_TARGETS = list(_tgts) + else: + FLUX_QKV_SPLIT_PATTERNS[_src] = list(_tgts) + + +def _get_inner_dim(state_dict: dict[str, torch.Tensor]) -> int: + """Infer inner_dim from state_dict weights.""" + for key in state_dict: + if "single_blocks." in key and ".linear1." in key and key.endswith(".bias"): + # linear1 contains Q, K, V, MLP fused - Q/K/V each have inner_dim + # Total size = 3 * inner_dim + mlp_hidden_dim = 3 * inner_dim + 4 * inner_dim = 7 * inner_dim + total = state_dict[key].shape[0] + return total // 7 + + return 3072 # Default + + +def apply_transforms(state_dict, transforms, rename_patterns, **ctx): + """Drive a forward state-dict conversion from a list of ``(source, targets, forward_fn)`` entries. + + For each key in ``state_dict``: scan ``transforms``; the first entry whose ``source`` substring matches expands the + value via ``forward_fn(value, **ctx)`` into one tensor per target, each at a key derived by ``key.replace(source, + target)`` then ``rename_patterns``. Keys that match no transform are just renamed. 
def map_from_diffusers(
    state_dict: dict[str, torch.Tensor],
    **kwargs,
) -> dict[str, torch.Tensor]:
    """
    Convert a Flux transformer state_dict from diffusers format to original format.

    Split attention projections are fused back along dim 0: double-block ``to_q``/``to_k``/``to_v`` (and the
    ``add_*_proj`` context projections) into their ``qkv`` tensor, single-block ``to_q``/``to_k``/``to_v``/``proj_mlp``
    into ``linear1``. A fused tensor is only emitted when every one of its parts is present; incomplete groups are
    skipped and produce no output key. Keys listed in ``FLUX_SPECIAL_KEYS_REVERSE`` are transformed, and every other
    key is renamed with ``FLUX_RENAME_PATTERNS_REVERSE``.

    Args:
        state_dict: State dict in diffusers format
        **kwargs: Accepted for interface compatibility; not used.

    Returns:
        State dict in original Flux format
    """
    # NOTE(review): ``FLUX_RENAME_PATTERNS_REVERSE`` is a plain ``{v: k}`` inversion of a table in which several
    # sources share one target (``.img_attn.norm.query_norm.scale`` and ``.norm.query_norm.scale`` both map to
    # ``.attn.norm_q.weight``; likewise for ``key_norm``). The inversion keeps only one of them, so double-block and
    # single-block norm keys cannot both round-trip - confirm against ``WeightMappingHandler.rename_key``.
    converted_state_dict = {}

    # original fused key -> {position inside the fused tensor: part}
    qkv_groups: dict[str, dict[int, torch.Tensor]] = {}
    qkvmlp_groups: dict[str, dict[int, torch.Tensor]] = {}

    for key, value in state_dict.items():
        # Handle special keys with transforms
        if key in FLUX_SPECIAL_KEYS_REVERSE:
            spec = FLUX_SPECIAL_KEYS_REVERSE[key]
            converted_state_dict[spec["target"]] = spec["transform"](value)
            continue

        # "single_transformer_blocks." contains "transformer_blocks." as a substring, so the single-block marker
        # has to be tested first; otherwise single-block q/k/v would be grouped as double-block QKV and the
        # single-block ``linear1`` tensor could never be reassembled.
        in_single_block = "single_transformer_blocks." in key
        in_double_block = not in_single_block and "transformer_blocks." in key

        if in_double_block:
            target = next((t for t in FLUX_QKV_SPLIT_PATTERNS_REVERSE if t in key), None)
            if target is not None:
                pattern = FLUX_QKV_SPLIT_PATTERNS_REVERSE[target]
                orig_key = WeightMappingHandler.rename_key(key.replace(target, pattern), FLUX_RENAME_PATTERNS_REVERSE)
                position = FLUX_QKV_SPLIT_PATTERNS[pattern].index(target)
                qkv_groups.setdefault(orig_key, {})[position] = value
                continue

        if in_single_block:
            target = next((t for t in FLUX_QKVMLP_TARGETS if t in key), None)
            if target is not None:
                orig_key = WeightMappingHandler.rename_key(
                    key.replace(target, FLUX_QKVMLP_SPLIT_PATTERN), FLUX_RENAME_PATTERNS_REVERSE
                )
                position = FLUX_QKVMLP_TARGETS.index(target)
                qkvmlp_groups.setdefault(orig_key, {})[position] = value
                continue

        # Standard rename
        converted_state_dict[WeightMappingHandler.rename_key(key, FLUX_RENAME_PATTERNS_REVERSE)] = value

    # Concatenate complete groups in target order (q, k, v[, mlp]).
    for groups, expected_parts in ((qkv_groups, 3), (qkvmlp_groups, len(FLUX_QKVMLP_TARGETS))):
        for orig_key, parts in groups.items():
            if len(parts) == expected_parts:
                converted_state_dict[orig_key] = torch.cat([parts[i] for i in sorted(parts)], dim=0)

    return converted_state_dict
+ """ + guidance_key = "guidance_in.in_layer.bias" + x_embedder_key = "img_in.weight" + + if not weight_mapping.is_original_format(state_dict): + guidance_key = weight_mapping.rename_key(guidance_key, FLUX_RENAME_PATTERNS) + x_embedder_key = weight_mapping.rename_key(x_embedder_key, FLUX_RENAME_PATTERNS) + + if x_embedder_key not in state_dict: + return None + + if guidance_key not in state_dict: + return "flux-schnell" + + in_channels = state_dict[x_embedder_key].shape[1] + if in_channels == 384: + return "flux-fill" + if in_channels == 128: + return "flux-depth" + + return "flux-dev" + + +# Assigned to ``FluxTransformer2DModel`` as the ``_weight_mapping`` class attribute in ``flux/model.py``. +FLUX_WEIGHT_MAPPING = WeightMappingHandler( + original_format_keys=_FLUX_STATE_DICT_KEYS, + available_configs=_FLUX_AVAILABLE_CONFIGS, + map_to_diffusers_fn=map_to_diffusers, + map_from_diffusers_fn=map_from_diffusers, + detect_config_fn=detect_config, + # Kicks in only when ``detect_config`` returns ``None`` (e.g. the ``img_in`` / ``x_embedder`` key is + # absent so we can't read in_channels). Most Flux checkpoints in the wild are dev-derived, so it's + # the safest fallback config to load. 
+ default_config="flux-dev", + default_subfolder="transformer", +) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/flux/model.py similarity index 96% rename from src/diffusers/models/transformers/transformer_flux.py rename to src/diffusers/models/transformers/flux/model.py index 78a77ebcfea9..b2201a52413f 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/flux/model.py @@ -20,26 +20,30 @@ import torch.nn as nn import torch.nn.functional as F -from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin -from ...utils import apply_lora_scale, logging -from ...utils.torch_utils import maybe_allow_in_graph -from .._modeling_parallel import ContextParallelInput, ContextParallelOutput -from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward -from ..attention_dispatch import dispatch_attention_fn -from ..cache_utils import CacheMixin -from ..embeddings import ( +from ....configuration_utils import register_to_config +from ....hooks._helpers import TransformerBlockMetadata +from ....loaders.lora import LoRAModelMixin +from ....utils import apply_lora_scale, logging +from ....utils.torch_utils import maybe_allow_in_graph +from ..._modeling_parallel import ContextParallelInput, ContextParallelOutput +from ...attention import AttentionMixin, AttentionModuleMixin, FeedForward +from ...attention_dispatch import dispatch_attention_fn +from ...cache_utils import CacheMixin +from ...embeddings import ( CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, apply_rotary_emb, get_1d_rotary_pos_embed, ) -from ..modeling_outputs import Transformer2DModelOutput -from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle +from ...modeling_outputs import 
Transformer2DModelOutput +from ...modeling_utils import ModelMixin, register_metadata +from ...normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle +from ._ip_adapter import FluxIPAdapterMixin +from ._lora import FLUX_LORA +from ._weight_mapping import FLUX_WEIGHT_MAPPING -logger = logging.get_logger(__name__) # pylint: disable=invalid-name +logger = logging.get_logger(__name__) def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None): @@ -353,6 +357,7 @@ def forward( @maybe_allow_in_graph +@register_metadata(TransformerBlockMetadata(return_hidden_states_index=1, return_encoder_hidden_states_index=0)) class FluxSingleTransformerBlock(nn.Module): def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0): super().__init__() @@ -407,6 +412,7 @@ def forward( @maybe_allow_in_graph +@register_metadata(TransformerBlockMetadata(return_hidden_states_index=1, return_encoder_hidden_states_index=0)) class FluxTransformerBlock(nn.Module): def __init__( self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6 @@ -524,12 +530,10 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: class FluxTransformer2DModel( ModelMixin, - ConfigMixin, - PeftAdapterMixin, - FromOriginalModelMixin, - FluxTransformer2DLoadersMixin, - CacheMixin, + LoRAModelMixin, AttentionMixin, + CacheMixin, + FluxIPAdapterMixin, ): """ The Transformer model introduced in Flux. 
@@ -564,7 +568,7 @@ class FluxTransformer2DModel( _supports_gradient_checkpointing = True _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] - _skip_layerwise_casting_patterns = ["pos_embed", "norm"] + _skip_layerwise_casting_patterns = ("pos_embed", "norm") _repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] _cp_plan = { "": { @@ -576,6 +580,9 @@ class FluxTransformer2DModel( "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), } + _lora = FLUX_LORA + _weight_mapping = FLUX_WEIGHT_MAPPING + @register_to_config def __init__( self, @@ -676,7 +683,6 @@ def forward( If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ - hidden_states = self.x_embedder(hidden_states) timestep = timestep.to(hidden_states.dtype) * 1000 diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py index d7cc96d018b3..0a55e3202f77 100644 --- a/src/diffusers/models/transformers/transformer_chroma.py +++ b/src/diffusers/models/transformers/transformer_chroma.py @@ -30,7 +30,7 @@ from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import CombinedTimestepLabelEmbeddings, FP32LayerNorm, RMSNorm -from .transformer_flux import FluxAttention, FluxAttnProcessor +from .flux import FluxAttention, FluxAttnProcessor logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index bdb87a385da7..2d8bc58683b2 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -877,7 +877,6 @@ def forward( `tuple` where the first element is the sample tensor. 
""" hidden_states = self.img_in(hidden_states) - timestep = timestep.to(hidden_states.dtype) if self.zero_cond_t: diff --git a/src/diffusers/models/transformers/utils.py b/src/diffusers/models/transformers/utils.py new file mode 100644 index 000000000000..a99e2fcdc392 --- /dev/null +++ b/src/diffusers/models/transformers/utils.py @@ -0,0 +1,77 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Shared utilities for transformer model implementations.""" + +from dataclasses import dataclass, fields + +import torch + + +@dataclass +class TransformerModuleOutput: + """Base class providing tuple-compatible iteration for structured submodule outputs. + + Doesn't declare any fields itself — subclasses define their own schema. Provides only the plumbing that lets + callers unpack positionally (``h, e = output``), index (``output[0]``), and check length, with ``None`` fields + transparently skipped so a single-stream output unpacks as a 1-tuple. This matches the legacy bare-tuple return + shape so subclasses can be adopted without touching callers. 
+ """ + + def _as_tuple(self): + """Tuple-compat view of the dataclass: declared field order, with ``None`` values skipped.""" + return tuple(getattr(self, f.name) for f in fields(self) if getattr(self, f.name) is not None) + + def __iter__(self): + return iter(self._as_tuple()) + + def __getitem__(self, idx): + return self._as_tuple()[idx] + + def __len__(self): + return len(self._as_tuple()) + + +@dataclass +class TransformerBlockOutput(TransformerModuleOutput): + """Structured return type for transformer-block ``forward`` methods. + + Replaces the historical pattern of returning bare tuples whose element ordering varied per model (e.g. Flux + returned ``(encoder_hidden_states, hidden_states)`` while CogVideoX returned ``(hidden_states, + encoder_hidden_states)``). Tuple-compatibility inherited from :class:`TransformerModuleOutput`. + + Attributes: + hidden_states: The block's primary output tensor. Always populated. + encoder_hidden_states: The text / context stream output for dual-stream blocks. ``None`` for single-stream. + """ + + hidden_states: torch.Tensor = None + encoder_hidden_states: torch.Tensor | None = None + + +@dataclass +class AttnProcessorOutput(TransformerModuleOutput): + """Structured return type for attention-processor ``__call__`` methods. + + Replaces the historical pattern of returning a bare tensor for single-stream attention and a bare ``(hidden_states, + encoder_hidden_states)`` tuple for dual-stream attention. Tuple-compatibility inherited from + :class:`TransformerModuleOutput`. + + Attributes: + hidden_states: The processor's primary output tensor. Always populated. + encoder_hidden_states: The text / context stream output for dual-stream attention processors. ``None`` for + single-stream. 
+ """ + + hidden_states: torch.Tensor = None + encoder_hidden_states: torch.Tensor | None = None diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 008426f5275e..f2ba49878710 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -30,6 +30,7 @@ GGUF_FILE_EXTENSION, HF_ENABLE_PARALLEL_LOADING, HF_MODULES_CACHE, + HUB_KWARGS, HUGGINGFACE_CO_RESOLVE_ENDPOINT, MIN_PEFT_VERSION, ONNX_EXTERNAL_WEIGHTS_NAME, diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py index cbfe2da0d32a..28b842ce2594 100644 --- a/src/diffusers/utils/constants.py +++ b/src/diffusers/utils/constants.py @@ -42,6 +42,19 @@ DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules")) DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"] + +# Canonical set of hub-download kwargs (with defaults) forwarded to ``_get_model_file`` and +# related loaders. Use ``{k: kwargs.pop(k, default) for k, default in HUB_KWARGS.items()}`` to +# extract them from a caller's ``**kwargs`` in one shot. +HUB_KWARGS = { + "cache_dir": None, + "force_download": False, + "proxies": None, + "local_files_only": None, + "token": None, + "revision": None, + "subfolder": None, +} DIFFUSERS_REQUEST_TIMEOUT = 60 DIFFUSERS_ATTN_BACKEND = os.getenv("DIFFUSERS_ATTN_BACKEND", "native") DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0").upper() in ENV_VARS_TRUE_VALUES