Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions src/mcore_bridge/patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from packaging import version
from peft.tuners.tuners_utils import BaseTuner
from torch import nn
from transformers.utils import is_torch_npu_available
from typing import Callable, List, Optional, Tuple

from mcore_bridge.utils import get_logger, is_flash_attn_3_available
Expand Down Expand Up @@ -593,8 +594,25 @@ def forward(self, *_args, **kwargs):
def _patch_TELinear():

def __repr__(self):
return (f'{type(self).__name__}(in_features={self.in_features}, '
f'out_features={self.out_features}, bias={self.use_bias}, TP={self.tp_size})')
if is_torch_npu_available():
# MindSpeed 0.15.x changes some TE debug fields to
# input_size/output_size. Keep this compatibility on the NPU path
# only so GPU and older versions retain their original field
# semantics.
in_features = getattr(self, 'in_features', getattr(self, 'input_size', None))
out_features = getattr(self, 'out_features', getattr(self, 'output_size', None))
use_bias = getattr(self, 'use_bias', getattr(self, 'bias', None) is not None)
Copy link

Copilot AI Apr 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On the NPU path, use_bias = getattr(self, 'use_bias', getattr(self, 'bias', None) is not None) can misreport bias when self.bias is a boolean (e.g., False still makes the expression true because it is not None). To make __repr__ robust across TE/MindSpeed variants, compute use_bias by first reading bias_attr = getattr(self, 'bias', None) and handling the boolean case explicitly (otherwise fall back to bias_attr is not None).

Suggested change
use_bias = getattr(self, 'use_bias', getattr(self, 'bias', None) is not None)
bias_attr = getattr(self, 'bias', None)
if hasattr(self, 'use_bias'):
use_bias = self.use_bias
elif isinstance(bias_attr, bool):
use_bias = bias_attr
else:
use_bias = bias_attr is not None

Copilot uses AI. Check for mistakes.
tp_size = getattr(self, 'tp_size', None)
if tp_size is None:
parallel_mode = getattr(self, 'parallel_mode', None)
tp_size = 1 if parallel_mode == 'duplicated' else 'unknown'
Comment on lines +607 to +608
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

When tp_size is missing, it is inferred from parallel_mode. However, parallel_mode can be None for local layers (as seen in the GPU path or default initialization), which should also imply a tp_size of 1. Currently, it defaults to 'unknown' in this case.

Suggested change
parallel_mode = getattr(self, 'parallel_mode', None)
tp_size = 1 if parallel_mode == 'duplicated' else 'unknown'
parallel_mode = getattr(self, 'parallel_mode', None)
tp_size = 1 if parallel_mode in ('duplicated', None) else 'unknown'

else:
in_features = self.in_features
out_features = self.out_features
use_bias = self.use_bias
tp_size = self.tp_size
return (f'{type(self).__name__}(in_features={in_features}, '
f'out_features={out_features}, bias={use_bias}, TP={tp_size})')

TELinear.__repr__ = __repr__

Expand Down
108 changes: 69 additions & 39 deletions src/mcore_bridge/tuners/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch.nn.functional as F
import warnings
from contextlib import contextmanager, nullcontext
from importlib import metadata
from megatron.core import parallel_state
Comment on lines 8 to 10
Copy link

Copilot AI Apr 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from importlib import metadata introduces a module-level name that collides with the existing metadata parameter used later in sharded_state_dict(...). This shadowing is legal but makes the file harder to read and can lead to accidental misuse of the module vs. the parameter. Consider importing with an alias (e.g., importlib_metadata) and updating _get_mindspeed_version() accordingly.

Copilot uses AI. Check for mistakes.
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.extensions.transformer_engine import (TEColumnParallelGroupedLinear, TEColumnParallelLinear,
Expand All @@ -30,6 +31,67 @@

mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0')
mcore_016 = version.parse(megatron.core.__version__) >= version.parse('0.16.0rc0')
MINDSPEED_015 = version.parse('0.15.0')


def _get_mindspeed_version():
    """Return the installed MindSpeed version as a ``packaging`` Version.

    Returns:
        The parsed MindSpeed version, or ``None`` when MindSpeed is not
        installed or its metadata cannot be read/parsed.

    The result is cached on the function object: ``metadata.version`` hits
    the filesystem and parses package metadata, and this helper is called
    repeatedly while LoRA adapter layers are being constructed.
    """
    _missing = object()
    cached = getattr(_get_mindspeed_version, '_cached', _missing)
    if cached is _missing:
        try:
            cached = version.parse(metadata.version('mindspeed'))
        except Exception:
            # Covers metadata.PackageNotFoundError (the original first
            # handler was redundant with this broad one) as well as any
            # malformed-metadata parse errors; callers treat "unknown"
            # the same as "not installed".
            cached = None
        _get_mindspeed_version._cached = cached
    return cached
Comment on lines +37 to +43
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The _get_mindspeed_version function is called multiple times during the LoRA update process. Since metadata.version involves filesystem access and parsing, it is inefficient to call it repeatedly. Caching the result would improve performance.

Suggested change
def _get_mindspeed_version():
try:
return version.parse(metadata.version('mindspeed'))
except metadata.PackageNotFoundError:
return None
except Exception:
return None
_MINDSPEED_VERSION = None
def _get_mindspeed_version():
global _MINDSPEED_VERSION
if _MINDSPEED_VERSION is not None:
return _MINDSPEED_VERSION
try:
_MINDSPEED_VERSION = version.parse(metadata.version('mindspeed'))
except (metadata.PackageNotFoundError, Exception):
_MINDSPEED_VERSION = False
return _MINDSPEED_VERSION if _MINDSPEED_VERSION is not False else None



def _use_legacy_npu_local_linear() -> bool:
    """Decide whether local LoRA linears on NPU must use plain ``nn.Linear``.

    Returns ``False`` off NPU. On NPU, returns ``True`` for MindSpeed
    versions below 0.15 — and also when the MindSpeed version cannot be
    determined, so an older NPU stack is never forced onto the 0.15 TE
    semantics.
    """
    if not is_torch_npu_available():
        return False
    detected = _get_mindspeed_version()
    if detected is not None:
        return detected < MINDSPEED_015
    # Unknown version: stay on the conservative legacy path.
    return True


def _build_local_te_linear(input_size: int, output_size: int, bias: bool, **kwargs):
    """Create a local (non-tensor-parallel) linear layer for LoRA adapters.

    On legacy NPU stacks (MindSpeed < 0.15 or unknown) this falls back to a
    plain ``nn.Linear``; otherwise it builds a ``TELinear``. On the NPU
    TELinear path, ``tp_group`` is dropped and ``parallel_mode`` is set to
    ``'duplicated'`` to match MindSpeed 0.15.x local-layer semantics.
    """
    if _use_legacy_npu_local_linear():
        return nn.Linear(in_features=input_size, out_features=output_size, bias=bias)

    te_kwargs = dict(kwargs)
    if is_torch_npu_available():
        # MindSpeed 0.15.x local TE linears use duplicated semantics and
        # this constructor path does not accept tp_group.
        te_kwargs.pop('tp_group', None)
        mode = 'duplicated'
    else:
        mode = None
    return TELinear(
        input_size=input_size,
        output_size=output_size,
        bias=bias,
        parallel_mode=mode,
        skip_weight_param_allocation=False,
        **te_kwargs,
    )


def _get_tensor_parallel_group_for_lora(base_layer):
"""Resolve the tensor-parallel group across TE and MindSpeed TE variants.

Megatron's TE layers expose ``tp_group`` directly, but MindSpeed 0.15.x
replaces some TE classes (for example
``MindSpeedTELayerNormColumnParallelLinear``) with implementations that keep
the same tensor-parallel semantics under ``parallel_group`` instead. LoRA
still needs to forward the right group into the newly created parallel
adapter layers, otherwise adapter injection fails before training starts.
"""
tp_group = getattr(base_layer, 'tp_group', None)
if tp_group is not None:
return tp_group
return getattr(base_layer, 'parallel_group', None)


class LoraParallelLinear(MegatronModule, LoraLayer):
Expand Down Expand Up @@ -102,25 +164,13 @@ def update_layer(self, adapter_name, r, *, lora_alpha, lora_dropout, init_lora_w
'is_expert': self.is_expert,
}
if mcore_013 and not (mcore_016 and self.is_grouped):
kwargs['tp_group'] = self.base_layer.tp_group
tp_group = _get_tensor_parallel_group_for_lora(self.base_layer)
if tp_group is not None:
kwargs['tp_group'] = tp_group
if isinstance(self.base_layer, TopKRouter):
router_shape = self.base_layer.weight.shape
lora_a = TELinear(
input_size=router_shape[1],
output_size=r,
bias=lora_bias,
parallel_mode=None,
skip_weight_param_allocation=False,
**kwargs,
)
lora_b = TELinear(
input_size=r,
output_size=router_shape[0],
bias=lora_bias,
parallel_mode=None,
skip_weight_param_allocation=False,
**kwargs,
)
lora_a = _build_local_te_linear(router_shape[1], r, lora_bias, **kwargs)
lora_b = _build_local_te_linear(r, router_shape[0], lora_bias, **kwargs)
elif self.is_parallel_a:
in_features = self.in_features * self.tp_size
if self.is_grouped:
Expand All @@ -147,14 +197,7 @@ def update_layer(self, adapter_name, r, *, lora_alpha, lora_dropout, init_lora_w
input_is_parallel=True,
**kwargs,
)
lora_b = TELinear(
input_size=r,
output_size=self.out_features,
bias=lora_bias,
parallel_mode=None,
skip_weight_param_allocation=False,
**kwargs,
)
lora_b = _build_local_te_linear(r, self.out_features, lora_bias, **kwargs)
lora_a.parallel_mode = self.base_layer.parallel_mode # fix moe_shared_expert_overlap
else:
if is_torch_npu_available():
Expand All @@ -177,20 +220,7 @@ def update_layer(self, adapter_name, r, *, lora_alpha, lora_dropout, init_lora_w
**kwargs,
)
else:
if is_torch_npu_available():
lora_a = nn.Linear(
in_features=self.in_features,
out_features=r,
bias=lora_bias,
)
else:
lora_a = TELinear(
input_size=self.in_features,
output_size=r,
bias=lora_bias,
parallel_mode=None,
skip_weight_param_allocation=False,
**kwargs)
lora_a = _build_local_te_linear(self.in_features, r, lora_bias, **kwargs)
lora_b = TEColumnParallelLinear(
input_size=r,
output_size=out_features,
Expand Down
Loading