Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion transformer_engine/debug/features/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,12 @@ def call_feature(self, call, feat_config, layer_name, **kwargs):
"""
if call.__name__ == "inspect_tensor":
kwargs_copy = kwargs.copy()
for k in ["quantizer", "columnwise_quantized_tensor", "rowwise_quantized_tensor"]:
for k in [
"quantizer",
"columnwise_quantized_tensor",
"rowwise_quantized_tensor",
"tp_size",
]:
if k not in call.__code__.co_varnames:
kwargs_copy.pop(k)
else:
Expand All @@ -490,6 +495,10 @@ def call_feature(self, call, feat_config, layer_name, **kwargs):
"inspect_tensor_postquantize is deprecated, use inspect_tensor instead.",
DeprecationWarning,
)
kwargs_copy = kwargs.copy()
for k in ["tp_size"]:
if k not in call.__code__.co_varnames:
kwargs_copy.pop(k, None)

return call(feat_config, layer_name, **kwargs_copy)

Expand Down
3 changes: 2 additions & 1 deletion transformer_engine/debug/features/log_fp8_tensor_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@ def inspect_tensor(
rowwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None,
columnwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None,
quantizer: Optional[Quantizer] = None,
tp_size: int = 1,
):
"""
API call used to collect the data about the tensor after process_tensor()/quantization.
Expand Down Expand Up @@ -357,7 +358,7 @@ def inspect_tensor(
)

skip_reduction, reduction_group, reduce_within_microbatch = get_reduction_params(
tensor_name, tp_group
tensor_name, tp_group, tp_size
)

STATS_BUFFERS.try_add_buffer(
Expand Down
3 changes: 2 additions & 1 deletion transformer_engine/debug/features/log_nvfp4_tensor_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ def inspect_tensor(
rowwise_quantized_tensor: Optional[QuantizedTensor] = None,
columnwise_quantized_tensor: Optional[QuantizedTensor] = None,
quantizer: Optional[Quantizer] = None,
tp_size: int = 1,
):
"""
API call used to collect the data about the tensor after process_tensor()/quantization.
Expand Down Expand Up @@ -199,7 +200,7 @@ def inspect_tensor(
)

skip_reduction, reduction_group, reduce_within_microbatch = get_reduction_params(
tensor_name, tp_group
tensor_name, tp_group, tp_size
)

# Add nvfp4_ prefix to all stats for internal use
Expand Down
3 changes: 2 additions & 1 deletion transformer_engine/debug/features/log_tensor_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def inspect_tensor(
rowwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None,
columnwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None,
quantizer: Optional[Quantizer] = None,
tp_size: int = 1,
): # pylint: disable=unused-argument
"""API call used to collect the data about the tensor before process_tensor()/quantization."""

Expand Down Expand Up @@ -214,7 +215,7 @@ def inspect_tensor(
)

skip_reduction, reduction_group, reduce_within_microbatch = get_reduction_params(
tensor_name, tp_group
tensor_name, tp_group, tp_size
)

for stat in config["stats"]:
Expand Down
12 changes: 9 additions & 3 deletions transformer_engine/debug/features/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,22 @@
from transformer_engine.debug.pytorch.debug_state import TEDebugState


def get_reduction_params(tensor_name: str, tp_group: torch.distributed.ProcessGroup):
def get_reduction_params(tensor_name: str, tp_group: torch.distributed.ProcessGroup, tp_size: int):
"""
Returns the statistics reduction parameters for the tensor.
"""
skip_reduction = False
reduction_group = debug_api.get_tensor_reduction_group()
reduce_within_microbatch = tensor_name != "weight"
if tensor_name == "weight":
if TEDebugState.weight_tensor_tp_group_reduce:
reduction_group = tp_group
if TEDebugState.weight_tensor_tp_group_reduce and tp_size > 1:
# Do not overwrite with `None`: in torch.distributed collectives
# group=None means the default/world process group.
if tp_group is not None:
reduction_group = tp_group
else:
# "Reduce in TP group" requested, but TP group is missing.
skip_reduction = True
else:
skip_reduction = True
return skip_reduction, reduction_group, reduce_within_microbatch
Expand Down
3 changes: 3 additions & 0 deletions transformer_engine/debug/pytorch/debug_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,15 @@ def __init__(
tensor_name: str,
parent_quantizer: Optional[Quantizer],
tp_group: torch.distributed.ProcessGroup,
tp_size: int,
):

super().__init__(rowwise=True, columnwise=True)
self.layer_name = layer_name
self.tensor_name = tensor_name
self.parent_quantizer = parent_quantizer
self.tp_group = tp_group # used in inspect_tensor calls
self.tp_size = tp_size
self.iteration = TEDebugState.get_iteration()

# Configure parent quantizer
Expand Down Expand Up @@ -263,6 +265,7 @@ def _call_inspect_tensor_api(
"tensor_name": self.tensor_name,
"iteration": TEDebugState.get_iteration(),
"tp_group": self.tp_group,
"tp_size": self.tp_size,
"columnwise_quantized_tensor": columnwise_gemm_tensor,
"rowwise_quantized_tensor": rowwise_gemm_tensor,
"quantizer": self.parent_quantizer,
Expand Down
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/module/grouped_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -1082,7 +1082,7 @@ def _get_debug_quantizers(self):
names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"]
return tuple(
[
DebugQuantizer(self.name + f".gemm_{q_id}", name, q, self.tp_group)
DebugQuantizer(self.name + f".gemm_{q_id}", name, q, self.tp_group, self.tp_size)
for q_id, q in enumerate(qs)
]
for name, qs in zip(names, original_quantizers)
Expand Down
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/module/layernorm_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -1646,7 +1646,7 @@ def _get_debug_quantizers(self, fp8_output, fp8_grad, is_grad_enabled):

names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"]
return tuple(
DebugQuantizer(self.name, name, q, self.tp_group)
DebugQuantizer(self.name, name, q, self.tp_group, self.tp_size)
for name, q in zip(names, original_quantizers)
)

Expand Down
1 change: 1 addition & 0 deletions transformer_engine/pytorch/module/layernorm_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2373,6 +2373,7 @@ def make_debug(prefix, offset):
label,
None if label in ("dgrad", "wgrad") else base_quantizers[i + offset],
self.tp_group,
self.tp_size,
)
for i, label in enumerate(labels)
]
Expand Down
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/module/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -1513,7 +1513,7 @@ def _get_debug_quantizers(self, fp8_output, fp8_grad, is_grad_enabled):

names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"]
return tuple(
DebugQuantizer(self.name, name, q, self.tp_group)
DebugQuantizer(self.name, name, q, self.tp_group, self.tp_size)
for name, q in zip(names, original_quantizers)
)

Expand Down
Loading