From c3831fff3e3d9e41d11e6379a8fda6e54ed82477 Mon Sep 17 00:00:00 2001 From: gongchensu Date: Fri, 26 Dec 2025 06:32:52 +0000 Subject: [PATCH 1/3] Issue/846 - Refactor embedding to support device-side input and CUDA graph recording --- include/infinicore/ops.hpp | 1 + include/infinicore/ops/embedding.hpp | 7 + include/infiniop.h | 1 + include/infiniop/ops/embedding.h | 26 ++ python/infinicore/nn/functional/embedding.py | 5 +- src/infinicore/nn/embedding.cc | 81 +---- src/infinicore/ops/embedding/embedding.cc | 82 ++--- .../ops/embedding/embedding_infiniop.cc | 49 +++ .../ops/embedding/cpu/embedding_cpu.cc | 110 ++++++ .../ops/embedding/cpu/embedding_cpu.h | 8 + src/infiniop/ops/embedding/embedding.h | 54 +++ .../ops/embedding/nvidia/embedding_kernel.cuh | 50 +++ .../ops/embedding/nvidia/embedding_nvidia.cu | 161 +++++++++ .../ops/embedding/nvidia/embedding_nvidia.cuh | 8 + src/infiniop/ops/embedding/operator.cc | 118 +++++++ .../EMBEDDING_GRAPH_RECORDING_COMPARISON.md | 159 +++++++++ .../nn/HOW_TO_USE_GRAPH_RECORDING_TEST.md | 317 ++++++++++++++++++ test/infinicore/nn/embedding.py | 11 +- .../nn/test_embedding_graph_recording.py | 284 ++++++++++++++++ test/infinicore/ops/embedding.py | 20 +- 20 files changed, 1387 insertions(+), 165 deletions(-) create mode 100644 include/infiniop/ops/embedding.h create mode 100644 src/infinicore/ops/embedding/embedding_infiniop.cc create mode 100644 src/infiniop/ops/embedding/cpu/embedding_cpu.cc create mode 100644 src/infiniop/ops/embedding/cpu/embedding_cpu.h create mode 100644 src/infiniop/ops/embedding/embedding.h create mode 100644 src/infiniop/ops/embedding/nvidia/embedding_kernel.cuh create mode 100644 src/infiniop/ops/embedding/nvidia/embedding_nvidia.cu create mode 100644 src/infiniop/ops/embedding/nvidia/embedding_nvidia.cuh create mode 100644 src/infiniop/ops/embedding/operator.cc create mode 100644 test/infinicore/nn/EMBEDDING_GRAPH_RECORDING_COMPARISON.md create mode 100644 test/infinicore/nn/HOW_TO_USE_GRAPH_RECORDING_TEST.md create mode 100644 test/infinicore/nn/test_embedding_graph_recording.py diff --git a/include/infinicore/ops.hpp b/include/infinicore/ops.hpp index 0937a4821..b8ae4332b 100644 --- a/include/infinicore/ops.hpp +++ b/include/infinicore/ops.hpp @@ -3,6 +3,7 @@ #include "ops/add.hpp" #include "ops/attention.hpp" #include "ops/causal_softmax.hpp" +#include "ops/embedding.hpp" #include "ops/matmul.hpp" #include "ops/ones.hpp" #include "ops/rearrange.hpp" diff --git a/include/infinicore/ops/embedding.hpp b/include/infinicore/ops/embedding.hpp index 4fd9991c4..6be997134 100644 --- a/include/infinicore/ops/embedding.hpp +++ b/include/infinicore/ops/embedding.hpp @@ -4,6 +4,13 @@ namespace infinicore::op { +class Embedding { +public: + using schema = void (*)(Tensor, Tensor, Tensor); + static void execute(Tensor out, Tensor input, Tensor weight); + static common::OpDispatcher &dispatcher(); +}; + Tensor embedding(Tensor input, Tensor weight); void embedding_(Tensor out, Tensor input, Tensor weight); } // namespace infinicore::op diff --git a/include/infiniop.h b/include/infiniop.h index 92e6f5963..034717ef4 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -8,6 +8,7 @@ #include "infiniop/ops/clip.h" #include "infiniop/ops/conv.h" #include "infiniop/ops/dequantize_awq.h" +#include "infiniop/ops/embedding.h" #include "infiniop/ops/gelu.h" #include "infiniop/ops/gemm.h" #include "infiniop/ops/layer_norm.h" diff --git a/include/infiniop/ops/embedding.h b/include/infiniop/ops/embedding.h new file mode 100644 index 000000000..e5ffc211d --- /dev/null +++ b/include/infiniop/ops/embedding.h @@ -0,0 +1,26 @@ +#ifndef __INFINIOP_EMBEDDING_API_H__ +#define __INFINIOP_EMBEDDING_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopEmbeddingDescriptor_t; + +__C __export infiniStatus_t infiniopCreateEmbeddingDescriptor( + infiniopHandle_t handle, + infiniopEmbeddingDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t input_desc, + infiniopTensorDescriptor_t weight_desc); + +__C __export infiniStatus_t infiniopEmbedding( + infiniopEmbeddingDescriptor_t desc, + void *output, + const void *input, + const void *weight, + void *stream); + +__C __export infiniStatus_t infiniopDestroyEmbeddingDescriptor( + infiniopEmbeddingDescriptor_t desc); + +#endif + diff --git a/python/infinicore/nn/functional/embedding.py b/python/infinicore/nn/functional/embedding.py index f346d380a..592a12290 100644 --- a/python/infinicore/nn/functional/embedding.py +++ b/python/infinicore/nn/functional/embedding.py @@ -22,9 +22,8 @@ def embedding( and (sparse is False) ), "Unsupported parameters." - assert "cpu" == input.device.type, ( - "The device of 'input' variable must be on the CPU." - ) + # Note: embedding now supports device-side input for graph recording + # The C++ implementation handles both CPU and device-side inputs if out is None: return Tensor(_infinicore.embedding(input._underlying, weight._underlying)) diff --git a/src/infinicore/nn/embedding.cc b/src/infinicore/nn/embedding.cc index 85645bf95..f1af03042 100644 --- a/src/infinicore/nn/embedding.cc +++ b/src/infinicore/nn/embedding.cc @@ -43,80 +43,13 @@ Embedding::Embedding(size_t num_embeddings, } Tensor Embedding::forward(const Tensor &indices) const { - // Get the shape of indices - auto indices_shape = indices->shape(); - - // Output shape: indices_shape + [embedding_dim] - std::vector output_shape = indices_shape; - output_shape.push_back(embedding_dim_); - - // Create output tensor on the same device as weight - auto out = Tensor::empty(output_shape, weight_->dtype(), weight_->device()); - - // Flatten indices for sequential row copies - auto cpu_device = Device(Device::Type::CPU, 0); - auto indices_cpu = indices->to(cpu_device)->contiguous(); - - // Calculate total number of lookups - size_t num_lookups = 1; - for (auto dim : indices_shape) { - num_lookups *= dim; - } - - const size_t row_bytes = embedding_dim_ * dsize(weight_->dtype()); - - // Source and destination base pointers - auto *weight_base = weight_->data(); - auto *out_base = out->data(); - - // Helper lambda to read index based on dtype with bounds checking - auto read_index = [&](size_t i) -> int64_t { - auto dtype = indices_cpu->dtype(); - if (dtype == DataType::I32) { - const auto *data = reinterpret_cast(indices_cpu->data()); - return static_cast(data[i]); - } else if (dtype == DataType::I64) { - const auto *data = reinterpret_cast(indices_cpu->data()); - return data[i]; - } else if (dtype == DataType::U32) { - const auto *data = reinterpret_cast(indices_cpu->data()); - return static_cast(data[i]); - } else if (dtype == DataType::U64) { - const auto *data = reinterpret_cast(indices_cpu->data()); - uint64_t val = data[i]; - // Check if value can fit in int64_t - if (val > static_cast(std::numeric_limits::max())) { - throw std::out_of_range("Index value out of range for int64_t: " + std::to_string(val)); - } - return static_cast(val); - } else { - throw std::runtime_error("Embedding indices must be integer type, got dtype=" + std::to_string(static_cast(dtype))); - } - }; - - if (weight_->device().getType() == Device::Type::CPU) { - // CPU path: memcpy row by row - for (size_t i = 0; i < num_lookups; ++i) { - int64_t idx = read_index(i); - if (idx < 0 || idx >= static_cast(num_embeddings_)) { - throw std::out_of_range( - "Index out of range: " + std::to_string(idx) + " (num_embeddings=" + std::to_string(num_embeddings_) + ")"); - } - std::memcpy(out_base + i * row_bytes, weight_base + idx * row_bytes, row_bytes); - } - } else { - // Device path: use stream-ordered D2D copies - for (size_t i = 0; i < num_lookups; ++i) { - int64_t idx = read_index(i); - if (idx < 0 || idx >= static_cast(num_embeddings_)) { - throw std::out_of_range( - "Index out of range: " + std::to_string(idx) + " (num_embeddings=" + std::to_string(num_embeddings_) + ")"); - } - context::memcpyD2D(out_base + i * row_bytes, weight_base + idx * row_bytes, row_bytes); - } - } - - return out; + // Ensure indices are contiguous for efficient access + // op::embedding now supports device-side input for graph recording + Tensor indices_contiguous = indices->is_contiguous() ? indices : indices->contiguous(); + + // Use op::embedding which now supports device-side input and batch dimension + // This enables full graph recording support without synchronization + return op::embedding(indices_contiguous, weight_); } std::string Embedding::extra_repr() const { diff --git a/src/infinicore/ops/embedding/embedding.cc b/src/infinicore/ops/embedding/embedding.cc index f1add0c97..cf5c41caf 100644 --- a/src/infinicore/ops/embedding/embedding.cc +++ b/src/infinicore/ops/embedding/embedding.cc @@ -1,15 +1,32 @@ #include "infinicore/ops/embedding.hpp" #include "infinicore/context/context.hpp" +#include "../../utils.hpp" #include +#include namespace infinicore::op { +common::OpDispatcher &Embedding::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +} + +void Embedding::execute(Tensor out, Tensor input, Tensor weight) { + // Check that output and weight are on the same device + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, weight); + + // Set device context + infinicore::context::setDevice(out->device()); + + // Use dispatcher to lookup kernel (infiniop implementation) + dispatcher().lookup(out->device().getType())(out, input, weight); +} + Tensor embedding(Tensor input, // LongTensor of arbitrary shape containing the indices to extract Tensor weight // Weight: Embedding matrix of floating point type with shape (V, embedding_dim), where V = maximum index + 1 ) { auto input_shape = input->shape(); auto weight_shape = weight->shape(); - // auto vocab_size = weight_shape[0]; auto embedding_dim = weight_shape[1]; // Assign memory to out variables @@ -22,68 +39,7 @@ Tensor embedding(Tensor input, // LongTensor of arbitrary shape containing the i } void embedding_(Tensor out, Tensor input, Tensor weight) { - assert(infinicore::DataType::I64 == input->dtype() || (infinicore::DataType::I32 == input->dtype())); - assert(infinicore::Device::Type::CPU == input->device().getType()); - - auto input_shape = input->shape(); - auto weight_shape = weight->shape(); - auto embedding_dim = weight_shape[1]; - - // Calculate the number of token - Size counts = 1; - for (auto &v : input_shape) { - counts *= v; - } - - // the bytes of one token - const Size bytes = dsize(weight->dtype()) * embedding_dim; - auto *weight_ptr = weight->data(); - auto *out_ptr = out->data(); - - // copies - if (weight->device().getType() == Device::Type::CPU) { - if (infinicore::DataType::I64 == input->dtype()) { - const int64_t *input_arr = reinterpret_cast(input->data()); - for (Size i = 0; i < counts; ++i) { - int64_t idx = input_arr[i]; - assert((idx >= 0) && (idx < weight_shape[0])); - std::memcpy(out_ptr + i * bytes, - weight_ptr + idx * bytes, - bytes); - } - } else if (infinicore::DataType::I32 == input->dtype()) { - const int32_t *input_arr = reinterpret_cast(input->data()); - - for (Size i = 0; i < counts; ++i) { - int32_t idx = input_arr[i]; - assert((idx >= 0) && (idx < weight_shape[0])); - std::memcpy(out_ptr + i * bytes, - weight_ptr + idx * bytes, - bytes); - } - } - - } else { - if (infinicore::DataType::I64 == input->dtype()) { - const int64_t *input_arr = reinterpret_cast(input->data()); - for (Size i = 0; i < counts; ++i) { - int64_t idx = input_arr[i]; - assert((idx >= 0) && (idx < weight_shape[0])); - context::memcpyD2D(out_ptr + i * bytes, - weight_ptr + idx * bytes, - bytes); - } - } else if (infinicore::DataType::I32 == input->dtype()) { - const int32_t *input_arr = reinterpret_cast(input->data()); - for (Size i = 0; i < counts; ++i) { - int32_t idx = input_arr[i]; - assert((idx >= 0) && (idx < weight_shape[0])); - context::memcpyD2D(out_ptr + i * bytes, - weight_ptr + idx * bytes, - bytes); - } - } - } + Embedding::execute(out, input, weight); } } // namespace infinicore::op diff --git a/src/infinicore/ops/embedding/embedding_infiniop.cc b/src/infinicore/ops/embedding/embedding_infiniop.cc new file mode 100644 index 000000000..af73f13fa --- /dev/null +++ b/src/infinicore/ops/embedding/embedding_infiniop.cc @@ -0,0 +1,49 @@ +#include "../../utils.hpp" +#include "infinicore/common/hash.hpp" +#include "infinicore/ops/embedding.hpp" +#include "infinicore/ops/common/cache.hpp" +#include + +namespace infinicore::op::embedding_impl::infiniop { + +thread_local common::OpCache caches( + 100, // capacity + [](infiniopEmbeddingDescriptor_t &desc) { + if (desc != nullptr) { + INFINICORE_CHECK_ERROR(infiniopDestroyEmbeddingDescriptor(desc)); + desc = nullptr; + } + }); + +void calculate(Tensor out, Tensor input, Tensor weight) { + size_t seed = hash_combine(out, input, weight); + + auto device = context::getDevice(); + auto &cache = caches.getCache(device); + + auto desc_opt = cache.get(seed); + infiniopEmbeddingDescriptor_t desc = nullptr; + + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreateEmbeddingDescriptor( + context::getInfiniopHandle(device), &desc, + out->desc(), input->desc(), weight->desc())); + cache.put(seed, desc); + } else { + desc = *desc_opt; + } + + INFINICORE_CHECK_ERROR(infiniopEmbedding( + desc, + out->data(), + input->data(), + weight->data(), + context::getStream())); +} + +static bool registered = []() { + Embedding::dispatcher().registerAll(&calculate, false); + return true; +}(); + +} // namespace infinicore::op::embedding_impl::infiniop diff --git a/src/infiniop/ops/embedding/cpu/embedding_cpu.cc b/src/infiniop/ops/embedding/cpu/embedding_cpu.cc new file mode 100644 index 000000000..e84eced6b --- /dev/null +++ b/src/infiniop/ops/embedding/cpu/embedding_cpu.cc @@ -0,0 +1,110 @@ +#include "../../../../utils.h" +#include "../../../tensor.h" +#include "../../../handle.h" +#include "embedding_cpu.h" +#include + +namespace op::embedding::cpu { + +struct Descriptor::Opaque {}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t input_desc, + infiniopTensorDescriptor_t weight_desc) { + + auto input_shape = input_desc->shape(); + auto weight_shape = weight_desc->shape(); + + CHECK_OR_RETURN(weight_shape.size() == 2, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(output_desc->shape().size() == input_shape.size() + 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + + auto output_shape = output_desc->shape(); + size_t embedding_dim = weight_shape[1]; + CHECK_OR_RETURN(output_shape.back() == embedding_dim, INFINI_STATUS_BAD_TENSOR_SHAPE); + + for (size_t i = 0; i < input_shape.size(); ++i) { + CHECK_OR_RETURN(output_shape[i] == input_shape[i], INFINI_STATUS_BAD_TENSOR_SHAPE); + } + + auto input_dtype = input_desc->dtype(); + auto weight_dtype = weight_desc->dtype(); + CHECK_OR_RETURN(input_dtype == INFINI_DTYPE_I32 || input_dtype == INFINI_DTYPE_I64, + INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_OR_RETURN(weight_dtype == INFINI_DTYPE_F32 || weight_dtype == INFINI_DTYPE_F16 || + weight_dtype == INFINI_DTYPE_BF16, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_OR_RETURN(output_desc->dtype() == weight_dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + + size_t num_indices = 1; + for (auto dim : input_shape) { + num_indices *= dim; + } + + size_t vocab_size = weight_shape[0]; + + *desc_ptr = new Descriptor( + num_indices, + embedding_dim, + vocab_size, + input_dtype, + weight_dtype, + new Opaque{}, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *output, + const void *input, + const void *weight, + void *stream) const { + + if (_num_indices == 0) { + return INFINI_STATUS_SUCCESS; + } + + size_t element_size = infiniSizeOf(_weight_dtype); + size_t row_bytes = _embedding_dim * element_size; + + if (_input_dtype == INFINI_DTYPE_I32) { + const int32_t *indices_ptr = reinterpret_cast(input); + const std::byte *weight_ptr = reinterpret_cast(weight); + std::byte *out_ptr = reinterpret_cast(output); + + for (size_t i = 0; i < _num_indices; ++i) { + int32_t idx = indices_ptr[i]; + if (idx >= 0 && static_cast(idx) < _vocab_size) { + std::memcpy(out_ptr + i * row_bytes, + weight_ptr + static_cast(idx) * row_bytes, + row_bytes); + } + } + } else if (_input_dtype == INFINI_DTYPE_I64) { + const int64_t *indices_ptr = reinterpret_cast(input); + const std::byte *weight_ptr = reinterpret_cast(weight); + std::byte *out_ptr = reinterpret_cast(output); + + for (size_t i = 0; i < _num_indices; ++i) { + int64_t idx = indices_ptr[i]; + if (idx >= 0 && static_cast(idx) < _vocab_size) { + std::memcpy(out_ptr + i * row_bytes, + weight_ptr + static_cast(idx) * row_bytes, + row_bytes); + } + } + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::embedding::cpu diff --git a/src/infiniop/ops/embedding/cpu/embedding_cpu.h b/src/infiniop/ops/embedding/cpu/embedding_cpu.h new file mode 100644 index 000000000..a5cc5b2d0 --- /dev/null +++ b/src/infiniop/ops/embedding/cpu/embedding_cpu.h @@ -0,0 +1,8 @@ +#ifndef __EMBEDDING_CPU_H__ +#define __EMBEDDING_CPU_H__ + +#include "../embedding.h" + +DESCRIPTOR(cpu) + +#endif // __EMBEDDING_CPU_H__ diff --git a/src/infiniop/ops/embedding/embedding.h b/src/infiniop/ops/embedding/embedding.h new file mode 100644 index 000000000..0e4b33009 --- /dev/null +++ b/src/infiniop/ops/embedding/embedding.h @@ -0,0 +1,54 @@ +#ifndef __EMBEDDING_H__ +#define __EMBEDDING_H__ + +#include "../../../utils.h" +#include "../../operator.h" + +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::embedding::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + size_t _num_indices; \ + size_t _embedding_dim; \ + size_t _vocab_size; \ + infiniDtype_t _input_dtype; \ + infiniDtype_t _weight_dtype; \ + \ + Descriptor( \ + size_t num_indices, \ + size_t embedding_dim, \ + size_t vocab_size, \ + infiniDtype_t input_dtype, \ + infiniDtype_t weight_dtype, \ + Opaque *opaque, \ + infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), \ + _num_indices(num_indices), \ + _embedding_dim(embedding_dim), \ + _vocab_size(vocab_size), \ + _input_dtype(input_dtype), \ + _weight_dtype(weight_dtype) {} \ + \ + public: \ + ~Descriptor(); \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t output_desc, \ + infiniopTensorDescriptor_t input_desc, \ + infiniopTensorDescriptor_t weight_desc); \ + \ + infiniStatus_t calculate( \ + void *output, \ + const void *input, \ + const void *weight, \ + void *stream) const; \ + }; \ + } + +#endif // __EMBEDDING_H__ diff --git a/src/infiniop/ops/embedding/nvidia/embedding_kernel.cuh b/src/infiniop/ops/embedding/nvidia/embedding_kernel.cuh new file mode 100644 index 000000000..8398bfbfc --- /dev/null +++ b/src/infiniop/ops/embedding/nvidia/embedding_kernel.cuh @@ -0,0 +1,50 @@ +#ifndef __EMBEDDING_CUDA_KERNEL_CUH__ +#define __EMBEDDING_CUDA_KERNEL_CUH__ + +#include "../../../devices/nvidia/nvidia_kernel_common.cuh" +#include + +namespace op::embedding::nvidia { + +template +INFINIOP_CUDA_KERNEL embeddingKernel( + T *output, + const IndexType *indices, + const T *weight, + size_t num_indices, + size_t embedding_dim, + size_t vocab_size) { + // Calculate global thread index + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (idx < num_indices) { + // Get the index value + IndexType index_val = indices[idx]; + + // Bounds check - handle negative indices gracefully + if (index_val >= 0 && static_cast(index_val) < vocab_size) { + // Copy embedding vector from weight to output + const T *src = weight + static_cast(index_val) * embedding_dim; + T *dst = output + idx * embedding_dim; + + // Copy embedding_dim elements + // Use vectorized copy for better performance when possible + size_t i = 0; + // Copy in chunks of 4 for better memory bandwidth utilization + for (; i + 4 <= embedding_dim; i += 4) { + dst[i] = src[i]; + dst[i + 1] = src[i + 1]; + dst[i + 2] = src[i + 2]; + dst[i + 3] = src[i + 3]; + } + // Copy remaining elements + for (; i < embedding_dim; ++i) { + dst[i] = src[i]; + } + } + } +} + +} // namespace op::embedding::nvidia + +#endif // __EMBEDDING_CUDA_KERNEL_CUH__ diff --git a/src/infiniop/ops/embedding/nvidia/embedding_nvidia.cu b/src/infiniop/ops/embedding/nvidia/embedding_nvidia.cu new file mode 100644 index 000000000..007e90c04 --- /dev/null +++ b/src/infiniop/ops/embedding/nvidia/embedding_nvidia.cu @@ -0,0 +1,161 @@ +#include "../../../devices/nvidia/nvidia_common.cuh" +#include "../../../devices/nvidia/nvidia_kernel_common.cuh" +#include "../../../tensor.h" +#include "../../../../utils.h" +#include "embedding_kernel.cuh" +#include "embedding_nvidia.cuh" +#include + +namespace op::embedding::nvidia { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t input_desc, + infiniopTensorDescriptor_t weight_desc) { + + auto handle_nvidia = reinterpret_cast(handle); + auto input_shape = input_desc->shape(); + auto weight_shape = weight_desc->shape(); + + // Validate shapes + CHECK_OR_RETURN(weight_shape.size() == 2, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(output_desc->shape().size() == input_shape.size() + 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + + // Check output shape matches input shape + embedding_dim + auto output_shape = output_desc->shape(); + size_t embedding_dim = weight_shape[1]; + CHECK_OR_RETURN(output_shape.back() == embedding_dim, INFINI_STATUS_BAD_TENSOR_SHAPE); + + for (size_t i = 0; i < input_shape.size(); ++i) { + CHECK_OR_RETURN(output_shape[i] == input_shape[i], INFINI_STATUS_BAD_TENSOR_SHAPE); + } + + // Validate dtypes + auto input_dtype = input_desc->dtype(); + auto weight_dtype = weight_desc->dtype(); + CHECK_OR_RETURN(input_dtype == INFINI_DTYPE_I32 || input_dtype == INFINI_DTYPE_I64, + INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_OR_RETURN(weight_dtype == INFINI_DTYPE_F32 || weight_dtype == INFINI_DTYPE_F16 || + weight_dtype == INFINI_DTYPE_BF16, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_OR_RETURN(output_desc->dtype() == weight_dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + + // Calculate number of indices (supporting batch dimension) + size_t num_indices = 1; + for (auto dim : input_shape) { + num_indices *= dim; + } + + size_t vocab_size = weight_shape[0]; + + *desc_ptr = new Descriptor( + num_indices, + embedding_dim, + vocab_size, + input_dtype, + weight_dtype, + new Opaque{handle_nvidia->internal()}, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *output, + const void *input, + const void *weight, + void *stream) const { + + if (_num_indices == 0) { + return INFINI_STATUS_SUCCESS; + } + + auto cuda_stream = reinterpret_cast(stream); + constexpr size_t BLOCK_SIZE = 256; + size_t grid_size = (_num_indices + BLOCK_SIZE - 1) / BLOCK_SIZE; + + // Launch kernel based on dtypes + if (_input_dtype == INFINI_DTYPE_I32) { + const int32_t *indices_ptr = reinterpret_cast(input); + + if (_weight_dtype == INFINI_DTYPE_F32) { + embeddingKernel<<>>( + reinterpret_cast(output), + indices_ptr, + reinterpret_cast(weight), + _num_indices, + _embedding_dim, + _vocab_size); + } else if (_weight_dtype == INFINI_DTYPE_F16) { + embeddingKernel<<>>( + reinterpret_cast(output), + indices_ptr, + reinterpret_cast(weight), + _num_indices, + _embedding_dim, + _vocab_size); + } else if (_weight_dtype == INFINI_DTYPE_BF16) { + embeddingKernel<<>>( + reinterpret_cast(output), + indices_ptr, + reinterpret_cast(weight), + _num_indices, + _embedding_dim, + _vocab_size); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + } else if (_input_dtype == INFINI_DTYPE_I64) { + const int64_t *indices_ptr = reinterpret_cast(input); + + if (_weight_dtype == INFINI_DTYPE_F32) { + embeddingKernel<<>>( + reinterpret_cast(output), + indices_ptr, + reinterpret_cast(weight), + _num_indices, + _embedding_dim, + _vocab_size); + } else if (_weight_dtype == INFINI_DTYPE_F16) { + embeddingKernel<<>>( + reinterpret_cast(output), + indices_ptr, + reinterpret_cast(weight), + _num_indices, + _embedding_dim, + _vocab_size); + } else if (_weight_dtype == INFINI_DTYPE_BF16) { + embeddingKernel<<>>( + reinterpret_cast(output), + indices_ptr, + reinterpret_cast(weight), + _num_indices, + _embedding_dim, + _vocab_size); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + // Check for kernel launch errors + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + return INFINI_STATUS_INTERNAL_ERROR; + } + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::embedding::nvidia diff --git a/src/infiniop/ops/embedding/nvidia/embedding_nvidia.cuh b/src/infiniop/ops/embedding/nvidia/embedding_nvidia.cuh new file mode 100644 index 000000000..c6b966d8d --- /dev/null +++ b/src/infiniop/ops/embedding/nvidia/embedding_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __EMBEDDING_CUDA_H__ +#define __EMBEDDING_CUDA_H__ + +#include "../embedding.h" + +DESCRIPTOR(nvidia) + +#endif // __EMBEDDING_CUDA_H__ diff --git a/src/infiniop/ops/embedding/operator.cc b/src/infiniop/ops/embedding/operator.cc new file mode 100644 index 000000000..af75842fa --- /dev/null +++ b/src/infiniop/ops/embedding/operator.cc @@ -0,0 +1,118 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/embedding.h" + +#ifdef ENABLE_CPU_API +#include "cpu/embedding_cpu.h" +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) +#include "nvidia/embedding_nvidia.cuh" +#endif + +__C infiniStatus_t infiniopCreateEmbeddingDescriptor( + infiniopHandle_t handle, + infiniopEmbeddingDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t input_desc, + infiniopTensorDescriptor_t weight_desc) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::embedding::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + output_desc, \ + input_desc, \ + weight_desc) + + switch (handle->device) { + +#ifdef ENABLE_CPU_API + CREATE(INFINI_DEVICE_CPU, cpu); +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) + CREATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#if defined(ENABLE_ILUVATAR_API) + CREATE(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#if defined(ENABLE_QY_API) + CREATE(INFINI_DEVICE_QY, nvidia); +#endif +#if defined(ENABLE_HYGON_API) + CREATE(INFINI_DEVICE_HYGON, nvidia); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CREATE +} + +__C infiniStatus_t infiniopEmbedding( + infiniopEmbeddingDescriptor_t desc, + void *output, + const void *input, + const void *weight, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(output, input, weight, stream) + + switch (desc->device_type) { + +#ifdef ENABLE_CPU_API + CALCULATE(INFINI_DEVICE_CPU, cpu); +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) + CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#if defined(ENABLE_ILUVATAR_API) + CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#if defined(ENABLE_QY_API) + CALCULATE(INFINI_DEVICE_QY, nvidia); +#endif +#if defined(ENABLE_HYGON_API) + CALCULATE(INFINI_DEVICE_HYGON, nvidia); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CALCULATE +} + +__C infiniStatus_t infiniopDestroyEmbeddingDescriptor(infiniopEmbeddingDescriptor_t desc) { + +#define DELETE(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + DELETE(INFINI_DEVICE_CPU, cpu); +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) + DELETE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#if defined(ENABLE_ILUVATAR_API) + DELETE(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#if defined(ENABLE_QY_API) + DELETE(INFINI_DEVICE_QY, nvidia); +#endif +#if defined(ENABLE_HYGON_API) + DELETE(INFINI_DEVICE_HYGON, nvidia); +#endif + } + +#undef DELETE + + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} diff --git a/test/infinicore/nn/EMBEDDING_GRAPH_RECORDING_COMPARISON.md b/test/infinicore/nn/EMBEDDING_GRAPH_RECORDING_COMPARISON.md new file mode 100644 index 000000000..686c10a1b --- /dev/null +++ b/test/infinicore/nn/EMBEDDING_GRAPH_RECORDING_COMPARISON.md @@ -0,0 +1,159 @@ +# Embedding 图录制支持对比 + +## 改动前后对比 + +### ❌ 改动前:不支持图录制 + +**关键问题代码**(在 `nn::Embedding::forward` 中): +```cpp +// 改动前的实现 +Tensor Embedding::forward(const Tensor &indices) const { + auto cpu_device = Device(Device::Type::CPU, 0); + auto indices_cpu = indices->to(cpu_device)->contiguous(); // ❌ 同步操作! + + // ... 后续处理 +} +``` + +**问题分析**: +1. `indices->to(cpu_device)` 会触发 **同步的 D2H(Device-to-Host)内存拷贝** +2. CUDA Graph 录制要求所有操作都是**异步的**,不能有同步点 +3. 同步操作会导致图录制失败或产生错误 + +**验证方法**: +```python +# 改动前:这个操作会失败或产生同步 +input_ids_device = infinicore.from_list(..., device="cuda:0") # 设备端输入 +output = embedding.forward(input_ids_device) # ❌ 内部会同步拷贝到 CPU +``` + +--- + +### ✅ 改动后:支持图录制 + +**关键改进代码**: +```cpp +// 改动后的实现 +Tensor Embedding::forward(const Tensor &indices) const { + Tensor indices_contiguous = indices->is_contiguous() ? indices : indices->contiguous(); + return op::embedding(indices_contiguous, weight_); // ✅ 直接使用设备端 kernel +} +``` + +**改进点**: +1. **移除了同步操作**:不再调用 `indices->to(cpu_device)` +2. **使用设备端 CUDA kernel**:通过 InfiniOP 调用 `embeddingKernel`,完全在设备端执行 +3. **完全异步**:所有操作都在 CUDA stream 上异步执行 + +**实现位置**: +- CUDA Kernel: `src/infiniop/ops/embedding/nvidia/embedding_nvidia.cu` +- Kernel 启动:使用 `cudaStream_t`,完全异步 +- 无同步点:没有 `cudaDeviceSynchronize()` 或 D2H 拷贝 + +**验证方法**: +```python +# 改动后:这个操作完全异步,支持图录制 +input_ids_device = infinicore.from_list(..., device="cuda:0") # 设备端输入 +output = embedding.forward(input_ids_device) # ✅ 直接使用设备端 kernel,无同步 +``` + +--- + +## 验证方法 + +### 方法 1: 代码检查 + +**检查点**: +1. ✅ 是否有 `->to(cpu_device)` 调用? +2. ✅ 是否有 `synchronize()` 调用? +3. ✅ 是否有设备端 kernel 实现? + +**改动前**: +```cpp +// ❌ 有同步操作 +auto indices_cpu = indices->to(cpu_device)->contiguous(); +``` + +**改动后**: +```cpp +// ✅ 无同步操作,直接使用设备端 kernel +return op::embedding(indices_contiguous, weight_); +``` + +### 方法 2: CUDA Graph API 测试 + +运行测试脚本: +```bash +python test/infinicore/nn/test_embedding_graph_recording.py +``` + +**预期结果**: +- ✅ 改动后:图录制成功 +- ❌ 改动前:图录制失败(因为同步操作) + +### 方法 3: 设备端输入测试 + +**关键测试**: +```python +# 创建设备端输入 +input_ids = infinicore.from_list([[1, 2, 3]], dtype=int64, device="cuda:0") + +# 执行 forward +output = embedding.forward(input_ids) # 改动前会失败或同步,改动后成功 +``` + +**改动前**: +- 需要先将 `input_ids` 拷贝到 CPU +- 触发同步操作,无法图录制 + +**改动后**: +- 直接使用设备端 `input_ids` +- 完全异步,支持图录制 + +--- + +## 技术细节对比 + +| 特性 | 改动前 | 改动后 | +|------|--------|--------| +| **输入设备** | 必须在 CPU | 支持设备端 | +| **同步操作** | ❌ 有(D2H拷贝) | ✅ 无 | +| **Kernel位置** | CPU 实现 | CUDA kernel | +| **图录制支持** | ❌ 不支持 | ✅ 支持 | +| **Batch维度** | ✅ 支持 | ✅ 支持 | +| **性能** | 较慢(同步开销) | 更快(异步) | + +--- + +## 关键代码位置 + +### 改动前的问题代码 +- `src/infinicore/nn/embedding.cc` (旧版本) + - 第58行:`indices->to(cpu_device)->contiguous()` ❌ + +### 改动后的实现 +- `src/infinicore/nn/embedding.cc` (新版本) + - 第48行:`indices->is_contiguous() ? indices : indices->contiguous()` ✅ + - 第52行:`return op::embedding(indices_contiguous, weight_)` ✅ + +- `src/infiniop/ops/embedding/nvidia/embedding_nvidia.cu` + - CUDA kernel 实现,完全异步 ✅ + +- `src/infinicore/ops/embedding/embedding_infiniop.cc` + - InfiniOP 包装,调用设备端 kernel ✅ + +--- + +## 总结 + +**改动前的关键问题**: +- ❌ `indices->to(cpu_device)` 触发同步 D2H 拷贝 +- ❌ 无法进行 CUDA Graph 录制 +- ❌ 性能较差(同步开销) + +**改动后的改进**: +- ✅ 移除所有同步操作 +- ✅ 使用设备端 CUDA kernel +- ✅ 完全支持 CUDA Graph 录制 +- ✅ 性能更好(完全异步) + diff --git a/test/infinicore/nn/HOW_TO_USE_GRAPH_RECORDING_TEST.md b/test/infinicore/nn/HOW_TO_USE_GRAPH_RECORDING_TEST.md new file mode 100644 index 000000000..e5e60db2b --- /dev/null +++ b/test/infinicore/nn/HOW_TO_USE_GRAPH_RECORDING_TEST.md @@ -0,0 +1,317 @@ +# Embedding 图录制测试使用指南 + +## 🚀 快速开始 + +### 运行测试 + +```bash +cd /home/zhuyue/codes/InfiniCore +python test/infinicore/nn/test_embedding_graph_recording.py +``` + +--- + +## 📊 改动前后对比 + +### ❌ 改动前:不支持图录制 + +#### 1. 运行测试 + +```bash +python test/infinicore/nn/test_embedding_graph_recording.py +``` + +#### 2. 预期输出 + +``` +============================================================ +Embedding 图录制支持验证 +============================================================ +============================================================ +测试 Embedding 图录制支持 +============================================================ + +1. 输入张量信息: + - Shape: [4, 32] + - Device: cuda + - Dtype: int64 + +2. 尝试 CUDA Graph 录制... + 使用 PyTorch CUDA Graph API 测试... + ✗ 图录制失败: [错误信息] + ✗ Embedding 不支持 CUDA Graph 录制(可能包含同步操作) + +3. 简化验证:检查异步操作支持 + ✓ 输入在设备上 + ⚠ 操作可能包含同步点(事件立即完成) ← 关键:说明有同步操作 + ✓ Forward 执行时间: X.XXX ms + ✓ 输出形状: [4, 32, 128] + ✓ 输出设备: cuda + ✗ 输出验证失败 + +============================================================ +测试 Embedding 设备端输入支持 +============================================================ + +测试 1: 设备端输入 + ✗ 设备端输入失败: [错误信息] + +============================================================ +测试结果总结 +============================================================ +CUDA Graph 录制: ✗ 失败 +设备端输入: ✗ 失败 +============================================================ +✗ 部分测试失败,Embedding 可能不完全支持图录制 +============================================================ +``` + +#### 3. 关键失败点 + +- **图录制失败**:因为代码中有 `indices->to(cpu_device)` 同步操作 +- **设备端输入失败**:需要先将输入拷贝到 CPU +- **异步验证显示同步点**:事件立即完成,说明有同步操作 + +--- + +### ✅ 改动后:支持图录制 + +#### 1. 运行测试 + +```bash +python test/infinicore/nn/test_embedding_graph_recording.py +``` + +#### 2. 预期输出 + +``` +============================================================ +Embedding 图录制支持验证 +============================================================ +============================================================ +测试 Embedding 图录制支持 +============================================================ + +1. 输入张量信息: + - Shape: [4, 32] + - Device: cuda + - Dtype: int64 + +2. 尝试 CUDA Graph 录制... + 使用 PyTorch CUDA Graph API 测试... + ✓ 成功完成图录制! + ✓ Embedding 支持 CUDA Graph 录制 + ✓ 图可以成功重放 + +============================================================ +测试 Embedding 设备端输入支持 +============================================================ + +测试 1: 设备端输入 + ✓ 设备端输入成功 + - 输入设备: cuda + - 输出设备: cuda + - 输出形状: [1, 5, 64] + +============================================================ +测试结果总结 +============================================================ +CUDA Graph 录制: ✓ 通过 +设备端输入: ✓ 通过 +============================================================ +✓ 所有测试通过!Embedding 支持图录制 +============================================================ +``` + +#### 3. 关键成功点 + +- **图录制成功**:所有操作都是异步的,无同步点 +- **设备端输入成功**:直接支持设备端输入,无需拷贝 +- **图可以重放**:验证图录制的正确性 + +--- + +## 🔍 如何判断当前是改动前还是改动后? + +### 方法 1: 代码检查(最快) + +```bash +# 检查是否有同步操作 +grep -n "to(cpu_device)" src/infinicore/nn/embedding.cc + +# 结果解读: +# - 有输出 → ❌ 改动前(不支持图录制) +# - 无输出 → ✅ 改动后(支持图录制) +``` + +### 方法 2: 检查设备端实现 + +```bash +# 检查是否有设备端 CUDA kernel +ls src/infiniop/ops/embedding/nvidia/embedding_nvidia.cu + +# 结果解读: +# - 不存在 → ❌ 改动前(不支持图录制) +# - 存在 → ✅ 改动后(支持图录制) +``` + +### 方法 3: 运行测试(最准确) + +```bash +python test/infinicore/nn/test_embedding_graph_recording.py + +# 查看 "CUDA Graph 录制" 测试结果: +# - ✓ 通过 → ✅ 改动后(支持图录制) +# - ✗ 失败 → ❌ 改动前(不支持图录制) +``` + +--- + +## 📝 测试内容详解 + +### 测试 1: CUDA Graph 录制 + +**目的**:验证 embedding 是否可以在 CUDA Graph 中录制 + +**工作原理**: +1. 使用 PyTorch 的 `torch.cuda.CUDAGraph()` API +2. 在图录制模式下执行 `embedding.forward()` +3. 如果包含同步操作,录制会失败 +4. 如果完全异步,录制会成功 + +**改动前**: +- ❌ 录制失败:因为 `indices->to(cpu_device)` 触发同步 + +**改动后**: +- ✅ 录制成功:使用设备端 CUDA kernel,完全异步 + +### 测试 2: 设备端输入支持 + +**目的**:验证 embedding 是否支持设备端输入 + +**工作原理**: +1. 创建设备端的 `input_ids` +2. 直接调用 `embedding.forward(input_ids)` +3. 检查是否成功且输出在设备上 + +**改动前**: +- ❌ 可能需要先将输入拷贝到 CPU(同步操作) + +**改动后**: +- ✅ 直接支持设备端输入(完全异步) + +### 测试 3: 异步操作验证(备用) + +**目的**:当 CUDA Graph API 不可用时,使用事件验证异步性 + +**工作原理**: +1. 使用 `DeviceEvent` 记录操作时间 +2. 检查操作是否立即完成(同步)或异步执行 + +**改动前**: +- ⚠️ 事件立即完成,说明有同步操作 + +**改动后**: +- ✅ 事件未立即完成,说明是异步操作 + +--- + +## 🛠️ 故障排查 + +### 问题 1: PyTorch 版本不支持 CUDA Graph + +**现象**: +``` +⚠ PyTorch 版本不支持 torch.cuda.graph,使用简化验证方法 +``` + +**解决**: +- 需要 PyTorch 2.0+ 版本 +- 测试会自动降级到简化验证方法 +- 简化验证也能检测是否支持图录制 + +### 问题 2: CUDA 不可用 + +**现象**: +``` +⚠ CUDA 不可用,跳过图录制测试 +``` + +**解决**: +- 确保 CUDA 设备可用 +- 测试需要 CUDA 环境 + +### 问题 3: 测试失败但不确定原因 + +**检查清单**: +1. ✅ 确认代码已编译(特别是 CUDA 支持) +2. ✅ 确认 CUDA 设备可用 +3. ✅ 检查 `src/infinicore/nn/embedding.cc` 是否还有 `to(cpu_device)` +4. ✅ 检查是否有 `src/infiniop/ops/embedding/nvidia/embedding_nvidia.cu` + +--- + +## 💡 快速验证脚本 + +创建一个简单的验证脚本: + +```bash +#!/bin/bash +# quick_check.sh + +cd /home/zhuyue/codes/InfiniCore + +echo "=== 1. 代码检查 ===" +if grep -q "to(cpu_device)" src/infinicore/nn/embedding.cc; then + echo "❌ 改动前:发现同步操作 to(cpu_device)" +else + echo "✅ 改动后:无同步操作" +fi + +echo "" +echo "=== 2. 设备端实现检查 ===" +if [ -f "src/infiniop/ops/embedding/nvidia/embedding_nvidia.cu" ]; then + echo "✅ 改动后:有设备端 CUDA kernel" +else + echo "❌ 改动前:无设备端 CUDA kernel" +fi + +echo "" +echo "=== 3. 运行测试 ===" +python test/infinicore/nn/test_embedding_graph_recording.py +``` + +使用方法: +```bash +chmod +x quick_check.sh +./quick_check.sh +``` + +--- + +## 📋 总结 + +### 改动前特征 + +| 检查项 | 结果 | +|--------|------| +| 代码中有 `to(cpu_device)` | ✅ 有 | +| 有设备端 CUDA kernel | ❌ 无 | +| 图录制测试 | ❌ 失败 | +| 设备端输入 | ❌ 失败 | + +### 改动后特征 + +| 检查项 | 结果 | +|--------|------| +| 代码中有 `to(cpu_device)` | ❌ 无 | +| 有设备端 CUDA kernel | ✅ 有 | +| 图录制测试 | ✅ 成功 | +| 设备端输入 | ✅ 成功 | + +### 最简单的判断方法 + +**运行测试脚本**,查看 "CUDA Graph 录制" 测试结果: +- ✅ **通过** → 支持图录制(改动后) +- ❌ **失败** → 不支持图录制(改动前) + diff --git a/test/infinicore/nn/embedding.py b/test/infinicore/nn/embedding.py index 667713537..023bc7762 100644 --- a/test/infinicore/nn/embedding.py +++ b/test/infinicore/nn/embedding.py @@ -114,14 +114,9 @@ def torch_operator(self, x, weight): def infinicore_operator(self, x, weight): """InfiniCore nn.Embedding implementation""" - - if x.device.type != "cpu": - # 将 input的数据 转移到 cpu 上 - x_torch = convert_infinicore_to_torch(x) - x_torch_cpu = x_torch.contiguous().cpu() - - x = infinicore.from_torch(x_torch_cpu) - + # Note: embedding now supports device-side input for graph recording + # No need to convert to CPU anymore - the implementation handles both CPU and device inputs + num_embeddings, embedding_dim = weight.shape model = infinicore.nn.Embedding( diff --git a/test/infinicore/nn/test_embedding_graph_recording.py b/test/infinicore/nn/test_embedding_graph_recording.py new file mode 100644 index 000000000..405f71e0d --- /dev/null +++ b/test/infinicore/nn/test_embedding_graph_recording.py @@ -0,0 +1,284 @@ +""" +测试 embedding 是否支持 CUDA Graph 录制 + +使用方法: + python test/infinicore/nn/test_embedding_graph_recording.py + +关键验证点: +1. 改动前:indices->to(cpu_device) 会触发同步的 D2H 拷贝,导致图录制失败 +2. 改动后:使用设备端 CUDA kernel,完全异步,支持图录制 + +预期结果: +- 改动前:图录制失败,设备端输入可能失败 +- 改动后:图录制成功,设备端输入成功 +""" + +import infinicore +import torch +import ctypes + + +def test_embedding_graph_recording(): + """测试 embedding 是否支持 CUDA Graph 录制""" + print("=" * 60) + print("测试 Embedding 图录制支持") + print("=" * 60) + + # 检查是否有 CUDA + if not torch.cuda.is_available(): + print("⚠ CUDA 不可用,跳过图录制测试") + return False + + device = infinicore.device("cuda", 0) + + # 创建 embedding 模块 + vocab_size = 1000 + embedding_dim = 128 + embedding = infinicore.nn.Embedding( + num_embeddings=vocab_size, + embedding_dim=embedding_dim, + dtype=infinicore.float32, + device=device + ) + + # 创建设备端的 input_ids(这是关键:改动前不支持,改动后支持) + batch_size = 4 + seq_len = 32 + input_ids_device = infinicore.from_list( + [[i % vocab_size for i in range(seq_len)] for _ in range(batch_size)], + dtype=infinicore.int64, + device=device + ) + + print(f"\n1. 输入张量信息:") + print(f" - Shape: {input_ids_device.shape}") + print(f" - Device: {input_ids_device.device.type}") + print(f" - Dtype: {input_ids_device.dtype}") + + # 尝试使用 CUDA Graph 录制 + print(f"\n2. 尝试 CUDA Graph 录制...") + + # 使用 PyTorch 的 CUDA Graph API 进行测试(更简单可靠) + try: + # 设置设备 + infinicore.set_device(device) + + # 使用 PyTorch 的 CUDA Graph API + # 注意:PyTorch 2.0+ 支持 torch.cuda.graph + try: + # 方法 1: 使用 PyTorch 的 CUDA Graph(推荐) + print(" 使用 PyTorch CUDA Graph API 测试...") + + # 创建 warmup 输入 + warmup_input = input_ids_device + + # Warmup(图录制前需要先执行一次,包括内存分配) + warmup_output = embedding.forward(warmup_input) + infinicore.sync_stream() # 同步确保 warmup 完成 + + # 预先分配输出张量(CUDA Graph 不支持动态内存分配) + # 输出形状: input_shape + [embedding_dim] + output_shape = list(input_ids_device.shape) + [embedding_dim] + output = infinicore.empty( + output_shape, + dtype=embedding.weight.dtype, + device=device + ) + + # Warmup embedding(确保内存分配完成) + import infinicore.nn.functional as F + F.embedding(warmup_input, embedding.weight, out=output) + infinicore.sync_stream() + + # 开始图录制(使用预先分配的 output) + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + # 使用 embedding 的 out 参数(in-place),传入预先分配的 output + F.embedding(input_ids_device, embedding.weight, out=output) + + print(" ✓ 成功完成图录制!") + print(" ✓ Embedding 支持 CUDA Graph 录制") + + # 验证图可以重复执行 + graph.replay() + infinicore.sync_stream() + + print(" ✓ 图可以成功重放") + return True + + except AttributeError: + # PyTorch 版本可能不支持 torch.cuda.graph + print(" ⚠ PyTorch 版本不支持 torch.cuda.graph,使用简化验证方法") + return test_embedding_async_verification(embedding, input_ids_device) + except RuntimeError as e: + error_msg = str(e) + if "capture" in error_msg.lower() or "graph" in error_msg.lower(): + print(f" ✗ 图录制失败: {e}") + print(" ✗ Embedding 不支持 CUDA Graph 录制(可能包含同步操作)") + return False + else: + print(f" ⚠ 图录制测试异常: {e}") + return test_embedding_async_verification(embedding, input_ids_device) + + except Exception as e: + print(f" ⚠ 图录制测试异常: {e}") + print(" 使用简化验证方法...") + import traceback + traceback.print_exc() + return test_embedding_async_verification(embedding, input_ids_device) + + +def test_embedding_async_verification(embedding, input_ids_device): + """ + 简化验证:检查是否有同步操作 + + 关键检查点: + 1. 输入是否可以在设备上(改动前需要 CPU,改动后支持设备) + 2. 操作是否完全异步(没有同步点) + """ + print("\n3. 简化验证:检查异步操作支持") + + # 验证 1: 输入可以在设备上 + if input_ids_device.device.type != "cuda": + print(" ✗ 输入不在设备上,无法验证") + return False + + print(" ✓ 输入在设备上") + + # 验证 2: 执行 forward,检查是否有同步操作 + # 如果改动前,这里会调用 indices->to(cpu_device),触发同步 + # 如果改动后,直接使用设备端 kernel,完全异步 + + try: + # 记录开始时间 + start_event = infinicore.DeviceEvent(enable_timing=True) + end_event = infinicore.DeviceEvent(enable_timing=True) + + start_event.record() + output = embedding.forward(input_ids_device) + end_event.record() + + # 不立即同步,检查操作是否异步 + # 如果操作是异步的,query 应该返回 False(未完成) + # 如果操作是同步的,可能已经完成 + + # 等待一小段时间 + import time + time.sleep(0.001) # 1ms + + # 检查事件状态 + is_complete = end_event.query() + + if not is_complete: + print(" ✓ 操作是异步的(事件未立即完成)") + else: + print(" ⚠ 操作可能包含同步点(事件立即完成)") + + # 同步并测量时间 + end_event.synchronize() + elapsed = start_event.elapsed_time(end_event) + + print(f" ✓ Forward 执行时间: {elapsed:.3f} ms") + print(f" ✓ 输出形状: {output.shape}") + print(f" ✓ 输出设备: {output.device.type}") + + # 验证输出正确性 + embedding_dim = embedding.embedding_dim() + expected_shape = (*input_ids_device.shape, embedding_dim) + if output.device.type == "cuda" and output.shape == expected_shape: + print(" ✓ 输出在设备上,形状正确") + return True + else: + print(f" ✗ 输出验证失败") + print(f" 期望形状: {expected_shape}, 实际形状: {output.shape}") + print(f" 期望设备: cuda, 实际设备: {output.device.type}") + return False + + except Exception as e: + print(f" ✗ 验证失败: {e}") + import traceback + traceback.print_exc() + return False + + +def test_embedding_device_input_support(): + """测试 embedding 是否支持设备端输入""" + print("\n" + "=" * 60) + print("测试 Embedding 设备端输入支持") + print("=" * 60) + + if not torch.cuda.is_available(): + print("⚠ CUDA 不可用,跳过测试") + return False + + device = infinicore.device("cuda", 0) + vocab_size = 100 + embedding_dim = 64 + + embedding = infinicore.nn.Embedding( + num_embeddings=vocab_size, + embedding_dim=embedding_dim, + dtype=infinicore.float32, + device=device + ) + + # 测试 1: 设备端输入(改动后支持) + print("\n测试 1: 设备端输入") + try: + input_ids_device = infinicore.from_list( + [[1, 2, 3, 4, 5]], + dtype=infinicore.int64, + device=device + ) + output = embedding.forward(input_ids_device) + print(f" ✓ 设备端输入成功") + print(f" - 输入设备: {input_ids_device.device.type}") + print(f" - 输出设备: {output.device.type}") + print(f" - 输出形状: {output.shape}") + return True + except Exception as e: + print(f" ✗ 设备端输入失败: {e}") + return False + + +def main(): + """主测试函数""" + print("\n" + "=" * 60) + print("Embedding 图录制支持验证") + print("=" * 60) + + results = [] + + # 测试 1: 图录制支持 + result1 = test_embedding_graph_recording() + results.append(("CUDA Graph 录制", result1)) + + # 测试 2: 设备端输入支持 + result2 = test_embedding_device_input_support() + results.append(("设备端输入", result2)) + + # 总结 + print("\n" + "=" * 60) + print("测试结果总结") + print("=" * 60) + + all_passed = True + for test_name, result in results: + status = "✓ 通过" if result else "✗ 失败" + print(f"{test_name}: {status}") + if not result: + all_passed = False + + print("\n" + "=" * 60) + if all_passed: + print("✓ 所有测试通过!Embedding 支持图录制") + else: + print("✗ 部分测试失败,Embedding 可能不完全支持图录制") + print("=" * 60) + + return all_passed + + +if __name__ == "__main__": + success = main() + exit(0 if success else 1) diff --git a/test/infinicore/ops/embedding.py b/test/infinicore/ops/embedding.py index a8bdc00b8..6cb7755af 100644 --- a/test/infinicore/ops/embedding.py +++ b/test/infinicore/ops/embedding.py @@ -102,23 +102,9 @@ def torch_operator(self, *args, out=None, **kwargs): def infinicore_operator(self, input, weight, out=None, **kwargs): """InfiniCore Embedding implementation""" - - if input.device.type == "cpu": - input_cpu = input - else: - # 将 input的数据 转移到 cpu 上 - torch_reference = torch.zeros( - input.shape, - dtype=to_torch_dtype(input.dtype), - device="cpu" if "cpu" == input.device.type else "cuda", - ) - torch_reference = convert_infinicore_to_torch(input) - torch_reference = torch_reference.contiguous().cpu() - - # 创建cpu的 input - input_cpu = infinicore_tensor_from_torch(torch_reference) - - return infinicore.nn.functional.embedding(input_cpu, weight, out=out) + # Note: embedding now supports device-side input for graph recording + # No need to convert to CPU anymore - the implementation handles both CPU and device inputs + return infinicore.nn.functional.embedding(input, weight, out=out) def main(): From 3b0680e7e3830406bb4ba0eb290896a1cd325c7a Mon Sep 17 00:00:00 2001 From: gongchensu Date: Fri, 26 Dec 2025 07:54:18 +0000 Subject: [PATCH 2/3] Issue/846 - Ensure embedding tensors are on the same device. Change format. --- src/infinicore/nn/embedding.cc | 11 ++++- src/infinicore/ops/embedding/embedding.cc | 12 +++--- .../ops/embedding/embedding_infiniop.cc | 2 +- .../ops/embedding/cpu/embedding_cpu.cc | 43 +++++++++---------- src/infiniop/ops/embedding/embedding.h | 4 +- .../ops/embedding/nvidia/embedding_kernel.cuh | 6 +-- .../ops/embedding/nvidia/embedding_nvidia.cu | 38 ++++++++-------- src/infiniop/ops/embedding/operator.cc | 6 +-- 8 files changed, 64 insertions(+), 58 deletions(-) diff --git a/src/infinicore/nn/embedding.cc b/src/infinicore/nn/embedding.cc index f1af03042..6aa86a4fa 100644 --- a/src/infinicore/nn/embedding.cc +++ b/src/infinicore/nn/embedding.cc @@ -43,10 +43,17 @@ Embedding::Embedding(size_t num_embeddings, } Tensor Embedding::forward(const Tensor &indices) const { + // Ensure indices are on the same device as weight + // This avoids synchronous memcpy in ops layer which would hurt performance + Tensor indices_on_device = indices; + if (indices->device() != device_) { + indices_on_device = indices->to(device_); + } + // Ensure indices are contiguous for efficient access // op::embedding now supports device-side input for graph recording - Tensor indices_contiguous = indices->is_contiguous() ? indices : indices->contiguous(); - + Tensor indices_contiguous = indices_on_device->is_contiguous() ? indices_on_device : indices_on_device->contiguous(); + // Use op::embedding which now supports device-side input and batch dimension // This enables full graph recording support without synchronization return op::embedding(indices_contiguous, weight_); diff --git a/src/infinicore/ops/embedding/embedding.cc b/src/infinicore/ops/embedding/embedding.cc index cf5c41caf..96f19803c 100644 --- a/src/infinicore/ops/embedding/embedding.cc +++ b/src/infinicore/ops/embedding/embedding.cc @@ -1,6 +1,6 @@ #include "infinicore/ops/embedding.hpp" -#include "infinicore/context/context.hpp" #include "../../utils.hpp" +#include "infinicore/context/context.hpp" #include #include @@ -12,12 +12,14 @@ common::OpDispatcher &Embedding::dispatcher() { } void Embedding::execute(Tensor out, Tensor input, Tensor weight) { - // Check that output and weight are on the same device - INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, weight); - + // Check that all tensors are on the same device + // This is critical: if input is on CPU while out/weight are on GPU, + // passing CPU pointer to CUDA kernel will cause memory access errors + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, input, weight); + // Set device context infinicore::context::setDevice(out->device()); - + // Use dispatcher to lookup kernel (infiniop implementation) dispatcher().lookup(out->device().getType())(out, input, weight); } diff --git a/src/infinicore/ops/embedding/embedding_infiniop.cc b/src/infinicore/ops/embedding/embedding_infiniop.cc index af73f13fa..dfbbb2f71 100644 --- a/src/infinicore/ops/embedding/embedding_infiniop.cc +++ b/src/infinicore/ops/embedding/embedding_infiniop.cc @@ -1,7 +1,7 @@ #include "../../utils.hpp" #include "infinicore/common/hash.hpp" -#include "infinicore/ops/embedding.hpp" #include "infinicore/ops/common/cache.hpp" +#include "infinicore/ops/embedding.hpp" #include namespace infinicore::op::embedding_impl::infiniop { diff --git a/src/infiniop/ops/embedding/cpu/embedding_cpu.cc b/src/infiniop/ops/embedding/cpu/embedding_cpu.cc index e84eced6b..8e6648063 100644 --- a/src/infiniop/ops/embedding/cpu/embedding_cpu.cc +++ b/src/infiniop/ops/embedding/cpu/embedding_cpu.cc @@ -1,7 +1,7 @@ +#include "embedding_cpu.h" #include "../../../../utils.h" -#include "../../../tensor.h" #include "../../../handle.h" -#include "embedding_cpu.h" +#include "../../../tensor.h" #include namespace op::embedding::cpu { @@ -21,33 +21,32 @@ infiniStatus_t Descriptor::create( auto input_shape = input_desc->shape(); auto weight_shape = weight_desc->shape(); - + CHECK_OR_RETURN(weight_shape.size() == 2, INFINI_STATUS_BAD_TENSOR_SHAPE); CHECK_OR_RETURN(output_desc->shape().size() == input_shape.size() + 1, INFINI_STATUS_BAD_TENSOR_SHAPE); - + auto output_shape = output_desc->shape(); size_t embedding_dim = weight_shape[1]; CHECK_OR_RETURN(output_shape.back() == embedding_dim, INFINI_STATUS_BAD_TENSOR_SHAPE); - + for (size_t i = 0; i < input_shape.size(); ++i) { CHECK_OR_RETURN(output_shape[i] == input_shape[i], INFINI_STATUS_BAD_TENSOR_SHAPE); } - + auto input_dtype = input_desc->dtype(); auto weight_dtype = weight_desc->dtype(); CHECK_OR_RETURN(input_dtype == INFINI_DTYPE_I32 || input_dtype == INFINI_DTYPE_I64, INFINI_STATUS_BAD_TENSOR_DTYPE); - CHECK_OR_RETURN(weight_dtype == INFINI_DTYPE_F32 || weight_dtype == INFINI_DTYPE_F16 || - weight_dtype == INFINI_DTYPE_BF16, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_OR_RETURN(weight_dtype == INFINI_DTYPE_F32 || weight_dtype == INFINI_DTYPE_F16 || weight_dtype == INFINI_DTYPE_BF16, INFINI_STATUS_BAD_TENSOR_DTYPE); CHECK_OR_RETURN(output_desc->dtype() == weight_dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); - + size_t num_indices = 1; for (auto dim : input_shape) { num_indices *= dim; } - + size_t vocab_size = weight_shape[0]; - + *desc_ptr = new Descriptor( num_indices, embedding_dim, @@ -57,7 +56,7 @@ infiniStatus_t Descriptor::create( new Opaque{}, handle->device, handle->device_id); - + return INFINI_STATUS_SUCCESS; } @@ -66,44 +65,44 @@ infiniStatus_t Descriptor::calculate( const void *input, const void *weight, void *stream) const { - + if (_num_indices == 0) { return INFINI_STATUS_SUCCESS; } - + size_t element_size = infiniSizeOf(_weight_dtype); size_t row_bytes = _embedding_dim * element_size; - + if (_input_dtype == INFINI_DTYPE_I32) { const int32_t *indices_ptr = reinterpret_cast(input); const std::byte *weight_ptr = reinterpret_cast(weight); std::byte *out_ptr = reinterpret_cast(output); - + for (size_t i = 0; i < _num_indices; ++i) { int32_t idx = indices_ptr[i]; if (idx >= 0 && static_cast(idx) < _vocab_size) { std::memcpy(out_ptr + i * row_bytes, - weight_ptr + static_cast(idx) * row_bytes, - row_bytes); + weight_ptr + static_cast(idx) * row_bytes, + row_bytes); } } } else if (_input_dtype == INFINI_DTYPE_I64) { const int64_t *indices_ptr = reinterpret_cast(input); const std::byte *weight_ptr = reinterpret_cast(weight); std::byte *out_ptr = reinterpret_cast(output); - + for (size_t i = 0; i < _num_indices; ++i) { int64_t idx = indices_ptr[i]; if (idx >= 0 && static_cast(idx) < _vocab_size) { std::memcpy(out_ptr + i * row_bytes, - weight_ptr + static_cast(idx) * row_bytes, - row_bytes); + weight_ptr + static_cast(idx) * row_bytes, + row_bytes); } } } else { return INFINI_STATUS_BAD_TENSOR_DTYPE; } - + return INFINI_STATUS_SUCCESS; } diff --git a/src/infiniop/ops/embedding/embedding.h b/src/infiniop/ops/embedding/embedding.h index 0e4b33009..e0135dbfe 100644 --- a/src/infiniop/ops/embedding/embedding.h +++ b/src/infiniop/ops/embedding/embedding.h @@ -20,8 +20,8 @@ size_t num_indices, \ size_t embedding_dim, \ size_t vocab_size, \ - infiniDtype_t input_dtype, \ - infiniDtype_t weight_dtype, \ + infiniDtype_t input_dtype, \ + infiniDtype_t weight_dtype, \ Opaque *opaque, \ infiniDevice_t device_type, \ int device_id) \ diff --git a/src/infiniop/ops/embedding/nvidia/embedding_kernel.cuh b/src/infiniop/ops/embedding/nvidia/embedding_kernel.cuh index 8398bfbfc..88f0ad92f 100644 --- a/src/infiniop/ops/embedding/nvidia/embedding_kernel.cuh +++ b/src/infiniop/ops/embedding/nvidia/embedding_kernel.cuh @@ -16,17 +16,17 @@ INFINIOP_CUDA_KERNEL embeddingKernel( size_t vocab_size) { // Calculate global thread index size_t idx = blockIdx.x * blockDim.x + threadIdx.x; - + if (idx < num_indices) { // Get the index value IndexType index_val = indices[idx]; - + // Bounds check - handle negative indices gracefully if (index_val >= 0 && static_cast(index_val) < vocab_size) { // Copy embedding vector from weight to output const T *src = weight + static_cast(index_val) * embedding_dim; T *dst = output + idx * embedding_dim; - + // Copy embedding_dim elements // Use vectorized copy for better performance when possible size_t i = 0; diff --git a/src/infiniop/ops/embedding/nvidia/embedding_nvidia.cu b/src/infiniop/ops/embedding/nvidia/embedding_nvidia.cu index 007e90c04..72d1e6514 100644 --- a/src/infiniop/ops/embedding/nvidia/embedding_nvidia.cu +++ b/src/infiniop/ops/embedding/nvidia/embedding_nvidia.cu @@ -1,7 +1,7 @@ +#include "../../../../utils.h" #include "../../../devices/nvidia/nvidia_common.cuh" #include "../../../devices/nvidia/nvidia_kernel_common.cuh" #include "../../../tensor.h" -#include "../../../../utils.h" #include "embedding_kernel.cuh" #include "embedding_nvidia.cuh" #include @@ -23,50 +23,48 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t input_desc, infiniopTensorDescriptor_t weight_desc) { - auto handle_nvidia = reinterpret_cast(handle); auto input_shape = input_desc->shape(); auto weight_shape = weight_desc->shape(); - + // Validate shapes CHECK_OR_RETURN(weight_shape.size() == 2, INFINI_STATUS_BAD_TENSOR_SHAPE); CHECK_OR_RETURN(output_desc->shape().size() == input_shape.size() + 1, INFINI_STATUS_BAD_TENSOR_SHAPE); - + // Check output shape matches input shape + embedding_dim auto output_shape = output_desc->shape(); size_t embedding_dim = weight_shape[1]; CHECK_OR_RETURN(output_shape.back() == embedding_dim, INFINI_STATUS_BAD_TENSOR_SHAPE); - + for (size_t i = 0; i < input_shape.size(); ++i) { CHECK_OR_RETURN(output_shape[i] == input_shape[i], INFINI_STATUS_BAD_TENSOR_SHAPE); } - + // Validate dtypes auto input_dtype = input_desc->dtype(); auto weight_dtype = weight_desc->dtype(); CHECK_OR_RETURN(input_dtype == INFINI_DTYPE_I32 || input_dtype == INFINI_DTYPE_I64, INFINI_STATUS_BAD_TENSOR_DTYPE); - CHECK_OR_RETURN(weight_dtype == INFINI_DTYPE_F32 || weight_dtype == INFINI_DTYPE_F16 || - weight_dtype == INFINI_DTYPE_BF16, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_OR_RETURN(weight_dtype == INFINI_DTYPE_F32 || weight_dtype == INFINI_DTYPE_F16 || weight_dtype == INFINI_DTYPE_BF16, INFINI_STATUS_BAD_TENSOR_DTYPE); CHECK_OR_RETURN(output_desc->dtype() == weight_dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); - + // Calculate number of indices (supporting batch dimension) size_t num_indices = 1; for (auto dim : input_shape) { num_indices *= dim; } - + size_t vocab_size = weight_shape[0]; - + *desc_ptr = new Descriptor( num_indices, embedding_dim, vocab_size, input_dtype, weight_dtype, - new Opaque{handle_nvidia->internal()}, + new Opaque{reinterpret_cast(handle)->internal()}, handle->device, handle->device_id); - + return INFINI_STATUS_SUCCESS; } @@ -75,19 +73,19 @@ infiniStatus_t Descriptor::calculate( const void *input, const void *weight, void *stream) const { - + if (_num_indices == 0) { return INFINI_STATUS_SUCCESS; } - + auto cuda_stream = reinterpret_cast(stream); constexpr size_t BLOCK_SIZE = 256; size_t grid_size = (_num_indices + BLOCK_SIZE - 1) / BLOCK_SIZE; - + // Launch kernel based on dtypes if (_input_dtype == INFINI_DTYPE_I32) { const int32_t *indices_ptr = reinterpret_cast(input); - + if (_weight_dtype == INFINI_DTYPE_F32) { embeddingKernel<<>>( reinterpret_cast(output), @@ -117,7 +115,7 @@ infiniStatus_t Descriptor::calculate( } } else if (_input_dtype == INFINI_DTYPE_I64) { const int64_t *indices_ptr = reinterpret_cast(input); - + if (_weight_dtype == INFINI_DTYPE_F32) { embeddingKernel<<>>( reinterpret_cast(output), @@ -148,13 +146,13 @@ infiniStatus_t Descriptor::calculate( } else { return INFINI_STATUS_BAD_TENSOR_DTYPE; } - + // Check for kernel launch errors cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) { return INFINI_STATUS_INTERNAL_ERROR; } - + return INFINI_STATUS_SUCCESS; } diff --git a/src/infiniop/ops/embedding/operator.cc b/src/infiniop/ops/embedding/operator.cc index af75842fa..50f2f05ed 100644 --- a/src/infiniop/ops/embedding/operator.cc +++ b/src/infiniop/ops/embedding/operator.cc @@ -18,7 +18,7 @@ __C infiniStatus_t infiniopCreateEmbeddingDescriptor( #define CREATE(CASE, NAMESPACE) \ case CASE: \ - return op::embedding::NAMESPACE::Descriptor::create( \ + return op::embedding::NAMESPACE::Descriptor::create( \ handle, \ reinterpret_cast(desc_ptr), \ output_desc, \ @@ -89,8 +89,8 @@ __C infiniStatus_t infiniopEmbedding( __C infiniStatus_t infiniopDestroyEmbeddingDescriptor(infiniopEmbeddingDescriptor_t desc) { -#define DELETE(CASE, NAMESPACE) \ - case CASE: \ +#define DELETE(CASE, NAMESPACE) \ + case CASE: \ delete reinterpret_cast(desc); \ return INFINI_STATUS_SUCCESS; From 3b3f0fbc72c6c5800b1ef76cf53fef8b8ae0b424 Mon Sep 17 00:00:00 2001 From: gongchensu Date: Fri, 26 Dec 2025 08:32:21 +0000 Subject: [PATCH 3/3] Issue/846 - Optimize embedding kernel with vectorized memory access and __ldg - Add vectorized memory access using float4/float2, half2, and bfloat162 - Use __ldg instruction for read-only weight and indices access - Add memory alignment checks to enable vectorized paths - Add __restrict__ keywords for better compiler optimization - Implement dynamic block size selection based on embedding_dim --- .../ops/embedding/nvidia/embedding_kernel.cuh | 162 ++++++++++++++++-- .../ops/embedding/nvidia/embedding_nvidia.cu | 26 ++- 2 files changed, 163 insertions(+), 25 deletions(-) diff --git a/src/infiniop/ops/embedding/nvidia/embedding_kernel.cuh b/src/infiniop/ops/embedding/nvidia/embedding_kernel.cuh index 88f0ad92f..0e85b5f6a 100644 --- a/src/infiniop/ops/embedding/nvidia/embedding_kernel.cuh +++ b/src/infiniop/ops/embedding/nvidia/embedding_kernel.cuh @@ -2,15 +2,127 @@ #define __EMBEDDING_CUDA_KERNEL_CUH__ #include "../../../devices/nvidia/nvidia_kernel_common.cuh" +#include #include +#include namespace op::embedding::nvidia { +// Helper function to check memory alignment +__forceinline__ __device__ bool is_aligned(const void *ptr, size_t alignment) { + // Use size_t for pointer arithmetic in device code (more compatible) + return (reinterpret_cast(ptr) % alignment == 0); +} + +// Vectorized copy for float type using float4 +template +__forceinline__ __device__ void copyVectorizedFloat4( + float *__restrict__ dst, + const float *__restrict__ src, + size_t embedding_dim) { + // Use float4 for vectorized access (16 bytes, 4 floats) + const float4 *src_vec = reinterpret_cast(src); + float4 *dst_vec = reinterpret_cast(dst); + size_t vec_count = embedding_dim / 4; + + // Vectorized copy using __ldg for read-only weight + for (size_t i = 0; i < vec_count; ++i) { + dst_vec[i] = __ldg(&src_vec[i]); + } + + // Copy remaining elements + size_t remaining = embedding_dim % 4; + if (remaining > 0) { + size_t offset = vec_count * 4; + for (size_t i = 0; i < remaining; ++i) { + dst[offset + i] = __ldg(&src[offset + i]); + } + } +} + +// Vectorized copy for float type using float2 (fallback when not aligned to 16 bytes) +template +__forceinline__ __device__ void copyVectorizedFloat2( + float *__restrict__ dst, + const float *__restrict__ src, + size_t embedding_dim) { + // Use float2 for vectorized access (8 bytes, 2 floats) + const float2 *src_vec = reinterpret_cast(src); + float2 *dst_vec = reinterpret_cast(dst); + size_t vec_count = embedding_dim / 2; + + // Vectorized copy using __ldg for read-only weight + for (size_t i = 0; i < vec_count; ++i) { + dst_vec[i] = __ldg(&src_vec[i]); + } + + // Copy remaining element if odd + if (embedding_dim % 2 != 0) { + dst[embedding_dim - 1] = __ldg(&src[embedding_dim - 1]); + } +} + +// Vectorized copy for half type using half2 +template +__forceinline__ __device__ void copyVectorizedHalf2( + half *__restrict__ dst, + const half *__restrict__ src, + size_t embedding_dim) { + // Use half2 for vectorized access (4 bytes, 2 halfs) + const half2 *src_vec = reinterpret_cast(src); + half2 *dst_vec = reinterpret_cast(dst); + size_t vec_count = embedding_dim / 2; + + // Vectorized copy using __ldg for read-only weight + for (size_t i = 0; i < vec_count; ++i) { + dst_vec[i] = __ldg(&src_vec[i]); + } + + // Copy remaining element if odd + if (embedding_dim % 2 != 0) { + dst[embedding_dim - 1] = __ldg(&src[embedding_dim - 1]); + } +} + +// Vectorized copy for bfloat16 type using bfloat162 +template +__forceinline__ __device__ void copyVectorizedBFloat162( + cuda_bfloat16 *__restrict__ dst, + const cuda_bfloat16 *__restrict__ src, + size_t embedding_dim) { + // Use bfloat162 for vectorized access (4 bytes, 2 bfloat16s) + const cuda_bfloat162 *src_vec = reinterpret_cast(src); + cuda_bfloat162 *dst_vec = reinterpret_cast(dst); + size_t vec_count = embedding_dim / 2; + + // Vectorized copy using __ldg for read-only weight + for (size_t i = 0; i < vec_count; ++i) { + dst_vec[i] = __ldg(&src_vec[i]); + } + + // Copy remaining element if odd + if (embedding_dim % 2 != 0) { + dst[embedding_dim - 1] = __ldg(&src[embedding_dim - 1]); + } +} + +// Scalar copy fallback with __ldg optimization +template +__forceinline__ __device__ void copyScalar( + T *__restrict__ dst, + const T *__restrict__ src, + size_t embedding_dim) { + // Scalar copy with __ldg for read-only weight + for (size_t i = 0; i < embedding_dim; ++i) { + dst[i] = __ldg(&src[i]); + } +} + template INFINIOP_CUDA_KERNEL embeddingKernel( - T *output, - const IndexType *indices, - const T *weight, + T *__restrict__ output, + const IndexType *__restrict__ indices, + const T *__restrict__ weight, size_t num_indices, size_t embedding_dim, size_t vocab_size) { @@ -19,7 +131,7 @@ INFINIOP_CUDA_KERNEL embeddingKernel( if (idx < num_indices) { // Get the index value - IndexType index_val = indices[idx]; + IndexType index_val = __ldg(&indices[idx]); // Bounds check - handle negative indices gracefully if (index_val >= 0 && static_cast(index_val) < vocab_size) { @@ -27,19 +139,35 @@ INFINIOP_CUDA_KERNEL embeddingKernel( const T *src = weight + static_cast(index_val) * embedding_dim; T *dst = output + idx * embedding_dim; - // Copy embedding_dim elements - // Use vectorized copy for better performance when possible - size_t i = 0; - // Copy in chunks of 4 for better memory bandwidth utilization - for (; i + 4 <= embedding_dim; i += 4) { - dst[i] = src[i]; - dst[i + 1] = src[i + 1]; - dst[i + 2] = src[i + 2]; - dst[i + 3] = src[i + 3]; - } - // Copy remaining elements - for (; i < embedding_dim; ++i) { - dst[i] = src[i]; + // Choose optimal copy strategy based on type and alignment + if constexpr (std::is_same_v) { + // Check alignment for float4 (16 bytes) + bool aligned_16 = is_aligned(src, 16) && is_aligned(dst, 16); + if (aligned_16 && embedding_dim >= 4 && embedding_dim % 4 == 0) { + copyVectorizedFloat4(dst, src, embedding_dim); + } else if (embedding_dim >= 2 && embedding_dim % 2 == 0) { + // Try float2 if not aligned to 16 bytes + copyVectorizedFloat2(dst, src, embedding_dim); + } else { + copyScalar(dst, src, embedding_dim); + } + } else if constexpr (std::is_same_v) { + // Use half2 for vectorized access + if (embedding_dim >= 2 && embedding_dim % 2 == 0) { + copyVectorizedHalf2(dst, src, embedding_dim); + } else { + copyScalar(dst, src, embedding_dim); + } + } else if constexpr (std::is_same_v) { + // Use bfloat162 for vectorized access + if (embedding_dim >= 2 && embedding_dim % 2 == 0) { + copyVectorizedBFloat162(dst, src, embedding_dim); + } else { + copyScalar(dst, src, embedding_dim); + } + } else { + // Fallback to scalar copy with __ldg + copyScalar(dst, src, embedding_dim); } } } diff --git a/src/infiniop/ops/embedding/nvidia/embedding_nvidia.cu b/src/infiniop/ops/embedding/nvidia/embedding_nvidia.cu index 72d1e6514..b714b0aa4 100644 --- a/src/infiniop/ops/embedding/nvidia/embedding_nvidia.cu +++ b/src/infiniop/ops/embedding/nvidia/embedding_nvidia.cu @@ -79,15 +79,25 @@ infiniStatus_t Descriptor::calculate( } auto cuda_stream = reinterpret_cast(stream); - constexpr size_t BLOCK_SIZE = 256; - size_t grid_size = (_num_indices + BLOCK_SIZE - 1) / BLOCK_SIZE; + + // Dynamic block size optimization based on embedding_dim + // Smaller embedding_dim benefits from larger block size (better occupancy) + // Larger embedding_dim benefits from smaller block size (more registers per thread) + size_t block_size = 256; // Default + if (_embedding_dim <= 64) { + block_size = 512; // Small embedding_dim: use larger block for better occupancy + } else if (_embedding_dim >= 1024) { + block_size = 128; // Large embedding_dim: use smaller block to reduce register pressure + } + + size_t grid_size = (_num_indices + block_size - 1) / block_size; // Launch kernel based on dtypes if (_input_dtype == INFINI_DTYPE_I32) { const int32_t *indices_ptr = reinterpret_cast(input); if (_weight_dtype == INFINI_DTYPE_F32) { - embeddingKernel<<>>( + embeddingKernel<<>>( reinterpret_cast(output), indices_ptr, reinterpret_cast(weight), @@ -95,7 +105,7 @@ infiniStatus_t Descriptor::calculate( _embedding_dim, _vocab_size); } else if (_weight_dtype == INFINI_DTYPE_F16) { - embeddingKernel<<>>( + embeddingKernel<<>>( reinterpret_cast(output), indices_ptr, reinterpret_cast(weight), @@ -103,7 +113,7 @@ infiniStatus_t Descriptor::calculate( _embedding_dim, _vocab_size); } else if (_weight_dtype == INFINI_DTYPE_BF16) { - embeddingKernel<<>>( + embeddingKernel<<>>( reinterpret_cast(output), indices_ptr, reinterpret_cast(weight), @@ -117,7 +127,7 @@ infiniStatus_t Descriptor::calculate( const int64_t *indices_ptr = reinterpret_cast(input); if (_weight_dtype == INFINI_DTYPE_F32) { - embeddingKernel<<>>( + embeddingKernel<<>>( reinterpret_cast(output), indices_ptr, reinterpret_cast(weight), @@ -125,7 +135,7 @@ infiniStatus_t Descriptor::calculate( _embedding_dim, _vocab_size); } else if (_weight_dtype == INFINI_DTYPE_F16) { - embeddingKernel<<>>( + embeddingKernel<<>>( reinterpret_cast(output), indices_ptr, reinterpret_cast(weight), @@ -133,7 +143,7 @@ infiniStatus_t Descriptor::calculate( _embedding_dim, _vocab_size); } else if (_weight_dtype == INFINI_DTYPE_BF16) { - embeddingKernel<<>>( + embeddingKernel<<>>( reinterpret_cast(output), indices_ptr, reinterpret_cast(weight),