diff --git a/src/mcore_bridge/patcher.py b/src/mcore_bridge/patcher.py
index b9d184c..4b17274 100644
--- a/src/mcore_bridge/patcher.py
+++ b/src/mcore_bridge/patcher.py
@@ -21,6 +21,7 @@
 from packaging import version
 from peft.tuners.tuners_utils import BaseTuner
 from torch import nn
+from transformers.utils import is_torch_npu_available
 from typing import Callable, List, Optional, Tuple
 
 from mcore_bridge.utils import get_logger, is_flash_attn_3_available
@@ -593,8 +594,25 @@ def forward(self, *_args, **kwargs):
 
 def _patch_TELinear():
 
     def __repr__(self):
-        return (f'{type(self).__name__}(in_features={self.in_features}, '
-                f'out_features={self.out_features}, bias={self.use_bias}, TP={self.tp_size})')
+        if is_torch_npu_available():
+            # MindSpeed 0.15.x renames some TE debug fields to
+            # input_size/output_size. Keep this compatibility on the NPU path
+            # only, so GPU and older versions retain their original field
+            # semantics.
+            in_features = getattr(self, 'in_features', getattr(self, 'input_size', None))
+            out_features = getattr(self, 'out_features', getattr(self, 'output_size', None))
+            use_bias = getattr(self, 'use_bias', getattr(self, 'bias', None) is not None)
+            tp_size = getattr(self, 'tp_size', None)
+            if tp_size is None:
+                parallel_mode = getattr(self, 'parallel_mode', None)
+                tp_size = 1 if parallel_mode == 'duplicated' else 'unknown'
+        else:
+            in_features = self.in_features
+            out_features = self.out_features
+            use_bias = self.use_bias
+            tp_size = self.tp_size
+        return (f'{type(self).__name__}(in_features={in_features}, '
+                f'out_features={out_features}, bias={use_bias}, TP={tp_size})')
 
     TELinear.__repr__ = __repr__
diff --git a/src/mcore_bridge/tuners/lora.py b/src/mcore_bridge/tuners/lora.py
index 6557178..2cc0f16 100644
--- a/src/mcore_bridge/tuners/lora.py
+++ b/src/mcore_bridge/tuners/lora.py
@@ -6,6 +6,7 @@
 import torch.nn.functional as F
 import warnings
 from contextlib import contextmanager, nullcontext
+from importlib import metadata
 from megatron.core import parallel_state
 from megatron.core.dist_checkpointing.mapping import ShardedStateDict
 from megatron.core.extensions.transformer_engine import (TEColumnParallelGroupedLinear, TEColumnParallelLinear,
@@ -30,6 +31,67 @@
 mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0')
 mcore_016 = version.parse(megatron.core.__version__) >= version.parse('0.16.0rc0')
 
+MINDSPEED_015 = version.parse('0.15.0')
+
+
+def _get_mindspeed_version():
+    try:
+        return version.parse(metadata.version('mindspeed'))
+    except metadata.PackageNotFoundError:
+        return None
+    except Exception:
+        return None
+
+
+def _use_legacy_npu_local_linear() -> bool:
+    if not is_torch_npu_available():
+        return False
+    mindspeed_version = _get_mindspeed_version()
+    if mindspeed_version is None:
+        # Fall back to the conservative path when the version is unknown so we
+        # do not force an older NPU stack onto the 0.15 TE semantics.
+        return True
+    return mindspeed_version < MINDSPEED_015
+
+
+def _build_local_te_linear(input_size: int, output_size: int, bias: bool, **kwargs):
+    if _use_legacy_npu_local_linear():
+        return nn.Linear(
+            in_features=input_size,
+            out_features=output_size,
+            bias=bias,
+        )
+    local_kwargs = dict(kwargs)
+    parallel_mode = None
+    if is_torch_npu_available():
+        # Local TE linear layers in MindSpeed 0.15.x use duplicated semantics,
+        # and this path does not accept tp_group.
+        local_kwargs.pop('tp_group', None)
+        parallel_mode = 'duplicated'
+    return TELinear(
+        input_size=input_size,
+        output_size=output_size,
+        bias=bias,
+        parallel_mode=parallel_mode,
+        skip_weight_param_allocation=False,
+        **local_kwargs,
+    )
+
+
+def _get_tensor_parallel_group_for_lora(base_layer):
+    """Resolve the tensor-parallel group across TE and MindSpeed TE variants.
+
+    Megatron's TE layers expose ``tp_group`` directly, but MindSpeed 0.15.x
+    replaces some TE classes (for example
+    ``MindSpeedTELayerNormColumnParallelLinear``) with implementations that
+    keep the same tensor-parallel semantics under ``parallel_group`` instead.
+    LoRA still needs to forward the right group into newly created parallel
+    adapter layers; otherwise adapter injection fails before training starts.
+    """
+    tp_group = getattr(base_layer, 'tp_group', None)
+    if tp_group is not None:
+        return tp_group
+    return getattr(base_layer, 'parallel_group', None)
 
 
 class LoraParallelLinear(MegatronModule, LoraLayer):
@@ -102,25 +164,13 @@ def update_layer(self, adapter_name, r, *, lora_alpha, lora_dropout, init_lora_w
             'is_expert': self.is_expert,
         }
         if mcore_013 and not (mcore_016 and self.is_grouped):
-            kwargs['tp_group'] = self.base_layer.tp_group
+            tp_group = _get_tensor_parallel_group_for_lora(self.base_layer)
+            if tp_group is not None:
+                kwargs['tp_group'] = tp_group
         if isinstance(self.base_layer, TopKRouter):
             router_shape = self.base_layer.weight.shape
-            lora_a = TELinear(
-                input_size=router_shape[1],
-                output_size=r,
-                bias=lora_bias,
-                parallel_mode=None,
-                skip_weight_param_allocation=False,
-                **kwargs,
-            )
-            lora_b = TELinear(
-                input_size=r,
-                output_size=router_shape[0],
-                bias=lora_bias,
-                parallel_mode=None,
-                skip_weight_param_allocation=False,
-                **kwargs,
-            )
+            lora_a = _build_local_te_linear(router_shape[1], r, lora_bias, **kwargs)
+            lora_b = _build_local_te_linear(r, router_shape[0], lora_bias, **kwargs)
         elif self.is_parallel_a:
             in_features = self.in_features * self.tp_size
             if self.is_grouped:
@@ -147,14 +197,7 @@ def update_layer(self, adapter_name, r, *, lora_alpha, lora_dropout, init_lora_w
                 input_is_parallel=True,
                 **kwargs,
             )
-            lora_b = TELinear(
-                input_size=r,
-                output_size=self.out_features,
-                bias=lora_bias,
-                parallel_mode=None,
-                skip_weight_param_allocation=False,
-                **kwargs,
-            )
+            lora_b = _build_local_te_linear(r, self.out_features, lora_bias, **kwargs)
             lora_a.parallel_mode = self.base_layer.parallel_mode  # fix moe_shared_expert_overlap
         else:
             if is_torch_npu_available():
@@ -177,20 +220,7 @@ def update_layer(self, adapter_name, r, *, lora_alpha, lora_dropout, init_lora_w
                 **kwargs,
             )
         else:
-            if is_torch_npu_available():
-                lora_a = nn.Linear(
-                    in_features=self.in_features,
-                    out_features=r,
-                    bias=lora_bias,
-                )
-            else:
-                lora_a = TELinear(
-                    input_size=self.in_features,
-                    output_size=r,
-                    bias=lora_bias,
-                    parallel_mode=None,
-                    skip_weight_param_allocation=False,
-                    **kwargs)
+            lora_a = _build_local_te_linear(self.in_features, r, lora_bias, **kwargs)
             lora_b = TEColumnParallelLinear(
                 input_size=r,
                 output_size=out_features,
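
Reviewer note: below is a minimal, standalone sketch of the version gate introduced in `lora.py`, runnable without Megatron, Transformer Engine, or MindSpeed installed. The `is_torch_npu_available` stub is an assumption for illustration only; the patch imports the real check from `transformers.utils`.

```python
from importlib import metadata

from packaging import version

MINDSPEED_015 = version.parse('0.15.0')


def is_torch_npu_available() -> bool:
    # Stub for the demo; assume an NPU stack is present.
    return True


def get_mindspeed_version():
    # Mirrors _get_mindspeed_version: any failure to resolve the installed
    # distribution is treated as "version unknown".
    try:
        return version.parse(metadata.version('mindspeed'))
    except Exception:
        return None


def use_legacy_npu_local_linear() -> bool:
    # Mirrors _use_legacy_npu_local_linear: only NPU stacks older than 0.15.0
    # (or of unknown version) keep the plain nn.Linear adapter path.
    if not is_torch_npu_available():
        return False
    mindspeed_version = get_mindspeed_version()
    if mindspeed_version is None:
        return True
    return mindspeed_version < MINDSPEED_015


if __name__ == '__main__':
    # Without mindspeed installed this prints True (legacy nn.Linear path);
    # with mindspeed >= 0.15.0 it prints False (TELinear with
    # parallel_mode='duplicated' and tp_group dropped from the kwargs).
    print(use_legacy_npu_local_linear())
```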
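
Similarly, a tiny sketch of the attribute fallback performed by `_get_tensor_parallel_group_for_lora`; the stand-in classes are hypothetical and exist only to exercise the `tp_group`/`parallel_group` resolution order:

```python
class TELayer:
    # Hypothetical stand-in for a Megatron TE layer exposing tp_group.
    tp_group = 'tp-group-0'


class MindSpeedTELayer:
    # Hypothetical stand-in for a MindSpeed 0.15.x layer exposing
    # parallel_group instead.
    parallel_group = 'tp-group-0'


def resolve_tp_group(base_layer):
    # Same logic as _get_tensor_parallel_group_for_lora: prefer tp_group,
    # fall back to MindSpeed's parallel_group, else None.
    tp_group = getattr(base_layer, 'tp_group', None)
    if tp_group is not None:
        return tp_group
    return getattr(base_layer, 'parallel_group', None)


assert resolve_tp_group(TELayer()) == 'tp-group-0'
assert resolve_tp_group(MindSpeedTELayer()) == 'tp-group-0'
assert resolve_tp_group(object()) is None
```

Returning `None` rather than raising lets the caller skip the `tp_group` kwarg entirely, which is what the `update_layer` hunk above does.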