Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ option(WITH_ASCEND "Enable Ascend backend" OFF)

option(WITH_TORCH "Enable PyTorch C++ backend" OFF)

option(WITH_NINETOOTHED "Enable NineToothed-generated NVIDIA kernels" OFF)

# Default OFF until CANN's `extract_host_stub.py` path handling is fixed for
# `scikit-build-core` temp-dir builds (triggers `KeyError` on the preprocessed
# object path). Enable explicitly with `-DBUILD_CUSTOM_KERNEL=ON` when the
Expand All @@ -29,6 +31,14 @@ option(AUTO_DETECT_DEVICES "Automatically detect available devices" OFF)
option(AUTO_DETECT_BACKENDS "Automatically detect available backends" OFF)
option(GENERATE_PYTHON_BINDINGS "Generate Python bindings" OFF)

# NineToothed code generation configuration.
set(NINETOOTHED_PYTHON_EXECUTABLE "" CACHE FILEPATH "Python executable used to run ninetoothed code generation")
set(NINETOOTHED_SOURCE_DIR "" CACHE PATH "Optional local ninetoothed source checkout; installed package is used when empty")
set(INFINIOPS_NINETOOTHED_OPS "rms_norm" CACHE STRING "Semicolon- or comma-separated NineToothed ops to generate")
set(INFINIOPS_NINETOOTHED_DTYPES "float32;float16;bfloat16" CACHE STRING "Semicolon- or comma-separated NineToothed dtypes to generate")
set(INFINIOPS_NINETOOTHED_RMS_NORM_NDIMS "2;3" CACHE STRING "Semicolon- or comma-separated RmsNorm input ranks to generate with NineToothed")
set(INFINIOPS_NINETOOTHED_BLOCK_SIZE "256" CACHE STRING "Block size baked into simple NineToothed elementwise kernels")

if(AUTO_DETECT_DEVICES)
message(STATUS "Auto-detecting available devices...")

Expand Down Expand Up @@ -231,6 +241,10 @@ if(_gpu_backend_count GREATER 1)
message(FATAL_ERROR "`WITH_NVIDIA`, `WITH_ILUVATAR`, `WITH_METAX`, `WITH_MOORE`, and `WITH_ASCEND` are mutually exclusive. Build one GPU backend at a time.")
endif()

if(WITH_NINETOOTHED AND NOT WITH_NVIDIA)
message(FATAL_ERROR "`WITH_NINETOOTHED` currently requires `WITH_NVIDIA=ON` because ninetoothed AOT uses caller=`cuda`.")
endif()

if(WITH_NVIDIA)
add_compile_definitions(WITH_NVIDIA=1)
enable_language(CUDA)
Expand Down
15 changes: 15 additions & 0 deletions scripts/generate_ninetoothed_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import pathlib
import sys


def main():
project_dir = pathlib.Path(__file__).resolve().parents[1]
sys.path.insert(0, str(project_dir / "src"))

from native.ninetoothed.codegen import main as codegen_main

codegen_main()


if __name__ == "__main__":
main()
51 changes: 51 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,57 @@ if(WITH_NVIDIA)
target_compile_definitions(infiniops PUBLIC WITH_NVIDIA=1)
target_sources(infiniops PRIVATE ${NVIDIA_SOURCES})

if(WITH_NINETOOTHED)
find_package(Python COMPONENTS Interpreter REQUIRED)

if(NINETOOTHED_PYTHON_EXECUTABLE)
set(_ninetoothed_python "${NINETOOTHED_PYTHON_EXECUTABLE}")
elseif(_TORCH_PYTHON)
set(_ninetoothed_python "${_TORCH_PYTHON}")
else()
set(_ninetoothed_python "${Python_EXECUTABLE}")
endif()
message(STATUS "NineToothed codegen Python: ${_ninetoothed_python}")

string(REPLACE "," ";" _ninetoothed_ops "${INFINIOPS_NINETOOTHED_OPS}")
string(REPLACE "," ";" _ninetoothed_dtypes "${INFINIOPS_NINETOOTHED_DTYPES}")
string(REPLACE "," ";" _ninetoothed_rms_norm_ndims "${INFINIOPS_NINETOOTHED_RMS_NORM_NDIMS}")

set(_ninetoothed_output_dir "${CMAKE_CURRENT_BINARY_DIR}/ninetoothed")
set(_ninetoothed_generator_args
"${PROJECT_SOURCE_DIR}/scripts/generate_ninetoothed_ops.py"
--output-dir "${_ninetoothed_output_dir}"
--ops ${_ninetoothed_ops}
--dtypes ${_ninetoothed_dtypes}
--rms-norm-ndims ${_ninetoothed_rms_norm_ndims}
--block-size "${INFINIOPS_NINETOOTHED_BLOCK_SIZE}")

if(NINETOOTHED_SOURCE_DIR)
list(APPEND _ninetoothed_generator_args
--ninetoothed-source-dir "${NINETOOTHED_SOURCE_DIR}")
endif()

execute_process(
COMMAND "${_ninetoothed_python}" ${_ninetoothed_generator_args}
WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}"
RESULT_VARIABLE _ninetoothed_generation_result
)

if(NOT _ninetoothed_generation_result EQUAL 0)
message(FATAL_ERROR "Generating NineToothed operator sources failed with `${_ninetoothed_python}`. Set `NINETOOTHED_PYTHON_EXECUTABLE` to a Python with `ninetoothed`, `triton`, `sympy`, and CUDA dependencies installed.")
endif()

include("${_ninetoothed_output_dir}/manifest.cmake")
set(_ninetoothed_compile_definitions
WITH_NINETOOTHED=1
INFINIOPS_NINETOOTHED_BLOCK_SIZE=${INFINIOPS_NINETOOTHED_BLOCK_SIZE})
target_compile_definitions(infiniops PUBLIC
${_ninetoothed_compile_definitions})
target_include_directories(infiniops PUBLIC
${INFINIOPS_NINETOOTHED_INCLUDE_DIRS})
target_sources(infiniops PRIVATE ${INFINIOPS_NINETOOTHED_SOURCES})
endif()

find_package(CUDAToolkit REQUIRED)
target_link_libraries(infiniops PUBLIC CUDA::cudart CUDA::cublas CUDA::cublasLt CUDA::cuda_driver)

Expand Down
101 changes: 101 additions & 0 deletions src/native/cuda/nvidia/ops/rms_norm/ninetoothed.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
#ifndef INFINI_OPS_NVIDIA_RMS_NORM_NINETOOTHED_H_
#define INFINI_OPS_NVIDIA_RMS_NORM_NINETOOTHED_H_

#ifdef WITH_NINETOOTHED

#include <cassert>
#include <cstdint>
#include <vector>

#include "base/rms_norm.h"
#include "data_type.h"
#include "native/ninetoothed/tensor.h"
#include "rms_norm/infiniops_ninetoothed_rms_norm.h"

#ifndef INFINIOPS_NINETOOTHED_BLOCK_SIZE
#define INFINIOPS_NINETOOTHED_BLOCK_SIZE 256
#endif

namespace infini::ops {

namespace detail {

inline NineToothedTensor ExpandedRmsNormWeight(const Tensor& weight,
const Tensor::Shape& shape,
std::vector<std::uint64_t>& sizes,
std::vector<std::int64_t>& strides) {
sizes.assign(shape.begin(), shape.end());
strides.assign(shape.size(), 0);
strides.back() = weight.strides().empty() ? 1 : weight.strides().back();

return NineToothedTensor{const_cast<void*>(weight.data()), sizes.data(),
strides.data()};
}

inline NineToothedResult LaunchNineToothedRmsNorm(
const Tensor::Shape& shape, DataType dtype, NineToothedStream stream,
NineToothedTensor input, NineToothedTensor weight, NineToothedTensor eps,
NineToothedTensor out, NineToothedTensor num_normalized_elements) {
const int dtype_index = ninetoothed::DTypeIndex(dtype);

if (dtype_index < 0) {
return 1;
}

return launch_infiniops_ninetoothed_rms_norm(
stream, input, weight, eps, out, num_normalized_elements,
ninetoothed::SizeArg(shape.size()), 1, dtype_index, dtype_index,
dtype_index, INFINIOPS_NINETOOTHED_BLOCK_SIZE);
}

} // namespace detail

template <>
class Operator<RmsNorm, Device::Type::kNvidia, 9> : public RmsNorm {
public:
using RmsNorm::RmsNorm;
using RmsNorm::operator();

void operator()(const Tensor input, const Tensor weight, float eps,
Tensor out) const override {
assert(input.dtype() == out.dtype() && out.dtype() == weight.dtype() &&
"operator `RmsNorm` requires all input and output tensors to have "
"the same dtype");
assert(input.shape() == out.shape() &&
"ninetoothed `RmsNorm` requires input and output tensors with the "
"same shape");
assert(weight.ndim() == 1 && weight.size(-1) == out.size(-1) &&
"ninetoothed `RmsNorm` requires a 1D weight matching the last "
"dimension");
assert((out.ndim() == 2 || out.ndim() == 3) &&
"ninetoothed `RmsNorm` currently supports rank-2 and rank-3 "
"tensors");

std::vector<std::uint64_t> weight_sizes;
std::vector<std::int64_t> weight_strides;
double eps_value = static_cast<double>(eps);
std::int64_t num_normalized_elements =
static_cast<std::int64_t>(out.size(-1));
std::uint64_t empty_shape[1] = {};
std::int64_t empty_strides[1] = {};

auto result = detail::LaunchNineToothedRmsNorm(
out.shape(), out.dtype(), static_cast<NineToothedStream>(stream_),
ninetoothed::FromTensor<NineToothedTensor>(input),
detail::ExpandedRmsNormWeight(weight, out.shape(), weight_sizes,
weight_strides),
ninetoothed::FromScalar<NineToothedTensor>(eps_value, empty_shape,
empty_strides),
ninetoothed::FromTensor<NineToothedTensor>(out),
ninetoothed::FromScalar<NineToothedTensor>(
num_normalized_elements, empty_shape, empty_strides));

assert(result == 0 && "ninetoothed `RmsNorm` launch failed");
}
};

} // namespace infini::ops

#endif // WITH_NINETOOTHED

#endif
165 changes: 165 additions & 0 deletions src/native/ninetoothed/codegen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
import argparse
import pathlib
import shutil
import sys


_DEFAULT_DTYPES = ("float32", "float16", "bfloat16")
_DEFAULT_RMS_NORM_NDIMS = (2, 3)
_SUPPORTED_OPS = ("rms_norm",)


def _import_ninetoothed(source_dir):
if source_dir is not None:
sys.path.insert(0, str(pathlib.Path(source_dir) / "src"))

import ninetoothed

return ninetoothed


def _rms_norm_premake(
ndim,
num_normalized_dims,
input_dtype,
weight_dtype,
output_dtype,
block_size,
):
import ntops

return ntops.kernels.rms_norm.premake(
ndim,
num_normalized_dims,
input_dtype=input_dtype,
weight_dtype=weight_dtype,
output_dtype=output_dtype,
block_size=block_size,
)


def _normalize_ndims(values):
ndims = []

for value in values:
ndim = int(value)

if ndim not in _DEFAULT_RMS_NORM_NDIMS:
raise ValueError(f"`RmsNorm` currently supports rank 2 and 3: {value!r}")

if ndim not in ndims:
ndims.append(ndim)

return tuple(ndims)


def _rms_norm_configs(ninetoothed, dtypes, ndims, block_size):
configs = []

for ndim in _normalize_ndims(ndims):
for dtype_name in dtypes:
dtype = getattr(ninetoothed, dtype_name)
configs.append(
(
(),
{
"ndim": ndim,
"num_normalized_dims": 1,
"input_dtype": dtype,
"weight_dtype": dtype,
"output_dtype": dtype,
"block_size": block_size,
},
{},
)
)

return tuple(configs)


def _generate_rms_norm(ninetoothed, output_dir, dtypes, rms_norm_ndims, block_size):
variant_dir = output_dir / "rms_norm"
variant_dir.mkdir(parents=True, exist_ok=True)
ninetoothed.build(
_rms_norm_premake,
_rms_norm_configs(ninetoothed, dtypes, rms_norm_ndims, block_size),
meta_parameters=None,
caller="cuda",
kernel_name="infiniops_ninetoothed_rms_norm",
output_dir=variant_dir,
lazy=False,
)


def _build_manifest(output_dir):
return sorted(
str(path)
for path in pathlib.Path(output_dir).rglob("*.cpp")
if not path.name.endswith(".tmp.cpp")
)


def _write_cmake_manifest(output_dir, sources):
manifest_path = pathlib.Path(output_dir) / "manifest.cmake"
lines = ["set(INFINIOPS_NINETOOTHED_SOURCES"]
lines.extend(f' "{source}"' for source in sources)
lines.append(")")
lines.append("")
lines.append(f'set(INFINIOPS_NINETOOTHED_INCLUDE_DIRS "{output_dir}")')
lines.append("")
manifest_path.write_text("\n".join(lines) + "\n")


def generate(
ops,
*,
output_dir,
dtypes=_DEFAULT_DTYPES,
rms_norm_ndims=_DEFAULT_RMS_NORM_NDIMS,
block_size=256,
ninetoothed_source_dir=None,
):
unknown_ops = tuple(op for op in ops if op not in _SUPPORTED_OPS)

if unknown_ops:
raise ValueError(f"unsupported ninetoothed ops: {', '.join(unknown_ops)}")

output_dir = pathlib.Path(output_dir)
shutil.rmtree(output_dir, ignore_errors=True)
output_dir.mkdir(parents=True, exist_ok=True)

ninetoothed = _import_ninetoothed(ninetoothed_source_dir)

if "rms_norm" in ops:
_generate_rms_norm(ninetoothed, output_dir, dtypes, rms_norm_ndims, block_size)

sources = _build_manifest(output_dir)
_write_cmake_manifest(output_dir, sources)

return sources


def _parse_args():
parser = argparse.ArgumentParser(
description="Generate ninetoothed operator sources for InfiniOps."
)
parser.add_argument("--output-dir", required=True)
parser.add_argument("--ops", nargs="+", default=_SUPPORTED_OPS)
parser.add_argument("--dtypes", nargs="+", default=_DEFAULT_DTYPES)
parser.add_argument("--rms-norm-ndims", nargs="+", default=_DEFAULT_RMS_NORM_NDIMS)
parser.add_argument("--block-size", type=int, default=256)
parser.add_argument("--ninetoothed-source-dir")

return parser.parse_args()


def main():
args = _parse_args()
generate(
args.ops,
output_dir=args.output_dir,
dtypes=tuple(args.dtypes),
rms_norm_ndims=tuple(args.rms_norm_ndims),
block_size=args.block_size,
ninetoothed_source_dir=args.ninetoothed_source_dir,
)
Loading
Loading