Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 101 additions & 9 deletions tests/cpp/operator/test_grouped_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <transformer_engine/cast.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/recipe.h>
#include <transformer_engine/swizzle.h>
#include <transformer_engine/transformer_engine.h>

#include "../test_common.h"
Expand All @@ -32,6 +33,7 @@ namespace {
// Input precision / quantization variant exercised by a grouped-GEMM test case.
enum class InputCase {
kFP8Current,  // FP8 (e4m3) operands with per-tensor scaling (NVTE_DELAYED_TENSOR_SCALING)
kBF16,        // plain BF16 operands, no quantization
kMXFP8,       // MXFP8 operands with 1D block scaling (NVTE_MXFP8_1D_SCALING, E8M0 scales)
};

enum class ShapeCase {
Expand All @@ -44,16 +46,29 @@ enum class ShapeCase {
// Computes the host/device workspace size (in bytes) needed by the grouped-GEMM
// setup for `num_tensors` GEMMs, rounded up to a 256-byte boundary.
//
// Layout: 8 pointer arrays (A, B, C, D, alpha, beta, a_scale, b_scale)
// followed by 6 int arrays (a_rows, a_cols, b_rows, b_cols, d_rows, d_cols).
// NOTE(review): the captured text contained both the pre- and post-change
// initializations of `size`; only the 8-pointer-array layout is kept here.
size_t grouped_setup_workspace_size(const size_t num_tensors) {
  const size_t ptr_bytes = num_tensors * sizeof(void*);
  const size_t int_bytes = num_tensors * sizeof(int);
  // 8 pointer arrays + 6 int arrays
  size_t size = 8 * ptr_bytes + 6 * int_bytes;
  // Round up to the alignment boundary (0 stays 0).
  const size_t alignment = 256;
  size = ((size + alignment - 1) / alignment) * alignment;
  return size;
}

Tensor make_fp8_operand(const std::string& name, const std::vector<size_t>& shape) {
Tensor input_fp32(name + "_fp32", shape, DType::kFloat32);
fillUniform(&input_fp32);

const size_t numel = shape[0] * shape[1];
std::vector<float> data(numel);
std::mt19937 gen(std::hash<std::string>{}(name));
// Random mean and stddev -> different amax per tensor -> different scales
std::uniform_real_distribution<float> param_dis(0.1f, 10.0f);
float mean = param_dis(gen);
float stddev = param_dis(gen);
std::normal_distribution<float> dis(mean, stddev);
for (size_t i = 0; i < numel; ++i) {
data[i] = dis(gen);
}
NVTE_CHECK_CUDA(cudaMemcpy(input_fp32.rowwise_dptr(), data.data(),
numel * sizeof(float), cudaMemcpyHostToDevice));

Tensor fp8(name, shape, TypeInfo<fp8e4m3>::dtype, true, true, NVTE_DELAYED_TENSOR_SCALING);

Expand All @@ -73,6 +88,63 @@ Tensor make_bf16_operand(const std::string& name, const std::vector<size_t>& sha
return t;
}

// Creates an MXFP8 operand with the correct data layout for GEMM.
// MXFP8 GEMM requirements (scales are along K dimension):
// A transposed -> needs rowwise data/scales
// A non-transposed -> needs columnwise data/scales
// B transposed -> needs columnwise data/scales
// B non-transposed -> needs rowwise data/scales
//
// Parameters:
//   name       - base name for the tensors created here (also seeds nothing; purely a label)
//   shape      - logical shape of the operand
//   is_A       - true when building the A operand, false for B (layout rule is inverted for B)
//   transposed - whether this operand is passed transposed to the GEMM
// Returns an e4m3 MXFP8 tensor whose scaling factors have already been swizzled
// for GEMM consumption (with_gemm_swizzled_scales is set on the result).
Tensor make_mxfp8_operand(const std::string& name, const std::vector<size_t>& shape,
bool is_A, bool transposed) {
// Determine which data layout we need
bool use_rowwise, use_colwise;
if (is_A) {
// A: transposed -> rowwise, non-transposed -> columnwise
use_rowwise = transposed;
use_colwise = !transposed;
} else {
// B: transposed -> columnwise, non-transposed -> rowwise (opposite of A!)
use_rowwise = !transposed;
use_colwise = transposed;
}

// Create BF16 input with random data
Tensor input_bf16(name + "_bf16", shape, DType::kBFloat16);
fillUniform(&input_bf16);

// Create MXFP8 tensor with only the required data layout
// (exactly one of use_rowwise/use_colwise is true by construction above).
Tensor mxfp8(name, shape, TypeInfo<fp8e4m3>::dtype, use_rowwise, use_colwise,
NVTE_MXFP8_1D_SCALING);

// Quantize BF16 -> MXFP8 (last arg is the stream; 0 = default stream)
nvte_quantize(input_bf16.data(), mxfp8.data(), 0);

// Create output tensor for swizzled scales (same data shape, same layout)
Tensor mxfp8_swizzled(name + "_swizzled", shape, TypeInfo<fp8e4m3>::dtype,
use_rowwise, use_colwise, NVTE_MXFP8_1D_SCALING);
mxfp8_swizzled.set_with_gemm_swizzled_scales(true); // Must be set BEFORE swizzle call

// Copy quantized data from mxfp8 to mxfp8_swizzled
// (the swizzle call below only rewrites the scale factors, not the data).
if (use_rowwise) {
size_t data_bytes = test::bytes(mxfp8.rowwise_shape(), mxfp8.dtype());
NVTE_CHECK_CUDA(cudaMemcpy(mxfp8_swizzled.rowwise_dptr(), mxfp8.rowwise_dptr(),
data_bytes, cudaMemcpyDeviceToDevice));
}
if (use_colwise) {
size_t data_bytes = test::bytes(mxfp8.columnwise_shape(), mxfp8.dtype());
NVTE_CHECK_CUDA(cudaMemcpy(mxfp8_swizzled.columnwise_dptr(), mxfp8.columnwise_dptr(),
data_bytes, cudaMemcpyDeviceToDevice));
}

// Swizzle scales for GEMM
nvte_swizzle_scaling_factors(mxfp8.data(), mxfp8_swizzled.data(), 0);

// Sync so the returned tensor is fully materialized before `mxfp8` and
// `input_bf16` go out of scope at the end of this function.
NVTE_CHECK_CUDA(cudaDeviceSynchronize());

return mxfp8_swizzled;
}

struct TestParams {
InputCase input_case;
bool transa;
Expand All @@ -88,16 +160,16 @@ struct TestParams {
// Returns the per-GEMM problem sizes {M, N, K} for the requested shape pattern.
// Three GEMMs per group; sizes are multiples of 128 so that every layout is
// valid for block-scaled (MXFP8) operands as well.
// NOTE(review): the captured text contained both the pre- and post-change
// `return` lines for each case; only the post-change shapes are kept here.
std::vector<std::tuple<size_t, size_t, size_t>> make_shapes(ShapeCase scase) {
  switch (scase) {
    case ShapeCase::kAllSame:
      // Identical M, N, K across all GEMMs
      return {{128, 256, 384}, {128, 256, 384}, {128, 256, 384}};
    case ShapeCase::kSameFirst:
      // Same M (first dim), varying N and K
      return {{128, 256, 384}, {128, 384, 512}, {128, 512, 640}};
    case ShapeCase::kSameLast:
      // Same N (last dim), varying M and K
      return {{128, 256, 384}, {256, 256, 512}, {384, 256, 640}};
    case ShapeCase::kAllDifferent:
    default:
      // Everything varies
      return {{128, 256, 384}, {256, 384, 512}, {384, 512, 640}};
  }
}

Expand Down Expand Up @@ -138,6 +210,13 @@ void run_grouped_gemm_case(const TestParams& params) {
B_tensors.emplace_back(make_bf16_operand("B" + std::to_string(i), b_shape));
break;
}
case InputCase::kMXFP8: {
A_tensors.emplace_back(make_mxfp8_operand("A" + std::to_string(i), a_shape,
/*is_A=*/true, params.transa));
B_tensors.emplace_back(make_mxfp8_operand("B" + std::to_string(i), b_shape,
/*is_A=*/false, params.transb));
break;
}
}
D_multi.emplace_back(Tensor("D_multi" + std::to_string(i),
std::vector<size_t>{M, N},
Expand Down Expand Up @@ -246,7 +325,9 @@ void run_grouped_gemm_case(const TestParams& params) {
cublas_ws.data(),
nullptr, // config (use defaults)
0);
NVTE_CHECK_CUDA(cudaDeviceSynchronize());

// Compare results
for (size_t i = 0; i < num_gemms; ++i) {
Tensor grouped_split("grouped_D" + std::to_string(i),
std::vector<size_t>{static_cast<size_t>(std::get<0>(shapes[i])),
Expand Down Expand Up @@ -277,7 +358,7 @@ TEST_P(GroupedGemmTest, CompareWithMultiTensorGemm) {
}

std::string MakeGroupedGemmTestName(const testing::TestParamInfo<GroupedGemmTest::ParamType>& info) {
constexpr const char* kInputNames[] = {"FP8Current", "BF16"};
constexpr const char* kInputNames[] = {"FP8Current", "BF16", "MXFP8"};
constexpr const char* kShapeNames[] = {"AllSame", "SameM", "SameN", "AllDiff"};
const std::string layout = std::string("ta") + (info.param.transa ? "T" : "N") +
"tb" + (info.param.transb ? "T" : "N");
Expand All @@ -288,16 +369,27 @@ std::string MakeGroupedGemmTestName(const testing::TestParamInfo<GroupedGemmTest

// TestParams: {input_case, transa, transb, shape_case, use_null_c}
const std::vector<TestParams> kTestParams = {
// FP8 tests (each tensor has random mean/stddev -> different scales)
{InputCase::kFP8Current, true, false, ShapeCase::kAllDifferent, false},
{InputCase::kFP8Current, false, true, ShapeCase::kAllDifferent, false},
{InputCase::kFP8Current, false, false, ShapeCase::kAllSame, false},
// BF16 tests (all four transpose combinations are covered)
{InputCase::kBF16, true, false, ShapeCase::kSameFirst, false},
{InputCase::kBF16, false, true, ShapeCase::kSameLast, false},
{InputCase::kBF16, false, false, ShapeCase::kAllSame, false},
{InputCase::kBF16, true, true, ShapeCase::kAllDifferent, false},
// Test NULL C (valid when beta=0)
{InputCase::kBF16, false, false, ShapeCase::kAllSame, true},
// MXFP8 tests (TN, NT, NN layouts; TT omitted — presumably unsupported, verify)
{InputCase::kMXFP8, true, false, ShapeCase::kAllSame, false},
{InputCase::kMXFP8, true, false, ShapeCase::kAllDifferent, false},
{InputCase::kMXFP8, false, true, ShapeCase::kAllSame, false},
{InputCase::kMXFP8, false, true, ShapeCase::kAllDifferent, false},
{InputCase::kMXFP8, false, false, ShapeCase::kAllSame, false},
{InputCase::kMXFP8, false, false, ShapeCase::kAllDifferent, false},
{InputCase::kMXFP8, false, false, ShapeCase::kSameFirst, false},
// MXFP8 with NULL C
{InputCase::kMXFP8, true, false, ShapeCase::kAllSame, true},
};

INSTANTIATE_TEST_SUITE_P(OperatorTest,
Expand Down
114 changes: 96 additions & 18 deletions tests/cpp/test_common.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1061,7 +1061,14 @@ std::array<size_t, 4> get_scale_tensor_dims(const size_t rows,
GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
const NVTEScalingMode scaling_mode) {
NVTE_CHECK(!tensors.empty(), "No tensors provided for grouped tensor build.");
const NVTEShape shape = tensors[0]->rowwise_shape();

// Check which data layouts are available (all tensors must have the same)
const bool has_rowwise = tensors[0]->rowwise();
const bool has_columnwise = tensors[0]->columnwise();
NVTE_CHECK(has_rowwise || has_columnwise, "Tensors must have at least one data layout.");

const NVTEShape shape = has_rowwise ? tensors[0]->rowwise_shape()
: tensors[0]->columnwise_shape();
const DType dtype = tensors[0]->dtype();
const size_t num_tensors = tensors.size();
const size_t elem_size = typeToNumBits(dtype) / 8;
Expand All @@ -1076,7 +1083,8 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
std::vector<int64_t> first_dims(num_tensors);
std::vector<int64_t> last_dims(num_tensors);
for (size_t i = 0; i < num_tensors; ++i) {
const auto s = tensors[i]->rowwise_shape();
const auto s = has_rowwise ? tensors[i]->rowwise_shape()
: tensors[i]->columnwise_shape();
NVTE_CHECK(s.ndim == 2, "Grouped tensor build expects 2D tensors.");
first_dims[i] = static_cast<int64_t>(s.data[0]);
last_dims[i] = static_cast<int64_t>(s.data[1]);
Expand Down Expand Up @@ -1105,10 +1113,11 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
};

const bool need_offsets = !same_first || !same_last;
const bool use_random_padding = need_offsets && scaling_mode != NVTE_MXFP8_1D_SCALING;
if (need_offsets) {
offsets[0] = 0;
for (size_t i = 1; i < num_tensors; ++i) {
offsets[i] = offsets[i - 1] + numel(i - 1) + random_padding();
offsets[i] = offsets[i - 1] + numel(i - 1) + (use_random_padding ? random_padding() : 0);
}
} else {
for (size_t i = 0; i < num_tensors; ++i) {
Expand Down Expand Up @@ -1146,21 +1155,24 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
: (logical_first * logical_last);
const size_t total_bytes = static_cast<size_t>(total_elems) * elem_size;

grouped.data = cuda_alloc(total_bytes);
for (size_t i = 0; i < num_tensors; ++i) {
const size_t offset_bytes = static_cast<size_t>(offsets[i]) * elem_size;
NVTE_CHECK_CUDA(cudaMemcpy(static_cast<char*>(grouped.data.get()) + offset_bytes,
tensors[i]->rowwise_dptr(),
grouped.tensor_bytes[i],
cudaMemcpyDeviceToDevice));
}

NVTEBasicTensor data_tensor{grouped.data.get(), static_cast<NVTEDType>(dtype), grouped.logical_shape};
NVTEGroupedTensor h = grouped.handle.get();
nvte_set_grouped_tensor_param(h, kNVTEGroupedRowwiseData, &data_tensor, sizeof(data_tensor));

const bool include_columnwise = isFp8Type(dtype) || isFp4Type(dtype);
if (include_columnwise) {
// Copy rowwise data if available
if (has_rowwise) {
grouped.data = cuda_alloc(total_bytes);
for (size_t i = 0; i < num_tensors; ++i) {
const size_t offset_bytes = static_cast<size_t>(offsets[i]) * elem_size;
NVTE_CHECK_CUDA(cudaMemcpy(static_cast<char*>(grouped.data.get()) + offset_bytes,
tensors[i]->rowwise_dptr(),
grouped.tensor_bytes[i],
cudaMemcpyDeviceToDevice));
}
NVTEBasicTensor data_tensor{grouped.data.get(), static_cast<NVTEDType>(dtype), grouped.logical_shape};
nvte_set_grouped_tensor_param(h, kNVTEGroupedRowwiseData, &data_tensor, sizeof(data_tensor));
}

// Copy columnwise data if available
if (has_columnwise) {
grouped.columnwise_data = cuda_alloc(total_bytes);
for (size_t i = 0; i < num_tensors; ++i) {
const size_t offset_bytes = static_cast<size_t>(offsets[i]) * elem_size;
Expand Down Expand Up @@ -1202,11 +1214,17 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
nvte_set_grouped_tensor_param(h, kNVTEGroupedTensorOffsets, &off_tensor, sizeof(off_tensor));
}

if (isFp8Type(dtype)) {
if (isFp8Type(dtype) && scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
// FP8 tensor scaling: one float scale_inv per tensor
// For delayed scaling, rowwise and columnwise share the same scale
std::vector<float> scale_inv_cpu(num_tensors, 1.f);
for (size_t i = 0; i < num_tensors; ++i) {
tensors[i]->to_cpu();
scale_inv_cpu[i] = tensors[i]->rowwise_cpu_scale_inv_ptr<float>()[0];
if (has_rowwise) {
scale_inv_cpu[i] = tensors[i]->rowwise_cpu_scale_inv_ptr<float>()[0];
} else {
scale_inv_cpu[i] = tensors[i]->columnwise_cpu_scale_inv_ptr<float>()[0];
}
}
grouped.scale_inv = cuda_alloc(sizeof(float) * num_tensors);
NVTE_CHECK_CUDA(cudaMemcpy(grouped.scale_inv.get(), scale_inv_cpu.data(),
Expand All @@ -1217,6 +1235,66 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
sizeof(scale_tensor));
nvte_set_grouped_tensor_param(h, kNVTEGroupedColumnwiseScaleInv, &scale_tensor,
sizeof(scale_tensor));
} else if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
// MXFP8: E8M0 scale_inv per block of 32 elements
// Helper to gather scale_inv from individual tensors into a contiguous buffer
auto gather_scales = [&](
auto get_shape_fn,
auto get_cpu_ptr_fn) -> std::pair<CudaPtr<>, size_t> {
// Compute total size and offsets
size_t total_bytes = 0;
std::vector<size_t> scale_offsets(num_tensors);
std::vector<size_t> numels(num_tensors);

for (size_t i = 0; i < num_tensors; ++i) {
scale_offsets[i] = total_bytes;
const NVTEShape shape = get_shape_fn(tensors[i]);
size_t numel = 1;
for (size_t d = 0; d < shape.ndim; ++d) {
numel *= shape.data[d];
}
numels[i] = numel;
total_bytes += numel; // E8M0 is 1 byte per element
}

// Allocate and copy
CudaPtr<> buffer = cuda_alloc(total_bytes);
for (size_t i = 0; i < num_tensors; ++i) {
tensors[i]->to_cpu();
NVTE_CHECK_CUDA(cudaGetLastError());
void* dst = static_cast<char*>(buffer.get()) + scale_offsets[i];
const void* src = get_cpu_ptr_fn(tensors[i]);
NVTE_CHECK_CUDA(cudaMemcpy(dst, src, numels[i], cudaMemcpyHostToDevice));
}
Comment on lines +1262 to +1268
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Redundant CPU sync for swizzled MXFP8 scales.

The loop calls tensors[i]->to_cpu() on line 1263, then immediately passes the tensor to get_cpu_ptr_fn(tensors[i]) on line 1267. However, both rowwise_cpu_scale_inv_ptr<uint8_t>() and columnwise_cpu_scale_inv_ptr<uint8_t>() internally call to_cpu() themselves (test_common.h lines 249 and 264), making the explicit call on line 1263 redundant.

Additionally, the GPU pointers are available directly via get_rowwise_scale_inv().data_ptr and get_columnwise_scale_inv().data_ptr, allowing a device-to-device copy that avoids the round-trip entirely:

Suggested change
for (size_t i = 0; i < num_tensors; ++i) {
tensors[i]->to_cpu();
NVTE_CHECK_CUDA(cudaGetLastError());
void* dst = static_cast<char*>(buffer.get()) + scale_offsets[i];
const void* src = get_cpu_ptr_fn(tensors[i]);
NVTE_CHECK_CUDA(cudaMemcpy(dst, src, numels[i], cudaMemcpyHostToDevice));
}
NVTE_CHECK_CUDA(cudaMemcpy(dst,
has_rowwise ? tensors[i]->tensor_.get_rowwise_scale_inv().data_ptr
: tensors[i]->tensor_.get_columnwise_scale_inv().data_ptr,
numels[i],
cudaMemcpyDeviceToDevice));

This improves both clarity and efficiency in test code.

return {std::move(buffer), total_bytes};
};

// Gather rowwise scale_inv if available
if (has_rowwise) {
auto [row_buffer, row_total] = gather_scales(
[](Tensor* t) { return t->rowwise_scale_inv_shape(); },
[](Tensor* t) { return t->rowwise_cpu_scale_inv_ptr<uint8_t>(); });
grouped.scale_inv = std::move(row_buffer);

NVTEShape row_shape = nvte_make_shape(&row_total, 1);
NVTEBasicTensor row_tensor{grouped.scale_inv.get(), kNVTEFloat8E8M0, row_shape};
nvte_set_grouped_tensor_param(h, kNVTEGroupedRowwiseScaleInv, &row_tensor, sizeof(row_tensor));
}

// Gather columnwise scale_inv if available
if (has_columnwise) {
auto [col_buffer, col_total] = gather_scales(
[](Tensor* t) { return t->columnwise_scale_inv_shape(); },
[](Tensor* t) { return t->columnwise_cpu_scale_inv_ptr<uint8_t>(); });
grouped.columnwise_scale_inv = std::move(col_buffer);

NVTEShape col_shape = nvte_make_shape(&col_total, 1);
NVTEBasicTensor col_tensor{grouped.columnwise_scale_inv.get(), kNVTEFloat8E8M0, col_shape};
nvte_set_grouped_tensor_param(h, kNVTEGroupedColumnwiseScaleInv, &col_tensor, sizeof(col_tensor));
}

// Mark as having swizzled scales (required for GEMM)
nvte_set_grouped_tensor_swizzled_scales(h, 1);
}

return grouped;
Expand Down
1 change: 1 addition & 0 deletions tests/cpp/test_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,7 @@ struct GroupedBuffers {
GroupedTensorHandle handle;
CudaPtr<> data;
CudaPtr<> scale_inv;
CudaPtr<> columnwise_scale_inv;
CudaPtr<int64_t> first_dims_dev;
CudaPtr<int64_t> last_dims_dev;
CudaPtr<int64_t> offsets_dev;
Expand Down
3 changes: 2 additions & 1 deletion transformer_engine/common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,8 @@ struct GroupedTensor {
last_dims(nullptr, std::vector<size_t>{0}, DType::kInt64),
tensor_offsets(nullptr, std::vector<size_t>{0}, DType::kInt64),
logical_shape(nvte_make_shape(nullptr, 1)),
nvte_tensor(0) {}
nvte_tensor(0),
with_gemm_swizzled_scales(false) {}

explicit operator NVTEGroupedTensor() const noexcept { return nvte_tensor; }

Expand Down
Loading
Loading