Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/twinkle/model/transformers/strategy/accelerate.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,13 +210,15 @@ def get_adapter_state_dict(self, model, adapter_name: str) -> dict:
unwrapped = self.unwrap_model(model)
state_dict = {}
adapter_suffix = f'.{adapter_name}.'
fallback_state_dict = {}
for name, param in unwrapped.named_parameters():
if not _is_lora_state_key(name) or adapter_suffix not in name:
if not _is_lora_state_key(name):
continue
local = torch_util.to_local_tensor(param)
state_dict[name] = local.cpu()
target_dict = state_dict if adapter_suffix in name else fallback_state_dict
target_dict[name] = local.cpu()
del local
return state_dict
return state_dict or fallback_state_dict
Comment on lines +213 to +221
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Populating `fallback_state_dict` by copying every non-matching LoRA parameter to the CPU is highly inefficient. In the standard case, where the requested adapter is found, this causes unnecessary GPU-to-CPU transfers and memory allocations for all other adapters' parameters, which are then immediately discarded.

To optimize this, we can first check whether any LoRA parameter name contains `adapter_suffix`. If one does, we copy only the matching parameters; otherwise, we fall back to copying all LoRA parameters.

Suggested change
fallback_state_dict = {}
for name, param in unwrapped.named_parameters():
if not _is_lora_state_key(name) or adapter_suffix not in name:
if not _is_lora_state_key(name):
continue
local = torch_util.to_local_tensor(param)
state_dict[name] = local.cpu()
target_dict = state_dict if adapter_suffix in name else fallback_state_dict
target_dict[name] = local.cpu()
del local
return state_dict
return state_dict or fallback_state_dict
has_adapter = any(
_is_lora_state_key(name) and adapter_suffix in name
for name, _ in unwrapped.named_parameters()
)
for name, param in unwrapped.named_parameters():
if not _is_lora_state_key(name):
continue
if has_adapter and adapter_suffix not in name:
continue
local = torch_util.to_local_tensor(param)
state_dict[name] = local.cpu()
del local
return state_dict



def _is_lora_state_key(name: str) -> bool:
Expand Down
11 changes: 7 additions & 4 deletions src/twinkle/model/transformers/strategy/native_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,7 @@ def get_adapter_state_dict(self, model, adapter_name: str) -> dict:
"""Collect only LoRA adapter parameters, with EP-aware all-gather."""
unwrapped = self.unwrap_model(model)
state_dict = {}
fallback_state_dict = {}

ep_fsdp_mesh = self.ep_fsdp_device_mesh
ep_group = None
Expand All @@ -366,7 +367,7 @@ def get_adapter_state_dict(self, model, adapter_name: str) -> dict:
adapter_suffix = f'.{adapter_name}.'

for name, param in unwrapped.named_parameters():
if not _is_lora_state_key(name) or adapter_suffix not in name:
if not _is_lora_state_key(name):
continue

local_full = torch_util.to_local_tensor(param)
Expand All @@ -375,13 +376,15 @@ def get_adapter_state_dict(self, model, adapter_name: str) -> dict:
gathered = [torch.empty_like(local_full) for _ in range(ep_world_size)]
dist.all_gather(gathered, local_full, group=ep_group)
local_full = torch.cat(gathered, dim=_ep_expert_state_dict_gather_dim(name))
state_dict[name] = local_full.cpu()
target_dict = state_dict if adapter_suffix in name else fallback_state_dict
target_dict[name] = local_full.cpu()
del gathered, local_full
else:
state_dict[name] = local_full.cpu()
target_dict = state_dict if adapter_suffix in name else fallback_state_dict
target_dict[name] = local_full.cpu()
del local_full
return state_dict or fallback_state_dict

return state_dict


def _detect_ep_expert_names(model: nn.Module) -> Set[str]:
Expand Down
Loading