diff --git a/.gitignore b/.gitignore index d9479360b..4944a5db4 100644 --- a/.gitignore +++ b/.gitignore @@ -31,4 +31,4 @@ cache/ # Compressed *.gz *.zip -*.tar +*.tar \ No newline at end of file diff --git a/include/infinicore/ops.hpp b/include/infinicore/ops.hpp index 0937a4821..79a0655be 100644 --- a/include/infinicore/ops.hpp +++ b/include/infinicore/ops.hpp @@ -1,8 +1,13 @@ #pragma once +#include "ops/adaptive_max_pool1d.hpp" #include "ops/add.hpp" +#include "ops/asinh.hpp" #include "ops/attention.hpp" +#include "ops/baddbmm.hpp" +#include "ops/bilinear.hpp" #include "ops/causal_softmax.hpp" +#include "ops/fmod.hpp" #include "ops/matmul.hpp" #include "ops/ones.hpp" #include "ops/rearrange.hpp" diff --git a/include/infinicore/ops/adaptive_max_pool1d.hpp b/include/infinicore/ops/adaptive_max_pool1d.hpp new file mode 100644 index 000000000..05e49b490 --- /dev/null +++ b/include/infinicore/ops/adaptive_max_pool1d.hpp @@ -0,0 +1,16 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" + +namespace infinicore::op { +class AdaptiveMaxPool1d { +public: + using schema = void (*)(Tensor, Tensor, size_t); + static void execute(Tensor y, Tensor x, size_t output_size); + static common::OpDispatcher &dispatcher(); +}; + +Tensor adaptive_max_pool1d(Tensor x, size_t output_size); +void adaptive_max_pool1d_(Tensor y, Tensor x, size_t output_size); +} // namespace infinicore::op \ No newline at end of file diff --git a/include/infinicore/ops/asinh.hpp b/include/infinicore/ops/asinh.hpp new file mode 100644 index 000000000..505eb97d9 --- /dev/null +++ b/include/infinicore/ops/asinh.hpp @@ -0,0 +1,16 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" + +namespace infinicore::op { +class Asinh { +public: + using schema = void (*)(Tensor, Tensor); + static void execute(Tensor y, Tensor x); + static common::OpDispatcher &dispatcher(); +}; + +Tensor asinh(Tensor x); +void asinh_(Tensor y, Tensor x); +} // namespace infinicore::op diff --git a/include/infinicore/ops/baddbmm.hpp b/include/infinicore/ops/baddbmm.hpp new file mode 100644 index 000000000..3c08b98d9 --- /dev/null +++ b/include/infinicore/ops/baddbmm.hpp @@ -0,0 +1,15 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" +#include + +namespace infinicore::op { + +Tensor baddbmm(Tensor input, Tensor batch1, Tensor batch2, + float beta = 1.0f, + float alpha = 1.0f); +void baddbmm_(Tensor out, Tensor input, Tensor batch1, Tensor batch2, + float beta = 1.0f, + float alpha = 1.0f); +} // namespace infinicore::op \ No newline at end of file diff --git a/include/infinicore/ops/bilinear.hpp b/include/infinicore/ops/bilinear.hpp new file mode 100644 index 000000000..3f5f44aac --- /dev/null +++ b/include/infinicore/ops/bilinear.hpp @@ -0,0 +1,12 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" +#include + +namespace infinicore::op { + +Tensor bilinear(Tensor x1, Tensor x2, Tensor weight, std::optional bias); +void bilinear_(Tensor out, Tensor x1, Tensor x2, Tensor weight, std::optional bias); + +} // namespace infinicore::op \ No newline at end of file diff --git a/include/infinicore/ops/fmod.hpp b/include/infinicore/ops/fmod.hpp new file mode 100644 index 000000000..87b90d515 --- /dev/null +++ b/include/infinicore/ops/fmod.hpp @@ -0,0 +1,16 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" + +namespace infinicore::op { +class Fmod { +public: + using schema = void (*)(Tensor, Tensor, Tensor); + static void execute(Tensor c, Tensor a, Tensor b); + static 
common::OpDispatcher &dispatcher(); +}; + +Tensor fmod(Tensor a, Tensor b); +void fmod_(Tensor c, Tensor a, Tensor b); +} // namespace infinicore::op diff --git a/include/infiniop.h b/include/infiniop.h index 92e6f5963..c4c114559 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -2,12 +2,15 @@ #define __INFINIOP_API_H__ #include "infiniop/handle.h" +#include "infiniop/ops/adaptive_max_pool1d.h" #include "infiniop/ops/add.h" +#include "infiniop/ops/asinh.h" #include "infiniop/ops/attention.h" #include "infiniop/ops/causal_softmax.h" #include "infiniop/ops/clip.h" #include "infiniop/ops/conv.h" #include "infiniop/ops/dequantize_awq.h" +#include "infiniop/ops/fmod.h" #include "infiniop/ops/gelu.h" #include "infiniop/ops/gemm.h" #include "infiniop/ops/layer_norm.h" diff --git a/include/infiniop/ops/adaptive_max_pool1d.h b/include/infiniop/ops/adaptive_max_pool1d.h new file mode 100644 index 000000000..484413e21 --- /dev/null +++ b/include/infiniop/ops/adaptive_max_pool1d.h @@ -0,0 +1,22 @@ +#ifndef __INFINIOP_ADAPTIVE_MAX_POOL1D_H__ +#define __INFINIOP_ADAPTIVE_MAX_POOL1D_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopAdaptiveMaxPool1dDescriptor_t; + +__C __export infiniStatus_t infiniopCreateAdaptiveMaxPool1dDescriptor( + infiniopHandle_t handle, + infiniopAdaptiveMaxPool1dDescriptor_t *desc, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t x_desc, + size_t output_size); + +__C __export infiniStatus_t infiniopGetAdaptiveMaxPool1dWorkspaceSize(infiniopAdaptiveMaxPool1dDescriptor_t desc, size_t *size); + +__C __export infiniStatus_t infiniopAdaptiveMaxPool1d(infiniopAdaptiveMaxPool1dDescriptor_t desc, void *workspace, size_t workspace_size, + void *y, const void *x, void *stream); + +__C __export infiniStatus_t infiniopDestroyAdaptiveMaxPool1dDescriptor(infiniopAdaptiveMaxPool1dDescriptor_t desc); + +#endif \ No newline at end of file diff --git a/include/infiniop/ops/asinh.h b/include/infiniop/ops/asinh.h new file mode 100644 index 000000000..4849bc422 --- /dev/null +++ b/include/infiniop/ops/asinh.h @@ -0,0 +1,24 @@ +#ifndef __INFINIOP_ASINH_API_H_ +#define __INFINIOP_ASINH_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopAsinhDescriptor_t; + +__C __export infiniStatus_t infiniopCreateAsinhDescriptor(infiniopHandle_t handle, + infiniopAsinhDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x); + +__C __export infiniStatus_t infiniopGetAsinhWorkspaceSize(infiniopAsinhDescriptor_t desc, size_t *size); + +__C __export infiniStatus_t infiniopAsinh(infiniopAsinhDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *y, + const void *x, + void *stream); + +__C __export infiniStatus_t infiniopDestroyAsinhDescriptor(infiniopAsinhDescriptor_t desc); + +#endif \ No newline at end of file diff --git a/include/infiniop/ops/fmod.h b/include/infiniop/ops/fmod.h new file mode 100644 index 000000000..b74b6daca --- /dev/null +++ b/include/infiniop/ops/fmod.h @@ -0,0 +1,26 @@ +#ifndef __INFINIOP_FMOD_API_H_ +#define __INFINIOP_FMOD_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopFmodDescriptor_t; + +__C __export infiniStatus_t infiniopCreateFmodDescriptor(infiniopHandle_t handle, + infiniopFmodDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t c, + infiniopTensorDescriptor_t a, + infiniopTensorDescriptor_t b); + +__C __export infiniStatus_t infiniopGetFmodWorkspaceSize(infiniopFmodDescriptor_t desc, 
size_t *size); + +__C __export infiniStatus_t infiniopFmod(infiniopFmodDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *c, + const void *a, + const void *b, + void *stream); + +__C __export infiniStatus_t infiniopDestroyFmodDescriptor(infiniopFmodDescriptor_t desc); + +#endif \ No newline at end of file diff --git a/python/infinicore/__init__.py b/python/infinicore/__init__.py index 5c541ec3c..a2f32b62c 100644 --- a/python/infinicore/__init__.py +++ b/python/infinicore/__init__.py @@ -41,10 +41,14 @@ ) from infinicore.ops.add import add from infinicore.ops.attention import attention +from infinicore.ops.asinh import asinh from infinicore.ops.matmul import matmul from infinicore.ops.mul import mul from infinicore.ops.narrow import narrow from infinicore.ops.rearrange import rearrange +from infinicore.ops.baddbmm import baddbmm +from infinicore.ops.bilinear import bilinear +from infinicore.ops.fmod import fmod from infinicore.tensor import ( Tensor, empty, @@ -101,6 +105,10 @@ # Operations. "add", "attention", + "asinh", + "baddbmm", + "bilinear", + "fmod", "matmul", "mul", "narrow", diff --git a/python/infinicore/nn/functional/__init__.py b/python/infinicore/nn/functional/__init__.py index 255079790..b9b313b67 100644 --- a/python/infinicore/nn/functional/__init__.py +++ b/python/infinicore/nn/functional/__init__.py @@ -1,3 +1,4 @@ +from .adaptive_max_pool1d import adaptive_max_pool1d from .causal_softmax import causal_softmax from .embedding import embedding from .linear import linear @@ -8,6 +9,7 @@ from .swiglu import swiglu __all__ = [ + "adaptive_max_pool1d", "causal_softmax", "random_sample", "rms_norm", diff --git a/python/infinicore/nn/functional/adaptive_max_pool1d.py b/python/infinicore/nn/functional/adaptive_max_pool1d.py new file mode 100644 index 000000000..74a8c56e9 --- /dev/null +++ b/python/infinicore/nn/functional/adaptive_max_pool1d.py @@ -0,0 +1,39 @@ +from typing import List + +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def adaptive_max_pool1d( + input: Tensor, + output_size: int, + *, + out=None, +) -> Tensor: + r"""Applies a 1D adaptive max pooling over an input signal composed of + several input planes. + + The output size is H_out. The algorithm used is fairly simple: + + .. math:: + \text{start} = \left\lfloor \frac{i \cdot L_{in}}{L_{out}} \right\rfloor + + \text{end} = \left\lceil \frac{(i + 1) \cdot L_{in}}{L_{out}} \right\rceil + + where :math:`L_{in}` is the size of the input dimension, and :math:`L_{out}` is the size of the output dimension. + + Args: + input (Tensor): Input tensor of shape (N, C, L_in) + output_size (int): The target output size (L_out) + out (Tensor, optional): Output tensor. + + Returns: + Tensor: The result of the adaptive max pooling operation. 
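+
+    Example:
+        For L_in = 5 and output_size = 3 the pooling windows are
+        [0, 2), [1, 4) and [3, 5); each output element is the maximum
+        over its window, y[..., i] = max(x[..., start:end]).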
+ """ + + if out is None: + return Tensor(_infinicore.adaptive_max_pool1d(input._underlying, output_size)) + + _infinicore.adaptive_max_pool1d_(out._underlying, input._underlying, output_size) + + return out diff --git a/python/infinicore/ops/asinh.py b/python/infinicore/ops/asinh.py new file mode 100644 index 000000000..05ec58779 --- /dev/null +++ b/python/infinicore/ops/asinh.py @@ -0,0 +1,11 @@ +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def asinh(input, *, out=None): + if out is None: + return Tensor(_infinicore.asinh(input._underlying)) + + _infinicore.asinh_(out._underlying, input._underlying) + + return out diff --git a/python/infinicore/ops/baddbmm.py b/python/infinicore/ops/baddbmm.py new file mode 100644 index 000000000..4a34cbb64 --- /dev/null +++ b/python/infinicore/ops/baddbmm.py @@ -0,0 +1,25 @@ +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def baddbmm(input, batch1, batch2, *, beta=1.0, alpha=1.0, out=None): + if out is None: + return Tensor( + _infinicore.baddbmm( + input._underlying, + batch1._underlying, + batch2._underlying, + float(beta), + float(alpha), + ) + ) + _infinicore.baddbmm_( + out._underlying, + input._underlying, + batch1._underlying, + batch2._underlying, + float(beta), + float(alpha), + ) + + return out diff --git a/python/infinicore/ops/bilinear.py b/python/infinicore/ops/bilinear.py new file mode 100644 index 000000000..4773dd825 --- /dev/null +++ b/python/infinicore/ops/bilinear.py @@ -0,0 +1,23 @@ +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def bilinear(input1, input2, weight, bias=None, *, out=None): + if out is None: + return Tensor( + _infinicore.bilinear( + input1._underlying, + input2._underlying, + weight._underlying, + bias._underlying if bias is not None else None, + ) + ) + _infinicore.bilinear_( + out._underlying, + input1._underlying, + input2._underlying, + weight._underlying, + bias._underlying if bias is not None else None, + ) + + return out diff --git a/python/infinicore/ops/fmod.py b/python/infinicore/ops/fmod.py new file mode 100644 index 000000000..e52be82cb --- /dev/null +++ b/python/infinicore/ops/fmod.py @@ -0,0 +1,11 @@ +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def fmod(input, other, *, out=None): + if out is None: + return Tensor(_infinicore.fmod(input._underlying, other._underlying)) + + _infinicore.fmod_(out._underlying, input._underlying, other._underlying) + + return out diff --git a/src/infinicore/ops/adaptive_max_pool1d/adaptive_max_pool1d.cc b/src/infinicore/ops/adaptive_max_pool1d/adaptive_max_pool1d.cc new file mode 100644 index 000000000..bd80b0771 --- /dev/null +++ b/src/infinicore/ops/adaptive_max_pool1d/adaptive_max_pool1d.cc @@ -0,0 +1,30 @@ +#include "infinicore/ops/adaptive_max_pool1d.hpp" + +#include "../../utils.hpp" + +namespace infinicore::op { + +common::OpDispatcher &AdaptiveMaxPool1d::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +} + +void AdaptiveMaxPool1d::execute(Tensor y, Tensor x, size_t output_size) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(y, x); + infinicore::context::setDevice(y->device()); + dispatcher().lookup(y->device().getType())(y, x, output_size); +} + +Tensor adaptive_max_pool1d(Tensor x, size_t output_size) { + infinicore::Shape y_shape = x->shape(); + y_shape.back() = output_size; + auto y = Tensor::empty(y_shape, x->dtype(), x->device()); + adaptive_max_pool1d_(y, x, output_size); + return y; +} + 
+void adaptive_max_pool1d_(Tensor y, Tensor x, size_t output_size) { + AdaptiveMaxPool1d::execute(y, x, output_size); +} + +} // namespace infinicore::op diff --git a/src/infinicore/ops/adaptive_max_pool1d/adaptive_max_pool1d_infiniop.cc b/src/infinicore/ops/adaptive_max_pool1d/adaptive_max_pool1d_infiniop.cc new file mode 100644 index 000000000..451489e15 --- /dev/null +++ b/src/infinicore/ops/adaptive_max_pool1d/adaptive_max_pool1d_infiniop.cc @@ -0,0 +1,52 @@ +#include "../../utils.hpp" +#include "infinicore/common/hash.hpp" +#include "infinicore/ops/adaptive_max_pool1d.hpp" +#include "infinicore/ops/common/cache.hpp" +#include + +namespace infinicore::op::adaptive_max_pool1d_impl::infiniop { + +thread_local common::OpCache caches( + 100, // capacity + [](infiniopAdaptiveMaxPool1dDescriptor_t &desc) { + if (desc != nullptr) { + INFINICORE_CHECK_ERROR(infiniopDestroyAdaptiveMaxPool1dDescriptor(desc)); + desc = nullptr; + } + }); + +void calculate(Tensor y, Tensor x, size_t out) { + size_t seed = hash_combine(y, x, out); + + auto device_type = context::getDevice().getType(); + auto device_index = context::getDevice().getIndex(); + + auto &cache = caches.getCache(device_type, device_index); + + auto desc_opt = cache.get(seed); + infiniopAdaptiveMaxPool1dDescriptor_t desc = nullptr; + + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreateAdaptiveMaxPool1dDescriptor( + context::getInfiniopHandle(y->device()), &desc, + y->desc(), x->desc(), out)); + cache.put(seed, desc); + } else { + desc = *desc_opt; + } + + size_t workspace_size = 0; + INFINICORE_CHECK_ERROR(infiniopGetAdaptiveMaxPool1dWorkspaceSize(desc, &workspace_size)); + std::shared_ptr workspace = context::allocateMemory(workspace_size); + + INFINICORE_CHECK_ERROR(infiniopAdaptiveMaxPool1d( + desc, workspace->data(), workspace_size, + y->data(), x->data(), context::getStream())); +} + +static bool registered = []() { + AdaptiveMaxPool1d::dispatcher().registerAll(&calculate, false); + return true; +}(); + +} // namespace infinicore::op::adaptive_max_pool1d_impl::infiniop diff --git a/src/infinicore/ops/asinh/asinh.cc b/src/infinicore/ops/asinh/asinh.cc new file mode 100644 index 000000000..fbf131d99 --- /dev/null +++ b/src/infinicore/ops/asinh/asinh.cc @@ -0,0 +1,27 @@ +#include "infinicore/ops/asinh.hpp" +#include "../../utils.hpp" + +namespace infinicore::op { + +common::OpDispatcher &Asinh::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +}; + +void Asinh::execute(Tensor y, Tensor x) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(y, x); + infinicore::context::setDevice(y->device()); + dispatcher().lookup(y->device().getType())(y, x); +} + +Tensor asinh(Tensor x) { + auto y = Tensor::empty(x->shape(), x->dtype(), x->device()); + asinh_(y, x); + return y; +} + +void asinh_(Tensor y, Tensor x) { + Asinh::execute(y, x); +} + +} // namespace infinicore::op diff --git a/src/infinicore/ops/asinh/asinh_infiniop.cc b/src/infinicore/ops/asinh/asinh_infiniop.cc new file mode 100644 index 000000000..ceed8d5a2 --- /dev/null +++ b/src/infinicore/ops/asinh/asinh_infiniop.cc @@ -0,0 +1,52 @@ +#include "../../utils.hpp" +#include "infinicore/common/hash.hpp" +#include "infinicore/ops/asinh.hpp" +#include "infinicore/ops/common/cache.hpp" +#include + +namespace infinicore::op::asinh_impl::infiniop { + +thread_local common::OpCache caches( + 100, // capacity + [](infiniopAsinhDescriptor_t &desc) { + if (desc != nullptr) { + INFINICORE_CHECK_ERROR(infiniopDestroyAsinhDescriptor(desc)); + desc = nullptr; + } + }); + +void 
calculate(Tensor y, Tensor x) { + size_t seed = hash_combine(y, x); + + auto device_type = context::getDevice().getType(); + auto device_index = context::getDevice().getIndex(); + + auto &cache = caches.getCache(device_type, device_index); + + auto desc_opt = cache.get(seed); + infiniopAsinhDescriptor_t desc = nullptr; + + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreateAsinhDescriptor( + context::getInfiniopHandle(y->device()), &desc, + y->desc(), x->desc())); + cache.put(seed, desc); + } else { + desc = *desc_opt; + } + + size_t workspace_size = 0; + INFINICORE_CHECK_ERROR(infiniopGetAsinhWorkspaceSize(desc, &workspace_size)); + std::shared_ptr workspace = context::allocateMemory(workspace_size); + + INFINICORE_CHECK_ERROR(infiniopAsinh( + desc, workspace->data(), workspace_size, + y->data(), x->data(), context::getStream())); +} + +static bool registered = []() { + Asinh::dispatcher().registerAll(&calculate, false); + return true; +}(); + +} // namespace infinicore::op::asinh_impl::infiniop diff --git a/src/infinicore/ops/baddbmm/baddbmm.cc b/src/infinicore/ops/baddbmm/baddbmm.cc new file mode 100644 index 000000000..3a8ee1518 --- /dev/null +++ b/src/infinicore/ops/baddbmm/baddbmm.cc @@ -0,0 +1,100 @@ +#include "infinicore/ops/baddbmm.hpp" +#include "infinicore/ops/gemm.hpp" +#include "infinicore/ops/rearrange.hpp" + +namespace infinicore::op { + +// 内联的 BLAS 兼容性检查,减少函数调用开销 +inline bool is_blas_compatible(const Tensor &t) { + const auto ndim = t->ndim(); + if (ndim == 2) { + const auto rs = t->stride(0); + const auto cs = t->stride(1); + if (rs != 1 && cs != 1) { + return false; + } + if (rs == 1 && cs == 1) { + return t->shape()[0] == 1 || t->shape()[1] == 1; + } + return true; + } else if (ndim == 3) { + const auto rs = t->stride(1); + const auto cs = t->stride(2); + if (t->shape()[0] > 1 && t->stride(0) == 0) { + return false; + } + if (rs != 1 && cs != 1) { + return false; + } + if (rs == 1 && cs == 1) { + return t->shape()[1] == 1 || t->shape()[2] == 1; + } + return true; + } + return false; +} + +inline void prepare_gemm_input(Tensor &output, Tensor &input, const size_t batch_size, const size_t m, const size_t n) { + const auto input_ndim = input->ndim(); + if (input_ndim == 2) { + rearrange_(output, input->as_strided( + {batch_size, m, n}, + {0, input->stride(0), input->stride(1)})); + } else if (input_ndim == 3 && input->shape()[0] == 1 && batch_size > 1) { + rearrange_(output, input->as_strided( + {batch_size, m, n}, + {0, input->stride(1), input->stride(2)})); + } else { + rearrange_(output, input); + } +} + +Tensor baddbmm(Tensor input, Tensor batch1, Tensor batch2, + float beta, + float alpha) { + const size_t batch_size = batch1->shape()[0]; + const size_t m = batch1->shape()[1]; + const size_t n = batch2->shape()[2]; + + const Tensor &a = is_blas_compatible(batch1) ? batch1 : rearrange(batch1); + const Tensor &b = is_blas_compatible(batch2) ? batch2 : rearrange(batch2); + + if (beta == 0.0f) { + return gemm(a, b, alpha, 0.0f); + } + + Tensor result = Tensor::empty({batch_size, m, n}, a->dtype(), a->device()); + + prepare_gemm_input(result, input, batch_size, m, n); + + gemm_(result, a, b, alpha, beta); + return result; +} + +void baddbmm_(Tensor out, Tensor input, Tensor batch1, Tensor batch2, + float beta, + float alpha) { + const size_t batch_size = batch1->shape()[0]; + const size_t m = batch1->shape()[1]; + const size_t n = batch2->shape()[2]; + + const Tensor &a = is_blas_compatible(batch1) ? 
batch1 : rearrange(batch1); + const Tensor &b = is_blas_compatible(batch2) ? batch2 : rearrange(batch2); + + const bool out_is_usable = out->is_contiguous() && out->ndim() == 3 && out->shape()[0] == batch_size && out->shape()[1] == m && out->shape()[2] == n; + + if (out_is_usable) { + if (beta != 0.0f && input->data() != out->data()) { + prepare_gemm_input(out, input, batch_size, m, n); + } + gemm_(out, a, b, alpha, beta); + } else { + Tensor result = Tensor::empty({batch_size, m, n}, a->dtype(), a->device()); + if (beta != 0.0f) { + prepare_gemm_input(result, input, batch_size, m, n); + } + gemm_(result, a, b, alpha, beta); + rearrange_(out, result); + } +} +} // namespace infinicore::op \ No newline at end of file diff --git a/src/infinicore/ops/bilinear/bilinear.cc b/src/infinicore/ops/bilinear/bilinear.cc new file mode 100644 index 000000000..ab88a28f9 --- /dev/null +++ b/src/infinicore/ops/bilinear/bilinear.cc @@ -0,0 +1,119 @@ +#include "infinicore/ops/bilinear.hpp" +#include "infinicore/ops/add.hpp" +#include "infinicore/ops/matmul.hpp" +#include "infinicore/ops/rearrange.hpp" + +#ifdef ENABLE_NVIDIA_API +namespace op::gemm::nvidia { +void set_tf32_enabled(bool); +} +#endif + +namespace infinicore::op { + +namespace { +// RAII 守卫:作用域内禁用 TF32 +struct ScopedTF32Disable { + ScopedTF32Disable() { +#ifdef ENABLE_NVIDIA_API + // 实际项目中建议添加检查,仅在 NVIDIA 设备上调用 + // 使用 ::op 强制从全局命名空间查找,避免被当前的 infinicore::op 遮蔽 + ::op::gemm::nvidia::set_tf32_enabled(false); +#endif + } + ~ScopedTF32Disable() { +#ifdef ENABLE_NVIDIA_API + ::op::gemm::nvidia::set_tf32_enabled(true); +#endif + } +}; + +inline bool is_gemm_compatible_3d(const Tensor &t) { + if (t->ndim() != 3) { + return false; + } + + const auto batch = t->shape()[0]; + const auto rows = t->shape()[1]; + const auto cols = t->shape()[2]; + const auto bs = t->stride(0); + const auto rs = t->stride(1); + const auto cs = t->stride(2); + + if (rs != 1 && cs != 1) { + return false; + } + + if (cs == 1) { + if (rs < static_cast(cols)) { + return false; + } + } else { + if (cs < static_cast(rows)) { + return false; + } + } + + if (batch > 1 && bs == 0) { + return false; + } + + return true; +} + +inline Tensor ensure_gemm_compatible(const Tensor &t) { + if (t->ndim() == 2) { + return t->is_contiguous() ? t : rearrange(t); + } else if (t->ndim() == 3) { + return is_gemm_compatible_3d(t) ? t : rearrange(t); + } + return t->is_contiguous() ? t : rearrange(t); +} + +} // anonymous namespace + +Tensor bilinear(Tensor x1, Tensor x2, Tensor weight, std::optional bias) { + ScopedTF32Disable tf32_guard; + + const size_t batch_size = x1->shape()[0]; + const size_t in1_features = x1->shape()[1]; + const size_t in2_features = x2->shape()[1]; + const size_t out_features = weight->shape()[0]; + + Tensor x1_compat = ensure_gemm_compatible(x1); + Tensor x2_compat = ensure_gemm_compatible(x2); + Tensor weight_cont = weight->is_contiguous() ? weight : weight->contiguous(); + + Tensor weight_permuted = weight_cont->permute({1, 0, 2}); + Tensor weight_permuted_cont = weight_permuted->is_contiguous() + ? 
weight_permuted + : weight_permuted->contiguous(); + Tensor weight_matrix = weight_permuted_cont->view({in1_features, out_features * in2_features}); + + Tensor intermediate = matmul(x1_compat, weight_matrix, 1.0f); + + Tensor intermediate_3d = intermediate->view({batch_size, out_features, in2_features}); + Tensor intermediate_transposed = intermediate_3d->permute({0, 2, 1}); + Tensor intermediate_compat = ensure_gemm_compatible(intermediate_transposed); + + Tensor x2_row = x2_compat->view({batch_size, 1, in2_features}); + Tensor x2_row_compat = ensure_gemm_compatible(x2_row); + + Tensor out_3d = matmul(x2_row_compat, intermediate_compat, 1.0f); + Tensor out = out_3d->view({batch_size, out_features}); + + if (bias) { + Tensor bias_broadcast = (*bias)->as_strided( + {batch_size, out_features}, + {0, (*bias)->strides()[0]}); + out = add(out, bias_broadcast); + } + return out; +} + +void bilinear_(Tensor out, Tensor x1, Tensor x2, Tensor weight, std::optional bias) { + Tensor result = bilinear(x1, x2, weight, bias); + rearrange_(out, result); +} + +} // namespace infinicore::op \ No newline at end of file diff --git a/src/infinicore/ops/fmod/fmod.cc b/src/infinicore/ops/fmod/fmod.cc new file mode 100644 index 000000000..30bee17d6 --- /dev/null +++ b/src/infinicore/ops/fmod/fmod.cc @@ -0,0 +1,28 @@ +#include "infinicore/ops/fmod.hpp" + +#include "../../utils.hpp" + +namespace infinicore::op { + +common::OpDispatcher &Fmod::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +}; + +void Fmod::execute(Tensor c, Tensor a, Tensor b) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b); + infinicore::context::setDevice(c->device()); + dispatcher().lookup(c->device().getType())(c, a, b); +} + +Tensor fmod(Tensor a, Tensor b) { + auto c = Tensor::empty(a->shape(), a->dtype(), a->device()); + fmod_(c, a, b); + return c; +} + +void fmod_(Tensor c, Tensor a, Tensor b) { + Fmod::execute(c, a, b); +} + +} // namespace infinicore::op diff --git a/src/infinicore/ops/fmod/fmod_infiniop.cc b/src/infinicore/ops/fmod/fmod_infiniop.cc new file mode 100644 index 000000000..e796090d0 --- /dev/null +++ b/src/infinicore/ops/fmod/fmod_infiniop.cc @@ -0,0 +1,52 @@ +#include "../../utils.hpp" +#include "infinicore/common/hash.hpp" +#include "infinicore/ops/common/cache.hpp" +#include "infinicore/ops/fmod.hpp" +#include + +namespace infinicore::op::fmod_impl::infiniop { + +thread_local common::OpCache caches( + 100, // capacity + [](infiniopFmodDescriptor_t &desc) { + if (desc != nullptr) { + INFINICORE_CHECK_ERROR(infiniopDestroyFmodDescriptor(desc)); + desc = nullptr; + } + }); + +void calculate(Tensor c, Tensor a, Tensor b) { + size_t seed = hash_combine(c, b, a); + + auto device_type = context::getDevice().getType(); + auto device_index = context::getDevice().getIndex(); + + auto &cache = caches.getCache(device_type, device_index); + + auto desc_opt = cache.get(seed); + infiniopFmodDescriptor_t desc = nullptr; + + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreateFmodDescriptor( + context::getInfiniopHandle(c->device()), &desc, + c->desc(), a->desc(), b->desc())); + cache.put(seed, desc); + } else { + desc = *desc_opt; + } + + size_t workspace_size = 0; + INFINICORE_CHECK_ERROR(infiniopGetFmodWorkspaceSize(desc, &workspace_size)); + std::shared_ptr workspace = context::allocateMemory(workspace_size); + + INFINICORE_CHECK_ERROR(infiniopFmod( + desc, workspace->data(), workspace_size, + c->data(), a->data(), b->data(), context::getStream())); +} + +static bool registered = []() { + 
Fmod::dispatcher().registerAll(&calculate, false); + return true; +}(); + +} // namespace infinicore::op::fmod_impl::infiniop \ No newline at end of file diff --git a/src/infinicore/pybind11/ops.hpp b/src/infinicore/pybind11/ops.hpp index 978defa17..5d954c47f 100644 --- a/src/infinicore/pybind11/ops.hpp +++ b/src/infinicore/pybind11/ops.hpp @@ -2,10 +2,15 @@ #include +#include "ops/adaptive_max_pool1d.hpp" #include "ops/add.hpp" +#include "ops/asinh.hpp" #include "ops/attention.hpp" +#include "ops/baddbmm.hpp" +#include "ops/bilinear.hpp" #include "ops/causal_softmax.hpp" #include "ops/embedding.hpp" +#include "ops/fmod.hpp" #include "ops/linear.hpp" #include "ops/matmul.hpp" #include "ops/mul.hpp" @@ -21,9 +26,14 @@ namespace py = pybind11; namespace infinicore::ops { inline void bind(py::module &m) { + bind_adaptive_max_pool1d(m); bind_add(m); bind_attention(m); + bind_asinh(m); + bind_baddbmm(m); + bind_bilinear(m); bind_causal_softmax(m); + bind_fmod(m); bind_random_sample(m); bind_linear(m); bind_matmul(m); diff --git a/src/infinicore/pybind11/ops/adaptive_max_pool1d.hpp b/src/infinicore/pybind11/ops/adaptive_max_pool1d.hpp new file mode 100644 index 000000000..747d92b9a --- /dev/null +++ b/src/infinicore/pybind11/ops/adaptive_max_pool1d.hpp @@ -0,0 +1,39 @@ +#pragma once + +#include + +#include "infinicore/ops/adaptive_max_pool1d.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_adaptive_max_pool1d(py::module &m) { + m.def("adaptive_max_pool1d", + &op::adaptive_max_pool1d, + py::arg("x"), + py::arg("output_size"), + R"doc(1D Adaptive Max Pooling. + +Args: + x: Input tensor of shape (N, C, L_in) or (N, L_in) + output_size: Target output size L_out +Returns: + Output tensor of shape (N, C, L_out) or (N, L_out) +)doc"); + + m.def("adaptive_max_pool1d_", + &op::adaptive_max_pool1d_, + py::arg("y"), + py::arg("x"), + py::arg("output_size"), + R"doc(In-place 1D Adaptive Max Pooling. 
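+Writes the pooled values of x into the preallocated output tensor y.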
+ +Args: + y: Output tensor of shape (N, C, L_out) or (N, L_out) + x: Input tensor of shape (N, C, L_in) or (N, L_in) + output_size: Target output size L_out +)doc"); +} + +} // namespace infinicore::ops \ No newline at end of file diff --git a/src/infinicore/pybind11/ops/asinh.hpp b/src/infinicore/pybind11/ops/asinh.hpp new file mode 100644 index 000000000..bf1fcca23 --- /dev/null +++ b/src/infinicore/pybind11/ops/asinh.hpp @@ -0,0 +1,24 @@ +#pragma once + +#include + +#include "infinicore/ops/asinh.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_asinh(py::module &m) { + m.def("asinh", + &op::asinh, + py::arg("x"), + R"doc(Element-wise inverse hyperbolic sine function.)doc"); + + m.def("asinh_", + &op::asinh_, + py::arg("y"), + py::arg("x"), + R"doc(In-place element-wise inverse hyperbolic sine function.)doc"); +} + +} // namespace infinicore::ops diff --git a/src/infinicore/pybind11/ops/baddbmm.hpp b/src/infinicore/pybind11/ops/baddbmm.hpp new file mode 100644 index 000000000..3aef0ce20 --- /dev/null +++ b/src/infinicore/pybind11/ops/baddbmm.hpp @@ -0,0 +1,56 @@ +#pragma once + +#include + +#include "infinicore/ops/baddbmm.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +Tensor py_baddbmm(Tensor input, Tensor batch1, Tensor batch2, float beta = 1.0f, float alpha = 1.0f) { + return op::baddbmm(input, batch1, batch2, beta, alpha); +} + +void py_baddbmm_(Tensor out, Tensor input, Tensor batch1, Tensor batch2, float beta = 1.0f, float alpha = 1.0f) { + op::baddbmm_(out, input, batch1, batch2, beta, alpha); +} + +inline void bind_baddbmm(py::module &m) { + m.def("baddbmm", + &py_baddbmm, + py::arg("input"), + py::arg("batch1"), + py::arg("batch2"), + py::arg("beta") = 1.0f, + py::arg("alpha") = 1.0f, + R"doc(Batched matrix-matrix product with addition. +Args: + input: Input tensor + batch1: First batch of matrices + batch2: Second batch of matrices + beta: Scaling factor for input tensor + alpha: Scaling factor for the product of batch1 and batch2 +Returns: + Output tensor after baddbmm operation +)doc"); + m.def("baddbmm_", + &py_baddbmm_, + py::arg("out"), + py::arg("input"), + py::arg("batch1"), + py::arg("batch2"), + py::arg("beta") = 1.0f, + py::arg("alpha") = 1.0f, + R"doc(In-place batched matrix-matrix product with addition. 
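+Computes out = beta * input + alpha * (batch1 @ batch2); a 2D input is broadcast
+across the batch dimension.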
+Args: + out: Output tensor + input: Input tensor + batch1: First batch of matrices + batch2: Second batch of matrices + beta: Scaling factor for input tensor + alpha: Scaling factor for the product of batch1 and batch2 +)doc"); +} + +} // namespace infinicore::ops \ No newline at end of file diff --git a/src/infinicore/pybind11/ops/bilinear.hpp b/src/infinicore/pybind11/ops/bilinear.hpp new file mode 100644 index 000000000..9c8ff80d6 --- /dev/null +++ b/src/infinicore/pybind11/ops/bilinear.hpp @@ -0,0 +1,61 @@ +#pragma once + +#include + +#include "infinicore/ops/bilinear.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +Tensor py_bilinear(Tensor x1, Tensor x2, Tensor weight, pybind11::object bias) { + std::optional bias_tensor = std::nullopt; + if (!bias.is_none()) { + bias_tensor = bias.cast(); + } + return op::bilinear(x1, x2, weight, bias_tensor); +} + +void py_bilinear_(Tensor out, Tensor x1, Tensor x2, Tensor weight, pybind11::object bias) { + std::optional bias_tensor = std::nullopt; + if (!bias.is_none()) { + bias_tensor = bias.cast(); + } + op::bilinear_(out, x1, x2, weight, bias_tensor); +} + +inline void bind_bilinear(py::module &m) { + m.def("bilinear", + &py_bilinear, + py::arg("x1"), + py::arg("x2"), + py::arg("weight"), + py::arg("bias"), + R"doc(Bilinear transformation of two input tensors. +Args: + x1: First input tensor + x2: Second input tensor + weight: Weight tensor + bias: Bias tensor (optional) +Returns: + Output tensor after bilinear transformation +)doc"); + + m.def("bilinear_", + &py_bilinear_, + py::arg("out"), + py::arg("x1"), + py::arg("x2"), + py::arg("weight"), + py::arg("bias"), + R"doc(In-place bilinear transformation of two input tensors. +Args: + out: Output tensor + x1: First input tensor + x2: Second input tensor + weight: Weight tensor + bias: Bias tensor (optional) +)doc"); +} + +} // namespace infinicore::ops \ No newline at end of file diff --git a/src/infinicore/pybind11/ops/fmod.hpp b/src/infinicore/pybind11/ops/fmod.hpp new file mode 100644 index 000000000..97af57da2 --- /dev/null +++ b/src/infinicore/pybind11/ops/fmod.hpp @@ -0,0 +1,26 @@ +#pragma once + +#include + +#include "infinicore/ops/fmod.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_fmod(py::module &m) { + m.def("fmod", + &op::fmod, + py::arg("a"), + py::arg("b"), + R"doc(Element-wise floating point remainder of division of two tensors.)doc"); + + m.def("fmod_", + &op::fmod_, + py::arg("c"), + py::arg("a"), + py::arg("b"), + R"doc(In-place element-wise floating point remainder of division of two tensors.)doc"); +} + +} // namespace infinicore::ops diff --git a/src/infiniop/ops/adaptive_max_pool1d/adaptive_max_pool1d.h b/src/infiniop/ops/adaptive_max_pool1d/adaptive_max_pool1d.h new file mode 100644 index 000000000..288c2ece4 --- /dev/null +++ b/src/infiniop/ops/adaptive_max_pool1d/adaptive_max_pool1d.h @@ -0,0 +1,47 @@ +#ifndef ADAPTIVE_MAX_POOL1D_H +#define ADAPTIVE_MAX_POOL1D_H + +#include "../../operator.h" +#include "info.h" + +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::adaptive_max_pool1d::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + AdaptiveMaxPool1dInfo _info; \ + size_t _workspace_size; \ + \ + Descriptor( \ + Opaque *opaque, \ + AdaptiveMaxPool1dInfo info, \ + size_t workspace_size, \ + infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), \ + _info(info), \ + 
_workspace_size(workspace_size) {} \ + \ + public: \ + ~Descriptor(); \ + \ + size_t workspaceSize() const { return _workspace_size; } \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t y_desc, \ + infiniopTensorDescriptor_t x_desc, \ + size_t output_size); \ + \ + infiniStatus_t calculate( \ + void *workspace, size_t workspace_size, \ + void *y, \ + const void *x, \ + void *stream) const; \ + }; \ + } + +#endif // ADAPTIVE_MAX_POOL1D_H \ No newline at end of file diff --git a/src/infiniop/ops/adaptive_max_pool1d/cpu/adaptive_max_pool1d_cpu.cc b/src/infiniop/ops/adaptive_max_pool1d/cpu/adaptive_max_pool1d_cpu.cc new file mode 100644 index 000000000..69edf83bc --- /dev/null +++ b/src/infiniop/ops/adaptive_max_pool1d/cpu/adaptive_max_pool1d_cpu.cc @@ -0,0 +1,98 @@ +#include "adaptive_max_pool1d_cpu.h" +#include "../../../devices/cpu/common_cpu.h" +#include "../../../reduce/cpu/reduce.h" +#include +#include + +namespace op::adaptive_max_pool1d::cpu { + +Descriptor::~Descriptor() {} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t x_desc, + size_t output_size) { + auto result = AdaptiveMaxPool1dInfo::create(y_desc, x_desc, output_size); + CHECK_RESULT(result); + *desc_ptr = new Descriptor(nullptr, result.take(), 0, handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t adaptiveMaxPool1d(const AdaptiveMaxPool1dInfo *info, T *y, const T *x) { + + const size_t ndim = info->ndim(); + const size_t batch_size = info->shape[0]; + const size_t channels = ndim > 2 ? info->shape[1] : 1; + + const size_t input_length = info->input_length(); + const size_t output_length = info->output_length(); + + // 计算总的任务块数 (Batch * Channels) + const ptrdiff_t total_blocks = static_cast(batch_size * channels); + + const ptrdiff_t x_stride_last = info->x_strides.back(); + +#pragma omp parallel for + for (ptrdiff_t block_idx = 0; block_idx < total_blocks; ++block_idx) { + const size_t i = block_idx / channels; // batch index + const size_t j = block_idx % channels; // channel index + + const T *x_ptr_base; + T *y_ptr_base; + + if (ndim > 2) { // (N, C, L) + x_ptr_base = x + i * info->x_strides[0] + j * info->x_strides[1]; + y_ptr_base = y + i * info->y_strides[0] + j * info->y_strides[1]; + } else { // (N, L) + x_ptr_base = x + i * info->x_strides[0]; + y_ptr_base = y + i * info->y_strides[0]; + } + + for (size_t out_idx = 0; out_idx < output_length; ++out_idx) { + // 计算池化窗口范围 [start_index, end_index) + // 公式参考 PyTorch: + // start = floor(out_idx * L_in / L_out) + // end = ceil((out_idx + 1) * L_in / L_out) + int start_index = std::floor((float)out_idx * input_length / output_length); + int end_index = std::ceil((float)(out_idx + 1) * input_length / output_length); + + start_index = std::max(start_index, 0); + end_index = std::min(end_index, (int)input_length); + int window_len = end_index - start_index; + + if (window_len <= 0) { + continue; + } + + const T *window_ptr = x_ptr_base + start_index * x_stride_last; + + auto max_val = op::common_cpu::reduce_op::max(window_ptr, window_len, x_stride_last); + y_ptr_base[out_idx] = utils::cast(max_val); + } + } + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, size_t workspace_size, + void *y, const void *x, + void *stream) const { + + if (_info.atype == INFINI_DTYPE_F32) { + return adaptiveMaxPool1d(&_info, 
(float *)y, (const float *)x); + } else if (_info.atype == INFINI_DTYPE_F16) { + return adaptiveMaxPool1d(&_info, (fp16_t *)y, (const fp16_t *)x); + } else if (_info.atype == INFINI_DTYPE_BF16) { + return adaptiveMaxPool1d(&_info, (bf16_t *)y, (const bf16_t *)x); + } else if (_info.atype == INFINI_DTYPE_F64) { + return adaptiveMaxPool1d(&_info, (double *)y, (const double *)x); + } + + return INFINI_STATUS_BAD_TENSOR_DTYPE; +} + +} // namespace op::adaptive_max_pool1d::cpu \ No newline at end of file diff --git a/src/infiniop/ops/adaptive_max_pool1d/cpu/adaptive_max_pool1d_cpu.h b/src/infiniop/ops/adaptive_max_pool1d/cpu/adaptive_max_pool1d_cpu.h new file mode 100644 index 000000000..f3e8ced3c --- /dev/null +++ b/src/infiniop/ops/adaptive_max_pool1d/cpu/adaptive_max_pool1d_cpu.h @@ -0,0 +1,8 @@ +#ifndef __ADAPTIVE_MAX_POOL1D_CPU_H__ +#define __ADAPTIVE_MAX_POOL1D_CPU_H__ + +#include "../adaptive_max_pool1d.h" + +DESCRIPTOR(cpu) + +#endif \ No newline at end of file diff --git a/src/infiniop/ops/adaptive_max_pool1d/cuda/kernel.cuh b/src/infiniop/ops/adaptive_max_pool1d/cuda/kernel.cuh new file mode 100644 index 000000000..814688846 --- /dev/null +++ b/src/infiniop/ops/adaptive_max_pool1d/cuda/kernel.cuh @@ -0,0 +1,54 @@ +#ifndef __ADAPTIVE_MAX_POOL1D_CUDA_KERNEL_H__ +#define __ADAPTIVE_MAX_POOL1D_CUDA_KERNEL_H__ + +#include +#include + +template +__device__ void adaptiveMaxPool1dBlock( + Tdata *__restrict__ y, + ptrdiff_t stride_y_batch, + ptrdiff_t stride_y_channel, + const Tdata *__restrict__ x, + ptrdiff_t stride_x_batch, + ptrdiff_t stride_x_channel, + ptrdiff_t stride_x_length, + size_t channels, + size_t input_length, + size_t output_length, + size_t ndim) { + + size_t block_idx = blockIdx.x; + size_t batch_idx = block_idx / channels; + size_t channel_idx = block_idx % channels; + + const Tdata *x_ptr; + Tdata *y_ptr; + + if (ndim > 2) { + x_ptr = x + batch_idx * stride_x_batch + channel_idx * stride_x_channel; + y_ptr = y + batch_idx * stride_y_batch + channel_idx * stride_y_channel; + } else { + x_ptr = x + batch_idx * stride_x_batch; + y_ptr = y + batch_idx * stride_y_batch; + } + + for (size_t out_idx = threadIdx.x; out_idx < output_length; out_idx += BLOCK_SIZE) { + int start_index = static_cast(floorf((float)out_idx * input_length / output_length)); + int end_index = static_cast(ceilf((float)(out_idx + 1) * input_length / output_length)); + + if (end_index <= start_index) { + continue; + } + + Tcompute max_val = Tcompute(x_ptr[start_index * stride_x_length]); + for (int i = start_index + 1; i < end_index; ++i) { + Tcompute val = Tcompute(x_ptr[i * stride_x_length]); + max_val = max(max_val, val); + } + + y_ptr[out_idx] = Tdata(max_val); + } +} + +#endif \ No newline at end of file diff --git a/src/infiniop/ops/adaptive_max_pool1d/info.h b/src/infiniop/ops/adaptive_max_pool1d/info.h new file mode 100644 index 000000000..7194d2d93 --- /dev/null +++ b/src/infiniop/ops/adaptive_max_pool1d/info.h @@ -0,0 +1,65 @@ +#ifndef __ADAPATIVE_MAX_POOL1D_H__ +#define __ADAPATIVE_MAX_POOL1D_H__ + +#include "../../../utils.h" +#include "../../tensor.h" +#include + +namespace op::adaptive_max_pool1d { + +class AdaptiveMaxPool1dInfo { + AdaptiveMaxPool1dInfo() = default; + +public: + infiniDtype_t atype; + std::vector shape; + std::vector y_strides; + std::vector x_strides; + size_t input_size; + size_t output_size; + size_t ndim() const { return shape.size(); } + size_t input_length() const { return input_size; } + size_t output_length() const { return output_size; } + + static utils::Result 
create( + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t x_desc, + size_t output_size) { + + auto atype = y_desc->dtype(); + if (x_desc->dtype() != atype) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + if (atype != INFINI_DTYPE_F16 && atype != INFINI_DTYPE_BF16 && atype != INFINI_DTYPE_F32 && atype != INFINI_DTYPE_F64) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + const size_t y_ndim = y_desc->ndim(); + const size_t x_ndim = x_desc->ndim(); + + if (y_ndim != x_ndim) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + for (size_t i = 0; i < y_ndim - 1; ++i) { + if (x_desc->dim(i) != y_desc->dim(i)) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + } + + if (y_desc->dim(y_ndim - 1) != output_size) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + return utils::Result(AdaptiveMaxPool1dInfo{ + atype, + y_desc->shape(), + y_desc->strides(), + x_desc->strides(), + x_desc->dim(x_ndim - 1), + output_size}); + } +}; +} // namespace op::adaptive_max_pool1d + +#endif // __ADAPATIVE_MAX_POOL1D_H__ \ No newline at end of file diff --git a/src/infiniop/ops/adaptive_max_pool1d/metax/adaptive_max_pool1d_metax.cuh b/src/infiniop/ops/adaptive_max_pool1d/metax/adaptive_max_pool1d_metax.cuh new file mode 100644 index 000000000..fcd068b6d --- /dev/null +++ b/src/infiniop/ops/adaptive_max_pool1d/metax/adaptive_max_pool1d_metax.cuh @@ -0,0 +1,8 @@ +#ifndef __ADAPTIVE_MAX_POOL1D_METAX_CUH__ +#define __ADAPTIVE_MAX_POOL1D_METAX_CUH__ + +#include "../adaptive_max_pool1d.h" + +DESCRIPTOR(metax) + +#endif \ No newline at end of file diff --git a/src/infiniop/ops/adaptive_max_pool1d/metax/adaptive_max_pool1d_metax.maca b/src/infiniop/ops/adaptive_max_pool1d/metax/adaptive_max_pool1d_metax.maca new file mode 100644 index 000000000..f72aae852 --- /dev/null +++ b/src/infiniop/ops/adaptive_max_pool1d/metax/adaptive_max_pool1d_metax.maca @@ -0,0 +1,130 @@ +#include "../../../devices/metax/metax_common.h" +#include "adaptive_max_pool1d_metax.cuh" + +#include "../../../devices/metax/metax_kernel_common.h" + +#include "../cuda/kernel.cuh" + +template +INFINIOP_METAX_KERNEL adaptiveMaxPool1dKernel( + Tdata *__restrict__ y, + ptrdiff_t stride_y_batch, + ptrdiff_t stride_y_channel, + const Tdata *__restrict__ x, + ptrdiff_t stride_x_batch, + ptrdiff_t stride_x_channel, + ptrdiff_t stride_x_length, + size_t channels, + size_t input_length, + size_t output_length, + size_t ndim) { + + adaptiveMaxPool1dBlock( + y, stride_y_batch, stride_y_channel, + x, stride_x_batch, stride_x_channel, stride_x_length, + channels, input_length, output_length,ndim); +} + +namespace op::adaptive_max_pool1d::metax { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor(){ + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t x_desc, + size_t output_size) { + + auto result = AdaptiveMaxPool1dInfo::create(y_desc, x_desc, output_size); + CHECK_RESULT(result); + auto info = result.take(); + + *desc_ptr = new Descriptor( + new Opaque{reinterpret_cast(handle)->internal()}, + std::move(info), + 0, + handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t launchKernel( + uint32_t numblock, + void *y, infiniDtype_t dtype, + ptrdiff_t stride_y_batch, ptrdiff_t stride_y_channel, + const void *x, + ptrdiff_t stride_x_batch, ptrdiff_t stride_x_channel, ptrdiff_t stride_x_length, + size_t channels, size_t input_length, size_t output_length, 
size_t ndim, + hcStream_t stream){ + +#define LAUNCH_KERNEL(Tdata, Tcompute) \ + adaptiveMaxPool1dKernel<<>> ( \ + reinterpret_cast(y), \ + stride_y_batch, stride_y_channel, \ + reinterpret_cast(x), \ + stride_x_batch, stride_x_channel, stride_x_length, \ + channels, input_length, output_length, ndim) + + if (dtype == INFINI_DTYPE_F16) { + LAUNCH_KERNEL(half, float); + } else if (dtype == INFINI_DTYPE_BF16) { + LAUNCH_KERNEL(__hpcc_bfloat16, float); + } else if (dtype == INFINI_DTYPE_F32) { + LAUNCH_KERNEL(float, float); + } else if (dtype == INFINI_DTYPE_F64) { + LAUNCH_KERNEL(double, double); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } +#undef LAUNCH_KERNEL + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, size_t workspace_size, + void *y, const void *x, + void *stream_) const { + + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + const size_t ndim = _info.ndim(); + const size_t batch_size = _info.shape[0]; + const size_t channels = ndim > 2 ? _info.shape[1] : 1; + const size_t input_length = _info.input_length(); + const size_t output_length = _info.output_length(); + + ptrdiff_t stride_x_batch = _info.x_strides[0]; + ptrdiff_t stride_x_channel = ndim > 2 ? _info.x_strides[1] : 0; + ptrdiff_t stride_x_length = _info.x_strides.back(); + + ptrdiff_t stride_y_batch = _info.y_strides[0]; + ptrdiff_t stride_y_channel = ndim > 2 ? _info.y_strides[1] : 0; + + uint32_t num_blocks = static_cast(batch_size * channels); + auto stream = reinterpret_cast(stream_); + + if (_opaque->internal->maxThreadsPerBlock() >= METAX_BLOCK_SIZE_1024) { + CHECK_STATUS(launchKernel( + num_blocks, y, _info.atype, + stride_y_batch, stride_y_channel, + x, stride_x_batch, stride_x_channel, stride_x_length, + channels, input_length, output_length, ndim, + stream)); + } else { + return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; + } + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::adaptive_max_pool1d::metax \ No newline at end of file diff --git a/src/infiniop/ops/adaptive_max_pool1d/moore/adaptive_max_pool1d_moore.h b/src/infiniop/ops/adaptive_max_pool1d/moore/adaptive_max_pool1d_moore.h new file mode 100644 index 000000000..c56ad6fd4 --- /dev/null +++ b/src/infiniop/ops/adaptive_max_pool1d/moore/adaptive_max_pool1d_moore.h @@ -0,0 +1,8 @@ +#ifndef __ADAPTIVE_MAX_POOL1D_MOOORE_H__ +#define __ADAPTIVE_MAX_POOL1D_MOOORE_H__ + +#include "../adaptive_max_pool1d.h" + +DESCRIPTOR(moore) + +#endif diff --git a/src/infiniop/ops/adaptive_max_pool1d/moore/adaptive_max_pool1d_moore.mu b/src/infiniop/ops/adaptive_max_pool1d/moore/adaptive_max_pool1d_moore.mu new file mode 100644 index 000000000..256392f78 --- /dev/null +++ b/src/infiniop/ops/adaptive_max_pool1d/moore/adaptive_max_pool1d_moore.mu @@ -0,0 +1,144 @@ +#include "../../../devices/moore/moore_common.h" +#include "adaptive_max_pool1d_moore.h" + +#include "../../../devices/moore/moore_kernel_common.h" + +#include "../cuda/kernel.cuh" + +template +INFINIOP_MOORE_KERNEL adaptiveMaxPool1dKernel( + Tdata *__restrict__ y, + ptrdiff_t stride_y_batch, + ptrdiff_t stride_y_channel, + const Tdata *__restrict__ x, + ptrdiff_t stride_x_batch, + ptrdiff_t stride_x_channel, + ptrdiff_t stride_x_length, + size_t channels, + size_t input_length, + size_t output_length, + size_t ndim){ + + adaptiveMaxPool1dBlock( + y, stride_y_batch, stride_y_channel, + x, stride_x_batch, stride_x_channel, stride_x_length, + channels, input_length, output_length, ndim); +} + +namespace 
op::adaptive_max_pool1d::moore { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t x_desc, + size_t output_size) { + auto result = AdaptiveMaxPool1dInfo::create(y_desc, x_desc, output_size); + CHECK_RESULT(result); + auto info = result.take(); + + *desc_ptr = new Descriptor( + new Opaque{reinterpret_cast(handle)->internal()}, + std::move(info), + 0, + handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t launchKernel( + uint32_t num_blocks, + void *y, infiniDtype_t dtype, + ptrdiff_t stride_y_batch, ptrdiff_t stride_y_channel, + const void *x, + ptrdiff_t stride_x_batch, ptrdiff_t stride_x_channel, ptrdiff_t stride_x_length, + size_t channels, size_t input_length, size_t output_length, size_t ndim, + musaStream_t musa_stream) { + +#define LAUNCH_KERNEL(Tdata, Tcompute) \ + adaptiveMaxPool1dKernel<<>>( \ + reinterpret_cast(y), \ + stride_y_batch, stride_y_channel, \ + reinterpret_cast(x), \ + stride_x_batch, stride_x_channel, stride_x_length, \ + channels, input_length, output_length, ndim) + + if (dtype == INFINI_DTYPE_F16) { + LAUNCH_KERNEL(half, float); + } else if (dtype == INFINI_DTYPE_BF16) { + LAUNCH_KERNEL(__mt_bfloat16, float); + } else if (dtype == INFINI_DTYPE_F32) { + LAUNCH_KERNEL(float, float); + } else if (dtype == INFINI_DTYPE_F64) { + LAUNCH_KERNEL(double, double); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + +#undef LAUNCH_KERNEL + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, size_t workspace_size, + void *y, const void *x, + void *stream) const { + + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + const size_t ndim = _info.ndim(); + const size_t batch_size = _info.shape[0]; + const size_t channels = ndim > 2 ? _info.shape[1] : 1; + const size_t input_length = _info.input_length(); + const size_t output_length = _info.output_length(); + + ptrdiff_t stride_x_batch = _info.x_strides[0]; + ptrdiff_t stride_x_channel = ndim > 2 ? _info.x_strides[1] : 0; + ptrdiff_t stride_x_length = _info.x_strides.back(); + + ptrdiff_t stride_y_batch = _info.y_strides[0]; + ptrdiff_t stride_y_channel = ndim > 2 ? 
_info.y_strides[1] : 0; + + uint32_t num_blocks = static_cast(batch_size * channels); + auto musa_stream = reinterpret_cast(stream); + + if (_opaque->internal->maxThreadsPerBlock() >= MOORE_BLOCK_SIZE_1024) { + CHECK_STATUS(launchKernel( + num_blocks, y, _info.atype, + stride_y_batch, stride_y_channel, + x, stride_x_batch, stride_x_channel, stride_x_length, + channels, input_length, output_length, ndim, + musa_stream)); + } else if (_opaque->internal->maxThreadsPerBlock() >= MOORE_BLOCK_SIZE_512) { + CHECK_STATUS(launchKernel( + num_blocks, y, _info.atype, + stride_y_batch, stride_y_channel, + x, stride_x_batch, stride_x_channel, stride_x_length, + channels, input_length, output_length, ndim, + musa_stream)); + } else if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_2048) { + CHECK_STATUS(launchKernel( + num_blocks, y, _info.atype, + stride_y_batch, stride_y_channel, + x, stride_x_batch, stride_x_channel, stride_x_length, + channels, input_length, output_length, ndim, + musa_stream)); + } else { + return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; + } + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::adaptive_max_pool1d::moore \ No newline at end of file diff --git a/src/infiniop/ops/adaptive_max_pool1d/nvidia/adaptive_max_pool1d_nvidia.cu b/src/infiniop/ops/adaptive_max_pool1d/nvidia/adaptive_max_pool1d_nvidia.cu new file mode 100644 index 000000000..96ffe573f --- /dev/null +++ b/src/infiniop/ops/adaptive_max_pool1d/nvidia/adaptive_max_pool1d_nvidia.cu @@ -0,0 +1,144 @@ +#include "../../../devices/nvidia/nvidia_common.cuh" +#include "adaptive_max_pool1d_nvidia.cuh" + +#include "../../../devices/nvidia/nvidia_kernel_common.cuh" + +#include "../cuda/kernel.cuh" + +template +INFINIOP_CUDA_KERNEL adaptiveMaxPool1dKernel( + Tdata *__restrict__ y, + ptrdiff_t stride_y_batch, + ptrdiff_t stride_y_channel, + const Tdata *__restrict__ x, + ptrdiff_t stride_x_batch, + ptrdiff_t stride_x_channel, + ptrdiff_t stride_x_length, + size_t channels, + size_t input_length, + size_t output_length, + size_t ndim) { + + adaptiveMaxPool1dBlock( + y, stride_y_batch, stride_y_channel, + x, stride_x_batch, stride_x_channel, stride_x_length, + channels, input_length, output_length, ndim); +} + +namespace op::adaptive_max_pool1d::nvidia { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t x_desc, + size_t output_size) { + auto result = AdaptiveMaxPool1dInfo::create(y_desc, x_desc, output_size); + CHECK_RESULT(result); + auto info = result.take(); + + *desc_ptr = new Descriptor( + new Opaque{reinterpret_cast(handle)->internal()}, + std::move(info), + 0, + handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t launchKernel( + uint32_t num_blocks, + void *y, infiniDtype_t dtype, + ptrdiff_t stride_y_batch, ptrdiff_t stride_y_channel, + const void *x, + ptrdiff_t stride_x_batch, ptrdiff_t stride_x_channel, ptrdiff_t stride_x_length, + size_t channels, size_t input_length, size_t output_length, size_t ndim, + cudaStream_t cuda_stream) { + +#define LAUNCH_KERNEL(Tdata, Tcompute) \ + adaptiveMaxPool1dKernel<<>>( \ + reinterpret_cast(y), \ + stride_y_batch, stride_y_channel, \ + reinterpret_cast(x), \ + stride_x_batch, stride_x_channel, stride_x_length, \ + channels, input_length, output_length, ndim) + + if (dtype == INFINI_DTYPE_F16) 
{ + LAUNCH_KERNEL(half, float); + } else if (dtype == INFINI_DTYPE_BF16) { + LAUNCH_KERNEL(__nv_bfloat16, float); + } else if (dtype == INFINI_DTYPE_F32) { + LAUNCH_KERNEL(float, float); + } else if (dtype == INFINI_DTYPE_F64) { + LAUNCH_KERNEL(double, double); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + +#undef LAUNCH_KERNEL + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, size_t workspace_size, + void *y, const void *x, + void *stream) const { + + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + const size_t ndim = _info.ndim(); + const size_t batch_size = _info.shape[0]; + const size_t channels = ndim > 2 ? _info.shape[1] : 1; + const size_t input_length = _info.input_length(); + const size_t output_length = _info.output_length(); + + ptrdiff_t stride_x_batch = _info.x_strides[0]; + ptrdiff_t stride_x_channel = ndim > 2 ? _info.x_strides[1] : 0; + ptrdiff_t stride_x_length = _info.x_strides.back(); + + ptrdiff_t stride_y_batch = _info.y_strides[0]; + ptrdiff_t stride_y_channel = ndim > 2 ? _info.y_strides[1] : 0; + + uint32_t num_blocks = static_cast(batch_size * channels); + auto cuda_stream = reinterpret_cast(stream); + + if (_opaque->internal->maxThreadsPerBlock() >= CUDA_BLOCK_SIZE_1024) { + CHECK_STATUS(launchKernel( + num_blocks, y, _info.atype, + stride_y_batch, stride_y_channel, + x, stride_x_batch, stride_x_channel, stride_x_length, + channels, input_length, output_length, ndim, + cuda_stream)); + } else if (_opaque->internal->maxThreadsPerBlock() >= CUDA_BLOCK_SIZE_512) { + CHECK_STATUS(launchKernel( + num_blocks, y, _info.atype, + stride_y_batch, stride_y_channel, + x, stride_x_batch, stride_x_channel, stride_x_length, + channels, input_length, output_length, ndim, + cuda_stream)); + } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) { + CHECK_STATUS(launchKernel( + num_blocks, y, _info.atype, + stride_y_batch, stride_y_channel, + x, stride_x_batch, stride_x_channel, stride_x_length, + channels, input_length, output_length, ndim, + cuda_stream)); + } else { + return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; + } + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::adaptive_max_pool1d::nvidia \ No newline at end of file diff --git a/src/infiniop/ops/adaptive_max_pool1d/nvidia/adaptive_max_pool1d_nvidia.cuh b/src/infiniop/ops/adaptive_max_pool1d/nvidia/adaptive_max_pool1d_nvidia.cuh new file mode 100644 index 000000000..b980ce269 --- /dev/null +++ b/src/infiniop/ops/adaptive_max_pool1d/nvidia/adaptive_max_pool1d_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __ADAPTIVE_MAX_POOL1D_CUDA_H__ +#define __ADAPTIVE_MAX_POOL1D_CUDA_H__ + +#include "../adaptive_max_pool1d.h" + +DESCRIPTOR(nvidia) + +#endif \ No newline at end of file diff --git a/src/infiniop/ops/adaptive_max_pool1d/operator.cc b/src/infiniop/ops/adaptive_max_pool1d/operator.cc new file mode 100644 index 000000000..7048a1033 --- /dev/null +++ b/src/infiniop/ops/adaptive_max_pool1d/operator.cc @@ -0,0 +1,147 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/adaptive_max_pool1d.h" + +#ifdef ENABLE_CPU_API +#include "cpu/adaptive_max_pool1d_cpu.h" +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) +#include "nvidia/adaptive_max_pool1d_nvidia.cuh" +#endif +#ifdef ENABLE_METAX_API +#include "metax/adaptive_max_pool1d_metax.cuh" +#endif +#ifdef ENABLE_MOORE_API +#include "moore/adaptive_max_pool1d_moore.h" +#endif + +__C infiniStatus_t 
infiniopCreateAdaptiveMaxPool1dDescriptor( + infiniopHandle_t handle, + infiniopAdaptiveMaxPool1dDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t x_desc, + size_t output_size) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::adaptive_max_pool1d::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + y_desc, \ + x_desc, \ + output_size) + + switch (handle->device) { +#ifdef ENABLE_CPU_API + CREATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + CREATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + CREATE(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_METAX_API + CREATE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_MOORE_API + CREATE(INFINI_DEVICE_MOORE, moore); +#endif + } +#undef CREATE + + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +__C infiniStatus_t infiniopGetAdaptiveMaxPool1dWorkspaceSize( + infiniopAdaptiveMaxPool1dDescriptor_t desc, + size_t *size) { +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + GET(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + GET(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_METAX_API + GET(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_MOORE_API + GET(INFINI_DEVICE_MOORE, moore); +#endif + } +#undef GET + + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +__C infiniStatus_t infiniopAdaptiveMaxPool1d( + infiniopAdaptiveMaxPool1dDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *y, + const void *x, + void *stream) { +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc)->calculate( \ + workspace, workspace_size, y, x, stream); + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + CALCULATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_METAX_API + CALCULATE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_MOORE_API + CALCULATE(INFINI_DEVICE_MOORE, moore); +#endif + } +#undef CALCULATE + + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +__C infiniStatus_t infiniopDestroyAdaptiveMaxPool1dDescriptor( + infiniopAdaptiveMaxPool1dDescriptor_t desc) { +#define DESTROY(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + DESTROY(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + DESTROY(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + DESTROY(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_METAX_API + DESTROY(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_MOORE_API + DESTROY(INFINI_DEVICE_MOORE, moore); +#endif + } +#undef DESTROY + + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} \ No newline at end of file diff --git a/src/infiniop/ops/asinh/cpu/asinh_cpu.cc b/src/infiniop/ops/asinh/cpu/asinh_cpu.cc new file mode 100644 index 000000000..4d7627473 --- /dev/null +++ b/src/infiniop/ops/asinh/cpu/asinh_cpu.cc @@ -0,0 +1,50 @@ +#include "asinh_cpu.h" + +namespace op::asinh::cpu { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + 
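The dispatch boilerplate above wires the four public entry points to whichever backend the handle was created for. A minimal call sequence from the C API side, assuming a handle, tensor descriptors, device buffers and a stream have already been set up by the usual infiniop initialization code (not shown), and with CHECK standing in for the caller's own status handling:

    infiniopAdaptiveMaxPool1dDescriptor_t desc = nullptr;
    CHECK(infiniopCreateAdaptiveMaxPool1dDescriptor(handle, &desc, y_desc, x_desc,
                                                    /*output_size=*/output_length));

    size_t workspace_size = 0;
    CHECK(infiniopGetAdaptiveMaxPool1dWorkspaceSize(desc, &workspace_size));
    // allocate `workspace` with at least `workspace_size` bytes on the device

    CHECK(infiniopAdaptiveMaxPool1d(desc, workspace, workspace_size, y, x, stream));
    CHECK(infiniopDestroyAdaptiveMaxPool1dDescriptor(desc));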
infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) { + + auto handle = reinterpret_cast(handle_); + auto dtype = out_desc->dtype(); + + const auto &x_desc = input_desc_vec.at(0); + const auto &y_shape = out_desc->shape(); + const auto &x_shape = x_desc->shape(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16); + + CHECK_SAME_SHAPE(y_shape, x_shape); + + CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const { + + switch (_dtype) { + case INFINI_DTYPE_F16: + return _device_info->calculate(_info, output, inputs, stream); + case INFINI_DTYPE_F32: + return _device_info->calculate(_info, output, inputs, stream); + case INFINI_DTYPE_F64: + return _device_info->calculate(_info, output, inputs, stream); + case INFINI_DTYPE_BF16: + return _device_info->calculate(_info, output, inputs, stream); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } +} + +} // namespace op::asinh::cpu \ No newline at end of file diff --git a/src/infiniop/ops/asinh/cpu/asinh_cpu.h b/src/infiniop/ops/asinh/cpu/asinh_cpu.h new file mode 100644 index 000000000..076fcb30a --- /dev/null +++ b/src/infiniop/ops/asinh/cpu/asinh_cpu.h @@ -0,0 +1,22 @@ +#ifndef __ASINH_CPU_H__ +#define __ASINH_CPU_H__ + +#include + +#include "../../../elementwise/cpu/elementwise_cpu.h" + +ELEMENTWISE_DESCRIPTOR(asinh, cpu) + +namespace op::asinh::cpu { +typedef struct AsinhOp { +public: + static constexpr size_t num_inputs = 1; + + template + T operator()(const T &x) const { + return std::asinh(x); + } +} AsinhOp; +} // namespace op::asinh::cpu + +#endif // __ASINH_CPU_H__ \ No newline at end of file diff --git a/src/infiniop/ops/asinh/cuda/kernel.cuh b/src/infiniop/ops/asinh/cuda/kernel.cuh new file mode 100644 index 000000000..2bd6dcbf0 --- /dev/null +++ b/src/infiniop/ops/asinh/cuda/kernel.cuh @@ -0,0 +1,29 @@ +#ifndef __ASINH_CUDA_KERNEL_H__ +#define __ASINH_CUDA_KERNEL_H__ + +namespace op::asinh::cuda { + +typedef struct AsinhOp { +public: + static constexpr size_t num_inputs = 1; + template + __device__ __forceinline__ T operator()(const T &x) const { + + if constexpr (std::is_same_v) { + float x_f = __half2float(x); + return __float2half(asinhf(x_f)); + } else if constexpr (std::is_same_v) { + float x_f = __bfloat162float(x); + return __float2bfloat16(asinhf(x_f)); + } else if constexpr (std::is_same_v) { + return asinhf(x); + } else { + return ::asinh(x); + } + } + +} AsinhOp; + +} // namespace op::asinh::cuda + +#endif // __ASINH_CUDA_KERNEL_H__ \ No newline at end of file diff --git a/src/infiniop/ops/asinh/metax/asinh.maca b/src/infiniop/ops/asinh/metax/asinh.maca new file mode 100644 index 000000000..f6f4ac3f9 --- /dev/null +++ b/src/infiniop/ops/asinh/metax/asinh.maca @@ -0,0 +1,58 @@ +#include "asinh_metax.h" +#include "../../../elementwise/metax/elementwise_metax.h" + +#include "../cuda/kernel.cuh" + +namespace op::asinh::metax { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) { + + auto handle = reinterpret_cast(handle_); + auto dtype = out_desc->dtype(); + + const auto &x_desc = input_desc_vec.at(0); + const auto &y_shape = out_desc->shape(); + const auto &x_shape = x_desc->shape(); + + CHECK_DTYPE(dtype, 
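The CPU and CUDA asinh backends share the same functor convention: a struct exposing a num_inputs constant and a templated operator() over the element type, which the elementwise framework then maps across the tensors. Under that assumption, adding another unary elementwise op only needs a new functor plus the usual descriptor boilerplate; a hypothetical CPU-side example (the atanh name and namespace are invented for illustration):

    #include <cmath>

    namespace op::atanh::cpu {
    struct AtanhOp {
        static constexpr size_t num_inputs = 1; // unary op

        template <typename T>
        T operator()(const T &x) const {
            return std::atanh(x);
        }
    };
    } // namespace op::atanh::cpu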
INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16); + + CHECK_SAME_SHAPE(y_shape, x_shape); + + // create CUDA elementwise descriptor + CREATE_ELEMENTWISE_METAX_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec) + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const { + + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + switch (_dtype) { + case INFINI_DTYPE_F16: + return _device_info->calculate<256, cuda::AsinhOp, half>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_BF16: + return _device_info->calculate<256, cuda::AsinhOp, cuda_bfloat16>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F32: + return _device_info->calculate<256, cuda::AsinhOp, float>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F64: + return _device_info->calculate<256, cuda::AsinhOp, double>(_info, workspace, output, inputs, stream); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } +} + +} // namespace op::asinh::metax diff --git a/src/infiniop/ops/asinh/metax/asinh_metax.h b/src/infiniop/ops/asinh/metax/asinh_metax.h new file mode 100644 index 000000000..dacb77f0d --- /dev/null +++ b/src/infiniop/ops/asinh/metax/asinh_metax.h @@ -0,0 +1,8 @@ +#ifndef __ASINH_METAX_API_H__ +#define __ASINH_METAX_API_H__ + +#include "../../../elementwise/metax/elementwise_metax_api.h" + +ELEMENTWISE_DESCRIPTOR(asinh, metax) + +#endif // __ASINH_METAX_API_H__ \ No newline at end of file diff --git a/src/infiniop/ops/asinh/moore/asinh_moore.h b/src/infiniop/ops/asinh/moore/asinh_moore.h new file mode 100644 index 000000000..36c93d53a --- /dev/null +++ b/src/infiniop/ops/asinh/moore/asinh_moore.h @@ -0,0 +1,8 @@ +#ifndef __ASINH_MOORE_API_H__ +#define __ASINH_MOORE_API_H__ + +#include "../../../elementwise/moore/elementwise_moore_api.h" + +ELEMENTWISE_DESCRIPTOR(asinh, moore) + +#endif // __ASINH_MOORE_API_H__ \ No newline at end of file diff --git a/src/infiniop/ops/asinh/moore/asinh_moore.mu b/src/infiniop/ops/asinh/moore/asinh_moore.mu new file mode 100644 index 000000000..35a8d6475 --- /dev/null +++ b/src/infiniop/ops/asinh/moore/asinh_moore.mu @@ -0,0 +1,59 @@ +#include "asinh_moore.h" + +#include "../../../elementwise/moore/elementwise_moore.h" + +#include "../cuda/kernel.cuh" + +namespace op::asinh::moore { +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) { + + auto handle = reinterpret_cast(handle_); + auto dtype = out_desc->dtype(); + + const auto &x_desc = input_desc_vec.at(0); + const auto &y_shape = out_desc->shape(); + const auto &x_shape = x_desc->shape(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16); + + CHECK_SAME_SHAPE(y_shape, x_shape); + + // create MOORE elementwise descriptor + CREATE_ELEMENTWISE_MOORE_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec) + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const { + + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + switch (_dtype) { + case INFINI_DTYPE_F16: + return _device_info->calculate<256, cuda::AsinhOp, half>(_info, workspace, output, inputs, stream); + case 
INFINI_DTYPE_BF16: + return _device_info->calculate<256, cuda::AsinhOp, cuda_bfloat16>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F32: + return _device_info->calculate<256, cuda::AsinhOp, float>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F64: + return _device_info->calculate<256, cuda::AsinhOp, double>(_info, workspace, output, inputs, stream); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} +} // namespace op::asinh::moore \ No newline at end of file diff --git a/src/infiniop/ops/asinh/nvidia/asinh_nvidia.cu b/src/infiniop/ops/asinh/nvidia/asinh_nvidia.cu new file mode 100644 index 000000000..77a4652bc --- /dev/null +++ b/src/infiniop/ops/asinh/nvidia/asinh_nvidia.cu @@ -0,0 +1,56 @@ +#include "../../../devices/nvidia/nvidia_common.cuh" +#include "../../../elementwise/nvidia/elementwise_nvidia.cuh" + +#include "../cuda/kernel.cuh" +#include "asinh_nvidia.cuh" + +namespace op::asinh::nvidia { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) { + + auto handle = reinterpret_cast(handle_); + auto dtype = out_desc->dtype(); + + const auto &x_desc = input_desc_vec.at(0); + const auto &y_shape = out_desc->shape(); + const auto &x_shape = x_desc->shape(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16); + + CHECK_SAME_SHAPE(y_shape, x_shape); + + CREATE_ELEMENTWISE_CUDA_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec) + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const { + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + switch (_dtype) { + case INFINI_DTYPE_F16: + return _device_info->calculate<256, cuda::AsinhOp, half>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_BF16: + return _device_info->calculate<256, cuda::AsinhOp, cuda_bfloat16>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F32: + return _device_info->calculate<256, cuda::AsinhOp, float>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F64: + return _device_info->calculate<256, cuda::AsinhOp, double>(_info, workspace, output, inputs, stream); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } +} + +} // namespace op::asinh::nvidia \ No newline at end of file diff --git a/src/infiniop/ops/asinh/nvidia/asinh_nvidia.cuh b/src/infiniop/ops/asinh/nvidia/asinh_nvidia.cuh new file mode 100644 index 000000000..5b75a553c --- /dev/null +++ b/src/infiniop/ops/asinh/nvidia/asinh_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __ASINH_NVIDIA_API_H__ +#define __ASINH_NVIDIA_API_H__ + +#include "../../../elementwise/nvidia/elementwise_nvidia_api.cuh" + +ELEMENTWISE_DESCRIPTOR(asinh, nvidia) + +#endif // __ASINH_NVIDIA_API_H \ No newline at end of file diff --git a/src/infiniop/ops/asinh/operator.cc b/src/infiniop/ops/asinh/operator.cc new file mode 100644 index 000000000..e3decacf1 --- /dev/null +++ b/src/infiniop/ops/asinh/operator.cc @@ -0,0 +1,141 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/asinh.h" + +#ifdef ENABLE_CPU_API +#include "cpu/asinh_cpu.h" +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) +#include "nvidia/asinh_nvidia.cuh" +#endif +#ifdef ENABLE_METAX_API +#include 
"metax/asinh_metax.h" +#endif +#ifdef ENABLE_MOORE_API +#include "moore/asinh_moore.h" +#endif + +__C infiniStatus_t infiniopCreateAsinhDescriptor( + infiniopHandle_t handle, + infiniopAsinhDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t x_desc) { +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::asinh::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + y_desc, \ + {x_desc}) + switch (handle->device) { + +#ifdef ENABLE_CPU_API + CREATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + CREATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + CREATE(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_METAX_API + CREATE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_MOORE_API + CREATE(INFINI_DEVICE_MOORE, moore); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef CREATE +} + +__C infiniStatus_t infiniopGetAsinhWorkspaceSize(infiniopAsinhDescriptor_t desc, size_t *size) { + +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + GET(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + GET(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_METAX_API + GET(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_MOORE_API + GET(INFINI_DEVICE_MOORE, moore); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef GET + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +__C infiniStatus_t infiniopAsinh(infiniopAsinhDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *y, + const void *x, + void *stream) { +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, workspace_size, y, {x}, stream); + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + CALCULATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_METAX_API + CALCULATE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_MOORE_API + CALCULATE(INFINI_DEVICE_MOORE, moore); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef CALCULATE +} + +__C infiniStatus_t infiniopDestroyAsinhDescriptor(infiniopAsinhDescriptor_t desc) { +#define GET(CASE, NAMESPACE) \ + case CASE: \ + reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + GET(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + GET(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_METAX_API + GET(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_MOORE_API + GET(INFINI_DEVICE_MOORE, moore); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef DELETE +} \ No newline at end of file diff --git a/src/infiniop/ops/fmod/cpu/fmod_cpu.cc b/src/infiniop/ops/fmod/cpu/fmod_cpu.cc new file mode 100644 index 000000000..1f27290de --- /dev/null +++ b/src/infiniop/ops/fmod/cpu/fmod_cpu.cc @@ -0,0 +1,53 @@ +#include "fmod_cpu.h" + +namespace op::fmod::cpu { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + 
infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) { + + auto handle = reinterpret_cast(handle_); + auto dtype = out_desc->dtype(); + + const auto &a_desc = input_desc_vec.at(0); + const auto &b_desc = input_desc_vec.at(1); + const auto &out_shape = out_desc->shape(); + const auto &a_shape = a_desc->shape(); + const auto &b_shape = b_desc->shape(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16); + + CHECK_SAME_SHAPE(out_shape, a_shape, b_shape); + + // create CPU elementwise descriptor + CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const { + + switch (_dtype) { + case INFINI_DTYPE_F16: + return _device_info->calculate(_info, output, inputs, stream); + case INFINI_DTYPE_F32: + return _device_info->calculate(_info, output, inputs, stream); + case INFINI_DTYPE_F64: + return _device_info->calculate(_info, output, inputs, stream); + case INFINI_DTYPE_BF16: + return _device_info->calculate(_info, output, inputs, stream); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + return INFINI_STATUS_SUCCESS; +} +} // namespace op::fmod::cpu diff --git a/src/infiniop/ops/fmod/cpu/fmod_cpu.h b/src/infiniop/ops/fmod/cpu/fmod_cpu.h new file mode 100644 index 000000000..54af25540 --- /dev/null +++ b/src/infiniop/ops/fmod/cpu/fmod_cpu.h @@ -0,0 +1,19 @@ +#ifndef _FMOD_CPU_H__ +#define _FMOD_CPU_H__ + +#include "../../../elementwise/cpu/elementwise_cpu.h" + +ELEMENTWISE_DESCRIPTOR(fmod, cpu) + +namespace op::fmod::cpu { +typedef struct FmodOp { +public: + static constexpr size_t num_inputs = 2; + template + T operator()(const T &a, const T &b) const { + return std::fmod(a, b); + } +} FmodOp; +} // namespace op::fmod::cpu + +#endif // _FMOD_CPU_H__ \ No newline at end of file diff --git a/src/infiniop/ops/fmod/cuda/kernel.cuh b/src/infiniop/ops/fmod/cuda/kernel.cuh new file mode 100644 index 000000000..6e30ed25e --- /dev/null +++ b/src/infiniop/ops/fmod/cuda/kernel.cuh @@ -0,0 +1,48 @@ +#ifndef __FMOD_CUDA_H__ +#define __FMOD_CUDA_H__ + +namespace op::fmod::cuda { +typedef struct FmodOp { + static constexpr size_t num_inputs = 2; + template + __device__ __forceinline__ T operator()(const T &a, const T &b) const { + // fmod(a, b) = a - b * trunc(a / b) + if constexpr (std::is_same_v) { + // 对于 half2,转换为 float 计算后再转回 + float2 af = __half22float2(a); + float2 bf = __half22float2(b); + float2 result; + result.x = ::fmodf(af.x, bf.x); + result.y = ::fmodf(af.y, bf.y); + return __float22half2_rn(result); + } else if constexpr (std::is_same_v) { + // 对于 bfloat162,转换为 float 计算后再转回 + float af_low = __bfloat162float(__low2bfloat16(a)); + float af_high = __bfloat162float(__high2bfloat16(a)); + float bf_low = __bfloat162float(__low2bfloat16(b)); + float bf_high = __bfloat162float(__high2bfloat16(b)); + return __floats2bfloat162_rn(::fmodf(af_low, bf_low), ::fmodf(af_high, bf_high)); + } else if constexpr (std::is_same_v) { + // 对于 half,转换为 float 计算后再转回 + float af = __half2float(a); + float bf = __half2float(b); + return __float2half(::fmodf(af, bf)); + } else if constexpr (std::is_same_v) { + // 对于 bfloat16,转换为 float 计算后再转回 + float af = __bfloat162float(a); + float bf = __bfloat162float(b); + return __float2bfloat16(::fmodf(af, bf)); + } else if constexpr (std::is_same_v) { + return ::fmodf(a, b); + } else if constexpr (std::is_same_v) { + return 
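fmod_cpu.h above calls std::fmod but only includes the elementwise framework header; the asinh CPU header carries an extra standard-library include for the same purpose. Unless elementwise_cpu.h already pulls the math header in transitively, fmod_cpu.h needs the one-line addition:

    #include <cmath> // std::fmod used by FmodOp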
::fmod(a, b); + } else { + // 整数类型使用 % 运算符 + return a % b; + } + } +} FmodOp; + +} // namespace op::fmod::cuda + +#endif // __FMOD_CUDA_H__ \ No newline at end of file diff --git a/src/infiniop/ops/fmod/metax/fmod_metax.h b/src/infiniop/ops/fmod/metax/fmod_metax.h new file mode 100644 index 000000000..ad5769231 --- /dev/null +++ b/src/infiniop/ops/fmod/metax/fmod_metax.h @@ -0,0 +1,8 @@ +#ifndef __FMOD_METAX_API_H__ +#define __FMOD_METAX_API_H__ + +#include "../../../elementwise/metax/elementwise_metax_api.h" + +ELEMENTWISE_DESCRIPTOR(fmod, metax) + +#endif // __FMOD_METAX_API_H__ diff --git a/src/infiniop/ops/fmod/metax/mul_metax.maca b/src/infiniop/ops/fmod/metax/mul_metax.maca new file mode 100644 index 000000000..c9d54ad62 --- /dev/null +++ b/src/infiniop/ops/fmod/metax/mul_metax.maca @@ -0,0 +1,61 @@ +#include "../../../elementwise/metax/elementwise_metax.h" + +#include "../cuda/kernel.cuh" + +#include "fmod_metax.h" + +namespace op::fmod::metax { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) { + + auto handle = reinterpret_cast(handle_); + auto dtype = out_desc->dtype(); + + const auto &a_desc = input_desc_vec.at(0); + const auto &b_desc = input_desc_vec.at(1); + const auto &c_shape = out_desc->shape(); + const auto &a_shape = a_desc->shape(); + const auto &b_shape = b_desc->shape(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16); + + CHECK_SAME_SHAPE(c_shape, a_shape, b_shape); + + CREATE_ELEMENTWISE_METAX_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec) + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const { + + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + switch (_dtype) { + case INFINI_DTYPE_F16: + return _device_info->calculate<256, cuda::FmodOp, half>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F32: + return _device_info->calculate<256, cuda::FmodOp, float>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F64: + return _device_info->calculate<256, cuda::FmodOp, double>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_BF16: + return _device_info->calculate<256, cuda::FmodOp, cuda_bfloat16>(_info, workspace, output, inputs, stream); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} +} // namespace op::fmod::metax diff --git a/src/infiniop/ops/fmod/moore/fmod_moore.h b/src/infiniop/ops/fmod/moore/fmod_moore.h new file mode 100644 index 000000000..b24c337a8 --- /dev/null +++ b/src/infiniop/ops/fmod/moore/fmod_moore.h @@ -0,0 +1,8 @@ +#ifndef __FMOD_MOORE_API_H__ +#define __FMOD_MOORE_API_H__ + +#include "../../../elementwise/moore/elementwise_moore_api.h" + +ELEMENTWISE_DESCRIPTOR(fmod, moore) + +#endif // __FMOD_MOORE_API_H__ \ No newline at end of file diff --git a/src/infiniop/ops/fmod/moore/fmod_moore.mu b/src/infiniop/ops/fmod/moore/fmod_moore.mu new file mode 100644 index 000000000..0c37da459 --- /dev/null +++ b/src/infiniop/ops/fmod/moore/fmod_moore.mu @@ -0,0 +1,63 @@ +#include "fmod_moore.h" + +#include "../../../elementwise/moore/elementwise_moore.h" + +#include "../cuda/kernel.cuh" + +namespace op::fmod::moore { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + 
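Both the floating-point branches and the integer fallback in FmodOp implement truncated division: the result carries the sign of the dividend, which is also how C++'s % behaves, so the branches agree. (The integer branch is currently unreachable anyway, since every backend's create() restricts the dtype to F16/BF16/F32/F64.) A small host-side check of the expected semantics, as a standalone sketch:

    #include <cassert>
    #include <cmath>

    int main() {
        // fmod truncates toward zero; the result keeps the sign of the dividend.
        assert(std::fmod( 5.5, 2.0) ==  1.5);
        assert(std::fmod(-5.5, 2.0) == -1.5);
        // C++ integer % truncates the same way, so the integer fallback matches.
        assert(-7 % 3 == -1);
        return 0;
    }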
infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) { + + auto handle = reinterpret_cast(handle_); + auto dtype = out_desc->dtype(); + + const auto &a_desc = input_desc_vec.at(0); + const auto &b_desc = input_desc_vec.at(1); + const auto &c_shape = out_desc->shape(); + const auto &a_shape = a_desc->shape(); + const auto &b_shape = b_desc->shape(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16); + + CHECK_SAME_SHAPE(c_shape, a_shape, b_shape); + + // create MOORE elementwise descriptor + CREATE_ELEMENTWISE_MOORE_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec) + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const { + + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + switch (_dtype) { + case INFINI_DTYPE_F16: + return _device_info->calculate<256, cuda::FmodOp, half>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_BF16: + return _device_info->calculate<256, cuda::FmodOp, cuda_bfloat16>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F32: + return _device_info->calculate<256, cuda::FmodOp, float>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F64: + return _device_info->calculate<256, cuda::FmodOp, double>(_info, workspace, output, inputs, stream); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::fmod::moore \ No newline at end of file diff --git a/src/infiniop/ops/fmod/nvidia/fmod_nvidia.cu b/src/infiniop/ops/fmod/nvidia/fmod_nvidia.cu new file mode 100644 index 000000000..a74295264 --- /dev/null +++ b/src/infiniop/ops/fmod/nvidia/fmod_nvidia.cu @@ -0,0 +1,59 @@ +#include "../../../elementwise/nvidia/elementwise_nvidia.cuh" +#include "../cuda/kernel.cuh" +#include "fmod_nvidia.cuh" + +namespace op::fmod::nvidia { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) { + + auto handle = reinterpret_cast(handle_); + auto dtype = out_desc->dtype(); + + const auto &a_desc = input_desc_vec.at(0); + const auto &b_desc = input_desc_vec.at(1); + const auto &c_shape = out_desc->shape(); + const auto &a_shape = a_desc->shape(); + const auto &b_shape = b_desc->shape(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16); + + CHECK_SAME_SHAPE(c_shape, a_shape, b_shape); + + CREATE_ELEMENTWISE_CUDA_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec) + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const { + + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + switch (_dtype) { + case INFINI_DTYPE_F16: + return _device_info->calculate<256, cuda::FmodOp, half>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F32: + return _device_info->calculate<256, cuda::FmodOp, float>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F64: + return _device_info->calculate<256, cuda::FmodOp, double>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_BF16: + return _device_info->calculate<256, cuda::FmodOp, cuda_bfloat16>(_info, 
workspace, output, inputs, stream); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} +} // namespace op::fmod::nvidia \ No newline at end of file diff --git a/src/infiniop/ops/fmod/nvidia/fmod_nvidia.cuh b/src/infiniop/ops/fmod/nvidia/fmod_nvidia.cuh new file mode 100644 index 000000000..e40d0088d --- /dev/null +++ b/src/infiniop/ops/fmod/nvidia/fmod_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __MUL_CUDA_API_H__ +#define __MUL_CUDA_API_H__ + +#include "../../../elementwise/nvidia/elementwise_nvidia_api.cuh" + +ELEMENTWISE_DESCRIPTOR(fmod, nvidia) + +#endif // __MUL_CUDA_API_H__ \ No newline at end of file diff --git a/src/infiniop/ops/fmod/operator.cc b/src/infiniop/ops/fmod/operator.cc new file mode 100644 index 000000000..1fd433c4a --- /dev/null +++ b/src/infiniop/ops/fmod/operator.cc @@ -0,0 +1,152 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/fmod.h" + +#ifdef ENABLE_CPU_API +#include "cpu/fmod_cpu.h" +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) +#include "nvidia/fmod_nvidia.cuh" +#endif +#ifdef ENABLE_METAX_API +#include "metax/fmod_metax.h" +#endif +#ifdef ENABLE_MOORE_API +#include "moore/fmod_moore.h" +#endif + +__C infiniStatus_t infiniopCreateFmodDescriptor( + infiniopHandle_t handle, + infiniopFmodDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t c_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc) { +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::fmod::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + c_desc, \ + {a_desc, \ + b_desc}) + + switch (handle->device) { + +#ifdef ENABLE_CPU_API + CREATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + CREATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + CREATE(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_METAX_API + CREATE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_MOORE_API + CREATE(INFINI_DEVICE_MOORE, moore); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CREATE +} + +__C infiniStatus_t infiniopGetFmodWorkspaceSize(infiniopFmodDescriptor_t desc, size_t *size) { + +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + GET(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + GET(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_METAX_API + GET(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_MOORE_API + GET(INFINI_DEVICE_MOORE, moore); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef GET + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +__C infiniStatus_t infiniopFmod( + infiniopFmodDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *c, + const void *a, + const void *b, + void *stream) { +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, workspace_size, c, {a, b}, stream) + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + CALCULATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_METAX_API + CALCULATE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_MOORE_API + 
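fmod_nvidia.cuh above still carries the __MUL_CUDA_API_H__ include guard (and the Metax implementation file is named mul_metax.maca), apparently left over from the mul operator it was copied from. If the real mul header uses the same guard, whichever of the two headers is included second silently becomes empty in that translation unit. A corrected guard, following the naming of the other fmod backend headers:

    #ifndef __FMOD_NVIDIA_API_H__
    #define __FMOD_NVIDIA_API_H__

    #include "../../../elementwise/nvidia/elementwise_nvidia_api.cuh"

    ELEMENTWISE_DESCRIPTOR(fmod, nvidia)

    #endif // __FMOD_NVIDIA_API_H__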
CALCULATE(INFINI_DEVICE_MOORE, moore); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef CALCULATE +} + +__C infiniStatus_t infiniopDestroyFmodDescriptor(infiniopFmodDescriptor_t desc) { + +#define GET(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + GET(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + GET(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_QY_API + GET(INFINI_DEVICE_QY, nvidia); +#endif +#ifdef ENABLE_METAX_API + GET(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_MOORE_API + GET(INFINI_DEVICE_MOORE, moore); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef DELETE +} diff --git a/src/infiniop/ops/gemm/cpu/gemm_cpu.cc b/src/infiniop/ops/gemm/cpu/gemm_cpu.cc index d19965614..6f7a2e3e0 100644 --- a/src/infiniop/ops/gemm/cpu/gemm_cpu.cc +++ b/src/infiniop/ops/gemm/cpu/gemm_cpu.cc @@ -64,7 +64,11 @@ void calculate( *c_ = utils::cast(beta * utils::cast(*c_) + alpha * sum); } } else { - *c_ = beta * (*c_) + alpha * sum; + if (beta == 0) { + *c_ = alpha * sum; + } else { + *c_ = beta * (*c_) + alpha * sum; + } } } } diff --git a/src/infiniop/ops/gemm/nvidia/gemm_nvidia.cu b/src/infiniop/ops/gemm/nvidia/gemm_nvidia.cu index 0e0c65f2b..580cca658 100644 --- a/src/infiniop/ops/gemm/nvidia/gemm_nvidia.cu +++ b/src/infiniop/ops/gemm/nvidia/gemm_nvidia.cu @@ -3,6 +3,14 @@ namespace op::gemm::nvidia { +// 添加线程局部控制开关 +thread_local bool g_tf32_enabled = true; + +// 暴露设置函数(非静态,以便外部链接) +void set_tf32_enabled(bool enabled) { + g_tf32_enabled = enabled; +} + struct Descriptor::Opaque { std::shared_ptr internal; }; @@ -71,7 +79,8 @@ infiniStatus_t Descriptor::calculate( #if defined(ENABLE_ILUVATAR_API) || defined(ENABLE_HYGON_API) compute_type = CUDA_R_32F; #else - compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; + // compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; + compute_type = g_tf32_enabled ? 
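The gemm CPU change above special-cases beta == 0 so that the destination is never read, matching the usual BLAS convention that with beta == 0 the output buffer may hold uninitialized data: multiplying such garbage (possibly NaN or Inf) by zero would otherwise poison the result. A minimal standalone illustration of the failure mode the branch avoids:

    #include <cassert>
    #include <cmath>

    int main() {
        float c = NAN;                         // e.g. freshly allocated, never written
        float alpha = 1.0f, sum = 3.0f, beta = 0.0f;

        float naive = beta * c + alpha * sum;  // 0 * NaN + 3 -> NaN
        float fixed = (beta == 0) ? alpha * sum
                                  : beta * c + alpha * sum;

        assert(std::isnan(naive));
        assert(fixed == 3.0f);
        return 0;
    }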
CUBLAS_COMPUTE_32F_FAST_TF32 : CUBLAS_COMPUTE_32F; #endif break; diff --git a/src/utils/rearrange.cc b/src/utils/rearrange.cc index 7465302d0..48c3c4ed9 100644 --- a/src/utils/rearrange.cc +++ b/src/utils/rearrange.cc @@ -144,13 +144,26 @@ void rearrange( utils::Result RearrangeMeta::distributeUnit(const std::vector &candidates) const { // 获取当前的unit大小 size_t current_unit = _meta[0]; + size_t ndim_value = this->ndim(); - // 寻找满足条件的unit值:当前unit能被其整除 + // 寻找满足条件的unit值:当前unit能被其整除,且所有strides也能被其整除 size_t new_unit = 0; for (size_t candidate : candidates) { if (current_unit % candidate == 0) { - new_unit = candidate; - break; + // 检查所有 strides 是否都能被 candidate 整除(确保内存对齐) + bool strides_aligned = true; + for (size_t i = 0; i < ndim_value; ++i) { + ptrdiff_t dst_stride = std::abs(dst_strides()[i]); + ptrdiff_t src_stride = std::abs(src_strides()[i]); + if (dst_stride % candidate != 0 || src_stride % candidate != 0) { + strides_aligned = false; + break; + } + } + if (strides_aligned) { + new_unit = candidate; + break; + } } } @@ -164,9 +177,6 @@ utils::Result RearrangeMeta::distributeUnit(const std::vector(_meta); } - // 获取当前维度 - size_t ndim_value = this->ndim(); - // 创建新的布局数组 std::vector layout(2 + (ndim_value + 1) * 3, 0); diff --git a/test/infinicore/ops/adaptive_max_pool1d.py b/test/infinicore/ops/adaptive_max_pool1d.py index 0e683b4f1..e3aa89d3b 100644 --- a/test/infinicore/ops/adaptive_max_pool1d.py +++ b/test/infinicore/ops/adaptive_max_pool1d.py @@ -67,9 +67,9 @@ def get_test_cases(self): def torch_operator(self, *args, **kwargs): return torch.nn.functional.adaptive_max_pool1d(*args, **kwargs) - # def infinicore_operator(self, *args, **kwargs): - # """InfiniCore implementation (operator not yet available).""" - # return infinicore.nn.functional.adaptive_max_pool1d(*args, **kwargs) + def infinicore_operator(self, *args, **kwargs): + """InfiniCore implementation (operator not yet available).""" + return infinicore.nn.functional.adaptive_max_pool1d(*args, **kwargs) def main(): diff --git a/test/infinicore/ops/asinh.py b/test/infinicore/ops/asinh.py index 97bcd5edb..715cded1b 100644 --- a/test/infinicore/ops/asinh.py +++ b/test/infinicore/ops/asinh.py @@ -97,9 +97,9 @@ def get_test_cases(self): def torch_operator(self, *args, **kwargs): return torch.asinh(*args, **kwargs) - # def infinicore_operator(self, *args, **kwargs): - # """InfiniCore implementation (operator not yet available).""" - # return infinicore.asinh(*args, **kwargs) + def infinicore_operator(self, *args, **kwargs): + """InfiniCore implementation (operator not yet available).""" + return infinicore.asinh(*args, **kwargs) def main(): diff --git a/test/infinicore/ops/baddbmm.py b/test/infinicore/ops/baddbmm.py index 35b4cd625..bc3e79e45 100644 --- a/test/infinicore/ops/baddbmm.py +++ b/test/infinicore/ops/baddbmm.py @@ -99,9 +99,9 @@ def get_test_cases(self): def torch_operator(self, *args, **kwargs): return torch.baddbmm(*args, **kwargs) - # def infinicore_operator(self, *args, **kwargs): - # """InfiniCore implementation (operator not yet available).""" - # return infinicore.baddbmm(*args, **kwargs) + def infinicore_operator(self, *args, **kwargs): + """InfiniCore implementation (operator not yet available).""" + return infinicore.baddbmm(*args, **kwargs) def main(): diff --git a/test/infinicore/ops/bilinear.py b/test/infinicore/ops/bilinear.py index da31e0b74..2b9970a4c 100644 --- a/test/infinicore/ops/bilinear.py +++ b/test/infinicore/ops/bilinear.py @@ -44,11 +44,17 @@ def parse_test_cases(): in2 = 
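set_tf32_enabled() above is deliberately non-static so other translation units can link against it, but this diff adds no public declaration for it, so a caller currently has to forward-declare the symbol. Because the flag is thread_local, toggling it only affects GEMMs issued from the calling thread. A hedged sketch of how a test harness might compare TF32 against full FP32 accumulation (run_fp32_gemm_both_ways and the elided calls are placeholders, not part of this patch):

    // Forward declaration; the definition lives in gemm_nvidia.cu.
    namespace op::gemm::nvidia {
    void set_tf32_enabled(bool enabled);
    }

    void run_fp32_gemm_both_ways() {
        op::gemm::nvidia::set_tf32_enabled(false); // CUBLAS_COMPUTE_32F: full FP32 accumulation
        // ...run the FP32 GEMM through the usual infiniop entry point, keep the result...

        op::gemm::nvidia::set_tf32_enabled(true);  // CUBLAS_COMPUTE_32F_FAST_TF32 (default)
        // ...run it again and compare against the full-precision result...
    }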
TensorSpec.from_tensor(in2_shape, in2_strides, dtype) weight = TensorSpec.from_tensor(weight_shape, weight_strides, dtype) + inputs = [in1, in2, weight] + if bias_present: + bias_shape = (weight_shape[0],) + bias = TensorSpec.from_tensor(bias_shape, None, dtype) + inputs.append(bias) + kwargs = {} test_cases.append( TestCase( - inputs=[in1, in2, weight], + inputs=inputs, kwargs=kwargs, output_spec=None, comparison_target=None, @@ -72,9 +78,10 @@ def get_test_cases(self): def torch_operator(self, *args, **kwargs): return torch.nn.functional.bilinear(*args, **kwargs) - # def infinicore_operator(self, *args, **kwargs): - # """InfiniCore implementation (operator not yet available).""" - # return infinicore.nn.functional.bilinear(*args, **kwargs) + def infinicore_operator(self, *args, **kwargs): + from infinicore.ops.bilinear import bilinear + + return bilinear(*args, **kwargs) def main(): diff --git a/test/infinicore/ops/fmod.py b/test/infinicore/ops/fmod.py index 6ef862154..d41ef8ef6 100644 --- a/test/infinicore/ops/fmod.py +++ b/test/infinicore/ops/fmod.py @@ -103,9 +103,9 @@ def get_test_cases(self): def torch_operator(self, *args, **kwargs): return torch.fmod(*args, **kwargs) - # def infinicore_operator(self, *args, **kwargs): - # """InfiniCore implementation (operator not yet available).""" - # return infinicore.fmod(*args, **kwargs) + def infinicore_operator(self, *args, **kwargs): + """InfiniCore implementation (operator not yet available).""" + return infinicore.fmod(*args, **kwargs) def main():