diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py
index 7039a91..e48fb08 100644
--- a/src/mcore_bridge/bridge/gpt_bridge.py
+++ b/src/mcore_bridge/bridge/gpt_bridge.py
@@ -267,11 +267,16 @@ def _set_module(self, mg_module, hf_state_dict, hf_prefix: str, to_mcore: bool):
         new_state_dict = {}
         for k, v in hf_state_dict.items():
             if self._peft_format:
-                if '.lora_A.' in k or '.lora_B.' in k or '.modules_to_save.' in k:
-                    k = k.replace(f'{self._adapter_name}.', '')
+                # Without adding a leading '.' here (e.g., '.lora_A.'),
+                # we avoid the case where mg_module itself is a linear layer (such as proj1).
+                if ('lora_A.' in k or 'lora_B.' in k
+                    or 'modules_to_save.' in k) and f'.{self._adapter_name}.' in k:
+                    k = k.replace(f'.{self._adapter_name}.', '.')
                 new_state_dict[k] = v
             else:
-                if '.lora_A.' in k or '.lora_B.' in k or 'original_module.' in k:
+                if 'lora_A.' in k or 'lora_B.' in k or 'original_module.' in k:
+                    continue
+                if 'modules_to_save.' in k and f'modules_to_save.{self._adapter_name}.' not in k:
                     continue
                 k = k.replace('base_layer.', '')
                 k = k.replace(f'modules_to_save.{self._adapter_name}.', '')
@@ -1324,24 +1329,26 @@ def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_i
             hf_state_dict['in_proj_b.weight_scale_inv'] = scale_inv[qkv_block + z_block:-a_block].clone()
             hf_state_dict['in_proj_a.weight_scale_inv'] = scale_inv[-a_block:].clone()
             del in_proj_weight
-        if to_mcore:
-            conv1d = hf_state_dict['conv1d.weight'].load()
-            q_c, k_c, v_c = torch.split(conv1d, [key_dim, key_dim, value_dim], dim=0)
-            conv1d = torch.cat([
-                *(x.reshape(num_key_heads, -1, *conv1d.shape[-2:]) for x in [q_c, k_c, v_c]),
-            ], dim=1).reshape((-1, *conv1d.shape[-2:]))
-            self._set_weight(mg_attn.conv1d.weight, conv1d, 'conv1d.weight')
-        else:
-            conv1d, _ = self._get_weight(None if mg_attn is None else mg_attn.conv1d.weight, 'conv1d.weight')
-            if conv1d is not None:
-                conv1d = conv1d.reshape(num_key_heads, -1, *conv1d.shape[-2:])
-                q_c, k_c, v_c = torch.split(
-                    conv1d, [key_dim // num_key_heads, key_dim // num_key_heads, value_dim // num_key_heads], dim=1)
-                q_c = q_c.reshape(-1, *q_c.shape[-2:])
-                k_c = k_c.reshape(-1, *k_c.shape[-2:])
-                v_c = v_c.reshape(-1, *v_c.shape[-2:])
-                conv1d = torch.concat([q_c, k_c, v_c], dim=0)
-                hf_state_dict['conv1d.weight'] = conv1d
+        if not self._peft_format:
+            if to_mcore:
+                conv1d = hf_state_dict['conv1d.weight'].load()
+                q_c, k_c, v_c = torch.split(conv1d, [key_dim, key_dim, value_dim], dim=0)
+                conv1d = torch.cat([
+                    *(x.reshape(num_key_heads, -1, *conv1d.shape[-2:]) for x in [q_c, k_c, v_c]),
+                ],
+                                   dim=1).reshape((-1, *conv1d.shape[-2:]))
+                self._set_weight(mg_attn.conv1d.weight, conv1d, 'conv1d.weight')
+            else:
+                conv1d, _ = self._get_weight(None if mg_attn is None else mg_attn.conv1d.weight, 'conv1d.weight')
+                if conv1d is not None:
+                    conv1d = conv1d.reshape(num_key_heads, -1, *conv1d.shape[-2:])
+                    q_c, k_c, v_c = torch.split(
+                        conv1d, [key_dim // num_key_heads, key_dim // num_key_heads, value_dim // num_key_heads], dim=1)
+                    q_c = q_c.reshape(-1, *q_c.shape[-2:])
+                    k_c = k_c.reshape(-1, *k_c.shape[-2:])
+                    v_c = v_c.reshape(-1, *v_c.shape[-2:])
+                    conv1d = torch.concat([q_c, k_c, v_c], dim=0)
+                    hf_state_dict['conv1d.weight'] = conv1d
         self._set_state_dict(mg_attn, 'dt_bias', hf_state_dict, 'dt_bias', to_mcore)
         self._set_state_dict(mg_attn, 'A_log', hf_state_dict, 'A_log', to_mcore)
         self._set_state_dict(mg_attn, 'out_norm.weight', hf_state_dict, 'norm.weight', to_mcore)
@@ -1703,7 +1710,7 @@ def export_weights(
         self.config = mg_models[0].config
         with torch.no_grad():
             for k, v in self._convert(mg_models, {}, hf_prefix, False, tqdm_desc=tqdm_desc):
-                if converter:
+                if converter and v is not None:
                     kv = converter(k, v)
                     if kv is None:
                         continue