diff --git a/src/twinkle/model/transformers/strategy/accelerate.py b/src/twinkle/model/transformers/strategy/accelerate.py index d0434991..8bcb99b7 100644 --- a/src/twinkle/model/transformers/strategy/accelerate.py +++ b/src/twinkle/model/transformers/strategy/accelerate.py @@ -210,13 +210,15 @@ def get_adapter_state_dict(self, model, adapter_name: str) -> dict: unwrapped = self.unwrap_model(model) state_dict = {} adapter_suffix = f'.{adapter_name}.' + fallback_state_dict = {} for name, param in unwrapped.named_parameters(): - if not _is_lora_state_key(name) or adapter_suffix not in name: + if not _is_lora_state_key(name): continue local = torch_util.to_local_tensor(param) - state_dict[name] = local.cpu() + target_dict = state_dict if adapter_suffix in name else fallback_state_dict + target_dict[name] = local.cpu() del local - return state_dict + return state_dict or fallback_state_dict def _is_lora_state_key(name: str) -> bool: diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py index e1e21ff4..638770d9 100644 --- a/src/twinkle/model/transformers/strategy/native_fsdp.py +++ b/src/twinkle/model/transformers/strategy/native_fsdp.py @@ -354,6 +354,7 @@ def get_adapter_state_dict(self, model, adapter_name: str) -> dict: """Collect only LoRA adapter parameters, with EP-aware all-gather.""" unwrapped = self.unwrap_model(model) state_dict = {} + fallback_state_dict = {} ep_fsdp_mesh = self.ep_fsdp_device_mesh ep_group = None @@ -366,7 +367,7 @@ def get_adapter_state_dict(self, model, adapter_name: str) -> dict: adapter_suffix = f'.{adapter_name}.' for name, param in unwrapped.named_parameters(): - if not _is_lora_state_key(name) or adapter_suffix not in name: + if not _is_lora_state_key(name): continue local_full = torch_util.to_local_tensor(param) @@ -375,13 +376,15 @@ def get_adapter_state_dict(self, model, adapter_name: str) -> dict: gathered = [torch.empty_like(local_full) for _ in range(ep_world_size)] dist.all_gather(gathered, local_full, group=ep_group) local_full = torch.cat(gathered, dim=_ep_expert_state_dict_gather_dim(name)) - state_dict[name] = local_full.cpu() + target_dict = state_dict if adapter_suffix in name else fallback_state_dict + target_dict[name] = local_full.cpu() del gathered, local_full else: - state_dict[name] = local_full.cpu() + target_dict = state_dict if adapter_suffix in name else fallback_state_dict + target_dict[name] = local_full.cpu() del local_full + return state_dict or fallback_state_dict - return state_dict def _detect_ep_expert_names(model: nn.Module) -> Set[str]: