From 1ba0bcfadae445cccd8206003a529e8357041e63 Mon Sep 17 00:00:00 2001 From: zhushuang <974198603@qq.com> Date: Tue, 30 Dec 2025 15:18:14 +0800 Subject: [PATCH] issue/848 - feat: add paged attention prefill for nvidia gpu with test pass --- include/infiniop.h | 5 +- .../infiniop/ops/paged_attention_prefill.h | 83 +++++ .../paged_attention_prefill/cuda/kernel.cuh | 134 ++++++++ .../ops/paged_attention_prefill/info.h | 107 ++++++ .../nvidia/paged_attention_prefill_nvidia.cu | 136 ++++++++ .../nvidia/paged_attention_prefill_nvidia.cuh | 8 + .../ops/paged_attention_prefill/operator.cc | 95 ++++++ .../paged_attention_prefill.h | 56 ++++ test/infiniop/libinfiniop/op_register.py | 48 +++ test/infiniop/paged_attention_prefill.py | 315 ++++++++++++++++++ test/infiniop/paged_caching_prefill.py | 250 ++++++++++++++ 11 files changed, 1235 insertions(+), 2 deletions(-) create mode 100644 include/infiniop/ops/paged_attention_prefill.h create mode 100644 src/infiniop/ops/paged_attention_prefill/cuda/kernel.cuh create mode 100644 src/infiniop/ops/paged_attention_prefill/info.h create mode 100644 src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu create mode 100644 src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cuh create mode 100644 src/infiniop/ops/paged_attention_prefill/operator.cc create mode 100644 src/infiniop/ops/paged_attention_prefill/paged_attention_prefill.h create mode 100644 test/infiniop/paged_attention_prefill.py create mode 100644 test/infiniop/paged_caching_prefill.py diff --git a/include/infiniop.h b/include/infiniop.h index 97f3bdaee..ccdab09c3 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -15,6 +15,9 @@ #include "infiniop/ops/lp_norm.h" #include "infiniop/ops/mul.h" #include "infiniop/ops/ones.h" +#include "infiniop/ops/paged_attention.h" +#include "infiniop/ops/paged_attention_prefill.h" +#include "infiniop/ops/paged_caching.h" #include "infiniop/ops/random_sample.h" #include "infiniop/ops/rearrange.h" #include "infiniop/ops/relu.h" @@ -31,7 +34,5 @@ #include "infiniop/ops/topksoftmax.h" #include "infiniop/ops/zeros.h" #include "infiniop/tensor_descriptor.h" -#include "infiniop/ops/paged_attention.h" -#include "infiniop/ops/paged_caching.h" #endif // __INFINIOP_API_H__ diff --git a/include/infiniop/ops/paged_attention_prefill.h b/include/infiniop/ops/paged_attention_prefill.h new file mode 100644 index 000000000..af10b1fc8 --- /dev/null +++ b/include/infiniop/ops/paged_attention_prefill.h @@ -0,0 +1,83 @@ +#ifndef __INFINIOP_PAGED_ATTENTION_PREFILL_API_H__ +#define __INFINIOP_PAGED_ATTENTION_PREFILL_API_H__ + +#include "../operator_descriptor.h" + +// Define an opaque handle for the Paged Attention Prefill descriptor. +typedef struct InfiniopDescriptor *infiniopPagedAttentionPrefillDescriptor_t; + +/** + * @brief Creates a descriptor for the Paged Attention Prefill operation. + * @param handle The handle to the InfiniOP library context. + * @param desc_ptr A pointer to store the created descriptor. + * @param out_desc Descriptor for the output tensor. + * @param q_desc Descriptor for the query tensor (packed/flattened). + * @param k_cache_desc Descriptor for the global physical key cache. + * @param v_cache_desc Descriptor for the global physical value cache. + * @param block_tables_desc Descriptor for the block tables mapping logic to physical blocks. + * @param cache_lens_desc Descriptor for the total sequence lengths (history + current). 
+ * @param seq_lens_desc Descriptor for the current prefill sequence lengths. + * @param offset_desc Descriptor for the start position of each sequence in the packed Q tensor. + * @param alibi_slopes_desc Optional descriptor for the ALiBi slopes tensor. Can be NULL. + * @param scale The attention scaling factor. + * @return infiniStatus_t Status code of the operation. + */ +__C __export infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor( + infiniopHandle_t handle, + infiniopPagedAttentionPrefillDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_cache_desc, + infiniopTensorDescriptor_t v_cache_desc, + infiniopTensorDescriptor_t block_tables_desc, + infiniopTensorDescriptor_t cache_lens_desc, + infiniopTensorDescriptor_t seq_lens_desc, + infiniopTensorDescriptor_t offset_desc, + infiniopTensorDescriptor_t alibi_slopes_desc, + float scale); + +/** + * @brief Retrieves the workspace size required for the Paged Attention Prefill operation. + */ +__C __export infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize( + infiniopPagedAttentionPrefillDescriptor_t desc, size_t *size); + +/** + * @brief Executes the Paged Attention Prefill operation. + * @param desc The Paged Attention Prefill descriptor. + * @param workspace Pointer to the workspace memory. + * @param workspace_size The size of the workspace. + * @param out Pointer to the output tensor data. + * @param q Pointer to the query tensor data (packed). + * @param k_cache Pointer to the global key cache data. + * @param v_cache Pointer to the global value cache data. + * @param block_tables Pointer to the block tables data. + * @param cache_lens Pointer to the total sequence lengths data. + * @param seq_lens Pointer to the current prefill sequence lengths data. + * @param offset Pointer to the sequence start offsets data. + * @param alibi_slopes Pointer to the ALiBi slopes data. Can be NULL. + * @param stream The CUDA/device stream for the operation. + * @return infiniStatus_t Status code of the operation. + */ +__C __export infiniStatus_t infiniopPagedAttentionPrefill( + infiniopPagedAttentionPrefillDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *out, + const void *q, + const void *k_cache, + const void *v_cache, + const void *block_tables, + const void *cache_lens, + const void *seq_lens, + const void *offset, + const void *alibi_slopes, + void *stream); + +/** + * @brief Destroys a Paged Attention Prefill descriptor. 
+ */ +__C __export infiniStatus_t infiniopDestroyPagedAttentionPrefillDescriptor( + infiniopPagedAttentionPrefillDescriptor_t desc); + +#endif // __INFINIOP_PAGED_ATTENTION_PREFILL_API_H__ diff --git a/src/infiniop/ops/paged_attention_prefill/cuda/kernel.cuh b/src/infiniop/ops/paged_attention_prefill/cuda/kernel.cuh new file mode 100644 index 000000000..ec9aad40c --- /dev/null +++ b/src/infiniop/ops/paged_attention_prefill/cuda/kernel.cuh @@ -0,0 +1,134 @@ +#ifndef __PAGED_ATTENTION_PREFILL_KERNEL_CUH__ +#define __PAGED_ATTENTION_PREFILL_KERNEL_CUH__ + +namespace op::paged_attention_prefill::cuda { + +// 辅助函数:二分查找确定当前 global_token_idx 属于哪个 sequence +__device__ __forceinline__ int find_seq_id(int token_idx, const int64_t *offset, int num_seqs) { + int low = 0, high = num_seqs - 1; + while (low <= high) { + int mid = (low + high) >> 1; + if (token_idx >= offset[mid] && token_idx < offset[mid + 1]) { + return mid; + } else if (token_idx < offset[mid]) { + high = mid - 1; + } else { + low = mid + 1; + } + } + return 0; +} + +template +__global__ void pagedAttentionPrefillKernel( + Tdata *out_, const Tdata *q_, const Tdata *k_cache_, const Tdata *v_cache_, + const int64_t *block_tables_, const int64_t *cache_lens_, const int64_t *seq_lens_, + const float *alibi_slopes_, + const size_t num_heads, const size_t num_kv_heads, const float scale, + const size_t max_num_blocks_per_seq, const size_t block_size, + const ptrdiff_t kv_block_stride, const ptrdiff_t kv_head_stride, + const size_t head_size, + const int64_t *offset_, + const size_t num_seqs) { + + // --- 使用 2D Grid 坐标 --- + const int global_token_idx = blockIdx.x; // 展平后的全局 token 索引 + const int head_idx = blockIdx.y; // Head 索引 + const int dim_idx = threadIdx.x; // Head 内部维度 + + if (dim_idx >= head_size) { + return; + } + + // --- 通过二分查找 offset 找到所属的 seq_idx --- + int seq_idx = find_seq_id(global_token_idx, offset_, num_seqs); + + // --- 获取该 Sequence 本次 Prefill 的长度 + const int64_t cur_new_len = seq_lens_[seq_idx]; + + // --- 该 token 在当前序列中的相对位置 + int q_token_idx = global_token_idx - offset_[seq_idx]; + + const Tdata *q_ptr_base = q_ + global_token_idx * num_heads * head_size + head_idx * head_size; + Tdata *out_ptr = out_ + global_token_idx * num_heads * head_size + head_idx * head_size; + + // --- KV Cache 相关信息 + const int64_t total_seq_len = cache_lens_[seq_idx]; + const int64_t history_len = total_seq_len - cur_new_len; + const int64_t causal_limit = history_len + q_token_idx; + + const size_t num_queries_per_kv = num_heads / num_kv_heads; + const size_t kv_head_idx = head_idx / num_queries_per_kv; + const int64_t *block_table = block_tables_ + seq_idx * max_num_blocks_per_seq; + + const float alibi_slope = (alibi_slopes_ == nullptr) ? 
0.0f : alibi_slopes_[head_idx]; + + // Pass 1: 计算 Score 并找最大值 + Tcompute max_score = -FLT_MAX; + for (int t = 0; t <= causal_limit; ++t) { + const int64_t b_idx = t / block_size; + const int64_t t_off = t % block_size; + const int64_t physical_block_id = block_table[b_idx]; + const Tdata *k_vec = k_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size; + + Tcompute score = 0.0f; + for (int d = 0; d < head_size; ++d) { + score += static_cast(q_ptr_base[d]) * static_cast(k_vec[d]); + } + score *= static_cast(scale); + if (alibi_slope != 0.0f) { + score += alibi_slope * static_cast(t - causal_limit); + } + if (score > max_score) { + max_score = score; + } + } + + // Pass 2: 计算 Sum of Exp + Tcompute sum_exp = 0.0f; + for (int t = 0; t <= causal_limit; ++t) { + const int64_t b_idx = t / block_size; + const int64_t t_off = t % block_size; + const int64_t physical_block_id = block_table[b_idx]; + const Tdata *k_vec = k_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size; + + Tcompute score = 0.0f; + for (int d = 0; d < head_size; ++d) { + score += static_cast(q_ptr_base[d]) * static_cast(k_vec[d]); + } + score *= static_cast(scale); + if (alibi_slope != 0.0f) { + score += alibi_slope * static_cast(t - causal_limit); + } + sum_exp += expf(static_cast(score - max_score)); + } + + // Pass 3: 加权求和得到输出 + Tcompute acc = 0.0f; + Tcompute inv_sum = 1.0f / (sum_exp + 1e-6f); + for (int t = 0; t <= causal_limit; ++t) { + const int64_t b_idx = t / block_size; + const int64_t t_off = t % block_size; + const int64_t physical_block_id = block_table[b_idx]; + + const Tdata *k_vec = k_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size; + Tcompute score = 0.0f; + for (int d = 0; d < head_size; ++d) { + score += static_cast(q_ptr_base[d]) * static_cast(k_vec[d]); + } + score *= static_cast(scale); + if (alibi_slope != 0.0f) { + score += alibi_slope * static_cast(t - causal_limit); + } + Tcompute prob = expf(static_cast(score - max_score)) * inv_sum; + + const Tdata *v_vec = v_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size; + acc += prob * static_cast(v_vec[dim_idx]); + } + + out_ptr[dim_idx] = static_cast(acc); +} + +} // namespace op::paged_attention_prefill::cuda + +#endif diff --git a/src/infiniop/ops/paged_attention_prefill/info.h b/src/infiniop/ops/paged_attention_prefill/info.h new file mode 100644 index 000000000..39c6b5715 --- /dev/null +++ b/src/infiniop/ops/paged_attention_prefill/info.h @@ -0,0 +1,107 @@ +#ifndef __PAGED_ATTENTION_PREFILL_INFO_H__ +#define __PAGED_ATTENTION_PREFILL_INFO_H__ + +#include "../../../utils.h" +#include "../../tensor.h" +#include +#include +#include + +namespace op::paged_attention_prefill { + +class PagedAttentionPrefillInfo { + PagedAttentionPrefillInfo() = default; + +public: + infiniDtype_t dtype; + float scale; + + size_t num_seqs; + size_t num_heads; + size_t num_kv_heads; + size_t head_size; + size_t block_size; + size_t max_num_blocks_per_seq; + size_t total_q_tokens; + + ptrdiff_t q_stride; + ptrdiff_t kv_block_stride; + ptrdiff_t kv_head_stride; + ptrdiff_t o_stride; + + static utils::Result create( + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_cache_desc, + infiniopTensorDescriptor_t v_cache_desc, + infiniopTensorDescriptor_t block_tables_desc, + infiniopTensorDescriptor_t cache_lens_desc, + infiniopTensorDescriptor_t 
seq_lens_desc, + infiniopTensorDescriptor_t offset_desc, + const std::optional &alibi_slopes_desc, + float scale) { + + auto dtype = q_desc->dtype(); + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32); + + if (out_desc->dtype() != dtype || k_cache_desc->dtype() != dtype || v_cache_desc->dtype() != dtype) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + if (offset_desc->dtype() != INFINI_DTYPE_I64 || seq_lens_desc->dtype() != INFINI_DTYPE_I64) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + if (alibi_slopes_desc.has_value() && alibi_slopes_desc.value() != nullptr) { + std::cerr << "[Error] PagedAttentionPrefill: ALiBi slopes are not supported yet." << std::endl; + return INFINI_STATUS_BAD_PARAM; + } + + // Q shape: [total_tokens, heads, dim] (3D) + auto q_shape = q_desc->shape(); + if (q_shape.size() < 3) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + size_t total_q_tokens = q_shape[0]; + + size_t num_heads = q_shape[q_shape.size() - 2]; + size_t head_size = q_shape[q_shape.size() - 1]; + + if (head_size != 128) { + std::cerr << "[Error] PagedAttentionPrefill head_size = 128 supported, got " << head_size << std::endl; + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + // 从 seq_lens 获取 num_seqs + size_t num_seqs = seq_lens_desc->shape()[0]; + + auto k_cache_shape = k_cache_desc->shape(); + size_t num_kv_heads = k_cache_shape[1]; + size_t block_size = v_cache_desc->shape()[2]; + size_t max_num_blocks_per_seq = block_tables_desc->shape()[1]; + + // 提取步长,需要保持多个请求的 Q 连续 + ptrdiff_t q_stride = q_desc->stride(0); + ptrdiff_t kv_block_stride = k_cache_desc->stride(0); + ptrdiff_t kv_head_stride = k_cache_desc->stride(1); + ptrdiff_t o_stride = out_desc->stride(0); + + return utils::Result(PagedAttentionPrefillInfo{ + dtype, + scale, + num_seqs, + num_heads, + num_kv_heads, + head_size, + block_size, + max_num_blocks_per_seq, + total_q_tokens, + q_stride, + kv_block_stride, + kv_head_stride, + o_stride}); + } +}; + +} // namespace op::paged_attention_prefill + +#endif diff --git a/src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu b/src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu new file mode 100644 index 000000000..02ed47186 --- /dev/null +++ b/src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu @@ -0,0 +1,136 @@ +#include +#include +#include +#include + +#include "../../../devices/nvidia/nvidia_common.cuh" +#include "../../../devices/nvidia/nvidia_kernel_common.cuh" +#include "../cuda/kernel.cuh" +#include "paged_attention_prefill_nvidia.cuh" + +// ============================================================================== +// Host wrapper to launch the global kernel +// ============================================================================== +template +infiniStatus_t launchPagedAttentionPrefill( + Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache, + const int64_t *block_tables, const int64_t *cache_lens, const int64_t *seq_lens, + const int64_t *offset, + const float *alibi_slopes, + const size_t num_heads, + const size_t num_seqs, + const size_t num_kv_heads, + const float scale, + const size_t max_num_blocks_per_seq, + const size_t block_size, + const size_t total_q_tokens, + const ptrdiff_t q_stride, + const ptrdiff_t kv_block_stride, + const ptrdiff_t kv_head_stride, + const ptrdiff_t o_stride, + const size_t head_size, + cudaStream_t stream) { + + if (total_q_tokens == 0 || num_heads == 0) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + 
// 使用 2D Grid: X轴是所有 Token,Y轴是所有 Head + dim3 grid(total_q_tokens, num_heads); + dim3 block(head_size); + + op::paged_attention_prefill::cuda::pagedAttentionPrefillKernel + <<>>( + out, q, k_cache, v_cache, + block_tables, cache_lens, seq_lens, alibi_slopes, + num_heads, num_kv_heads, scale, + max_num_blocks_per_seq, block_size, + kv_block_stride, kv_head_stride, + head_size, + offset, num_seqs); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + std::cerr << "CUDA Kernel Launch Failed: " << cudaGetErrorString(err) << std::endl; + return INFINI_STATUS_INTERNAL_ERROR; + } + + return INFINI_STATUS_SUCCESS; +} + +namespace op::paged_attention_prefill::nvidia { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_cache_desc, + infiniopTensorDescriptor_t v_cache_desc, + infiniopTensorDescriptor_t block_tables_desc, + infiniopTensorDescriptor_t cache_lens_desc, + infiniopTensorDescriptor_t seq_lens_desc, + infiniopTensorDescriptor_t offset_desc, + const std::optional &alibi_slopes_desc, + float scale) { + + auto info = PagedAttentionPrefillInfo::create(out_desc, q_desc, k_cache_desc, v_cache_desc, + block_tables_desc, cache_lens_desc, seq_lens_desc, + offset_desc, + alibi_slopes_desc, scale); + CHECK_RESULT(info); + + *desc_ptr = new Descriptor( + new Opaque{reinterpret_cast(handle)->internal()}, + info.take(), 0, handle->device, handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, size_t workspace_size, + void *out, const void *q, const void *k_cache, const void *v_cache, + const void *block_tables, const void *cache_lens, const void *seq_lens, + const void *offset, + const void *alibi_slopes, + void *stream_) const { + + cudaStream_t stream = (cudaStream_t)stream_; + + if (_info.head_size > 1024) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + +#define LAUNCH_KERNEL(Tdata, Tcompute) \ + launchPagedAttentionPrefill( \ + (Tdata *)out, (const Tdata *)q, (const Tdata *)k_cache, (const Tdata *)v_cache, \ + (const int64_t *)block_tables, (const int64_t *)cache_lens, (const int64_t *)seq_lens, \ + (const int64_t *)offset, \ + (const float *)alibi_slopes, \ + _info.num_heads, _info.num_seqs, _info.num_kv_heads, \ + _info.scale, _info.max_num_blocks_per_seq, \ + _info.block_size, _info.total_q_tokens, \ + _info.q_stride, _info.kv_block_stride, _info.kv_head_stride, _info.o_stride, \ + _info.head_size, \ + stream) + + if (_info.dtype == INFINI_DTYPE_F16) { + return LAUNCH_KERNEL(half, float); + } else if (_info.dtype == INFINI_DTYPE_BF16) { + return LAUNCH_KERNEL(__nv_bfloat16, float); + } else if (_info.dtype == INFINI_DTYPE_F32) { + return LAUNCH_KERNEL(float, float); + } + + return INFINI_STATUS_BAD_TENSOR_DTYPE; +} + +} // namespace op::paged_attention_prefill::nvidia diff --git a/src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cuh b/src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cuh new file mode 100644 index 000000000..b9d3e97f1 --- /dev/null +++ b/src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __PAGED_ATTENTION_PREFILL_NVIDIA_H__ +#define __PAGED_ATTENTION_PREFILL_NVIDIA_H__ + +#include "../paged_attention_prefill.h" + 
+DESCRIPTOR(nvidia)
+
+#endif // __PAGED_ATTENTION_PREFILL_NVIDIA_H__
diff --git a/src/infiniop/ops/paged_attention_prefill/operator.cc b/src/infiniop/ops/paged_attention_prefill/operator.cc
new file mode 100644
index 000000000..fe7688300
--- /dev/null
+++ b/src/infiniop/ops/paged_attention_prefill/operator.cc
@@ -0,0 +1,95 @@
+#include "../../operator.h"
+#include "../../handle.h"
+#include "infiniop/ops/paged_attention_prefill.h"
+
+#ifdef ENABLE_NVIDIA_API
+#include "nvidia/paged_attention_prefill_nvidia.cuh"
+#endif
+
+__C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
+    infiniopHandle_t handle,
+    infiniopPagedAttentionPrefillDescriptor_t *desc_ptr,
+    infiniopTensorDescriptor_t out_desc,
+    infiniopTensorDescriptor_t q_desc,
+    infiniopTensorDescriptor_t k_cache_desc,
+    infiniopTensorDescriptor_t v_cache_desc,
+    infiniopTensorDescriptor_t block_tables_desc,
+    infiniopTensorDescriptor_t cache_lens_desc,
+    infiniopTensorDescriptor_t seq_lens_desc,
+    infiniopTensorDescriptor_t offset_desc,
+    infiniopTensorDescriptor_t alibi_slopes_desc,
+    float scale) {
+
+    infiniopTensorDescriptor_t alibi_opt = (alibi_slopes_desc == nullptr) ? nullptr : alibi_slopes_desc;
+
+#define CREATE(CASE, NAMESPACE)                                                                 \
+    case CASE:                                                                                  \
+        return op::paged_attention_prefill::NAMESPACE::Descriptor::create(                     \
+            handle,                                                                             \
+            reinterpret_cast<op::paged_attention_prefill::NAMESPACE::Descriptor **>(desc_ptr), \
+            out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc, cache_lens_desc,  \
+            seq_lens_desc, offset_desc, alibi_opt, scale);
+
+    switch (handle->device) {
+#ifdef ENABLE_NVIDIA_API
+        CREATE(INFINI_DEVICE_NVIDIA, nvidia)
+#endif
+    }
+    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+}
+
+__C infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize(
+    infiniopPagedAttentionPrefillDescriptor_t desc,
+    size_t *size) {
+
+#define GET(CASE, NAMESPACE)                                                                                    \
+    case CASE:                                                                                                  \
+        *size = reinterpret_cast<op::paged_attention_prefill::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
+        return INFINI_STATUS_SUCCESS;
+
+    switch (desc->device_type) {
+#ifdef ENABLE_NVIDIA_API
+        GET(INFINI_DEVICE_NVIDIA, nvidia)
+#endif
+    }
+    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+}
+
+__C infiniStatus_t infiniopPagedAttentionPrefill(
+    infiniopPagedAttentionPrefillDescriptor_t desc,
+    void *workspace, size_t workspace_size,
+    void *out, const void *q, const void *k_cache, const void *v_cache,
+    const void *block_tables, const void *cache_lens, const void *seq_lens,
+    const void *offset,
+    const void *alibi_slopes,
+    void *stream) {
+
+#define CALCULATE(CASE, NAMESPACE)                                                                             \
+    case CASE:                                                                                                 \
+        return reinterpret_cast<const op::paged_attention_prefill::NAMESPACE::Descriptor *>(desc)->calculate( \
+            workspace, workspace_size, out, q, k_cache, v_cache, block_tables,                                 \
+            cache_lens, seq_lens, offset, alibi_slopes, stream);
+
+    switch (desc->device_type) {
+#ifdef ENABLE_NVIDIA_API
+        CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
+#endif
+    }
+    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+}
+
+__C infiniStatus_t infiniopDestroyPagedAttentionPrefillDescriptor(
+    infiniopPagedAttentionPrefillDescriptor_t desc) {
+
+#define DESTROY(CASE, NAMESPACE)                                                              \
+    case CASE:                                                                                \
+        delete reinterpret_cast<op::paged_attention_prefill::NAMESPACE::Descriptor *>(desc); \
+        return INFINI_STATUS_SUCCESS;
+
+    switch (desc->device_type) {
+#ifdef ENABLE_NVIDIA_API
+        DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
+#endif
+    }
+    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+}
diff --git a/src/infiniop/ops/paged_attention_prefill/paged_attention_prefill.h b/src/infiniop/ops/paged_attention_prefill/paged_attention_prefill.h
new file mode 100644
index 000000000..50d3ece1a
--- /dev/null
+++ b/src/infiniop/ops/paged_attention_prefill/paged_attention_prefill.h
@@ -0,0 +1,56 @@
+#ifndef PAGED_ATTENTION_PREFILL_H
+#define PAGED_ATTENTION_PREFILL_H
+
+#include "../../operator.h"
+#include "info.h"
+
+#define DESCRIPTOR(NAMESPACE)                                                        \
+                                                                                     \
+    namespace op::paged_attention_prefill::NAMESPACE {                               \
+    class Descriptor final : public InfiniopDescriptor {                             \
+        struct Opaque;                                                               \
+        Opaque *_opaque;                                                             \
+        PagedAttentionPrefillInfo _info;                                             \
+        size_t _workspace_size;                                                      \
+                                                                                     \
+        Descriptor(                                                                  \
+            Opaque *opaque,                                                          \
+            PagedAttentionPrefillInfo info,                                          \
+            size_t workspace_size,                                                   \
+            infiniDevice_t device_type,                                              \
+            int device_id)                                                           \
+            : InfiniopDescriptor{device_type, device_id},                            \
+              _opaque(opaque),                                                       \
+              _info(info),                                                           \
+              _workspace_size(workspace_size) {}                                     \
+                                                                                     \
+    public:                                                                          \
+        ~Descriptor();                                                               \
+                                                                                     \
+        size_t workspaceSize() const { return _workspace_size; }                     \
+                                                                                     \
+        static infiniStatus_t create(                                                \
+            infiniopHandle_t handle,                                                 \
+            Descriptor **desc_ptr,                                                   \
+            infiniopTensorDescriptor_t out_desc,                                     \
+            infiniopTensorDescriptor_t q_desc,                                       \
+            infiniopTensorDescriptor_t k_cache_desc,                                 \
+            infiniopTensorDescriptor_t v_cache_desc,                                 \
+            infiniopTensorDescriptor_t block_tables_desc,                            \
+            infiniopTensorDescriptor_t cache_lens_desc,                              \
+            infiniopTensorDescriptor_t seq_lens_desc,                                \
+            infiniopTensorDescriptor_t offset_desc,                                  \
+            const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,      \
+            float scale);                                                            \
+                                                                                     \
+        infiniStatus_t calculate(                                                    \
+            void *workspace, size_t workspace_size,                                  \
+            void *out, const void *q, const void *k_cache, const void *v_cache,     \
+            const void *block_tables, const void *cache_lens, const void *seq_lens, \
+            const void *offset,                                                      \
+            const void *alibi_slopes,                                                \
+            void *stream) const;                                                     \
+    };                                                                               \
+    }
+
+#endif // PAGED_ATTENTION_PREFILL_H
diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py
index aa0ce5250..7c95ff84d 100644
--- a/test/infiniop/libinfiniop/op_register.py
+++ b/test/infiniop/libinfiniop/op_register.py
@@ -939,6 +939,7 @@ def tanh_(lib):
         infiniopOperatorDescriptor_t,
     ]
 
+
 @OpRegister.operator
 def scaled_mm_int8_(lib):
     lib.infiniopCreateI8GemmDescriptor.restype = c_int32
@@ -1061,3 +1062,50 @@ def paged_caching_(lib):
     lib.infiniopDestroyPagedCachingDescriptor.argtypes = [
         infiniopOperatorDescriptor_t,
     ]
+
+
+@OpRegister.operator
+def paged_attention_prefill_(lib):
+    lib.infiniopCreatePagedAttentionPrefillDescriptor.restype = c_int32
+    lib.infiniopCreatePagedAttentionPrefillDescriptor.argtypes = [
+        infiniopHandle_t,
+        POINTER(infiniopOperatorDescriptor_t),
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+        c_float,
+    ]
+
+    lib.infiniopGetPagedAttentionPrefillWorkspaceSize.restype = c_int32
+    lib.infiniopGetPagedAttentionPrefillWorkspaceSize.argtypes = [
+        infiniopOperatorDescriptor_t,
+        POINTER(c_size_t),
+    ]
+
+    lib.infiniopPagedAttentionPrefill.restype = c_int32
+    lib.infiniopPagedAttentionPrefill.argtypes = [
+        infiniopOperatorDescriptor_t,
+        c_void_p,
+        c_size_t,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+    ]
+
+    lib.infiniopDestroyPagedAttentionPrefillDescriptor.restype = c_int32
+    lib.infiniopDestroyPagedAttentionPrefillDescriptor.argtypes = [
+        infiniopOperatorDescriptor_t,
+    ]
diff --git a/test/infiniop/paged_attention_prefill.py b/test/infiniop/paged_attention_prefill.py
new file mode 100644
index 000000000..948fd72d5
--- /dev/null
+++ b/test/infiniop/paged_attention_prefill.py
@@ 
-0,0 +1,315 @@ +import torch +import ctypes +from ctypes import c_uint64 +from libinfiniop import ( + LIBINFINIOP, + TestTensor, + get_test_devices, + check_error, + test_operator, + get_args, + debug, + get_tolerance, + profile_operation, + InfiniDtype, + InfiniDtypeNames, + InfiniDeviceNames, + infiniopOperatorDescriptor_t, + TestWorkspace, +) + +# ============================================================================== +# Configuration (Internal Use Only) +# ============================================================================== +_TEST_CASES = [ + # num_seqs, num_heads, num_kv_heads, head_size, block_size, max_step_len, num_rounds + (1, 1, 1, 128, 8, 16, 1), + (1, 4, 4, 128, 8, 16, 4), + (2, 8, 8, 128, 16, 32, 2), + (4, 16, 16, 128, 8, 64, 3), + (8, 64, 64, 128, 8, 16, 5), + (16, 128, 128, 128, 8, 16, 4), +] + +_TENSOR_DTYPES = [InfiniDtype.F32, InfiniDtype.BF16, InfiniDtype.F16] + +_TOLERANCE_MAP = { + InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-5}, + InfiniDtype.F16: {"atol": 1e-2, "rtol": 1e-2}, + InfiniDtype.BF16: {"atol": 2e-2, "rtol": 2e-2}, +} + +DEBUG = False +PROFILE = False +NUM_PRERUN = 5 +NUM_ITERATIONS = 10 + + +# ============================================================================== +# Helper Classes & Reference Implementation +# ============================================================================== +class SimpleCacheManager: + def __init__(self, num_blocks, block_size): + self.num_blocks = num_blocks + self.block_size = block_size + self.free_blocks = list(range(num_blocks)) + self.request_to_blocks = {} + self.request_to_len = {} + + def allocate_slots(self, request_id, num_new_tokens): + if request_id not in self.request_to_len: + self.request_to_len[request_id] = 0 + self.request_to_blocks[request_id] = [] + + start_pos = self.request_to_len[request_id] + new_total_len = start_pos + num_new_tokens + needed_blocks = (new_total_len + self.block_size - 1) // self.block_size + added_blocks = needed_blocks - len(self.request_to_blocks[request_id]) + + for _ in range(added_blocks): + self.request_to_blocks[request_id].append(self.free_blocks.pop(0)) + + self.request_to_len[request_id] = new_total_len + return self.request_to_blocks[request_id], new_total_len + + +def ref_paged_attention_multi_turn( + query_new, k_cache, v_cache, block_tables, seq_lens, new_lens, offset, scale +): + block_size = k_cache.shape[2] + outputs = torch.zeros_like(query_new) + for i in range(len(offset) - 1): + total_len = seq_lens[i].item() + num_new = new_lens[i].item() + history_len = total_len - num_new + + table = block_tables[i] + keys_all, values_all = [], [] + for j in range(total_len): + b_id = table[j // block_size].item() + off = j % block_size + keys_all.append(k_cache[b_id, :, off, :]) + values_all.append(v_cache[b_id, :, off, :]) + + K = torch.stack(keys_all, dim=0) + V = torch.stack(values_all, dim=0) + Q = query_new[offset[i] : offset[i] + num_new, :, :] + + scores = torch.einsum("qhd,khd->hqk", Q, K).float() * scale + + mask = torch.full((num_new, total_len), float("-inf"), device=Q.device) + for q_idx in range(num_new): + mask[q_idx, : history_len + q_idx + 1] = 0.0 + + scores = scores + mask.unsqueeze(0) + attn_weights = torch.softmax(scores, dim=-1).to(Q.dtype) + out = torch.einsum("hqk,khd->qhd", attn_weights, V) + + outputs[offset[i] : offset[i] + num_new, :, :] = out + + return outputs + + +# ============================================================================== +# Test Operator Implementation +# 
============================================================================== +def test( + handle, + device, + num_seqs, + num_heads, + num_kv_heads, + head_size, + block_size, + max_step_len, + num_rounds, + dtype=InfiniDtype.F16, + sync=None, +): + print( + f"Testing PagedAttentionPrefill on {InfiniDeviceNames[device]} with " + f"seqs:{num_seqs}, heads:{num_heads}, head_size:{head_size}, " + f"block:{block_size}, max_step_len:{max_step_len}, num_rounds:{num_rounds}, dtype:{InfiniDtypeNames[dtype]}" + ) + + # 1. Initialize persistent resources + num_blocks = 8192 + manager = SimpleCacheManager(num_blocks, block_size) + scale = head_size**-0.5 + + k_cache = TestTensor( + (num_blocks, num_kv_heads, block_size, head_size), None, dtype, device + ) + v_cache = TestTensor( + (num_blocks, num_kv_heads, block_size, head_size), None, dtype, device + ) + + # Multi-turn testing loop + for r in range(num_rounds): + # Prepare dynamic inputs for this round + seq_lens_cpu = torch.randint( + 1, max_step_len + 1, (num_seqs,), dtype=torch.int64 + ) + + q_total_tokens = seq_lens_cpu.sum().item() + q_packed_tensors = torch.zeros(q_total_tokens, num_heads, head_size) + + cache_lens_list = [] + all_block_tables = [] + + offset_list = [] + cur_offset = 0 + for i in range(num_seqs): + offset_list.append(cur_offset) + + cur_new_len = seq_lens_cpu[i].item() + table, cache_len = manager.allocate_slots(i, cur_new_len) + cache_lens_list.append(cache_len) + all_block_tables.append(table) + + # Simulated KV insertion + k_new = torch.randn(cur_new_len, num_kv_heads, head_size) + v_new = torch.randn(cur_new_len, num_kv_heads, head_size) + q_val = torch.randn(cur_new_len, num_heads, head_size) + q_packed_tensors[cur_offset : cur_offset + cur_new_len] = q_val + + cur_offset = cur_offset + cur_new_len + + history_len = cache_len - cur_new_len + for t in range(cur_new_len): + logical_pos = history_len + t + b_id = table[logical_pos // block_size] + off = logical_pos % block_size + k_cache.torch_tensor()[b_id, :, off, :] = k_new[t] + v_cache.torch_tensor()[b_id, :, off, :] = v_new[t] + + offset_list.append(cur_offset) + + k_cache.actual_tensor().copy_(k_cache._torch_tensor) + v_cache.actual_tensor().copy_(v_cache._torch_tensor) + + # 2. Wrap tensors for Infiniop + q_new = TestTensor.from_torch(q_packed_tensors, dtype, device) + out = TestTensor.from_torch(q_packed_tensors, dtype, device) + out.actual_tensor().zero_() + + cache_lens = TestTensor.from_torch( + torch.tensor(cache_lens_list, dtype=torch.int64), InfiniDtype.I64, device + ) + seq_lens = TestTensor.from_torch(seq_lens_cpu, InfiniDtype.I64, device) + + offset = TestTensor.from_torch( + torch.tensor(offset_list, dtype=torch.int64), InfiniDtype.I64, device + ) + + max_blocks = max(len(t) for t in all_block_tables) + padded_tables = [t + [0] * (max_blocks - len(t)) for t in all_block_tables] + block_tables = TestTensor.from_torch( + torch.tensor(padded_tables, dtype=torch.int64), InfiniDtype.I64, device + ) + + # 3. Reference Calculation + def torch_paged_attention_multi_turn(): + return ref_paged_attention_multi_turn( + q_new.torch_tensor(), + k_cache.torch_tensor(), + v_cache.torch_tensor(), + block_tables.torch_tensor(), + cache_lens.torch_tensor(), + seq_lens.torch_tensor(), + offset.torch_tensor(), + scale, + ) + + ans = torch_paged_attention_multi_turn() + + # 4. 
Infiniop Operator Execution + descriptor = infiniopOperatorDescriptor_t() + check_error( + LIBINFINIOP.infiniopCreatePagedAttentionPrefillDescriptor( + handle, + ctypes.byref(descriptor), + out.descriptor, + q_new.descriptor, + k_cache.descriptor, + v_cache.descriptor, + block_tables.descriptor, + cache_lens.descriptor, + seq_lens.descriptor, + offset.descriptor, + None, # alibi_slopes_desc + scale, + ) + ) + + workspace_size = c_uint64(0) + check_error( + LIBINFINIOP.infiniopGetPagedAttentionPrefillWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, device) + + def lib_attn(): + check_error( + LIBINFINIOP.infiniopPagedAttentionPrefill( + descriptor, + workspace.data(), + workspace_size.value, + out.data(), + q_new.data(), + k_cache.data(), + v_cache.data(), + block_tables.data(), + cache_lens.data(), + seq_lens.data(), + offset.data(), + None, + None, + ) + ) + + lib_attn() + if sync: + sync() + + # 5. Validation + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + if DEBUG: + debug(out.actual_tensor(), ans, atol=atol, rtol=rtol) + + assert torch.allclose(out.actual_tensor(), ans, atol=atol, rtol=rtol) + + # Profiling + if PROFILE: + profile_operation( + f"Torch_R{r}", + lambda: torch_paged_attention_multi_turn(), + device, + NUM_PRERUN, + NUM_ITERATIONS, + ) + profile_operation( + f" Lib_R{r}", lambda: lib_attn(), device, NUM_PRERUN, NUM_ITERATIONS + ) + + check_error( + LIBINFINIOP.infiniopDestroyPagedAttentionPrefillDescriptor(descriptor) + ) + + +# ============================================================================== +# Main Execution +# ============================================================================== +if __name__ == "__main__": + args = get_args() + + DEBUG = args.debug + PROFILE = args.profile + NUM_PRERUN = args.num_prerun + NUM_ITERATIONS = args.num_iterations + + for device in get_test_devices(args): + test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) + + print("\033[92mTest passed!\033[0m") diff --git a/test/infiniop/paged_caching_prefill.py b/test/infiniop/paged_caching_prefill.py new file mode 100644 index 000000000..1fa9957fc --- /dev/null +++ b/test/infiniop/paged_caching_prefill.py @@ -0,0 +1,250 @@ +import torch +import ctypes +from ctypes import c_uint64 +from libinfiniop import ( + LIBINFINIOP, + TestTensor, + get_test_devices, + check_error, + test_operator, + get_args, + debug, + get_tolerance, + profile_operation, + TestWorkspace, + InfiniDtype, + InfiniDtypeNames, + InfiniDeviceNames, + infiniopOperatorDescriptor_t, +) + +# ============================================================================== +# Configuration (Internal Use Only) +# ============================================================================== +_TEST_CASES = [ + # num_seqs, max_step_len, num_kv_heads, head_size, block_size, num_rounds + (1, 16, 1, 128, 8, 5), + (2, 64, 8, 128, 16, 2), + (8, 128, 32, 128, 16, 3), + (5, 512, 40, 128, 16, 3), + (16, 64, 8, 128, 32, 1), + (10, 256, 40, 128, 32, 3), +] + +_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32] + +_TOLERANCE_MAP = { + InfiniDtype.F32: {"atol": 1e-8, "rtol": 1e-8}, + InfiniDtype.F16: {"atol": 1e-8, "rtol": 1e-8}, + InfiniDtype.BF16: {"atol": 1e-8, "rtol": 1e-8}, +} + +DEBUG = False +PROFILE = False +NUM_PRERUN = 5 +NUM_ITERATIONS = 10 + + +# ============================================================================== +# Helper Classes & Reference Implementation +# 
============================================================================== +class SimpleCacheManager: + def __init__(self, num_blocks, block_size): + self.num_blocks = num_blocks + self.block_size = block_size + self.free_blocks = list(range(num_blocks)) + self.request_to_blocks = {} + self.request_to_len = {} + + def allocate_slots(self, request_id, num_new_tokens): + if request_id not in self.request_to_len: + self.request_to_len[request_id] = 0 + self.request_to_blocks[request_id] = [] + + start_pos = self.request_to_len[request_id] + new_total_len = start_pos + num_new_tokens + needed_blocks = (new_total_len + self.block_size - 1) // self.block_size + added_blocks = needed_blocks - len(self.request_to_blocks[request_id]) + + for _ in range(added_blocks): + self.request_to_blocks[request_id].append(self.free_blocks.pop(0)) + + slots = [] + for i in range(start_pos, new_total_len): + block_idx_in_seq = i // self.block_size + block_offset = i % self.block_size + physical_block_id = self.request_to_blocks[request_id][block_idx_in_seq] + slots.append(physical_block_id * self.block_size + block_offset) + + self.request_to_len[request_id] = new_total_len + return torch.tensor(slots, dtype=torch.int32) + + +def ref_paged_caching(k_new, v_new, k_pool, v_pool, slots, block_size): + """Reference implementation for incremental caching.""" + for i in range(k_new.shape[0]): + slot = slots[i].item() + b_id = slot // block_size + off = slot % block_size + k_pool[b_id, :, off, :] = k_new[i] + v_pool[b_id, :, off, :] = v_new[i] + return k_pool, v_pool + + +# ============================================================================== +# Test Operator Implementation +# ============================================================================== +def test( + handle, + device, + num_seqs, + max_step_len, + num_kv_heads, + head_size, + block_size, + num_rounds, + dtype=InfiniDtype.F16, + sync=None, +): + print( + f"Testing PagedCaching on {InfiniDeviceNames[device]} with " + f"seqs:{num_seqs}, max_step_len:{max_step_len}, num_kv_heads:{num_kv_heads}, head_size:{head_size}, " + f"block_size:{block_size}, rounds:{num_rounds}, dtype:{InfiniDtypeNames[dtype]}" + ) + + # 1. Initialize Global Cache Pool + num_blocks = 8192 + manager = SimpleCacheManager(num_blocks, block_size) + + k_cache_pool = TestTensor( + (num_blocks, num_kv_heads, block_size, head_size), None, dtype, device + ) + v_cache_pool = TestTensor( + (num_blocks, num_kv_heads, block_size, head_size), None, dtype, device + ) + + # Reference pools (CPU/Torch) + k_pool_ref = k_cache_pool.torch_tensor().clone() + v_pool_ref = v_cache_pool.torch_tensor().clone() + + for r in range(num_rounds): + # Prepare incremental data for this round + round_ntok_list = torch.randint( + 1, max_step_len + 1, (num_seqs,), dtype=torch.int32 + ) + all_slots, all_k, all_v = [], [], [] + + for i in range(num_seqs): + n_new = round_ntok_list[i].item() + all_slots.append(manager.allocate_slots(i, n_new)) + all_k.append(torch.randn(n_new, num_kv_heads, head_size)) + all_v.append(torch.randn(n_new, num_kv_heads, head_size)) + + k_in_torch = torch.cat(all_k, dim=0) + v_in_torch = torch.cat(all_v, dim=0) + slots_torch = torch.cat(all_slots, dim=0) + + k_in = TestTensor.from_torch(k_in_torch, dtype, device) + v_in = TestTensor.from_torch(v_in_torch, dtype, device) + slot_mapping = TestTensor.from_torch(slots_torch, InfiniDtype.I64, device) + + # 2. 
Reference Calculation + def torch_caching(): + nonlocal k_pool_ref, v_pool_ref + return ref_paged_caching( + k_in.torch_tensor(), + v_in.torch_tensor(), + k_pool_ref, + v_pool_ref, + slots_torch, + block_size, + ) + + torch_caching() + + # 3. Infiniop Operator Execution + descriptor = infiniopOperatorDescriptor_t() + check_error( + LIBINFINIOP.infiniopCreatePagedCachingDescriptor( + handle, + ctypes.byref(descriptor), + k_in.descriptor, + v_in.descriptor, + k_cache_pool.descriptor, + v_cache_pool.descriptor, + slot_mapping.descriptor, + ) + ) + + workspace_size = c_uint64(0) + check_error( + LIBINFINIOP.infiniopGetPagedCachingWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, device) + + def lib_caching(): + check_error( + LIBINFINIOP.infiniopPagedCaching( + descriptor, + workspace.data(), + workspace_size.value, + k_in.data(), + v_in.data(), + k_cache_pool.data(), + v_cache_pool.data(), + slot_mapping.data(), + None, + ) + ) + + lib_caching() + if sync: + sync() + + # 4. Validation + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + + if DEBUG: + # Check a small slice of the updated cache + debug(k_cache_pool.actual_tensor(), k_pool_ref, atol=atol, rtol=rtol) + + assert torch.allclose( + k_cache_pool.actual_tensor(), k_pool_ref, atol=atol, rtol=rtol + ) + assert torch.allclose( + v_cache_pool.actual_tensor(), v_pool_ref, atol=atol, rtol=rtol + ) + + # 5. Profiling + if PROFILE: + profile_operation( + f"Torch_R{r}", + lambda: torch_caching(), + device, + NUM_PRERUN, + NUM_ITERATIONS, + ) + profile_operation( + f" Lib_R{r}", lambda: lib_caching(), device, NUM_PRERUN, NUM_ITERATIONS + ) + + check_error(LIBINFINIOP.infiniopDestroyPagedCachingDescriptor(descriptor)) + + +# ============================================================================== +# Main Execution +# ============================================================================== +if __name__ == "__main__": + args = get_args() + + DEBUG = args.debug + PROFILE = args.profile + NUM_PRERUN = args.num_prerun + NUM_ITERATIONS = args.num_iterations + + for device in get_test_devices(args): + test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) + + print("\033[92mTest passed!\033[0m")
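
A note on the packed layout that the kernel and both tests share: offset holds num_seqs + 1 cumulative token counts into the packed query tensor (request i owns packed positions offset[i] through offset[i + 1] - 1), seq_lens holds the number of new tokens per request, cache_lens holds history plus new tokens, and the query at position q_idx inside its request may attend to the first history_len + q_idx + 1 cached positions. The sketch below restates only that convention in plain PyTorch; toy_prefill_reference, k_per_seq, and v_per_seq are illustrative names that do not exist in this patch, it assumes num_heads == num_kv_heads (as in the test cases above), and it is not the library API.

import torch


def find_seq_id(token_idx, offset):
    # Mirrors the kernel's binary search: which request owns this packed token?
    lo, hi = 0, len(offset) - 2
    while lo <= hi:
        mid = (lo + hi) // 2
        if offset[mid] <= token_idx < offset[mid + 1]:
            return mid
        if token_idx < offset[mid]:
            hi = mid - 1
        else:
            lo = mid + 1
    return 0


def toy_prefill_reference(q_packed, k_per_seq, v_per_seq, offset, seq_lens, cache_lens, scale):
    # q_packed: [total_new_tokens, num_heads, head_size]
    # k_per_seq[i] / v_per_seq[i]: [cache_lens[i], num_heads, head_size], already gathered
    # from the paged cache in logical token order; offset/seq_lens/cache_lens are int lists.
    out = torch.zeros_like(q_packed)
    for tok in range(q_packed.shape[0]):
        i = find_seq_id(tok, offset)
        q_idx = tok - offset[i]
        history = cache_lens[i] - seq_lens[i]
        limit = history + q_idx + 1  # causal window: full history plus new tokens up to q_idx
        k = k_per_seq[i][:limit]     # [limit, num_heads, head_size]
        v = v_per_seq[i][:limit]
        scores = torch.einsum("hd,khd->hk", q_packed[tok], k).float() * scale
        probs = torch.softmax(scores, dim=-1).to(q_packed.dtype)
        out[tok] = torch.einsum("hk,khd->hd", probs, v)
    return out

Fed with the same packed Q, gathered K/V, and length metadata that test/infiniop/paged_attention_prefill.py constructs, this should reproduce, up to the listed tolerances, what infiniopPagedAttentionPrefill writes into out.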