Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 29 additions & 22 deletions src/mcore_bridge/bridge/gpt_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,11 +267,16 @@ def _set_module(self, mg_module, hf_state_dict, hf_prefix: str, to_mcore: bool):
new_state_dict = {}
for k, v in hf_state_dict.items():
if self._peft_format:
if '.lora_A.' in k or '.lora_B.' in k or '.modules_to_save.' in k:
k = k.replace(f'{self._adapter_name}.', '')
# Without adding a leading '.' here (e.g., '.lora_A.'),
# we avoid the case where mg_module itself is a linear layer (such as proj1).
if ('lora_A.' in k or 'lora_B.' in k
or 'modules_to_save.' in k) and f'.{self._adapter_name}.' in k:
k = k.replace(f'.{self._adapter_name}.', '.')
new_state_dict[k] = v
else:
if '.lora_A.' in k or '.lora_B.' in k or 'original_module.' in k:
if 'lora_A.' in k or 'lora_B.' in k or 'original_module.' in k:
continue
if 'modules_to_save.' in k and f'modules_to_save.{self._adapter_name}.' not in k:
continue
k = k.replace('base_layer.', '')
k = k.replace(f'modules_to_save.{self._adapter_name}.', '')
Expand Down Expand Up @@ -1324,24 +1329,26 @@ def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_i
hf_state_dict['in_proj_b.weight_scale_inv'] = scale_inv[qkv_block + z_block:-a_block].clone()
hf_state_dict['in_proj_a.weight_scale_inv'] = scale_inv[-a_block:].clone()
del in_proj_weight
if to_mcore:
conv1d = hf_state_dict['conv1d.weight'].load()
q_c, k_c, v_c = torch.split(conv1d, [key_dim, key_dim, value_dim], dim=0)
conv1d = torch.cat([
*(x.reshape(num_key_heads, -1, *conv1d.shape[-2:]) for x in [q_c, k_c, v_c]),
], dim=1).reshape((-1, *conv1d.shape[-2:]))
self._set_weight(mg_attn.conv1d.weight, conv1d, 'conv1d.weight')
else:
conv1d, _ = self._get_weight(None if mg_attn is None else mg_attn.conv1d.weight, 'conv1d.weight')
if conv1d is not None:
conv1d = conv1d.reshape(num_key_heads, -1, *conv1d.shape[-2:])
q_c, k_c, v_c = torch.split(
conv1d, [key_dim // num_key_heads, key_dim // num_key_heads, value_dim // num_key_heads], dim=1)
q_c = q_c.reshape(-1, *q_c.shape[-2:])
k_c = k_c.reshape(-1, *k_c.shape[-2:])
v_c = v_c.reshape(-1, *v_c.shape[-2:])
conv1d = torch.concat([q_c, k_c, v_c], dim=0)
hf_state_dict['conv1d.weight'] = conv1d
if not self._peft_format:
if to_mcore:
conv1d = hf_state_dict['conv1d.weight'].load()
q_c, k_c, v_c = torch.split(conv1d, [key_dim, key_dim, value_dim], dim=0)
conv1d = torch.cat([
*(x.reshape(num_key_heads, -1, *conv1d.shape[-2:]) for x in [q_c, k_c, v_c]),
],
dim=1).reshape((-1, *conv1d.shape[-2:]))
self._set_weight(mg_attn.conv1d.weight, conv1d, 'conv1d.weight')
else:
conv1d, _ = self._get_weight(None if mg_attn is None else mg_attn.conv1d.weight, 'conv1d.weight')
if conv1d is not None:
conv1d = conv1d.reshape(num_key_heads, -1, *conv1d.shape[-2:])
q_c, k_c, v_c = torch.split(
conv1d, [key_dim // num_key_heads, key_dim // num_key_heads, value_dim // num_key_heads], dim=1)
q_c = q_c.reshape(-1, *q_c.shape[-2:])
k_c = k_c.reshape(-1, *k_c.shape[-2:])
v_c = v_c.reshape(-1, *v_c.shape[-2:])
conv1d = torch.concat([q_c, k_c, v_c], dim=0)
hf_state_dict['conv1d.weight'] = conv1d
self._set_state_dict(mg_attn, 'dt_bias', hf_state_dict, 'dt_bias', to_mcore)
self._set_state_dict(mg_attn, 'A_log', hf_state_dict, 'A_log', to_mcore)
self._set_state_dict(mg_attn, 'out_norm.weight', hf_state_dict, 'norm.weight', to_mcore)
Expand Down Expand Up @@ -1703,7 +1710,7 @@ def export_weights(
self.config = mg_models[0].config
with torch.no_grad():
for k, v in self._convert(mg_models, {}, hf_prefix, False, tqdm_desc=tqdm_desc):
if converter:
if converter and v is not None:
kv = converter(k, v)
if kv is None:
continue
Expand Down
Loading