diff --git a/custom_ops/xpu_ops/src/ops/get_padding_offset.cc b/custom_ops/xpu_ops/src/ops/get_padding_offset.cc index 7251eabd9b7..274dba7ee16 100644 --- a/custom_ops/xpu_ops/src/ops/get_padding_offset.cc +++ b/custom_ops/xpu_ops/src/ops/get_padding_offset.cc @@ -16,29 +16,29 @@ #include "paddle/extension.h" #include "xpu/plugin.h" -std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids, - const paddle::Tensor &cum_offsets, - const paddle::Tensor &token_num, - const paddle::Tensor &seq_len) { +std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor& input_ids, + const paddle::Tensor& seq_len, + const int64_t cpu_token_num) { phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); - auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx); + auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx); std::vector<int64_t> input_ids_shape = input_ids.shape(); const int bsz = seq_len.shape()[0]; const int seq_length = input_ids_shape[1]; - auto cum_offsets_out = cum_offsets.copy_to(cum_offsets.place(), false); - auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false); + const int token_num_data = static_cast<int>(cpu_token_num); - const int token_num_data = cpu_token_num.data<int64_t>()[0]; auto x_remove_padding = paddle::full( {token_num_data}, 0, paddle::DataType::INT64, input_ids.place()); auto batch_id_per_token = paddle::full( {token_num_data}, 0, paddle::DataType::INT32, input_ids.place()); + auto cum_offsets_out = + paddle::full({bsz}, 0, paddle::DataType::INT32, input_ids.place()); auto cu_seqlens_q = paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place()); auto cu_seqlens_k = paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place()); + if (token_num_data > 0) { int r = fastdeploy::plugin::get_padding_offset(xpu_ctx->x_context(), @@ -48,7 +48,6 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids, cu_seqlens_k.data<int>(), x_remove_padding.data<int64_t>(), input_ids.data<int64_t>(), - cum_offsets.data<int>(), 
seq_len.data<int>(), seq_length, bsz, @@ -64,20 +63,15 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids, } std::vector<std::vector<int64_t>> GetPaddingOffsetInferShape( - const std::vector<int64_t> &input_ids_shape, - const std::vector<int64_t> &cum_offsets_shape, - const std::vector<int64_t> &token_num_shape, - const std::vector<int64_t> &seq_len_shape) { + const std::vector<int64_t>& input_ids_shape, + const std::vector<int64_t>& seq_len_shape) { int64_t bsz = seq_len_shape[0]; - int64_t seq_len = input_ids_shape[1]; return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}}; } std::vector<paddle::DataType> GetPaddingOffsetInferDtype( - const paddle::DataType &input_ids_dtype, - const paddle::DataType &cum_offsets_dtype, - const paddle::DataType &token_num_dtype, - const paddle::DataType &seq_len_dtype) { + const paddle::DataType& input_ids_dtype, + const paddle::DataType& seq_len_dtype) { return {input_ids_dtype, seq_len_dtype, seq_len_dtype, @@ -86,12 +80,13 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype( } PD_BUILD_OP(get_padding_offset) - .Inputs({"input_ids", "cum_offsets", "token_num", "seq_len"}) + .Inputs({"input_ids", "seq_len"}) .Outputs({"x_remove_padding", "cum_offsets_out", "batch_id_per_token", "cu_seqlens_q", "cu_seqlens_k"}) + .Attrs({"cpu_token_num: int64_t"}) .SetKernelFn(PD_KERNEL(GetPaddingOffset)) .SetInferShapeFn(PD_INFER_SHAPE(GetPaddingOffsetInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(GetPaddingOffsetInferDtype)); diff --git a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc index 14468ddda48..c7f6e5616e5 100644 --- a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc +++ b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc @@ -456,9 +456,8 @@ void GetOutputEPDynamic(const paddle::Tensor& x, int msg_queue_id); std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor& input_ids, - const paddle::Tensor& cum_offsets, - const paddle::Tensor& token_num, - const paddle::Tensor& seq_len); + const paddle::Tensor& seq_len, + const int64_t cpu_token_num); void GetStopFlagsMulti(const paddle::Tensor& topk_ids, const 
paddle::Tensor& stop_flags, @@ -975,9 +974,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("get_padding_offset", &GetPaddingOffset, py::arg("input_ids"), - py::arg("cum_offsets"), - py::arg("token_num"), py::arg("seq_len"), + py::arg("cpu_token_num"), "get padding offset function"); m.def("init_kv_signal_per_query", diff --git a/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h b/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h index 1cbd7a8029b..da34b507aff 100644 --- a/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h +++ b/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h @@ -68,13 +68,12 @@ DLL_EXPORT int token_penalty_multi_scores(api::Context* ctx, const int64_t length_bad_words); DLL_EXPORT int get_padding_offset(api::Context* ctx, - int* padding_offset, + int* batch_id_per_token, int* cum_offsets_out, int* cu_seqlens_q, int* cu_seqlens_k, int64_t* x_remove_padding, const int64_t* input_ids, - const int* cum_offsets, const int* seq_lens, const int max_seq_len, const int bs, diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/get_padding_offset.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/get_padding_offset.xpu index 2b2b283daa6..2f1a4a0db59 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/get_padding_offset.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/get_padding_offset.xpu @@ -1,50 +1,123 @@ #include "xpu/kernel/cluster.h" #include "xpu/kernel/cluster_partition.h" #include "xpu/kernel/cluster_primitive.h" +#include "xpu/kernel/cluster_simd.h" namespace fd_xpu3 { -__global__ void get_padding_offset(int *batch_id_per_token, - int *cum_offsets_out, - int *cu_seqlens_q, - int *cu_seqlens_k, - const int *cum_offsets, - const int *seq_lens, +#define MAX_BATCH_SIZE 1024 + +static inline __device__ int v_reduce_sum_int32(int32x16_t& v0) { + auto v1 = vsrlp_int32x16(1 << 8, v0); + v0 = vvadd_int32x16(v0, v1); + v1 = vsrlp_int32x16(1 << 7, v0); + v0 = vvadd_int32x16(v0, v1); + v1 = vsrlp_int32x16(1 << 6, 
v0); + v0 = vvadd_int32x16(v0, v1); + v1 = vsrlp_int32x16(1 << 5, v0); + v0 = vvadd_int32x16(v0, v1); + return vextract_int32x16(v0, 1); +} + +inline __device__ int primitive_reduce_sum_sm(__shared_ptr__ const int* x, + int64_t len) { + int32x16_t x_l, x_h; + int32x16_t sum = vset_zero_int(); + const auto rounddown_len = rounddown32(len); + + for (int64_t i = 0; i < rounddown_len; i += 32) { + vload2_sm(x + i, x_l, x_h); + sum = vvadd_int32x16(sum, x_l); + sum = vvadd_int32x16(sum, x_h); + } + + if (rounddown_len < len) { + const auto mask = ~(-1 << (len - rounddown_len)); + vload2_sm_mz(x + rounddown_len, x_l, x_h, mask); + sum = vvadd_int32x16(sum, x_l); + sum = vvadd_int32x16(sum, x_h); + } + return v_reduce_sum_int32(sum); +} + +__global__ void get_padding_offset(int64_t* ids_remove_padding, + int* batch_id_per_token, + int* cum_offsets_out, + int* cu_seqlens_q, + int* cu_seqlens_k, + const int64_t* input_data, + const int* seq_lens, const int max_seq_len, const int bs) { int cid = core_id(); int ncores = core_num(); int clusterid = cluster_id(); int nclusters = cluster_num(); - int tid = clusterid * ncores + cid; - - int buf_len = 32; - __simd__ int batch_id_per_token_lm[buf_len]; - __simd__ int cum_offsets_lm[16]; - int seq_len_lm; - for (int i = clusterid; i < bs; i += nclusters) { - GM2LM_ASYNC(seq_lens + i, &seq_len_lm, sizeof(int)); - GM2LM(cum_offsets + i - 1, cum_offsets_lm, 2 * sizeof(int)); - if (i == 0) { - cum_offsets_lm[0] = 0; - } - for (int j = cid * buf_len; j < seq_len_lm; j += ncores * buf_len) { - int cur_len = min(seq_len_lm - j, buf_len); - for (int k = 0; k < cur_len; k++) { - batch_id_per_token_lm[k] = i; - } - mfence_lm(); - LM2GM(batch_id_per_token_lm, - batch_id_per_token + i * max_seq_len - cum_offsets_lm[0] + j, - cur_len * sizeof(int)); + + __shared__ int sm_seq_lens[MAX_BATCH_SIZE]; + __shared__ int sm_cum_seq_len; + __simd__ __shared__ int buffer_cu_seqlens[64]; + + if (cid == 0) { + GM2SM(seq_lens, sm_seq_lens, sizeof(int) * bs); 
+ } + sync_all(); + + for (int bi = clusterid; bi < bs; bi += nclusters) { + int cum_seq_len = 0; + for (int i = cid; i <= bi; i += ncores) { + cum_seq_len += sm_seq_lens[i]; } + buffer_cu_seqlens[cid] = cum_seq_len; + mfence(); + sync_all(); + if (cid == 0) { - int cum_seq_len = (i + 1) * max_seq_len - cum_offsets_lm[1]; - mfence_lm(); - LM2GM_ASYNC(cum_offsets_lm, cum_offsets_out + i, sizeof(int)); - LM2GM_ASYNC(&cum_seq_len, cu_seqlens_q + i + 1, sizeof(int)); - LM2GM(&cum_seq_len, cu_seqlens_k + i + 1, sizeof(int)); + cum_seq_len = + primitive_reduce_sum_sm(buffer_cu_seqlens, min(bi + 1, ncores)); + + LM2GM_ASYNC(&cum_seq_len, cu_seqlens_q + bi + 1, sizeof(int)); + LM2GM_ASYNC(&cum_seq_len, cu_seqlens_k + bi + 1, sizeof(int)); + + int cum_offset = bi * max_seq_len - (cum_seq_len - sm_seq_lens[bi]); + LM2GM(&cum_offset, cum_offsets_out + bi, sizeof(int)); + + sm_cum_seq_len = cum_seq_len; + } + mfence(); + sync_all(); + + const int lm_seq_lens = sm_seq_lens[bi]; + const int tgt_offset = sm_cum_seq_len - lm_seq_lens; + const int buf_len = 32; + __simd__ int64_t input_lm[buf_len]; + __simd__ int batch_id_lm[buf_len]; + + for (int k = 0; k < buf_len; k++) { + batch_id_lm[k] = bi; + } + mfence_lm(); + + for (int j = cid * buf_len; j < lm_seq_lens; j += ncores * buf_len) { + int cur_len = min(lm_seq_lens - j, buf_len); + GM2LM(input_data + bi * max_seq_len + j, + input_lm, + sizeof(int64_t) * cur_len); + LM2GM(input_lm, + ids_remove_padding + tgt_offset + j, + sizeof(int64_t) * cur_len); + LM2GM(batch_id_lm, + batch_id_per_token + tgt_offset + j, + sizeof(int) * cur_len); } + mfence(); + sync_all(); + } + + if (cid == 0 && clusterid == 0) { + const int lm_zero = 0; + LM2GM_ASYNC(&lm_zero, cu_seqlens_q, sizeof(int)); + LM2GM(&lm_zero, cu_seqlens_k, sizeof(int)); } } diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/get_padding_offset.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/get_padding_offset.cpp index 551185ffdac..9659aeb9d8a 100644 --- 
a/custom_ops/xpu_ops/src/plugin/src/wrapper/get_padding_offset.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/get_padding_offset.cpp @@ -19,176 +19,121 @@ namespace fd_xpu3 { -__attribute__((global)) void get_padding_offset(int *padding_offset, - int *cum_offsets_out, - int *cu_seqlens_q, - int *cu_seqlens_k, - const int *cum_offsets, - const int *seq_lens, +__attribute__((global)) void get_padding_offset(int64_t* ids_remove_padding, + int* batch_id_per_token, + int* cum_offsets_out, + int* cu_seqlens_q, + int* cu_seqlens_k, + const int64_t* input_data, + const int* seq_lens, const int max_seq_len, const int bs); -__attribute__((global)) void remove_padding(int64_t *x_remove_padding, - const int64_t *input_data, - const int *seq_lens, - const int *cum_offsets, - const int sequence_length, - const int bs); } // namespace fd_xpu3 namespace fastdeploy { namespace plugin { -static int get_padding_offset_cpu(int *padding_offset, - int *cum_offsets_out, - int *cu_seqlens_q, - int *cu_seqlens_k, - const int *cum_offsets, - const int *seq_lens, - const int max_seq_len, - const int bs) { +static int cpu_wrapper(api::Context* ctx, + int* batch_id_per_token, + int* cum_offsets_out, + int* cu_seqlens_q, + int* cu_seqlens_k, + int64_t* x_remove_padding, + const int64_t* input_ids, + const int* seq_lens, + const int max_seq_len, + const int bs) { + int cum_seq_len = 0; + cu_seqlens_q[0] = 0; + cu_seqlens_k[0] = 0; for (int i = 0; i < bs; i++) { - int cum_offset = i == 0 ? 
0 : cum_offsets[i - 1]; + cum_offsets_out[i] = i * max_seq_len - cum_seq_len; for (int j = 0; j < seq_lens[i]; j++) { - // TODO(mayang02): check offset of padding_offset - padding_offset[i * max_seq_len - cum_offset + j] = cum_offset; + const int tgt = cum_seq_len + j; + x_remove_padding[tgt] = input_ids[i * max_seq_len + j]; + batch_id_per_token[tgt] = i; } - cum_offsets_out[i] = cum_offset; - int cum_seq_len = (i + 1) * max_seq_len - cum_offsets[i]; + cum_seq_len += seq_lens[i]; cu_seqlens_q[i + 1] = cum_seq_len; cu_seqlens_k[i + 1] = cum_seq_len; } return api::SUCCESS; } -static int remove_padding_cpu(int64_t *x_remove_padding, - const int64_t *input_data, - const int *seq_lens, - const int *cum_offsets, - const int sequence_length, - const int bs) { - for (int i = 0; i < bs; i++) { - for (int j = 0; j < seq_lens[i]; j++) { - const int tgt_seq_id = i * sequence_length - cum_offsets[i] + j; - const int src_seq_id = i * sequence_length + j; - // TODO(mayang02): check offset of x_remove_padding - x_remove_padding[tgt_seq_id] = input_data[src_seq_id]; - } - } - return api::SUCCESS; -} - -static int cpu_wrapper(api::Context *ctx, - int *padding_offset, - int *cum_offsets_out, - int *cu_seqlens_q, - int *cu_seqlens_k, - int64_t *x_remove_padding, - const int64_t *input_ids, - const int *cum_offsets, - const int *seq_lens, - const int max_seq_len, - const int bs) { - get_padding_offset_cpu(padding_offset, - cum_offsets_out, - cu_seqlens_q, - cu_seqlens_k, - cum_offsets, - seq_lens, - max_seq_len, - bs); - remove_padding_cpu( - x_remove_padding, input_ids, seq_lens, cum_offsets_out, max_seq_len, bs); - return api::SUCCESS; -} - -static int xpu3_wrapper(api::Context *ctx, - int *padding_offset, - int *cum_offsets_out, - int *cu_seqlens_q, - int *cu_seqlens_k, - int64_t *x_remove_padding, - const int64_t *input_ids, - const int *cum_offsets, - const int *seq_lens, +static int xpu3_wrapper(api::Context* ctx, + int* batch_id_per_token, + int* cum_offsets_out, + int* 
cu_seqlens_q, + int* cu_seqlens_k, + int64_t* x_remove_padding, + const int64_t* input_ids, + const int* seq_lens, const int max_seq_len, const int bs) { using XPU_INT64 = typename api::XPUIndexType<int64_t>::type; - auto get_padding_offset = fd_xpu3::get_padding_offset; - auto remove_padding = fd_xpu3::remove_padding; int32_t ret_xre = - get_padding_offset<<<ctx->ncluster(), 64, ctx->xpu_stream>>>( - padding_offset, + fd_xpu3::get_padding_offset<<<ctx->ncluster(), 64, ctx->xpu_stream>>>( + reinterpret_cast<XPU_INT64*>(x_remove_padding), + batch_id_per_token, cum_offsets_out, cu_seqlens_q, cu_seqlens_k, - cum_offsets, + reinterpret_cast<const XPU_INT64*>(input_ids), seq_lens, max_seq_len, bs); KERNEL_ASSERT_SUCCESS(ctx, ret_xre); - ret_xre = remove_padding<<<ctx->ncluster(), 64, ctx->xpu_stream>>>( - reinterpret_cast<XPU_INT64 *>(x_remove_padding), - reinterpret_cast<const XPU_INT64 *>(input_ids), - seq_lens, - cum_offsets_out, - max_seq_len, - bs); - KERNEL_ASSERT_SUCCESS(ctx, ret_xre); return api::SUCCESS; } -int get_padding_offset(api::Context *ctx, - int *padding_offset, - int *cum_offsets_out, - int *cu_seqlens_q, - int *cu_seqlens_k, - int64_t *x_remove_padding, - const int64_t *input_ids, - const int *cum_offsets, - const int *seq_lens, +int get_padding_offset(api::Context* ctx, + int* batch_id_per_token, + int* cum_offsets_out, + int* cu_seqlens_q, + int* cu_seqlens_k, + int64_t* x_remove_padding, + const int64_t* input_ids, + const int* seq_lens, const int max_seq_len, const int bs, const int64_t token_num) { WRAPPER_CHECK_CTX(ctx); WRAPPER_DUMP_FUNCTION_T1(ctx, "get_padding_offset", int); WRAPPER_DUMP_PARAM4( - ctx, padding_offset, cum_offsets_out, cu_seqlens_q, cu_seqlens_k); - WRAPPER_DUMP_PARAM4(ctx, x_remove_padding, input_ids, cum_offsets, seq_lens); - WRAPPER_DUMP_PARAM2(ctx, max_seq_len, bs); + ctx, batch_id_per_token, cum_offsets_out, cu_seqlens_q, cu_seqlens_k); + WRAPPER_DUMP_PARAM4(ctx, x_remove_padding, input_ids, seq_lens, max_seq_len); + WRAPPER_DUMP_PARAM2(ctx, bs, token_num); WRAPPER_DUMP(ctx); WRAPPER_ASSERT_GT(ctx, bs, 0); 
WRAPPER_ASSERT_GT(ctx, max_seq_len, 0); - WRAPPER_CHECK_PTR(ctx, int, token_num, padding_offset); + WRAPPER_CHECK_PTR(ctx, int64_t, token_num, x_remove_padding); + WRAPPER_CHECK_PTR(ctx, int, token_num, batch_id_per_token); WRAPPER_CHECK_PTR(ctx, int, bs, cum_offsets_out); WRAPPER_CHECK_PTR(ctx, int, bs + 1, cu_seqlens_q); WRAPPER_CHECK_PTR(ctx, int, bs + 1, cu_seqlens_k); - WRAPPER_CHECK_PTR(ctx, int64_t, token_num, x_remove_padding); WRAPPER_CHECK_PTR(ctx, int64_t, bs * max_seq_len, input_ids); - WRAPPER_CHECK_PTR(ctx, int, bs, cum_offsets); WRAPPER_CHECK_PTR(ctx, int, bs, seq_lens); if (ctx->dev().type() == api::kCPU) { return cpu_wrapper(ctx, - padding_offset, + batch_id_per_token, cum_offsets_out, cu_seqlens_q, cu_seqlens_k, x_remove_padding, input_ids, - cum_offsets, seq_lens, max_seq_len, bs); } if (ctx->dev().type() == api::kXPU3) { return xpu3_wrapper(ctx, - padding_offset, + batch_id_per_token, cum_offsets_out, cu_seqlens_q, cu_seqlens_k, x_remove_padding, input_ids, - cum_offsets, seq_lens, max_seq_len, bs); diff --git a/custom_ops/xpu_ops/test/test_get_padding_offset.py b/custom_ops/xpu_ops/test/test_get_padding_offset.py index 614386488a6..311f7f324be 100644 --- a/custom_ops/xpu_ops/test/test_get_padding_offset.py +++ b/custom_ops/xpu_ops/test/test_get_padding_offset.py @@ -21,8 +21,7 @@ max_len = 10 seq_lens = np.array([4, 3, 6], "int32").reshape(-1, 1) -cum_offset = np.cumsum((max_len - seq_lens).flatten(), -1, "int32") -token_num = np.sum(seq_lens) +token_num = int(np.sum(seq_lens)) bs = seq_lens.shape[0] input_ids = np.zeros([bs, max_len], "int64") for i in range(bs): @@ -32,34 +31,44 @@ ( x_remove_padding, cum_offsets_out, - padding_offset, + batch_id_per_token, cu_seqlens_q, cu_seqlens_k, ) = get_padding_offset( paddle.to_tensor(input_ids), - paddle.to_tensor(cum_offset), - paddle.to_tensor(token_num), - paddle.to_tensor(seq_lens), + paddle.to_tensor(seq_lens.flatten()), + token_num, ) print("input_ids:\n", input_ids) -print("cum_offset:\n", 
cum_offset) +print("seq_lens:\n", seq_lens.flatten()) print("token_num:\n", token_num) -print("seq_lens:\n", seq_lens) print("x_remove_padding:\n", x_remove_padding) print("cum_offsets_out:\n", cum_offsets_out) -print("padding_offset:\n", padding_offset) +print("batch_id_per_token:\n", batch_id_per_token) print("cu_seqlens_q:\n", cu_seqlens_q) print("cu_seqlens_k:\n", cu_seqlens_k) ref_x_remove_padding = np.array([8, 7, 8, 2, 4, 5, 5, 7, 6, 1, 7, 2, 6], "int64") ref_cum_offsets_out = np.array([0, 6, 13], "int32") -ref_padding_offset = np.array([0, 0, 0, 0, 6, 6, 6, 13, 13, 13, 13, 13, 13], "int32") +ref_batch_id_per_token = np.array([0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 2], "int32") ref_cu_seqlens_q = np.array([0, 4, 7, 13], "int32") ref_cu_seqlens_k = np.array([0, 4, 7, 13], "int32") -assert sum(ref_x_remove_padding - x_remove_padding) == 0, "Check x_remove_padding failed." -assert sum(ref_cum_offsets_out - cum_offsets_out) == 0, "Check cum_offsets_out failed." -assert sum(ref_padding_offset - padding_offset) == 0, "Check padding_offset failed." -assert sum(ref_cu_seqlens_q - cu_seqlens_q) == 0, "Check cu_seqlens_q failed." -assert sum(ref_cu_seqlens_k - cu_seqlens_k) == 0, "Check cu_seqlens_k failed." 
+assert ( + np.sum(np.abs(ref_x_remove_padding - x_remove_padding.numpy())) == 0 +), f"Check x_remove_padding failed.\nref: {ref_x_remove_padding}\ngot: {x_remove_padding.numpy()}" +assert ( + np.sum(np.abs(ref_cum_offsets_out - cum_offsets_out.numpy())) == 0 +), f"Check cum_offsets_out failed.\nref: {ref_cum_offsets_out}\ngot: {cum_offsets_out.numpy()}" +assert ( + np.sum(np.abs(ref_batch_id_per_token - batch_id_per_token.numpy())) == 0 +), f"Check batch_id_per_token failed.\nref: {ref_batch_id_per_token}\ngot: {batch_id_per_token.numpy()}" +assert ( + np.sum(np.abs(ref_cu_seqlens_q - cu_seqlens_q.numpy())) == 0 +), f"Check cu_seqlens_q failed.\nref: {ref_cu_seqlens_q}\ngot: {cu_seqlens_q.numpy()}" +assert ( + np.sum(np.abs(ref_cu_seqlens_k - cu_seqlens_k.numpy())) == 0 +), f"Check cu_seqlens_k failed.\nref: {ref_cu_seqlens_k}\ngot: {cu_seqlens_k.numpy()}" + +print("\nAll checks passed!") diff --git a/fastdeploy/model_executor/xpu_pre_and_post_process.py b/fastdeploy/model_executor/xpu_pre_and_post_process.py index 8a449b597d0..5c959875271 100644 --- a/fastdeploy/model_executor/xpu_pre_and_post_process.py +++ b/fastdeploy/model_executor/xpu_pre_and_post_process.py @@ -111,6 +111,7 @@ def xpu_pre_process( max_len = input_ids.shape[1] cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time, dtype="int32") token_num = paddle.sum(seq_lens_this_time) + token_num_cpu = paddle.sum(seq_lens_this_time).cpu() if use_speculate_method: ( @@ -151,7 +152,7 @@ def xpu_pre_process( batch_id_per_token, cu_seqlens_q, cu_seqlens_k, - ) = get_padding_offset(input_ids, cum_offsets_now, token_num, seq_lens_this_time) + ) = get_padding_offset(input_ids, seq_lens_this_time, token_num_cpu) share_inputs["cum_offsets"] = cum_offsets share_inputs["batch_id_per_token"] = batch_id_per_token