From c5848d663223b88200786f7eb03b50c3c68e335d Mon Sep 17 00:00:00 2001 From: xichengpro Date: Tue, 2 Jun 2026 20:12:04 +0800 Subject: [PATCH 1/2] fix: fallback to all lora params to avoid empty adapter weights Added `fallback_state_dict` in `get_adapter_state_dict` to prevent saving an empty dictionary when the specific adapter suffix is not found in the parameter names. --- src/twinkle/model/transformers/strategy/accelerate.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/twinkle/model/transformers/strategy/accelerate.py b/src/twinkle/model/transformers/strategy/accelerate.py index d0434991..8bcb99b7 100644 --- a/src/twinkle/model/transformers/strategy/accelerate.py +++ b/src/twinkle/model/transformers/strategy/accelerate.py @@ -210,13 +210,15 @@ def get_adapter_state_dict(self, model, adapter_name: str) -> dict: unwrapped = self.unwrap_model(model) state_dict = {} adapter_suffix = f'.{adapter_name}.' + fallback_state_dict = {} for name, param in unwrapped.named_parameters(): - if not _is_lora_state_key(name) or adapter_suffix not in name: + if not _is_lora_state_key(name): continue local = torch_util.to_local_tensor(param) - state_dict[name] = local.cpu() + target_dict = state_dict if adapter_suffix in name else fallback_state_dict + target_dict[name] = local.cpu() del local - return state_dict + return state_dict or fallback_state_dict def _is_lora_state_key(name: str) -> bool: From 96125973aea05c84574121741c9f35cf58da8a94 Mon Sep 17 00:00:00 2001 From: xichengpro Date: Tue, 2 Jun 2026 20:30:59 +0800 Subject: [PATCH 2/2] fix: fallback to all lora params to avoid empty adapter weights --- .../model/transformers/strategy/native_fsdp.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py index e1e21ff4..638770d9 100644 --- a/src/twinkle/model/transformers/strategy/native_fsdp.py +++ b/src/twinkle/model/transformers/strategy/native_fsdp.py @@ -354,6 +354,7 @@ def get_adapter_state_dict(self, model, adapter_name: str) -> dict: """Collect only LoRA adapter parameters, with EP-aware all-gather.""" unwrapped = self.unwrap_model(model) state_dict = {} + fallback_state_dict = {} ep_fsdp_mesh = self.ep_fsdp_device_mesh ep_group = None @@ -366,7 +367,7 @@ def get_adapter_state_dict(self, model, adapter_name: str) -> dict: adapter_suffix = f'.{adapter_name}.' for name, param in unwrapped.named_parameters(): - if not _is_lora_state_key(name) or adapter_suffix not in name: + if not _is_lora_state_key(name): continue local_full = torch_util.to_local_tensor(param) @@ -375,13 +376,15 @@ def get_adapter_state_dict(self, model, adapter_name: str) -> dict: gathered = [torch.empty_like(local_full) for _ in range(ep_world_size)] dist.all_gather(gathered, local_full, group=ep_group) local_full = torch.cat(gathered, dim=_ep_expert_state_dict_gather_dim(name)) - state_dict[name] = local_full.cpu() + target_dict = state_dict if adapter_suffix in name else fallback_state_dict + target_dict[name] = local_full.cpu() del gathered, local_full else: - state_dict[name] = local_full.cpu() + target_dict = state_dict if adapter_suffix in name else fallback_state_dict + target_dict[name] = local_full.cpu() del local_full + return state_dict or fallback_state_dict - return state_dict def _detect_ep_expert_names(model: nn.Module) -> Set[str]: