diff --git a/CMakeLists.txt b/CMakeLists.txt
index 9973438cf..976913c63 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -18,6 +18,8 @@ option(WITH_ASCEND "Enable Ascend backend" OFF)
 option(WITH_TORCH "Enable PyTorch C++ backend" OFF)
 
 
+option(WITH_NINETOOTHED "Enable NineToothed-generated NVIDIA kernels" OFF)
+
 # Default OFF until CANN's `extract_host_stub.py` path handling is fixed for
 # `scikit-build-core` temp-dir builds (triggers `KeyError` on the preprocessed
 # object path). Enable explicitly with `-DBUILD_CUSTOM_KERNEL=ON` when the
@@ -29,6 +31,14 @@ option(AUTO_DETECT_DEVICES "Automatically detect available devices" OFF)
 option(AUTO_DETECT_BACKENDS "Automatically detect available backends" OFF)
 option(GENERATE_PYTHON_BINDINGS "Generate Python bindings" OFF)
 
+# NineToothed code generation configuration.
+set(NINETOOTHED_PYTHON_EXECUTABLE "" CACHE FILEPATH "Python executable used to run ninetoothed code generation")
+set(NINETOOTHED_SOURCE_DIR "" CACHE PATH "Optional local ninetoothed source checkout; installed package is used when empty")
+set(INFINIOPS_NINETOOTHED_OPS "rms_norm" CACHE STRING "Semicolon- or comma-separated NineToothed ops to generate")
+set(INFINIOPS_NINETOOTHED_DTYPES "float32;float16;bfloat16" CACHE STRING "Semicolon- or comma-separated NineToothed dtypes to generate")
+set(INFINIOPS_NINETOOTHED_RMS_NORM_NDIMS "2;3" CACHE STRING "Semicolon- or comma-separated RmsNorm input ranks to generate with NineToothed")
+set(INFINIOPS_NINETOOTHED_BLOCK_SIZE "256" CACHE STRING "Block size baked into simple NineToothed elementwise kernels")
+
 if(AUTO_DETECT_DEVICES)
     message(STATUS "Auto-detecting available devices...")
 
@@ -231,6 +241,10 @@ if(_gpu_backend_count GREATER 1)
     message(FATAL_ERROR "`WITH_NVIDIA`, `WITH_ILUVATAR`, `WITH_METAX`, `WITH_MOORE`, and `WITH_ASCEND` are mutually exclusive. Build one GPU backend at a time.")
 endif()
 
+if(WITH_NINETOOTHED AND NOT WITH_NVIDIA)
+    message(FATAL_ERROR "`WITH_NINETOOTHED` currently requires `WITH_NVIDIA=ON` because ninetoothed AOT uses caller=`cuda`.")
+endif()
+
 if(WITH_NVIDIA)
     add_compile_definitions(WITH_NVIDIA=1)
     enable_language(CUDA)
diff --git a/scripts/generate_ninetoothed_ops.py b/scripts/generate_ninetoothed_ops.py
new file mode 100644
index 000000000..612015ecd
--- /dev/null
+++ b/scripts/generate_ninetoothed_ops.py
@@ -0,0 +1,16 @@
+import pathlib
+import sys
+
+
+def main():
+    """Put the in-tree `src` directory on `sys.path` and run the code generator."""
+    project_dir = pathlib.Path(__file__).resolve().parents[1]
+    sys.path.insert(0, str(project_dir / "src"))
+
+    from native.ninetoothed.codegen import main as codegen_main
+
+    codegen_main()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 762b9d48f..246bc0d22 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -39,6 +39,75 @@ if(WITH_NVIDIA)
     target_compile_definitions(infiniops PUBLIC WITH_NVIDIA=1)
     target_sources(infiniops PRIVATE ${NVIDIA_SOURCES})
 
+    # Generate NineToothed kernel sources at configure time and compile them
+    # into `infiniops`. The generator writes `manifest.cmake`, which sets
+    # `INFINIOPS_NINETOOTHED_SOURCES` and `INFINIOPS_NINETOOTHED_INCLUDE_DIRS`.
+    if(WITH_NINETOOTHED)
+        find_package(Python COMPONENTS Interpreter REQUIRED)
+
+        # Interpreter precedence: the explicit override, then `_TORCH_PYTHON`
+        # when it is set, then the interpreter found by `find_package`.
+        if(NINETOOTHED_PYTHON_EXECUTABLE)
+            set(_ninetoothed_python "${NINETOOTHED_PYTHON_EXECUTABLE}")
+        elseif(_TORCH_PYTHON)
+            set(_ninetoothed_python "${_TORCH_PYTHON}")
+        else()
+            set(_ninetoothed_python "${Python_EXECUTABLE}")
+        endif()
+        message(STATUS "NineToothed codegen Python: ${_ninetoothed_python}")
+
+        # The block size is passed to the generator and baked into a compile
+        # definition, so reject anything but a positive integer up front.
+        if(NOT "${INFINIOPS_NINETOOTHED_BLOCK_SIZE}" MATCHES "^[1-9][0-9]*$")
+            message(FATAL_ERROR "`INFINIOPS_NINETOOTHED_BLOCK_SIZE` must be a positive integer, got `${INFINIOPS_NINETOOTHED_BLOCK_SIZE}`.")
+        endif()
+
+        # Accept comma-separated values as well as CMake lists.
+        string(REPLACE "," ";" _ninetoothed_ops "${INFINIOPS_NINETOOTHED_OPS}")
+        string(REPLACE "," ";" _ninetoothed_dtypes "${INFINIOPS_NINETOOTHED_DTYPES}")
+        string(REPLACE "," ";" _ninetoothed_rms_norm_ndims "${INFINIOPS_NINETOOTHED_RMS_NORM_NDIMS}")
+
+        set(_ninetoothed_output_dir "${CMAKE_CURRENT_BINARY_DIR}/ninetoothed")
+        set(_ninetoothed_generator_args
+            "${PROJECT_SOURCE_DIR}/scripts/generate_ninetoothed_ops.py"
+            --output-dir "${_ninetoothed_output_dir}"
+            --ops ${_ninetoothed_ops}
+            --dtypes ${_ninetoothed_dtypes}
+            --rms-norm-ndims ${_ninetoothed_rms_norm_ndims}
+            --block-size "${INFINIOPS_NINETOOTHED_BLOCK_SIZE}")
+
+        if(NINETOOTHED_SOURCE_DIR)
+            list(APPEND _ninetoothed_generator_args
+                --ninetoothed-source-dir "${NINETOOTHED_SOURCE_DIR}")
+        endif()
+
+        # Re-run configuration, and with it generation, whenever the generator
+        # scripts change so the generated kernels do not go stale.
+        set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS
+            "${PROJECT_SOURCE_DIR}/scripts/generate_ninetoothed_ops.py"
+            "${PROJECT_SOURCE_DIR}/src/native/ninetoothed/codegen.py")
+
+        execute_process(
+            COMMAND "${_ninetoothed_python}" ${_ninetoothed_generator_args}
+            WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}"
+            RESULT_VARIABLE _ninetoothed_generation_result
+        )
+
+        if(NOT _ninetoothed_generation_result EQUAL 0)
+            message(FATAL_ERROR "Generating NineToothed operator sources failed with `${_ninetoothed_python}`. Set `NINETOOTHED_PYTHON_EXECUTABLE` to a Python with `ninetoothed`, `triton`, `sympy`, and CUDA dependencies installed.")
+        endif()
+
+        include("${_ninetoothed_output_dir}/manifest.cmake")
+        set(_ninetoothed_compile_definitions
+            WITH_NINETOOTHED=1
+            INFINIOPS_NINETOOTHED_BLOCK_SIZE=${INFINIOPS_NINETOOTHED_BLOCK_SIZE})
+        target_compile_definitions(infiniops PUBLIC
+            ${_ninetoothed_compile_definitions})
+        target_include_directories(infiniops PUBLIC
+            ${INFINIOPS_NINETOOTHED_INCLUDE_DIRS})
+        target_sources(infiniops PRIVATE ${INFINIOPS_NINETOOTHED_SOURCES})
+    endif()
+
     find_package(CUDAToolkit REQUIRED)
     target_link_libraries(infiniops PUBLIC CUDA::cudart CUDA::cublas CUDA::cublasLt CUDA::cuda_driver)
 
diff --git a/src/native/cuda/nvidia/ops/rms_norm/ninetoothed.h b/src/native/cuda/nvidia/ops/rms_norm/ninetoothed.h
new file mode 100644
index 000000000..ce69381e7
--- /dev/null
+++ b/src/native/cuda/nvidia/ops/rms_norm/ninetoothed.h
@@ -0,0 +1,110 @@
+#ifndef INFINI_OPS_NVIDIA_RMS_NORM_NINETOOTHED_H_
+#define INFINI_OPS_NVIDIA_RMS_NORM_NINETOOTHED_H_
+
+#ifdef WITH_NINETOOTHED
+
+#include
+#include
+#include
+
+#include "base/rms_norm.h"
+#include "data_type.h"
+#include "native/ninetoothed/tensor.h"
+#include "rms_norm/infiniops_ninetoothed_rms_norm.h"
+
+#ifndef INFINIOPS_NINETOOTHED_BLOCK_SIZE
+#define INFINIOPS_NINETOOTHED_BLOCK_SIZE 256
+#endif
+
+namespace infini::ops {
+
+namespace detail {
+
+// Describes the 1D `weight` as a tensor with the sizes of `shape` (which must
+// be non-empty) and a zero stride on every dimension but the last. `sizes` and
+// `strides` back the returned tensor, so they must outlive it.
+inline NineToothedTensor ExpandedRmsNormWeight(const Tensor& weight,
+                                               const Tensor::Shape& shape,
+                                               std::vector& sizes,
+                                               std::vector& strides) {
+  sizes.assign(shape.begin(), shape.end());
+  strides.assign(shape.size(), 0);
+  strides.back() = weight.strides().empty() ? 1 : weight.strides().back();
+
+  return NineToothedTensor{const_cast(weight.data()), sizes.data(),
+                           strides.data()};
+}
+
+// Launches the generated RmsNorm kernel. Returns 1 without launching when
+// `dtype` has no NineToothed dtype index; otherwise the launcher's result.
+inline NineToothedResult LaunchNineToothedRmsNorm(
+    const Tensor::Shape& shape, DataType dtype, NineToothedStream stream,
+    NineToothedTensor input, NineToothedTensor weight, NineToothedTensor eps,
+    NineToothedTensor out, NineToothedTensor num_normalized_elements) {
+  const int dtype_index = ninetoothed::DTypeIndex(dtype);
+
+  if (dtype_index < 0) {
+    return 1;
+  }
+
+  return launch_infiniops_ninetoothed_rms_norm(
+      stream, input, weight, eps, out, num_normalized_elements,
+      ninetoothed::SizeArg(shape.size()), 1, dtype_index, dtype_index,
+      dtype_index, INFINIOPS_NINETOOTHED_BLOCK_SIZE);
+}
+
+}  // namespace detail
+
+// NineToothed-backed `RmsNorm`: validates dtypes and shapes with `assert`, then
+// launches the generated kernel on `stream_`.
+template <>
+class Operator : public RmsNorm {
+ public:
+  using RmsNorm::RmsNorm;
+  using RmsNorm::operator();
+
+  void operator()(const Tensor input, const Tensor weight, float eps,
+                  Tensor out) const override {
+    assert(input.dtype() == out.dtype() && out.dtype() == weight.dtype() &&
+           "operator `RmsNorm` requires all input and output tensors to have "
+           "the same dtype");
+    assert(input.shape() == out.shape() &&
+           "ninetoothed `RmsNorm` requires input and output tensors with the "
+           "same shape");
+    assert(weight.ndim() == 1 && weight.size(-1) == out.size(-1) &&
+           "ninetoothed `RmsNorm` requires a 1D weight matching the last "
+           "dimension");
+    assert((out.ndim() == 2 || out.ndim() == 3) &&
+           "ninetoothed `RmsNorm` currently supports rank-2 and rank-3 "
+           "tensors");
+
+    std::vector weight_sizes;
+    std::vector weight_strides;
+    double eps_value = static_cast(eps);
+    std::int64_t num_normalized_elements =
+        static_cast(out.size(-1));
+    std::uint64_t empty_shape[1] = {};
+    std::int64_t empty_strides[1] = {};
+
+    auto result = detail::LaunchNineToothedRmsNorm(
+        out.shape(), out.dtype(), static_cast(stream_),
+        ninetoothed::FromTensor(input),
+        detail::ExpandedRmsNormWeight(weight, out.shape(), weight_sizes,
+                                      weight_strides),
+        ninetoothed::FromScalar(eps_value, empty_shape,
+                                empty_strides),
+        ninetoothed::FromTensor(out),
+        ninetoothed::FromScalar(
+            num_normalized_elements, empty_shape, empty_strides));
+
+    // NOTE(review): `assert` is compiled out when `NDEBUG` is defined, so a
+    // failed launch is not reported in such builds.
+    assert(result == 0 && "ninetoothed `RmsNorm` launch failed");
+  }
+};
+
+}  // namespace infini::ops
+
+#endif  // WITH_NINETOOTHED
+
+#endif
diff --git a/src/native/ninetoothed/codegen.py b/src/native/ninetoothed/codegen.py
new file mode 100644
index 000000000..0e1c1dad2
--- /dev/null
+++ b/src/native/ninetoothed/codegen.py
@@ -0,0 +1,184 @@
+import argparse
+import pathlib
+import shutil
+import sys
+
+
+_DEFAULT_DTYPES = ("float32", "float16", "bfloat16")
+_DEFAULT_RMS_NORM_NDIMS = (2, 3)
+_SUPPORTED_OPS = ("rms_norm",)
+
+
+def _import_ninetoothed(source_dir):
+    """Import `ninetoothed`, optionally from a local checkout's `src` directory."""
+    if source_dir is not None:
+        sys.path.insert(0, str(pathlib.Path(source_dir) / "src"))
+
+    import ninetoothed
+
+    return ninetoothed
+
+
+def _rms_norm_premake(
+    ndim,
+    num_normalized_dims,
+    input_dtype,
+    weight_dtype,
+    output_dtype,
+    block_size,
+):
+    """Forward one build configuration to `ntops.kernels.rms_norm.premake`."""
+    import ntops
+
+    return ntops.kernels.rms_norm.premake(
+        ndim,
+        num_normalized_dims,
+        input_dtype=input_dtype,
+        weight_dtype=weight_dtype,
+        output_dtype=output_dtype,
+        block_size=block_size,
+    )
+
+
+def _normalize_ndims(values):
+    """Convert ranks to `int`, rejecting unsupported ranks and dropping duplicates."""
+    ndims = []
+
+    for value in values:
+        ndim = int(value)
+
+        if ndim not in _DEFAULT_RMS_NORM_NDIMS:
+            raise ValueError(f"`RmsNorm` currently supports rank 2 and 3: {value!r}")
+
+        if ndim not in ndims:
+            ndims.append(ndim)
+
+    return tuple(ndims)
+
+
+def _rms_norm_configs(ninetoothed, dtypes, ndims, block_size):
+    """Build one `((), kwargs, {})` build configuration per rank and dtype pair."""
+    configs = []
+
+    for ndim in _normalize_ndims(ndims):
+        for dtype_name in dtypes:
+            dtype = getattr(ninetoothed, dtype_name)
+            configs.append(
+                (
+                    (),
+                    {
+                        "ndim": ndim,
+                        "num_normalized_dims": 1,
+                        "input_dtype": dtype,
+                        "weight_dtype": dtype,
+                        "output_dtype": dtype,
+                        "block_size": block_size,
+                    },
+                    {},
+                )
+            )
+
+    return tuple(configs)
+
+
+def _generate_rms_norm(ninetoothed, output_dir, dtypes, rms_norm_ndims, block_size):
+    """Build the RmsNorm kernels into the `rms_norm` subdirectory of `output_dir`."""
+    variant_dir = output_dir / "rms_norm"
+    variant_dir.mkdir(parents=True, exist_ok=True)
+    ninetoothed.build(
+        _rms_norm_premake,
+        _rms_norm_configs(ninetoothed, dtypes, rms_norm_ndims, block_size),
+        meta_parameters=None,
+        caller="cuda",
+        kernel_name="infiniops_ninetoothed_rms_norm",
+        output_dir=variant_dir,
+        lazy=False,
+    )
+
+
+def _build_manifest(output_dir):
+    """Return sorted paths of generated `.cpp` files, skipping `*.tmp.cpp`."""
+    return sorted(
+        str(path)
+        for path in pathlib.Path(output_dir).rglob("*.cpp")
+        if not path.name.endswith(".tmp.cpp")
+    )
+
+
+def _write_cmake_manifest(output_dir, sources):
+    """Write `manifest.cmake` listing the generated sources and include directory.
+
+    Paths are emitted with forward slashes because a backslash inside a quoted
+    CMake argument starts an escape sequence.
+    """
+    manifest_path = pathlib.Path(output_dir) / "manifest.cmake"
+    include_dir = pathlib.Path(output_dir).as_posix()
+    lines = ["set(INFINIOPS_NINETOOTHED_SOURCES"]
+    lines.extend(f'    "{pathlib.PurePath(source).as_posix()}"' for source in sources)
+    lines.append(")")
+    lines.append("")
+    lines.append(f'set(INFINIOPS_NINETOOTHED_INCLUDE_DIRS "{include_dir}")')
+    lines.append("")
+    manifest_path.write_text("\n".join(lines) + "\n", encoding="utf-8")
+
+
+def generate(
+    ops,
+    *,
+    output_dir,
+    dtypes=_DEFAULT_DTYPES,
+    rms_norm_ndims=_DEFAULT_RMS_NORM_NDIMS,
+    block_size=256,
+    ninetoothed_source_dir=None,
+):
+    """Generate sources for `ops` into `output_dir` and return their paths.
+
+    `output_dir` is deleted and recreated, so it must be dedicated to generated
+    files. Raises `ValueError` for an op that is not in `_SUPPORTED_OPS`.
+    """
+    unknown_ops = tuple(op for op in ops if op not in _SUPPORTED_OPS)
+
+    if unknown_ops:
+        raise ValueError(f"unsupported ninetoothed ops: {', '.join(unknown_ops)}")
+
+    output_dir = pathlib.Path(output_dir)
+    shutil.rmtree(output_dir, ignore_errors=True)
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    ninetoothed = _import_ninetoothed(ninetoothed_source_dir)
+
+    if "rms_norm" in ops:
+        _generate_rms_norm(ninetoothed, output_dir, dtypes, rms_norm_ndims, block_size)
+
+    sources = _build_manifest(output_dir)
+    _write_cmake_manifest(output_dir, sources)
+
+    return sources
+
+
+def _parse_args():
+    """Parse the generator's command-line arguments."""
+    parser = argparse.ArgumentParser(
+        description="Generate ninetoothed operator sources for InfiniOps."
+    )
+    parser.add_argument("--output-dir", required=True)
+    parser.add_argument("--ops", nargs="+", default=_SUPPORTED_OPS)
+    parser.add_argument("--dtypes", nargs="+", default=_DEFAULT_DTYPES)
+    parser.add_argument("--rms-norm-ndims", nargs="+", default=_DEFAULT_RMS_NORM_NDIMS)
+    parser.add_argument("--block-size", type=int, default=256)
+    parser.add_argument("--ninetoothed-source-dir")
+
+    return parser.parse_args()
+
+
+def main():
+    """Command-line entry point: parse arguments and run `generate`."""
+    args = _parse_args()
+    generate(
+        args.ops,
+        output_dir=args.output_dir,
+        dtypes=tuple(args.dtypes),
+        rms_norm_ndims=tuple(args.rms_norm_ndims),
+        block_size=args.block_size,
+        ninetoothed_source_dir=args.ninetoothed_source_dir,
+    )
diff --git a/src/native/ninetoothed/tensor.h b/src/native/ninetoothed/tensor.h
new file mode 100644
index 000000000..f37625350
--- /dev/null
+++ b/src/native/ninetoothed/tensor.h
@@ -0,0 +1,67 @@
+#ifndef INFINI_OPS_NATIVE_NINETOOTHED_TENSOR_H_
+#define INFINI_OPS_NATIVE_NINETOOTHED_TENSOR_H_
+
+#include
+#include
+#include
+#include
+
+#include "data_type.h"
+#include "tensor.h"
+
+namespace infini::ops::ninetoothed {
+
+// Maps a `DataType` to the integer dtype index handed to the generated
+// launcher, or -1 for an unsupported type. NOTE(review): 8/9/10 presumably
+// match the indices the generated code expects -- confirm against its header.
+inline int DTypeIndex(DataType dtype) {
+  switch (dtype) {
+    case DataType::kFloat16:
+      return 8;
+    case DataType::kBFloat16:
+      return 9;
+    case DataType::kFloat32:
+      return 10;
+    default:
+      return -1;
+  }
+}
+
+// Narrows a size to `int`. The range check is an `assert`, so it is only
+// enforced when `NDEBUG` is not defined.
+inline int SizeArg(Tensor::Size size) {
+  assert(size <= static_cast(std::numeric_limits::max()) &&
+         "ninetoothed launch config dimensions must fit in int");
+  return static_cast(size);
+}
+
+// Builds a `NineToothedTensor` from `tensor`'s data pointer, shape, and
+// strides; the shape and stride arrays are referenced, not copied.
+template
+NineToothedTensor FromTensor(const Tensor& tensor) {
+  static_assert(sizeof(Tensor::Size) == sizeof(std::uint64_t));
+  static_assert(sizeof(Tensor::Stride) == sizeof(std::int64_t));
+  static_assert(std::is_unsigned_v);
+  static_assert(std::is_signed_v);
+
+  return NineToothedTensor{
+      const_cast(tensor.data()),
+      reinterpret_cast(
+          const_cast(tensor.shape().data())),
+      reinterpret_cast(
+          const_cast(tensor.strides().data())),
+  };
+}
+
+// Wraps a scalar by address using the caller-provided shape and stride
+// arrays; `value` and both arrays must outlive the returned tensor.
+template
+NineToothedTensor FromScalar(T& value, std::uint64_t* empty_shape,
+                             std::int64_t* empty_strides) {
+  return NineToothedTensor{static_cast(&value), empty_shape,
+                           empty_strides};
+}
+
+}  // namespace infini::ops::ninetoothed
+
+#endif
diff --git a/tests/test_generate_ninetoothed_ops.py b/tests/test_generate_ninetoothed_ops.py
new file mode 100644
index 000000000..f51081206
--- /dev/null
+++ b/tests/test_generate_ninetoothed_ops.py
@@ -0,0 +1,142 @@
+import importlib.util
+import pathlib
+import sys
+import tempfile
+import types
+import unittest
+from unittest import mock
+
+
+def _load_generator_module():
+    """Load `codegen.py` by file path so the test does not depend on `sys.path`."""
+    path = (
+        pathlib.Path(__file__).resolve().parents[1]
+        / "src"
+        / "native"
+        / "ninetoothed"
+        / "codegen.py"
+    )
+    spec = importlib.util.spec_from_file_location(
+        "ninetoothed_codegen_under_test", path
+    )
+    module = importlib.util.module_from_spec(spec)
+    assert spec.loader is not None
+    sys.modules[spec.name] = module
+    spec.loader.exec_module(module)
+
+    return module
+
+
+class GenerateNineToothedOpsTest(unittest.TestCase):
+    def test_generate_rms_norm_uses_ntops_premake_with_rank_configs(self):
+        """`generate` passes rank/dtype configs to `build` with the `ntops` premake."""
+        module = _load_generator_module()
+        self.addCleanup(sys.modules.pop, module.__name__, None)
+        calls = []
+
+        fake_ninetoothed = types.SimpleNamespace(
+            float32="nt.float32",
+        )
+        fake_ninetoothed.build = lambda *args, **kwargs: calls.append((args, kwargs))
+
+        fake_arrangement = object()
+        fake_application = object()
+        fake_tensors = object()
+        premake_calls = []
+
+        def fake_ntops_premake(*args, **kwargs):
+            premake_calls.append((args, kwargs))
+            return fake_arrangement, fake_application, fake_tensors
+
+        fake_ntops = types.SimpleNamespace(
+            kernels=types.SimpleNamespace(
+                rms_norm=types.SimpleNamespace(premake=fake_ntops_premake)
+            )
+        )
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            tmp_path = pathlib.Path(tmpdir)
+
+            with (
+                mock.patch.object(
+                    module, "_import_ninetoothed", return_value=fake_ninetoothed
+                ),
+                mock.patch.object(
+                    module, "_build_manifest", return_value=["kernel.cpp"]
+                ),
+            ):
+                manifest = module.generate(
+                    ["rms_norm"],
+                    output_dir=tmp_path,
+                    dtypes=("float32",),
+                    rms_norm_ndims=(2,),
+                    block_size=256,
+                )
+
+        self.assertEqual(manifest, ["kernel.cpp"])
+        self.assertEqual(len(calls), 1)
+
+        args, kwargs = calls[0]
+        premake, configs = args
+        self.assertEqual(
+            configs,
+            (
+                (
+                    (),
+                    {
+                        "ndim": 2,
+                        "num_normalized_dims": 1,
+                        "input_dtype": "nt.float32",
+                        "weight_dtype": "nt.float32",
+                        "output_dtype": "nt.float32",
+                        "block_size": 256,
+                    },
+                    {},
+                ),
+            ),
+        )
+        self.assertEqual(kwargs["caller"], "cuda")
+        self.assertEqual(
+            kwargs["kernel_name"],
+            "infiniops_ninetoothed_rms_norm",
+        )
+        self.assertEqual(kwargs["output_dir"], tmp_path / "rms_norm")
+        self.assertIs(kwargs["lazy"], False)
+        self.assertIsNone(kwargs["meta_parameters"])
+
+        with mock.patch.dict(
+            sys.modules,
+            {
+                "ntops": fake_ntops,
+            },
+        ):
+            arrangement, application, tensors = premake(
+                ndim=2,
+                num_normalized_dims=1,
+                input_dtype="nt.float32",
+                weight_dtype="nt.float32",
+                output_dtype="nt.float32",
+                block_size=256,
+            )
+
+        self.assertIs(arrangement, fake_arrangement)
+        self.assertIs(application, fake_application)
+        self.assertIs(tensors, fake_tensors)
+        self.assertEqual(
+            premake_calls,
+            [
+                (
+                    (2, 1),
+                    {
+                        "input_dtype": "nt.float32",
+                        "weight_dtype": "nt.float32",
+                        "output_dtype": "nt.float32",
+                        "block_size": 256,
+                    },
+                )
+            ],
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()