From df162ead7c9bdf25094443682c203b2a969b38f3 Mon Sep 17 00:00:00 2001 From: Krzysztof Rymski Date: Tue, 24 Feb 2026 03:26:23 -0800 Subject: [PATCH] Implementation of tiled attention with bf16 and circular buffers which reduces memory requirements by 4x on longer context on gemma models. It also supports better parallelism for small batch sizes / small models. It also is able to utilize VDPBF16PS for nice 2x improvement on avx512 PiperOrigin-RevId: 874517319 --- BUILD.bazel | 2 + CMakeLists.txt | 10 +- gemma/activations.h | 18 + gemma/configs.h | 4 + gemma/flash_attention.cc | 614 ++++++++++++++++++++++++++++ gemma/flash_attention.h | 98 +++-- gemma/flash_attention_test.cc | 292 +++++++++++++ gemma/gemma.cc | 11 +- gemma/gemma_args.h | 8 + gemma/kv_cache.cc | 62 +++ gemma/kv_cache.h | 80 ++++ gemma/tiled_attention.cc | 660 ++++++++++++++++++++++++++++++ gemma/tiled_attention.h | 42 ++ gemma/tiled_attention_test.cc | 749 ++++++++++++++++++++++++++++++++++ ops/ops-inl.h | 444 ++++++++++++++++++++ 15 files changed, 3056 insertions(+), 38 deletions(-) create mode 100644 gemma/tiled_attention.cc create mode 100644 gemma/tiled_attention.h create mode 100644 gemma/tiled_attention_test.cc diff --git a/BUILD.bazel b/BUILD.bazel index cb6ca504..885bb665 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -652,10 +652,12 @@ cc_library( name = "gemma_lib", srcs = [ "gemma/gemma.cc", + "gemma/tiled_attention.cc", "gemma/vit.cc", ], hdrs = [ "gemma/gemma.h", + "gemma/tiled_attention.h", "gemma/vit.h", ], exec_properties = { diff --git a/CMakeLists.txt b/CMakeLists.txt index 47d7c4c2..58bfab51 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -93,6 +93,8 @@ set(SOURCES gemma/model_store.h gemma/tensor_info.cc gemma/tensor_info.h + gemma/tiled_attention.cc + gemma/tiled_attention.h gemma/tokenizer.cc gemma/tokenizer.h gemma/vit.cc @@ -171,20 +173,20 @@ install(TARGETS libgemma DESTINATION lib) if(BUILD_GEMMA_DLL) add_library(gemma_shared SHARED ${SOURCES}) set_property(TARGET gemma_shared PROPERTY CXX_STANDARD 17) -set_target_properties(gemma_shared PROPERTIES +set_target_properties(gemma_shared PROPERTIES PREFIX "" OUTPUT_NAME "gemma" ) set_property(TARGET gemma_shared PROPERTY POSITION_INDEPENDENT_CODE ON) target_include_directories(gemma_shared PUBLIC ./) -target_link_libraries(gemma_shared PRIVATE +target_link_libraries(gemma_shared PRIVATE $ $ $ ) target_include_directories(gemma_shared PUBLIC ${sentencepiece_SOURCE_DIR}) -target_compile_definitions(gemma_shared - PRIVATE +target_compile_definitions(gemma_shared + PRIVATE GEMMA_EXPORTS $<$:_CRT_SECURE_NO_WARNINGS NOMINMAX> ) diff --git a/gemma/activations.h b/gemma/activations.h index 1e0a56ac..3df61c5a 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -153,6 +153,14 @@ struct AttentionActivations { // Accumulation of attention outputs over heads MatStorageT att_sums; + MatStorageT k_tile_vec; + MatStorageT v_tile_vec; + std::vector> sub_task_att_out; + std::vector + sub_task_exp_denominator_sums; + std::vector + sub_task_max_logits; + // Rope MatStorageT inv_timescale; MatStorageT inv_timescale_global; @@ -244,6 +252,16 @@ struct AttentionActivationsPtrs { // Accumulation of attention outputs over heads, size batch_size x // model_dim. MatPtrT att_sums; + // Stores intermediate results of computing QKV, + // [qbatch * kv_heads , k_tile_size * qkv_dim] + MatPtrT k_tile_vec; + MatPtrT v_tile_vec; + // Used by TiledFlashAttention to store intermediate results. + std::vector>* sub_task_att_out; + std::vector* + sub_task_exp_denominator_sums; + std::vector* + sub_task_max_logits; // Inverse timescales for RoPE computation. MatPtrT inv_timescale; // Inverse timescales for global RoPE computation. diff --git a/gemma/configs.h b/gemma/configs.h index 2f15ee81..3df7ec99 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -83,6 +83,8 @@ static inline bool EnumValid(LayerAttentionType type) { enum class AttentionImpl { kOld, kFlash, + kFlashTransposedQs, + kFlashTransposedQsBF16, kSentinel, }; @@ -108,6 +110,8 @@ static inline int AttentionImplToFlags(AttentionImpl impl, case AttentionImpl::kOld: return kAttentionUseOld; case AttentionImpl::kFlash: + case AttentionImpl::kFlashTransposedQs: + case AttentionImpl::kFlashTransposedQsBF16: default: return 0; } diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index 8aa29cf4..ebe8ee18 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -921,6 +921,620 @@ Tile4FlashState TileFlashAttention4( return state; } +template , + typename T> +static HWY_INLINE void QDotKTilexUpTo8TransposedKDoubleWidth( + DQ_T df, const Q_T* HWY_RESTRICT q, const Q_T* HWY_RESTRICT q2, + const T* HWY_RESTRICT k_transposed_tile, size_t qkv_dim, VQ_T& sum0_p0, + VQ_T& sum0_p1, VQ_T& sum1_p0, VQ_T& sum1_p1, VQ_T& sum2_p0, VQ_T& sum2_p1, + VQ_T& sum3_p0, VQ_T& sum3_p1, VQ_T& sum4_p0, VQ_T& sum4_p1, VQ_T& sum5_p0, + VQ_T& sum5_p1, VQ_T& sum6_p0, VQ_T& sum6_p1, VQ_T& sum7_p0, VQ_T& sum7_p1) { + const PackedSpan k_transposed_span = + MakeConstSpan(k_transposed_tile, gcpp::KVCache::kTileSize * qkv_dim); + HWY_DASSERT(kNumQueries <= 8); + HWY_DASSERT(gcpp::KVCache::kTileSize >= + hn::Lanes(df) * 2); // So we can decompress 2 lanes at a time. + sum0_p0 = hn::Zero(df); + sum0_p1 = hn::Zero(df); + if constexpr (kNumQueries >= 2) { + sum1_p0 = hn::Zero(df); + sum1_p1 = hn::Zero(df); + } + if constexpr (kNumQueries >= 3) { + sum2_p0 = hn::Zero(df); + sum2_p1 = hn::Zero(df); + } + if constexpr (kNumQueries >= 4) { + sum3_p0 = hn::Zero(df); + sum3_p1 = hn::Zero(df); + } + if constexpr (kNumQueries >= 5) { + sum4_p0 = hn::Zero(df); + sum4_p1 = hn::Zero(df); + } + if constexpr (kNumQueries >= 6) { + sum5_p0 = hn::Zero(df); + sum5_p1 = hn::Zero(df); + } + if constexpr (kNumQueries >= 7) { + sum6_p0 = hn::Zero(df); + sum6_p1 = hn::Zero(df); + } + if constexpr (kNumQueries >= 8) { + sum7_p0 = hn::Zero(df); + sum7_p1 = hn::Zero(df); + } + + constexpr int kFirstHalfAmountOfQueries = std::min(kNumQueries, 4); + constexpr int kSecondHalfAmountOfQueries = + kNumQueries - kFirstHalfAmountOfQueries; + HWY_UNROLL(1) + for (size_t i = 0; i < qkv_dim; ++i) { + VQ_T k_vec1, k_vec2; + if constexpr (HWY_TARGET == HWY_AVX2) { + hwy::Prefetch(k_transposed_span.ptr + (i + 3) * gcpp::KVCache::kTileSize); + hwy::Prefetch(k_transposed_span.ptr + (i + 4) * gcpp::KVCache::kTileSize); + } + Decompress2(df, k_transposed_span, i * gcpp::KVCache::kTileSize, k_vec1, + k_vec2); + sum0_p0 = hn::MulAdd( + k_vec1, hn::Set(df, q[i * kFirstHalfAmountOfQueries + 0]), sum0_p0); + sum0_p1 = hn::MulAdd( + k_vec2, hn::Set(df, q[i * kFirstHalfAmountOfQueries + 0]), sum0_p1); + if constexpr (kNumQueries >= 2) { + sum1_p0 = hn::MulAdd( + k_vec1, hn::Set(df, q[i * kFirstHalfAmountOfQueries + 1]), sum1_p0); + sum1_p1 = hn::MulAdd( + k_vec2, hn::Set(df, q[i * kFirstHalfAmountOfQueries + 1]), sum1_p1); + } + if constexpr (kNumQueries >= 3) { + sum2_p0 = hn::MulAdd( + k_vec1, hn::Set(df, q[i * kFirstHalfAmountOfQueries + 2]), sum2_p0); + sum2_p1 = hn::MulAdd( + k_vec2, hn::Set(df, q[i * kFirstHalfAmountOfQueries + 2]), sum2_p1); + } + if constexpr (kNumQueries >= 4) { + sum3_p0 = hn::MulAdd( + k_vec1, hn::Set(df, q[i * kFirstHalfAmountOfQueries + 3]), sum3_p0); + sum3_p1 = hn::MulAdd( + k_vec2, hn::Set(df, q[i * kFirstHalfAmountOfQueries + 3]), sum3_p1); + } + if constexpr (kNumQueries >= 5) { + sum4_p0 = hn::MulAdd( + k_vec1, hn::Set(df, q2[i * kSecondHalfAmountOfQueries + 0]), sum4_p0); + sum4_p1 = hn::MulAdd( + k_vec2, hn::Set(df, q2[i * kSecondHalfAmountOfQueries + 0]), sum4_p1); + } + if constexpr (kNumQueries >= 6) { + sum5_p0 = hn::MulAdd( + k_vec1, hn::Set(df, q2[i * kSecondHalfAmountOfQueries + 1]), sum5_p0); + sum5_p1 = hn::MulAdd( + k_vec2, hn::Set(df, q2[i * kSecondHalfAmountOfQueries + 1]), sum5_p1); + } + if constexpr (kNumQueries >= 7) { + sum6_p0 = hn::MulAdd( + k_vec1, hn::Set(df, q2[i * kSecondHalfAmountOfQueries + 2]), sum6_p0); + sum6_p1 = hn::MulAdd( + k_vec2, hn::Set(df, q2[i * kSecondHalfAmountOfQueries + 2]), sum6_p1); + } + if constexpr (kNumQueries >= 8) { + sum7_p0 = hn::MulAdd( + k_vec1, hn::Set(df, q2[i * kSecondHalfAmountOfQueries + 3]), sum7_p0); + sum7_p1 = hn::MulAdd( + k_vec2, hn::Set(df, q2[i * kSecondHalfAmountOfQueries + 3]), sum7_p1); + } + } +} + +template , typename T> +static HWY_INLINE void QDotKTilexUpTo8TransposedKDoubleWidthBF16( + DF df, const BF16* HWY_RESTRICT q, const BF16* HWY_RESTRICT q2, + const T* HWY_RESTRICT k_transposed_tile, size_t qkv_dim, VF& sum0_p0, + VF& sum0_p1, VF& sum1_p0, VF& sum1_p1, VF& sum2_p0, VF& sum2_p1, + VF& sum3_p0, VF& sum3_p1, VF& sum4_p0, VF& sum4_p1, VF& sum5_p0, + VF& sum5_p1, VF& sum6_p0, VF& sum6_p1, VF& sum7_p0, VF& sum7_p1) { + using DBF = hn::ScalableTag; + const DBF dbf; + using VBF = hn::Vec; + const PackedSpan k_transposed_span = + MakeConstSpan(k_transposed_tile, gcpp::KVCache::kTileSize * qkv_dim); + [[maybe_unused]] HWY_LANES_CONSTEXPR size_t lanes_bf16 = hn::Lanes(dbf); + HWY_DASSERT(hn::Lanes(dbf) <= gcpp::KVCache::kTileSize); + HWY_DASSERT(kNumQueries <= 8); + HWY_DASSERT(gcpp::KVCache::kTileSize >= + hn::Lanes(df) * 2); // So we can decompress 2 lanes at a time. + sum0_p0 = hn::Zero(df); + sum0_p1 = hn::Zero(df); + if constexpr (kNumQueries >= 2) { + sum1_p0 = hn::Zero(df); + sum1_p1 = hn::Zero(df); + } + if constexpr (kNumQueries >= 3) { + sum2_p0 = hn::Zero(df); + sum2_p1 = hn::Zero(df); + } + if constexpr (kNumQueries >= 4) { + sum3_p0 = hn::Zero(df); + sum3_p1 = hn::Zero(df); + } + if constexpr (kNumQueries >= 5) { + sum4_p0 = hn::Zero(df); + sum4_p1 = hn::Zero(df); + } + if constexpr (kNumQueries >= 6) { + sum5_p0 = hn::Zero(df); + sum5_p1 = hn::Zero(df); + } + if constexpr (kNumQueries >= 7) { + sum6_p0 = hn::Zero(df); + sum6_p1 = hn::Zero(df); + } + if constexpr (kNumQueries >= 8) { + sum7_p0 = hn::Zero(df); + sum7_p1 = hn::Zero(df); + } + VF helper_sum0_p0 = hn::Zero(df), helper_sum0_p1 = hn::Zero(df); + VF helper_sum1_p0 = hn::Zero(df), helper_sum1_p1 = hn::Zero(df); + VF helper_sum2_p0 = hn::Zero(df), helper_sum2_p1 = hn::Zero(df); + VF helper_sum3_p0 = hn::Zero(df), helper_sum3_p1 = hn::Zero(df); + VF helper_sum4_p0 = hn::Zero(df), helper_sum4_p1 = hn::Zero(df); + VF helper_sum5_p0 = hn::Zero(df), helper_sum5_p1 = hn::Zero(df); + VF helper_sum6_p0 = hn::Zero(df), helper_sum6_p1 = hn::Zero(df); + VF helper_sum7_p0 = hn::Zero(df), helper_sum7_p1 = hn::Zero(df); + const float* q_float_ptr = HWY_RCAST_ALIGNED(const float*, q); + const float* q2_float_ptr = HWY_RCAST_ALIGNED(const float*, q2); + constexpr int kFirstHalfAmountOfQueries = std::min(kNumQueries, 4); + constexpr int kSecondHalfAmountOfQueries = + kNumQueries - kFirstHalfAmountOfQueries; + + for (size_t i = 0; i < qkv_dim / 2; i++) { + VBF k_vec1, k_vec2; + Decompress2(dbf, k_transposed_span, i * 2 * gcpp::KVCache::kTileSize, + k_vec1, k_vec2); + + VF q_0_as_float = hn::Set(df, q_float_ptr[i * kFirstHalfAmountOfQueries]); + VBF q_0 = hn::BitCast(dbf, q_0_as_float); + sum0_p0 = + hn::ReorderWidenMulAccumulate(df, k_vec1, q_0, sum0_p0, helper_sum0_p0); + sum0_p1 = + hn::ReorderWidenMulAccumulate(df, k_vec2, q_0, sum0_p1, helper_sum0_p1); + if constexpr (kNumQueries >= 2) { + VF q_1_as_float = + hn::Set(df, q_float_ptr[i * kFirstHalfAmountOfQueries + 1]); + VBF q_1 = hn::BitCast(dbf, q_1_as_float); + sum1_p0 = hn::ReorderWidenMulAccumulate(df, k_vec1, q_1, sum1_p0, + helper_sum1_p0); + sum1_p1 = hn::ReorderWidenMulAccumulate(df, k_vec2, q_1, sum1_p1, + helper_sum1_p1); + } + if constexpr (kNumQueries >= 3) { + VF q_2_as_float = + hn::Set(df, q_float_ptr[i * kFirstHalfAmountOfQueries + 2]); + VBF q_2 = hn::BitCast(dbf, q_2_as_float); + sum2_p0 = hn::ReorderWidenMulAccumulate(df, k_vec1, q_2, sum2_p0, + helper_sum2_p0); + sum2_p1 = hn::ReorderWidenMulAccumulate(df, k_vec2, q_2, sum2_p1, + helper_sum2_p1); + } + if constexpr (kNumQueries >= 4) { + VF q_3_as_float = + hn::Set(df, q_float_ptr[i * kFirstHalfAmountOfQueries + 3]); + VBF q_3 = hn::BitCast(dbf, q_3_as_float); + sum3_p0 = hn::ReorderWidenMulAccumulate(df, k_vec1, q_3, sum3_p0, + helper_sum3_p0); + sum3_p1 = hn::ReorderWidenMulAccumulate(df, k_vec2, q_3, sum3_p1, + helper_sum3_p1); + } + if constexpr (kNumQueries >= 5) { + VF q_4_as_float = + hn::Set(df, q2_float_ptr[i * kSecondHalfAmountOfQueries + 0]); + VBF q_4 = hn::BitCast(dbf, q_4_as_float); + sum4_p0 = hn::ReorderWidenMulAccumulate(df, k_vec1, q_4, sum4_p0, + helper_sum4_p0); + sum4_p1 = hn::ReorderWidenMulAccumulate(df, k_vec2, q_4, sum4_p1, + helper_sum4_p1); + } + if constexpr (kNumQueries >= 6) { + VF q_5_as_float = + hn::Set(df, q2_float_ptr[i * kSecondHalfAmountOfQueries + 1]); + VBF q_5 = hn::BitCast(dbf, q_5_as_float); + sum5_p0 = hn::ReorderWidenMulAccumulate(df, k_vec1, q_5, sum5_p0, + helper_sum5_p0); + sum5_p1 = hn::ReorderWidenMulAccumulate(df, k_vec2, q_5, sum5_p1, + helper_sum5_p1); + } + if constexpr (kNumQueries >= 7) { + VF q_6_as_float = + hn::Set(df, q2_float_ptr[i * kSecondHalfAmountOfQueries + 2]); + VBF q_6 = hn::BitCast(dbf, q_6_as_float); + sum6_p0 = hn::ReorderWidenMulAccumulate(df, k_vec1, q_6, sum6_p0, + helper_sum6_p0); + sum6_p1 = hn::ReorderWidenMulAccumulate(df, k_vec2, q_6, sum6_p1, + helper_sum6_p1); + } + if constexpr (kNumQueries >= 8) { + VF q_7_as_float = + hn::Set(df, q2_float_ptr[i * kSecondHalfAmountOfQueries + 3]); + VBF q_7 = hn::BitCast(dbf, q_7_as_float); + sum7_p0 = hn::ReorderWidenMulAccumulate(df, k_vec1, q_7, sum7_p0, + helper_sum7_p0); + sum7_p1 = hn::ReorderWidenMulAccumulate(df, k_vec2, q_7, sum7_p1, + helper_sum7_p1); + } + } +#if HWY_NATIVE_DOT_BF16 == 0 + sum0_p0 = hn::Add(sum0_p0, helper_sum0_p0); + sum0_p1 = hn::Add(sum0_p1, helper_sum0_p1); + if constexpr (kNumQueries >= 2) { + sum1_p0 = hn::Add(sum1_p0, helper_sum1_p0); + sum1_p1 = hn::Add(sum1_p1, helper_sum1_p1); + } + if constexpr (kNumQueries >= 3) { + sum2_p0 = hn::Add(sum2_p0, helper_sum2_p0); + sum2_p1 = hn::Add(sum2_p1, helper_sum2_p1); + } + if constexpr (kNumQueries >= 4) { + sum3_p0 = hn::Add(sum3_p0, helper_sum3_p0); + sum3_p1 = hn::Add(sum3_p1, helper_sum3_p1); + } + if constexpr (kNumQueries >= 5) { + sum4_p0 = hn::Add(sum4_p0, helper_sum4_p0); + sum4_p1 = hn::Add(sum4_p1, helper_sum4_p1); + } + if constexpr (kNumQueries >= 6) { + sum5_p0 = hn::Add(sum5_p0, helper_sum5_p0); + sum5_p1 = hn::Add(sum5_p1, helper_sum5_p1); + } + if constexpr (kNumQueries >= 7) { + sum6_p0 = hn::Add(sum6_p0, helper_sum6_p0); + sum6_p1 = hn::Add(sum6_p1, helper_sum6_p1); + } + if constexpr (kNumQueries >= 8) { + sum7_p0 = hn::Add(sum7_p0, helper_sum7_p0); + sum7_p1 = hn::Add(sum7_p1, helper_sum7_p1); + } +#endif +} + +template > +static HWY_INLINE void ApplySoftCap(DF df, float att_cap, float one_over_cap, + VF& x0, VF& x1, VF& x2, VF& x3, VF& x4, + VF& x5, VF& x6, VF& x7) { + if (att_cap > 0.0f) { + VF cap = hn::Set(df, att_cap); + VF one_over_cap_vec = hn::Set(df, one_over_cap); + x0 = hn::Mul(cap, hn::CallTanh(df, hn::Mul(x0, one_over_cap_vec))); + if constexpr (kVTileSize >= 2) { + x1 = hn::Mul(cap, hn::CallTanh(df, hn::Mul(x1, one_over_cap_vec))); + } + if constexpr (kVTileSize >= 3) { + x2 = hn::Mul(cap, hn::CallTanh(df, hn::Mul(x2, one_over_cap_vec))); + } + if constexpr (kVTileSize >= 4) { + x3 = hn::Mul(cap, hn::CallTanh(df, hn::Mul(x3, one_over_cap_vec))); + } + if constexpr (kVTileSize >= 5) { + x4 = hn::Mul(cap, hn::CallTanh(df, hn::Mul(x4, one_over_cap_vec))); + } + if constexpr (kVTileSize >= 6) { + x5 = hn::Mul(cap, hn::CallTanh(df, hn::Mul(x5, one_over_cap_vec))); + } + if constexpr (kVTileSize >= 7) { + x6 = hn::Mul(cap, hn::CallTanh(df, hn::Mul(x6, one_over_cap_vec))); + } + if constexpr (kVTileSize >= 8) { + x7 = hn::Mul(cap, hn::CallTanh(df, hn::Mul(x7, one_over_cap_vec))); + } + } +} + +template , typename DU, + class VU = hn::Vec> +static HWY_NOINLINE void ApplyMasking( + DF df, DU du, size_t position, + const size_t* HWY_RESTRICT first_pos_per_query, + const size_t* HWY_RESTRICT last_pos_per_query, VF& x0_p0, VF& x0_p1, + VF& x1_p0, VF& x1_p1, VF& x2_p0, VF& x2_p1, VF& x3_p0, VF& x3_p1, VF& x4_p0, + VF& x4_p1, VF& x5_p0, VF& x5_p1, VF& x6_p0, VF& x6_p1, VF& x7_p0, + VF& x7_p1) { + VU lane_indices = hn::Iota(du, 0); + HWY_LANES_CONSTEXPR size_t kTileSize = hn::Lanes(df); + auto per_lane_pos_p0 = hn::Add(hn::Set(du, position), lane_indices); + auto per_lane_pos_p1 = + hn::Add(hn::Set(du, position + kTileSize), lane_indices); + + VF neg_inf = hn::Set(df, kNegInf); + + auto apply_mask_for_query = [&](int query_idx, VF& x_p0, VF& x_p1) HWY_ATTR { + const size_t first_pos = first_pos_per_query[query_idx]; + const size_t last_pos = last_pos_per_query[query_idx]; + + auto valid_tokens_mask_p0 = hn::Ge(per_lane_pos_p0, hn::Set(du, first_pos)); + valid_tokens_mask_p0 = hn::And( + valid_tokens_mask_p0, hn::Le(per_lane_pos_p0, hn::Set(du, last_pos))); + x_p0 = + hn::IfThenElse(hn::RebindMask(df, valid_tokens_mask_p0), x_p0, neg_inf); + + auto valid_tokens_mask_p1 = hn::Ge(per_lane_pos_p1, hn::Set(du, first_pos)); + valid_tokens_mask_p1 = hn::And( + valid_tokens_mask_p1, hn::Le(per_lane_pos_p1, hn::Set(du, last_pos))); + x_p1 = + hn::IfThenElse(hn::RebindMask(df, valid_tokens_mask_p1), x_p1, neg_inf); + }; + + if constexpr (kNumQueries >= 1) { + apply_mask_for_query(0, x0_p0, x0_p1); + } + if constexpr (kNumQueries >= 2) { + apply_mask_for_query(1, x1_p0, x1_p1); + } + if constexpr (kNumQueries >= 3) { + apply_mask_for_query(2, x2_p0, x2_p1); + } + if constexpr (kNumQueries >= 4) { + apply_mask_for_query(3, x3_p0, x3_p1); + } + if constexpr (kNumQueries >= 5) { + apply_mask_for_query(4, x4_p0, x4_p1); + } + if constexpr (kNumQueries >= 6) { + apply_mask_for_query(5, x5_p0, x5_p1); + } + if constexpr (kNumQueries >= 7) { + apply_mask_for_query(6, x6_p0, x6_p1); + } + if constexpr (kNumQueries >= 8) { + apply_mask_for_query(7, x7_p0, x7_p1); + } +} + +// Performs tiled flash attention for arbitrary number of queries +// It depends on kv being tiled. +// Runs 2 loops one over tiles, and inner one over queries(up to 4 at a time). +// It moves NF*2 timesteps forward in kv at a time. +// Args: +// kvs - hwy::Span of MatPtrT of shape (kvs, (tile_count, qkv_dim * +// kTileSize * 2)) This span allows to pass kv cache that is not contiguous, +// all except for the last one should have theirs row count be true, +// as it will be used to figure out when to switch to the next one. +// q_T_in_groups_up_to_4 - Span of float* All except last float* +// should have (qkv_dim, 4) Last one can have any size up to 4. +// start_pos_per_query - start position in kv to start attention from () +// last_pos_per_query - last position in kv to attend to (exclusive) +// queries_per_timestep - how many queries begin/end on the same timestep +// attention_shape - see struct definition for more details. +// att_cap - soft cap on attention logits +// att_out - MatPtrT of shape (q_count, qkv_dim) +// exp_denominator_sums and max_logits: float* of shape: +// (RountedUpTo(q_count,4),) +// Need to be have multiple of 4 elements alocated and +// be initizalized If you need to compute over multiple chunks of kv's you can +// keep values between calls to this function and avoid explicit merge. +template +HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits( + const hwy::Span> kvs, int q_count, + const hwy::Span q_T_in_groups_up_to_4, + hwy::Span start_pos_per_query, + hwy::Span last_pos_per_query, const float att_cap, + MatPtrT& att_out, float* HWY_RESTRICT exp_denominator_sums, + float* HWY_RESTRICT max_logits) { + using DF = hn::ScalableTag; + const DF df; + using VF = hn::Vec; + using DU = hn::ScalableTag; + [[maybe_unused]] const DU du; + constexpr int kTileSize = gcpp::KVCache::kTileSize; + HWY_LANES_CONSTEXPR size_t kHTileSize = hn::Lanes(df); + constexpr int kNumQueriesPerGroup = 4; + constexpr int kNumQueriesPerLoop = + (!HWY_ARCH_X86 || (HWY_TARGET <= HWY_AVX3)) ? 8 : 4; + constexpr int kNumGroupsPerLoop = kNumQueriesPerLoop / kNumQueriesPerGroup; + const size_t full_groups_of_queries = q_count / kNumQueriesPerGroup; + const size_t num_loops = hwy::DivCeil(q_count, kNumQueriesPerLoop); + const size_t qkv_dim = att_out.Cols(); + HWY_DASSERT(kHTileSize <= hn::MaxLanes(df)); + HWY_LANES_CONSTEXPR size_t step_size = kHTileSize * 2; + size_t smallest_start_pos = std::numeric_limits::max(); + size_t largest_last_pos = std::numeric_limits::min(); + for (size_t i = 0; i < start_pos_per_query.size(); ++i) { + smallest_start_pos = std::min(smallest_start_pos, start_pos_per_query[i]); + largest_last_pos = std::max(largest_last_pos, last_pos_per_query[i]); + } + // start / end positions per group of 4 queries. + std::vector> pos_data(num_loops * 4); + hwy::Span min_start_pos_per_group(pos_data.data(), num_loops); + hwy::Span max_start_pos_per_group(pos_data.data() + num_loops, + num_loops); + hwy::Span min_last_pos_per_group(pos_data.data() + 2 * num_loops, + num_loops); + hwy::Span max_last_pos_per_group(pos_data.data() + 3 * num_loops, + num_loops); + + for (size_t i = 0; i < num_loops; ++i) { + size_t min_start = std::numeric_limits::max(); + size_t max_start = 0; + size_t min_last = std::numeric_limits::max(); + size_t max_last = 0; + for (int j = 0; j < kNumQueriesPerLoop; ++j) { + if (i * kNumQueriesPerLoop + j < q_count) { + min_start = std::min(min_start, + start_pos_per_query[i * kNumQueriesPerLoop + j]); + max_start = std::max(max_start, + start_pos_per_query[i * kNumQueriesPerLoop + j]); + min_last = + std::min(min_last, last_pos_per_query[i * kNumQueriesPerLoop + j]); + max_last = + std::max(max_last, last_pos_per_query[i * kNumQueriesPerLoop + j]); + } + } + min_start_pos_per_group[i] = min_start; + max_start_pos_per_group[i] = max_start; + min_last_pos_per_group[i] = min_last; + max_last_pos_per_group[i] = max_last; + } + const size_t base_pos = smallest_start_pos - (smallest_start_pos % kTileSize); + const size_t rem = smallest_start_pos % kTileSize; + const size_t num_skipped_sub_tiles = rem / step_size; + size_t position = base_pos + num_skipped_sub_tiles * step_size; + [[maybe_unused]] float one_over_cap = 1.0f / att_cap; + std::vector> att_out_per_query; + att_out_per_query.reserve(num_loops); + for (size_t i = 0; i < num_loops; ++i) { + att_out_per_query.emplace_back("att_out", + Extents2D(kNumQueriesPerLoop, qkv_dim)); + att_out_per_query.back().SetPtr(att_out.Row(i * kNumQueriesPerLoop), + att_out.Stride()); + } + size_t current_kv_start_offset = 0; + size_t current_kv_idx = 0; + + auto inner_loop = [&](int q_group_idx) HWY_ATTR { + int loop_idx = q_group_idx / (kNumQueriesPerLoop / kNumQueriesPerGroup); + if (position + step_size <= min_start_pos_per_group[loop_idx] || + position > max_last_pos_per_group[loop_idx]) { + return; + } + VF x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0, x_3_p_1; + VF x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1; + const size_t pos_in_tile = position % kTileSize; + // tile base can point to same tile as previous loop iteration, hence no + // HWY_RESTRICT + // KVs are unaligned and we only use unaligned loads in this implementation. + const KV_T* tile_base = + reinterpret_cast(kvs[current_kv_idx].RowBytes( + (position - current_kv_start_offset) / kTileSize)); + + const KV_T* v_tile = + tile_base + qkv_dim * kTileSize + (pos_in_tile)*qkv_dim; + const Q_T* q_group = q_T_in_groups_up_to_4[q_group_idx]; + const Q_T* q2_group = nullptr; + if (kNumQueries > 4) { + q2_group = q_T_in_groups_up_to_4[q_group_idx + 1]; + } + if constexpr (IsF32()) { + const KV_T* k_transposed_tile = tile_base + pos_in_tile; + QDotKTilexUpTo8TransposedKDoubleWidth( + df, q_group, q2_group, k_transposed_tile, qkv_dim, x_0_p_0, x_0_p_1, + x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0, x_3_p_1, x_4_p_0, + x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1); + } else if constexpr (IsBF16()) { + const KV_T* k_transposed_tile = tile_base + pos_in_tile * 2; + QDotKTilexUpTo8TransposedKDoubleWidthBF16( + df, q_group, q2_group, k_transposed_tile, qkv_dim, x_0_p_0, x_0_p_1, + x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0, x_3_p_1, x_4_p_0, + x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1); + } else { + static_assert( + false, + "Query type type not supported, only float and BF16 are supported"); + } + + constexpr int kFirstHalfAmountOfQueries = std::min(kNumQueries, 4); + constexpr int kSecondHalfAmountOfQueries = + kNumQueries - kFirstHalfAmountOfQueries; + ApplySoftCap( + df, att_cap, one_over_cap, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, + x_2_p_1, x_3_p_0, x_3_p_1); + if constexpr (kNumQueries > 4) { + ApplySoftCap( + df, att_cap, one_over_cap, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, + x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1); + } + + if (position < max_start_pos_per_group[loop_idx] || + position + step_size - 1 > min_last_pos_per_group[loop_idx]) { + ApplyMasking( + df, du, position, + start_pos_per_query.data() + q_group_idx * kNumQueriesPerGroup, + last_pos_per_query.data() + q_group_idx * kNumQueriesPerGroup, + x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0, + x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1, + x_7_p_0, x_7_p_1); + } + HWY_ALIGN float scales[kNumQueriesPerLoop]; + // HWY_UNROLL(kNumQueriesPerLoop) + for (size_t i = 0; i < kNumQueriesPerLoop; ++i) { + scales[i] = 1.0f; + } + FlashAttentionTileStepAndApplySoftCap( + df, 0.0f, 1.0f, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, + x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1, + x_7_p_0, x_7_p_1, max_logits, exp_denominator_sums, scales, q_group_idx, + kNumQueriesPerGroup); + if constexpr (IsF32()) { + MulByConstAndAddTileUpTo8( + df, scales, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, + x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, + x_6_p_1, x_7_p_0, x_7_p_1, v_tile, att_out_per_query[loop_idx]); + } else if constexpr (IsBF16()) { + MulByConstAndAddTileUpTo8_BF16( + df, scales, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, + x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, + x_6_p_1, x_7_p_0, x_7_p_1, v_tile, att_out_per_query[loop_idx]); + } + }; + + while (position <= largest_last_pos) { + while (position - current_kv_start_offset >= + kvs[current_kv_idx].Rows() * kTileSize) { + current_kv_start_offset += kvs[current_kv_idx].Rows() * kTileSize; + current_kv_idx++; + } + int group_idx = 0; + for (; group_idx + kNumGroupsPerLoop <= full_groups_of_queries; + group_idx += kNumGroupsPerLoop) { + inner_loop.template operator()(group_idx); + } + if (group_idx < full_groups_of_queries) { + inner_loop.template operator()<4>(group_idx); + group_idx++; + } + switch (q_count % kNumQueriesPerGroup) { + case 1: + inner_loop.template operator()<1>(group_idx); + break; + case 2: + inner_loop.template operator()<2>(group_idx); + break; + case 3: + inner_loop.template operator()<3>(group_idx); + break; + default: + break; + } + + position += step_size; + } +} + +void DispatchTileFlashAttentionReturnExpSumsAndMaxLogits( + hwy::Span kvs, int q_count, + const hwy::Span q_T_in_groups_up_to_4, + hwy::Span start_pos_per_query, + hwy::Span last_pos_per_query, const float att_cap, + MatPtrT& att_out, float* HWY_RESTRICT exp_denominator_sums, + float* HWY_RESTRICT max_logits) { + CallUpcastedKVs(kvs, [&](const auto& kv_t) { + return TileFlashAttentionReturnExpSumsAndMaxLogits( + kv_t, q_count, q_T_in_groups_up_to_4, start_pos_per_query, + last_pos_per_query, att_cap, att_out, exp_denominator_sums, max_logits); + }); +} + +void DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsBF16( + hwy::Span kvs, int q_count, + const hwy::Span q_T_in_groups_up_to_4, + hwy::Span start_pos_per_query, + hwy::Span last_pos_per_query, const float att_cap, + MatPtrT& att_out, float* HWY_RESTRICT exp_denominator_sums, + float* HWY_RESTRICT max_logits) { + CallUpcastedKVs(kvs, [&](const auto& kv_t) { + return TileFlashAttentionReturnExpSumsAndMaxLogits( + kv_t, q_count, q_T_in_groups_up_to_4, start_pos_per_query, + last_pos_per_query, att_cap, att_out, exp_denominator_sums, max_logits); + }); +} + // Rounds n to a number that can be used as the number of Q rows in a tile // of flash attention. static size_t RoundToSuitablePowerOf2(size_t n) { diff --git a/gemma/flash_attention.h b/gemma/flash_attention.h index 81bfcdf7..5529d9fd 100644 --- a/gemma/flash_attention.h +++ b/gemma/flash_attention.h @@ -22,46 +22,78 @@ #include +#include "gemma/configs.h" #include "gemma/flash_structs.h" +#include "gemma/kv_cache.h" #include "gemma/query.h" +#include "util/basics.h" +#include "util/mat.h" +#include "util/threading_context.h" +#include "hwy/aligned_allocator.h" #include "hwy/highway.h" namespace gcpp { // Passed to HWY_VISIT_TARGETS; declares for one target. -#define GEMMA_DECL_FLASH_ATTENTION(TARGET, NAMESPACE) \ - namespace NAMESPACE { \ - void RMSNormAndPositionalEncoding( \ - size_t num_tokens, const QBatch& qbatch, MatPtrT& q, \ - const MatPtr& query_norm_scale, size_t layer_idx, \ - const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \ - \ - void SingleFlashAttention(size_t start_pos, size_t last_pos, \ - const BF16* HWY_RESTRICT q, \ - const MatPtrT& k, const MatPtrT& v, \ - size_t layer_idx, \ - const AttentionActivationsPtrs& activations, \ - float* HWY_RESTRICT att_out, \ - ThreadingContext& ctx, size_t worker); \ - \ - Tile4FlashState TileFlashAttention4( \ - const MatPtrT& q, const uint32_t* HWY_RESTRICT q_offsets, \ - const MatPtrT& k, size_t start_pos, \ - const uint32_t* HWY_RESTRICT last_pos, size_t min_last_pos, \ - size_t max_last_pos, const MatPtrT& v, size_t layer_idx, \ - const LayerWeightsPtrs& layer, const AttentionActivations& activations, \ - MatPtrT& att_out, const uint32_t* HWY_RESTRICT out_offsets, \ - ThreadingContext& ctx, const size_t worker); \ - \ - size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \ - size_t total_tasks, size_t target_parallelism); \ - \ - void FlashAttention(size_t num_tokens, size_t target_parallelism, \ - size_t layer_idx, const MatPtr& query_norm_scale, \ - AttentionActivationsPtrs& activations, QBatch& qbatch, \ - ThreadingContext& ctx, AttentionImpl attention_impl); \ - \ - /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ +#define GEMMA_DECL_FLASH_ATTENTION(TARGET, NAMESPACE) \ + namespace NAMESPACE { \ + void RMSNormAndPositionalEncoding( \ + size_t num_tokens, const QBatch& qbatch, MatPtrT& q, \ + const MatPtr& query_norm_scale, size_t layer_idx, \ + const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \ + \ + void SingleFlashAttention(size_t start_pos, size_t last_pos, \ + const BF16* HWY_RESTRICT q, \ + const MatPtrT& k, const MatPtrT& v, \ + size_t layer_idx, \ + const AttentionActivationsPtrs& activations, \ + float* HWY_RESTRICT att_out, \ + ThreadingContext& ctx, size_t worker); \ + \ + Tile4FlashState TileFlashAttention4( \ + const MatPtrT& q, const uint32_t* HWY_RESTRICT q_offsets, \ + const MatPtrT& k, size_t start_pos, \ + const uint32_t* HWY_RESTRICT last_pos, size_t min_last_pos, \ + size_t max_last_pos, const MatPtrT& v, size_t layer_idx, \ + const AttentionActivationsPtrs& activations, MatPtrT& att_out, \ + const uint32_t* HWY_RESTRICT out_offsets, ThreadingContext& ctx, \ + const size_t worker); \ + \ + void TileFlashAttention( \ + const MatPtrT& q, const uint32_t* HWY_RESTRICT q_offsets, \ + const StridedView& qT, const MatPtrT& k, \ + const size_t start_pos, const uint32_t* HWY_RESTRICT last_pos, \ + const size_t min_last_pos, const size_t max_last_pos, \ + const MatPtrT& v, const size_t layer_idx, \ + const AttentionActivationsPtrs& activations, MatPtrT& att_out, \ + const uint32_t* HWY_RESTRICT out_offsets, ThreadingContext& ctx, \ + const size_t worker); \ + \ + size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \ + size_t total_tasks, size_t target_parallelism); \ + \ + void FlashAttention(size_t num_tokens, size_t target_parallelism, \ + size_t layer_idx, const MatPtr& query_norm_scale, \ + AttentionActivationsPtrs& activations, QBatch& qbatch, \ + ThreadingContext& ctx, AttentionImpl attention_impl); \ + \ + void DispatchTileFlashAttentionReturnExpSumsAndMaxLogits( \ + hwy::Span kvs, int q_count, \ + const hwy::Span q_T_in_groups_up_to_4, \ + hwy::Span start_pos_per_query, \ + hwy::Span last_pos_per_query, const float att_cap, \ + MatPtrT& att_out, float* HWY_RESTRICT exp_denominator_sums, \ + float* HWY_RESTRICT max_logits); \ + \ + void DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsBF16( \ + hwy::Span kvs, int q_count, \ + const hwy::Span q_T_in_groups_up_to_4, \ + hwy::Span start_pos_per_query, \ + hwy::Span last_pos_per_query, const float att_cap, \ + MatPtrT& att_out, float* HWY_RESTRICT exp_denominator_sums, \ + float* HWY_RESTRICT max_logits); \ + \ + /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ } // namespace NAMESPACE // Function declarations for each SIMD target. Allows direct call from the diff --git a/gemma/flash_attention_test.cc b/gemma/flash_attention_test.cc index 33702b88..bbb63f5a 100644 --- a/gemma/flash_attention_test.cc +++ b/gemma/flash_attention_test.cc @@ -181,6 +181,298 @@ void TestAttention() { TestFlashAttention(256); } +const std::vector exp_denominator_sums_gold = { + 58.722088f, 58.445938f, 58.17153f, 57.89886f, + 58.580994f, 58.302643f, 58.026085f, 57.751308f}; +const std::vector max_logits_gold = { + 0.009613638f, 0.019227259f, 0.02884084f, 0.038454376f, + 0.04888253f, 0.058658823f, 0.06843502f, 0.078211054f}; +const std::vector att_out_gold = { + 0.600945, 0.300472, 0.200315, 0.150236, 0.120189, 0.100158, 0.085849, + 0.075118, 0.066772, 0.060095, 0.054631, 0.050079, 0.046227, 0.042925, + 0.040063, 0.037559, 0.035350, 0.033386, 0.031629, 0.030047, 0.028616, + 0.027316, 0.026128, 0.025039, 0.024038, 0.023113, 0.022257, 0.021462, + 0.020722, 0.020032, 0.019385, 0.018780, 0.018210, 0.017675, 0.017170, + 0.016693, 0.016242, 0.015814, 0.015409, 0.015024, 0.014657, 0.014308, + 0.013975, 0.013658, 0.013354, 0.013064, 0.012786, 0.012520, 0.012264, + 0.012019, 0.011783, 0.011557, 0.011339, 0.011129, 0.010926, 0.010731, + 0.010543, 0.010361, 0.010186, 0.010016, 0.009852, 0.009693, 0.009539, + 0.009390, 0.601890, 0.300945, 0.200630, 0.150473, 0.120378, 0.100315, + 0.085984, 0.075236, 0.066877, 0.060189, 0.054717, 0.050158, 0.046299, + 0.042992, 0.040126, 0.037618, 0.035405, 0.033438, 0.031678, 0.030095, + 0.028661, 0.027359, 0.026169, 0.025079, 0.024076, 0.023150, 0.022292, + 0.021496, 0.020755, 0.020063, 0.019416, 0.018809, 0.018239, 0.017703, + 0.017197, 0.016719, 0.016267, 0.015839, 0.015433, 0.015047, 0.014680, + 0.014331, 0.013997, 0.013679, 0.013375, 0.013085, 0.012806, 0.012539, + 0.012283, 0.012038, 0.011802, 0.011575, 0.011356, 0.011146, 0.010943, + 0.010748, 0.010559, 0.010377, 0.010202, 0.010032, 0.009867, 0.009708, + 0.009554, 0.009405, 0.602835, 0.301418, 0.200945, 0.150709, 0.120567, + 0.100473, 0.086119, 0.075354, 0.066982, 0.060284, 0.054803, 0.050236, + 0.046372, 0.043060, 0.040189, 0.037677, 0.035461, 0.033491, 0.031728, + 0.030142, 0.028706, 0.027402, 0.026210, 0.025118, 0.024113, 0.023186, + 0.022327, 0.021530, 0.020787, 0.020095, 0.019446, 0.018839, 0.018268, + 0.017730, 0.017224, 0.016745, 0.016293, 0.015864, 0.015457, 0.015071, + 0.014703, 0.014353, 0.014019, 0.013701, 0.013396, 0.013105, 0.012826, + 0.012559, 0.012303, 0.012057, 0.011820, 0.011593, 0.011374, 0.011164, + 0.010961, 0.010765, 0.010576, 0.010394, 0.010218, 0.010047, 0.009883, + 0.009723, 0.009569, 0.009419, 0.603780, 0.301890, 0.201260, 0.150945, + 0.120756, 0.100630, 0.086254, 0.075473, 0.067087, 0.060378, 0.054889, + 0.050315, 0.046445, 0.043127, 0.040252, 0.037736, 0.035516, 0.033543, + 0.031778, 0.030189, 0.028751, 0.027445, 0.026251, 0.025158, 0.024151, + 0.023222, 0.022362, 0.021564, 0.020820, 0.020126, 0.019477, 0.018868, + 0.018296, 0.017758, 0.017251, 0.016772, 0.016318, 0.015889, 0.015482, + 0.015095, 0.014726, 0.014376, 0.014041, 0.013722, 0.013417, 0.013126, + 0.012846, 0.012579, 0.012322, 0.012076, 0.011839, 0.011611, 0.011392, + 0.011181, 0.010978, 0.010782, 0.010593, 0.010410, 0.010234, 0.010063, + 0.009898, 0.009738, 0.009584, 0.009434, 0.614887, 0.307443, 0.204962, + 0.153722, 0.122977, 0.102481, 0.087841, 0.076861, 0.068321, 0.061489, + 0.055899, 0.051241, 0.047299, 0.043920, 0.040992, 0.038430, 0.036170, + 0.034160, 0.032362, 0.030744, 0.029280, 0.027949, 0.026734, 0.025620, + 0.024595, 0.023649, 0.022774, 0.021960, 0.021203, 0.020496, 0.019835, + 0.019215, 0.018633, 0.018085, 0.017568, 0.017080, 0.016619, 0.016181, + 0.015766, 0.015372, 0.014997, 0.014640, 0.014300, 0.013975, 0.013664, + 0.013367, 0.013083, 0.012810, 0.012549, 0.012298, 0.012057, 0.011825, + 0.011602, 0.011387, 0.011180, 0.010980, 0.010787, 0.010601, 0.010422, + 0.010248, 0.010080, 0.009918, 0.009760, 0.009608, 0.615864, 0.307932, + 0.205288, 0.153966, 0.123173, 0.102644, 0.087981, 0.076983, 0.068429, + 0.061586, 0.055988, 0.051322, 0.047374, 0.043990, 0.041058, 0.038491, + 0.036227, 0.034215, 0.032414, 0.030793, 0.029327, 0.027994, 0.026777, + 0.025661, 0.024635, 0.023687, 0.022810, 0.021995, 0.021237, 0.020529, + 0.019867, 0.019246, 0.018663, 0.018114, 0.017596, 0.017107, 0.016645, + 0.016207, 0.015791, 0.015397, 0.015021, 0.014663, 0.014322, 0.013997, + 0.013686, 0.013388, 0.013103, 0.012830, 0.012569, 0.012317, 0.012076, + 0.011844, 0.011620, 0.011405, 0.011198, 0.010998, 0.010805, 0.010618, + 0.010438, 0.010264, 0.010096, 0.009933, 0.009776, 0.009623, 0.616841, + 0.308421, 0.205614, 0.154210, 0.123368, 0.102807, 0.088120, 0.077105, + 0.068538, 0.061684, 0.056076, 0.051403, 0.047449, 0.044060, 0.041123, + 0.038553, 0.036285, 0.034269, 0.032465, 0.030842, 0.029373, 0.028038, + 0.026819, 0.025702, 0.024674, 0.023725, 0.022846, 0.022030, 0.021270, + 0.020561, 0.019898, 0.019276, 0.018692, 0.018142, 0.017624, 0.017134, + 0.016671, 0.016233, 0.015816, 0.015421, 0.015045, 0.014687, 0.014345, + 0.014019, 0.013708, 0.013410, 0.013124, 0.012851, 0.012589, 0.012337, + 0.012095, 0.011862, 0.011639, 0.011423, 0.011215, 0.011015, 0.010822, + 0.010635, 0.010455, 0.010281, 0.010112, 0.009949, 0.009791, 0.009638, + 0.617818, 0.308909, 0.205939, 0.154455, 0.123564, 0.102970, 0.088260, + 0.077227, 0.068646, 0.061782, 0.056165, 0.051485, 0.047524, 0.044130, + 0.041188, 0.038614, 0.036342, 0.034323, 0.032517, 0.030891, 0.029420, + 0.028083, 0.026862, 0.025742, 0.024713, 0.023762, 0.022882, 0.022065, + 0.021304, 0.020594, 0.019930, 0.019307, 0.018722, 0.018171, 0.017652, + 0.017162, 0.016698, 0.016258, 0.015841, 0.015445, 0.015069, 0.014710, + 0.014368, 0.014041, 0.013729, 0.013431, 0.013145, 0.012871, 0.012609, + 0.012356, 0.012114, 0.011881, 0.011657, 0.011441, 0.011233, 0.011032, + 0.010839, 0.010652, 0.010471, 0.010297, 0.010128, 0.009965, 0.009807, + 0.009653}; + +void TestTiledFlashAttention() { + int qkv_dim = 64; + int kv_seq_len = 60; // number of tokens we will attend to. Not divisible by + // tiles size to test the padding logic. + int padded_kv_seq_len = hwy::RoundUpTo(kv_seq_len, gcpp::KVCache::kTileSize); + float att_cap = 10.0f; + int num_queries = 8; + int num_queries_per_timestep = 4; + int num_tokens = num_queries / num_queries_per_timestep; + int kv_seq_end = + kv_seq_len - hwy::DivCeil(num_queries, num_queries_per_timestep); + ThreadingArgs threading_args; + ThreadingContext ctx(threading_args); + MatStorageT kv( + "kv", + Extents2D(padded_kv_seq_len, 2 * qkv_dim * gcpp::KVCache::kTileSize), + ctx.allocator, MatPadding::kPacked); + // fill in kvs with predictable, synthetic data + for (int i = 0; i < padded_kv_seq_len; ++i) { + for (int j = 0; j < qkv_dim; ++j) { + const int tile_idx = i / gcpp::KVCache::kTileSize; + const int in_tile_offset = i % gcpp::KVCache::kTileSize; + const float val_k = 0.01f * (i + 1) / (j + 1); + const float val_v = 0.02f * (i + 1) / (j + 1); + kv.Row(tile_idx)[j * gcpp::KVCache::kTileSize + in_tile_offset] = val_k; + const size_t v_offset = qkv_dim * gcpp::KVCache::kTileSize; + kv.Row(tile_idx)[v_offset + in_tile_offset * qkv_dim + j] = val_v; + } + } + std::vector q_float(4 * qkv_dim); + std::vector q_float2(4 * qkv_dim); + // fill in qs with predictable, synthetic data + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < qkv_dim; j++) { + float val_1 = 0.01f * (i + 1) / (j + 1); + float val_2 = 0.01f * (i + 4 + 1) / (j + 1); + q_float[j * 4 + i] = val_1; + q_float2[j * 4 + i] = val_2; + } + } + const float* q_T[2] = {q_float.data(), q_float2.data()}; + + MatStorageT att_out("att_out", Extents2D(num_queries, qkv_dim), + ctx.allocator, MatPadding::kPacked); + using DF = hn::ScalableTag; + const DF df; + HWY_LANES_CONSTEXPR size_t lanes = hn::Lanes(df); + size_t num_queries_rounded_to_laness = hwy::RoundUpTo(num_queries, lanes); + std::vector exp_denominator_sums(num_queries_rounded_to_laness); + std::vector max_logits(num_queries_rounded_to_laness); + for (size_t i = 0; i < num_queries; ++i) { + hwy::ZeroBytes(att_out.Row(i), + qkv_dim * sizeof(decltype(att_out.Row(i)[0]))); + exp_denominator_sums[i] = 0.0f; + max_logits[i] = -std::numeric_limits::max() / 2.0f; + } + std::vector> start_pos_per_query; + std::vector> last_pos_per_query; + start_pos_per_query.reserve(num_queries); + last_pos_per_query.reserve(num_queries); + for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { + ssize_t query_last_pos = kv_seq_end + token_idx; + ssize_t query_start_pos = + std::max(query_last_pos - 100000 + 1, static_cast(0)); + for (int q_head_idx = 0; q_head_idx < num_queries_per_timestep; + ++q_head_idx) { + start_pos_per_query.push_back(query_start_pos); + last_pos_per_query.push_back(query_last_pos); + } + } + + hwy::Span kvs(&kv, 1); + DispatchTileFlashAttentionReturnExpSumsAndMaxLogits( + kvs, num_queries, hwy::Span(q_T, 2), + hwy::Span(start_pos_per_query), + hwy::Span(last_pos_per_query), att_cap, att_out, + exp_denominator_sums.data(), max_logits.data()); + + // TODO: Replace with Other implementation for generating goldens. + // Current values are taken from a point in time where code was run with gemma + // and output looked good. Not ideal but should be good enough to test the + // plumbing and detect regressions. + PrintMatPtr(att_out); + for (int i = 0; i < num_queries; ++i) { + std::cerr << "exp_d: " << exp_denominator_sums[i] + << " max_logit: " << max_logits[i] << std::endl; + EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 1e-4f) + << "i=" << i; + EXPECT_NEAR(max_logits[i], max_logits_gold[i], 1e-6f) << "i=" << i; + for (int j = 0; j < qkv_dim; ++j) { + EXPECT_NEAR(att_out.Row(i)[j], att_out_gold[i * qkv_dim + j], 1e-6f); + } + } +} + +void TestTiledFlashAttentionBF16() { + int qkv_dim = 64; + int kv_seq_len = 60; // number of tokens we will attend to. Not divisible by + // tiles size to test the padding logic. + int padded_kv_seq_len = hwy::RoundUpTo(kv_seq_len, gcpp::KVCache::kTileSize); + float att_cap = 10.0f; + int num_queries = 8; + int num_queries_per_timestep = 4; + int num_tokens = num_queries / num_queries_per_timestep; + int kv_seq_end = + kv_seq_len - hwy::DivCeil(num_queries, num_queries_per_timestep); + ThreadingArgs threading_args; + ThreadingContext ctx(threading_args); + MatStorageT kv( + "kv", + Extents2D(padded_kv_seq_len, 2 * qkv_dim * gcpp::KVCache::kTileSize), + ctx.allocator, MatPadding::kPacked); + // fill in kvs with predictable, synthetic data + for (int i = 0; i < padded_kv_seq_len; i++) { + for (int j = 0; j < qkv_dim; j+=2) { + const int tile_idx = i / gcpp::KVCache::kTileSize; + const int in_tile_offset = i % gcpp::KVCache::kTileSize; + const float val_k_1 = 0.01f * (i + 1) / (j + 1); + const float val_k_2 = 0.01f * (i + 1) / (j + 2); + kv.Row(tile_idx)[j * gcpp::KVCache::kTileSize + in_tile_offset * 2] = + hwy::ConvertScalarTo(val_k_1); + kv.Row(tile_idx)[j * gcpp::KVCache::kTileSize + in_tile_offset * 2 + 1] = + hwy::ConvertScalarTo(val_k_2); + } + } + const size_t v_offset = qkv_dim * gcpp::KVCache::kTileSize; + for (int i = 0; i < padded_kv_seq_len; i += 2) { + for (int j = 0; j < qkv_dim; j++) { + const int tile_idx = i / gcpp::KVCache::kTileSize; + const int in_tile_offset = i % gcpp::KVCache::kTileSize; + const float val_v_1 = 0.02f * (i + 1) / (j + 1); + const float val_v_2 = 0.02f * (i + 2) / (j + 1); + kv.Row(tile_idx)[v_offset + in_tile_offset * qkv_dim + j * 2] = + hwy::ConvertScalarTo(val_v_1); + kv.Row(tile_idx)[v_offset + in_tile_offset * qkv_dim + j * 2 + 1] = + hwy::ConvertScalarTo(val_v_2); + } + } + + std::vector q_float(num_queries_per_timestep * qkv_dim); + std::vector q_float2(num_queries_per_timestep * qkv_dim); + // fill in qs with predictable, synthetic data + for (int i = 0; i < num_queries_per_timestep; ++i) { + for (int j = 0; j < qkv_dim; j += 2) { + q_float[j * num_queries_per_timestep + i * 2] = + hwy::ConvertScalarTo(0.01f * (i + 1) / (j + 1)); + q_float[j * num_queries_per_timestep + i * 2 + 1] = + hwy::ConvertScalarTo(0.01f * (i + 1) / (j + 2)); + + q_float2[j * num_queries_per_timestep + i * 2] = + hwy::ConvertScalarTo( + 0.01f * (i + num_queries_per_timestep + 1) / (j + 1)); + q_float2[j * num_queries_per_timestep + i * 2 + 1] = + hwy::ConvertScalarTo( + 0.01f * (i + num_queries_per_timestep + 1) / (j + 2)); + } + } + const BF16* q_T[2] = {q_float.data(), q_float2.data()}; + + MatStorageT att_out("att_out", Extents2D(num_queries, qkv_dim), + ctx.allocator, MatPadding::kPacked); + + HWY_LANES_CONSTEXPR size_t lanes = 4; + size_t num_queries_rounded_to_laness = hwy::RoundUpTo(num_queries, lanes); + std::vector exp_denominator_sums(num_queries_rounded_to_laness); + std::vector max_logits(num_queries_rounded_to_laness); + for (size_t i = 0; i < num_queries; ++i) { + hwy::ZeroBytes(att_out.Row(i), + qkv_dim * sizeof(decltype(att_out.Row(i)[0]))); + exp_denominator_sums[i] = 0.0f; + max_logits[i] = -std::numeric_limits::max() / 2.0f; + } + std::vector> start_pos_per_query; + std::vector> last_pos_per_query; + start_pos_per_query.reserve(num_queries); + last_pos_per_query.reserve(num_queries); + for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { + ssize_t query_last_pos = kv_seq_end + token_idx; + ssize_t query_start_pos = + std::max(query_last_pos - 100000 + 1, static_cast(0)); + for (int q_head_idx = 0; q_head_idx < num_queries_per_timestep; + ++q_head_idx) { + start_pos_per_query.push_back(query_start_pos); + last_pos_per_query.push_back(query_last_pos); + } + } + hwy::Span kvs(&kv, 1); + DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsBF16( + kvs, num_queries, hwy::Span(q_T, 2), + hwy::Span(start_pos_per_query), + hwy::Span(last_pos_per_query), att_cap, att_out, + exp_denominator_sums.data(), max_logits.data()); + + // TODO: Replace with Other implementation for generating goldens. + // Current values are taken from a point in time where code was run with gemma + // and output looked good. Not ideal but should be good enough to test the + // plumbing and detect regressions. + PrintMatPtr(att_out); + for (int i = 0; i < num_queries; ++i) { + std::cerr << "exp_d: " << exp_denominator_sums[i] + << " max_logit: " << max_logits[i] << std::endl; + EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 2e-2f) + << "i=" << i; + EXPECT_NEAR(max_logits[i], max_logits_gold[i], 1e-3f) << "i=" << i; + for (int j = 0; j < qkv_dim; ++j) { + EXPECT_NEAR(att_out.Row(i)[j], att_out_gold[i * qkv_dim + j], 1e-3f); + } + } +} + // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace gcpp diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 2450af8c..90bbca3f 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -42,7 +42,8 @@ // After highway.h #include "gemma/attention.h" // includes highway.h #include "gemma/gemma-inl.h" -#include "gemma/vit.h" // includes highway.h +#include "gemma/tiled_attention.h" // includes highway.h +#include "gemma/vit.h" // includes highway.h #ifndef GEMMA_CC_ONCE #define GEMMA_CC_ONCE @@ -80,6 +81,14 @@ namespace HWY_NAMESPACE { void Attention(LayerAttentionType type, const size_t num_tokens, const size_t layer_idx, const LayerWeightsPtrs& layer, Activations& activations, QBatch& qbatch, MatMulEnv& env) { + if (activations.attention_impl == AttentionImpl::kFlashTransposedQs || + activations.attention_impl == AttentionImpl::kFlashTransposedQsBF16) { + TiledAttention( + activations.attention_impl, num_tokens, layer_idx, layer, + activations.attention, qbatch, env, + AttentionImplToFlags(activations.attention_impl, HWY_NATIVE_DOT_BF16)); + return; + } if (type == LayerAttentionType::kGemma) { // TODO: remove flag to enable FlashAttention. diff --git a/gemma/gemma_args.h b/gemma/gemma_args.h index 6ccb5b38..b7cfcb22 100644 --- a/gemma/gemma_args.h +++ b/gemma/gemma_args.h @@ -148,6 +148,14 @@ struct RuntimeConfig { // Which attention implementation to use. AttentionImpl attention_impl = AttentionImpl::kFlash; + // Right now it only work for tiled kv cache, implementations. + // If not set, it will be set based on the attention_impl. + // F32 for tiled + // BF16 for tiled bf16 + // If you want to use type other than F32 or BF16, you might need to update + // call upcasted. + std::optional kv_cache_type = {}; + // Functions operating on the generated tokens. StreamFunc stream_token; BatchStreamFunc batch_stream_token; diff --git a/gemma/kv_cache.cc b/gemma/kv_cache.cc index 49276f83..d225f526 100644 --- a/gemma/kv_cache.cc +++ b/gemma/kv_cache.cc @@ -51,10 +51,72 @@ KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args, Extents2D(CappedSeqLen(config, inference_args), config.KVCacheCols()), allocator) {} +KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args, + const RuntimeConfig& runtime_config, + const Allocator& allocator) + : allocator_(allocator) { + if (runtime_config.attention_impl == AttentionImpl::kFlashTransposedQs || + runtime_config.attention_impl == AttentionImpl::kFlashTransposedQsBF16 + ) { + const size_t num_tiles = + hwy::DivCeil(CappedSeqLen(config, inference_args), kTileSize); + tiled_seq_len = num_tiles * kTileSize; + int tile_length = 2 * config.layer_configs[0].qkv_dim * kTileSize; + Type kv_cache_type; + if (runtime_config.attention_impl == AttentionImpl::kFlashTransposedQsBF16 + ) { + kv_cache_type = runtime_config.kv_cache_type.value_or(Type::kBF16); + } else { + kv_cache_type = runtime_config.kv_cache_type.value_or(Type::kF32); + } + auto num_tiles_per_head = [](size_t window_size, size_t prefill_tbatch_size, + size_t max_seq_len) { + return hwy::DivCeil( + std::min(max_seq_len, window_size + prefill_tbatch_size), kTileSize); + }; + + size_t total_num_tiles = 0; + for (size_t window_size : config.attention_window_sizes) { + total_num_tiles += + num_tiles_per_head(window_size, runtime_config.prefill_tbatch_size, + config.max_seq_len) * + config.layer_configs[0].kv_heads; + } + Extents2D extents(total_num_tiles, tile_length); + compact_kv_cache_ptr = MatPtr("kv_tiled", kv_cache_type, extents); + compact_kv_cache.AllocateFor(compact_kv_cache_ptr, allocator, + MatPadding::kPacked); + total_num_tiles = 0; + kv_head_ptrs.reserve(config.attention_window_sizes.size() * + config.layer_configs[0].kv_heads); + for (size_t window_size : config.attention_window_sizes) { + for (size_t kv = 0; kv < config.layer_configs[0].kv_heads; ++kv) { + size_t num_tiles_per_kv_head = + num_tiles_per_head(window_size, runtime_config.prefill_tbatch_size, + config.max_seq_len); + MatPtr kv_ptr("kv_ptr", kv_cache_type, + Extents2D(num_tiles_per_kv_head, tile_length)); + kv_ptr.SetPtr(compact_kv_cache_ptr.RowBytes(total_num_tiles), + compact_kv_cache_ptr.Stride()); + kv_head_ptrs.emplace_back(std::move(kv_ptr)); + total_num_tiles += num_tiles_per_kv_head; + } + } + } else { + kv_cache = MatStorageT( + "kv", + Extents2D(CappedSeqLen(config, inference_args), config.KVCacheCols()), + allocator, MatPadding::kOdd); + } +} + KVCache KVCache::Copy() { KVCache copy(kv_cache.Extents(), allocator_); CopyMat(kv_cache, copy.kv_cache); + + CopyMat(compact_kv_cache_ptr, copy.compact_kv_cache_ptr); + copy.tiled_seq_len = tiled_seq_len; return copy; } diff --git a/gemma/kv_cache.h b/gemma/kv_cache.h index fe6a1ff9..91b6b7f3 100644 --- a/gemma/kv_cache.h +++ b/gemma/kv_cache.h @@ -31,31 +31,103 @@ namespace gcpp { using KV_t = float; +struct KVCache; // A non-owning view of a KVCache. struct KVCachePtr { bool IsEmpty() const { return kv_cache.Rows() == 0; } size_t SeqLen() const; + bool IsTiled() const; MatPtrT kv_cache; + KVCache* cache = nullptr; }; struct KVCache { KVCache(const ModelConfig& config, const InferenceArgs& inference_args, const Allocator& allocator); + KVCache(const ModelConfig& config, const InferenceArgs& inference_args, + const RuntimeConfig& runtime_config, const Allocator& allocator); // Returns a deep copy of the KVCache. Use explicit function instead of // copy ctor to make the cost explicit. KVCache Copy(); size_t SeqLen() const { + if (IsTiled()) { + return tiled_seq_len.value(); + } return kv_cache.Rows(); } + bool IsTiled() const { + return tiled_seq_len.has_value(); + } + + // This function returns a vector of pointers and handles wraparound for local + // layers. + // You can use this function to get kv's, + // it will slice internal circular buffer and give you parts of it that are in + // order. Keep in mind that this gives out pointers to tiles, and for local + // layers start_pos might be in a middle of the first tile. At start_pos % + // kTileSize + std::vector GetPointers(int layer_idx, int kv_head_idx, + int num_kv_heads, int start_pos, + bool is_global_layer) { + if (!IsTiled()) { + HWY_ABORT("This function is only meant to be used with tiled KV caches."); + } + MatPtr& source_ptr = kv_head_ptrs[layer_idx * num_kv_heads + kv_head_idx]; + if (is_global_layer) { + return {source_ptr}; + } + size_t start_tile_mod_window = (start_pos / kTileSize) % source_ptr.Rows(); + size_t start_len = source_ptr.Rows() - start_tile_mod_window; + MatPtr start_ptr("kv_start", source_ptr.GetType(), + Extents2D(start_len, source_ptr.Cols())); + start_ptr.SetPtr(source_ptr.RowBytes(start_tile_mod_window), + source_ptr.Cols()); + return {start_ptr, source_ptr}; + } + + static constexpr size_t kTileSize = 32; + std::optional tiled_seq_len = std::nullopt; + // Default Format + // If tiled_seq_len is not set, then the kv_cache is assumed to be [seq_len, + // layers * kv_heads * qkv_dim * 2]. + // + // Tiled Format + // If tiled_seq_len is set, the kv cache is stored in tiled format. + // Allocations must happen in full tiles. + // The order of dimensions on rows is: [layer, kv_head, tile]. + // The total number of rows is: + // num_layers * num_kv_heads * (tiled_seq_len / kTileSize). + // Each tile (containing kTileSize elements from the sequence) can be thought + // of as storing K^T and V, where K is shaped [kTileSize, qkv_dim]. + + // Type erased kv cache. It's compact because local layers are allocated as + // circular buffers. + MatPtr compact_kv_cache_ptr; + MatOwner compact_kv_cache; + // Pointers to the raw KV storage indexed by layer and head. This helps + // accessing the tiles even though different layers may have a different + // number of tiles in storage. All pointers point into compact_kv_cache. + + // To access the tiles of (layer_idx, head_idx), index the array with + // layer_idx * num_kv_heads + kv_head_idx. + // Or use GetPointers function. + + // The returned MatPtr will have one tile per row. The number of rows for + // global layers is max_seq_len/kTileSize. For local layers it is slightly + // more than attention_window_size[layer_idx] / kTileSize. For local layers, a + // given token_idx is in row (token_idx / kTileSize) % + // kv_head_ptrs[...].Rows(). + std::vector kv_head_ptrs; MatStorageT kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2] KVCachePtr ToPtr() { return KVCachePtr{ .kv_cache = kv_cache, + .cache = this, }; } @@ -67,9 +139,17 @@ struct KVCache { }; inline size_t KVCachePtr::SeqLen() const { + if (IsTiled()) { + return cache->tiled_seq_len.value(); + } return kv_cache.Rows(); } +inline bool KVCachePtr::IsTiled() const { + // MPU code create a KVCachePtr without kv_cache. + return cache != nullptr && cache->tiled_seq_len.has_value(); +} + // Convenience function to create views into KVCaches. std::vector ToKVCachePtrs(const hwy::Span& kv_caches); diff --git a/gemma/tiled_attention.cc b/gemma/tiled_attention.cc new file mode 100644 index 00000000..c36828f9 --- /dev/null +++ b/gemma/tiled_attention.cc @@ -0,0 +1,660 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "compression/compress.h" +#include "compression/types.h" // GEMMA_DISABLED_TARGETS +#include "gemma/configs.h" +#include "gemma/gemma.h" +#include "gemma/kv_cache.h" +#include "ops/matmul.h" +#include "hwy/aligned_allocator.h" +#include "hwy/base.h" + +// Note: HWY_DISABLED_TARGETS needs to be defined the same everywhere. +#ifndef HWY_DISABLED_TARGETS +#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS +#endif // HWY_DISABLED_TARGETS + +#include "util/basics.h" +#include "util/mat.h" +#include "util/threading_context.h" + +// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "gemma/tiled_attention.cc" // NOLINT +// clang-format on +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +// After highway.h +#include "gemma/attention.h" +#include "gemma/flash_attention.h" // includes highway.h +#include "gemma/gemma-inl.h" +#include "ops/ops-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { + +static HWY_INLINE void MergeOnlineSoftmax( + const float* HWY_RESTRICT other_att_out, const float other_softmax_max, + const float other_softmax_d, int qkv_dim, + float* HWY_RESTRICT accumulator_att_out, float& accumulator_softmax_max, + float& accumulator_softmax_d) { + if (other_softmax_d == 0.0f) { + return; + } + if (accumulator_softmax_d == 0.0f) { + memcpy(accumulator_att_out, other_att_out, + qkv_dim * sizeof(*accumulator_att_out)); + accumulator_softmax_max = other_softmax_max; + accumulator_softmax_d = other_softmax_d; + return; + } + const float m_new = std::max(accumulator_softmax_max, other_softmax_max); + const float exp_l = std::exp(accumulator_softmax_max - m_new); + const float exp_r = std::exp(other_softmax_max - m_new); + const float d_new = accumulator_softmax_d * exp_l + other_softmax_d * exp_r; + const float d_new_inv = 1.0f / d_new; + const float c1 = accumulator_softmax_d * exp_l * d_new_inv; + const float c2 = other_softmax_d * exp_r * d_new_inv; + MulByConst(c1, accumulator_att_out, qkv_dim); + MulByConstAndAdd(c2, other_att_out, accumulator_att_out, qkv_dim); + accumulator_softmax_max = m_new; + accumulator_softmax_d = d_new; +} + +// Forked from ComputeQKV. But it stores the K/V in the tiled format +// KV_T is type stored in the KV cache (typically float or BF16). +template +static HWY_INLINE void ComputeQKVTransposedTile( + size_t num_tokens, const size_t layer_idx, const LayerWeightsPtrs& layer, + AttentionImpl attention_impl, AttentionActivationsPtrs& activations, + const QBatch& qbatch, const int flags, MatMulEnv& env) { + PROFILER_ZONE("Gen.Attention.QKVTiled"); + const hwy::Divisor div_qbatch(qbatch.Size()); + const size_t num_interleaved = num_tokens * div_qbatch.GetDivisor(); + const LayerConfig& layer_config = layer.layer_config; + const size_t qkv_dim = layer_config.qkv_dim; + const size_t kv_heads = layer_config.kv_heads; + + // The original qkv_einsum_w has shape [(heads + kv_heads * 2), qkv_dim, + // model_dim], which we reshaped to (heads + kv_heads * 2) * qkv_dim rows. + // This computes Q and stores it in activations.q. + // The original qkv_einsum_w has shape [(heads + kv_heads * 2), qkv_dim, + // model_dim], which we reshaped to (heads + kv_heads * 2) * qkv_dim rows. + // This computes Q and stores it in activations.q. + CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w1, + /*add=*/nullptr, env, activations.q); + + // Compute the combined KV output from pre_att_rms_out. + // The output shape is [num_interleaved, kv_heads * 2 * qkv_dim]. + const size_t kv_out_cols = kv_heads * 2 * qkv_dim; + hwy::AlignedFreeUniquePtr kv_out_mem = + hwy::AllocateAligned(num_interleaved * kv_out_cols); + float* kv_out_data = kv_out_mem.get(); + MatPtrT kv_out_mat("kv_out", Extents2D(num_interleaved, kv_out_cols)); + kv_out_mat.SetPtr(kv_out_data, kv_out_cols); + CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2, + /*add=*/nullptr, env, kv_out_mat); + + // Apply positional encodings and store K/V in tiled format. + hwy::Divisor div_kv_heads(kv_heads); + + hn::ScalableTag df; + static hwy::Divisor tile_size_divisor(KVCache::kTileSize); + ParallelFor( + Parallelism::kFlat, kv_heads * qbatch.Size(), env.ctx, + /*cluster_idx=*/0, Callers::kAttComputeQKV, + [&](size_t task, size_t worker) HWY_ATTR { + const size_t kv_head = div_kv_heads.Remainder(task); + const size_t query_idx = div_kv_heads.Divide(task); + CompressPerThread tls; + size_t current_token_idx = 0; + float* k_tile_vec = activations.k_tile_vec.Row(task); + float* v_tile_vec = activations.v_tile_vec.Row(task); + HWY_ALIGN float k_f32[kMaxQKVDim]; + const size_t start_pos = qbatch.Pos(query_idx); + const bool is_global_layer = + activations.config.IsGlobalLayer(layer_idx); + std::vector kv_ptrs = + qbatch.KV(query_idx).cache->GetPointers( + layer_idx, kv_head, kv_heads, start_pos, is_global_layer); + size_t tile_offset = 0; + if (!is_global_layer) { + tile_offset = start_pos / KVCache::kTileSize; + } + + while (current_token_idx < num_tokens) { + const size_t pos = start_pos + current_token_idx; + const size_t pos_mod = activations.div_seq_len.Remainder(pos); + const size_t tile_idx = tile_size_divisor.Divide(pos_mod); + const size_t relative_tile_idx = tile_idx - tile_offset; + KV_T* tile_ptr; + int kv_ptr_idx = 0; + size_t absolute_rows = 0; + while (absolute_rows + kv_ptrs[kv_ptr_idx].Rows() <= + relative_tile_idx) { + absolute_rows += kv_ptrs[kv_ptr_idx].Rows(); + kv_ptr_idx++; + } + tile_ptr = HWY_RCAST_ALIGNED( + KV_T*, + kv_ptrs[kv_ptr_idx].RowBytes(relative_tile_idx - absolute_rows)); + PackedSpan tile_packed_span{tile_ptr, + 2 * qkv_dim * KVCache::kTileSize}; + + DecompressAndZeroPad(df, tile_packed_span, 0, k_tile_vec, + qkv_dim * KVCache::kTileSize); + DecompressAndZeroPad(df, tile_packed_span, + qkv_dim * KVCache::kTileSize, v_tile_vec, + qkv_dim * KVCache::kTileSize); + + size_t token_in_tile_idx = current_token_idx; + while (token_in_tile_idx < num_tokens) { + const size_t current_pos = + qbatch.Pos(query_idx) + token_in_tile_idx; + const size_t current_pos_mod = + activations.div_seq_len.Remainder(current_pos); + if (tile_size_divisor.Divide(current_pos_mod) != tile_idx) { + break; // Moved to next tile + } + + const float* kv_row = + kv_out_data + + (token_in_tile_idx * qbatch.Size() + query_idx) * kv_out_cols; + const float* k_ptr = kv_row + kv_head * 2 * qkv_dim; + const float* v_ptr = kv_row + kv_head * 2 * qkv_dim + qkv_dim; + hwy::CopyBytes(k_ptr, k_f32, qkv_dim * sizeof(float)); + if (layer.key_norm_scale.HasPtr()) { + CallUpcasted(&layer.key_norm_scale, [&](const auto* weights_t) { + RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, k_f32, + qkv_dim, env.ctx, worker); + }); + } + PositionalEncodingQK( + k_f32, layer_idx, activations, env.ctx, worker, + current_pos , + /*mul=*/1.0f); + + const size_t in_tile_idx = current_pos_mod % KVCache::kTileSize; + if (attention_impl == AttentionImpl::kFlashTransposedQsBF16) { + const int in_tile_idx_mod_2 = in_tile_idx % 2; + for (int dim = 0; dim < qkv_dim; dim += 2) { + const int dim_mod_2 = dim % 2; + // Pack k's in pairs in preparation for BF16 dot product. + // See flash_attention.cc + // QDotKTilexUpTo4TransposedKDoubleWidthBF16 + k_tile_vec[(dim - dim_mod_2) * KVCache::kTileSize + + in_tile_idx * 2] = k_f32[dim]; + k_tile_vec[(dim - dim_mod_2) * KVCache::kTileSize + + in_tile_idx * 2 + 1] = k_f32[dim + 1]; + // Pack v's in pairs + v_tile_vec[(in_tile_idx - in_tile_idx_mod_2) * qkv_dim + + dim * 2 + in_tile_idx_mod_2] = v_ptr[dim]; + v_tile_vec[(in_tile_idx - in_tile_idx_mod_2) * qkv_dim + + (dim + 1) * 2 + in_tile_idx_mod_2] = v_ptr[dim + 1]; + } + + } else { + for (int i = 0; i < qkv_dim; ++i) { + k_tile_vec[i * KVCache::kTileSize + in_tile_idx] = k_f32[i]; + } + Compress(v_ptr, qkv_dim, tls, tile_packed_span, + qkv_dim * (KVCache::kTileSize + in_tile_idx)); + } + + token_in_tile_idx++; + } + Compress(k_tile_vec, qkv_dim * KVCache::kTileSize, tls, + tile_packed_span, 0); + if (attention_impl == AttentionImpl::kFlashTransposedQsBF16) { + Compress(v_tile_vec, qkv_dim * KVCache::kTileSize, tls, + tile_packed_span, qkv_dim * KVCache::kTileSize); + } + current_token_idx = token_in_tile_idx; + } + }); +} + +// TODO: optimize with gathers +// This format might change in the future, when kernel will be updated to +// support more than 8 queries. +// Input (num_queries, qkv_dim) +// Output (qkv_dim, num_queries) +void TransposeQ(const MatPtrT& queries, + hwy::Span transposed_queries_span) { + const size_t qkv_dim = queries.Cols(); + const size_t num_queries = queries.Rows(); + HWY_ASSERT(transposed_queries_span.size() == num_queries * qkv_dim); + for (size_t i = 0; i < qkv_dim; i++) { + for (size_t j = 0; j < num_queries; ++j) { + transposed_queries_span[i * num_queries + j] = queries.Row(j)[i]; + } + } +} + +// Transposes queries +// Input: vector of pointers to subsequent queries. (allows for arbitrary +// strides) +// qkv_dim: dimension of query +// allocator: aligned allocator to use for temporary storage +// +// Output: Pointer to contiguous memory with shape (qkv_dim, +// queries.size()) +void TransposeStridedQueries( + hwy::Span queries, int qkv_dim, + hwy::Span transposed_queries) { + namespace hn = hwy::HWY_NAMESPACE; + using DF = hn::ScalableTag; + const DF df; + using VF = hn::Vec; + using DI = hn::ScalableTag; + const DI di; + using VI = hn::Vec; + const size_t lanes = hn::Lanes(df); + const size_t num_queries = queries.size(); + const size_t num_queries_rounded_up = hwy::RoundUpTo(num_queries, lanes); + std::vector> query_offsets( + num_queries_rounded_up); + for (size_t i = 0; i < num_queries; ++i) { + query_offsets[i] = queries[i] - queries[0]; + } + for (size_t i = num_queries; i < num_queries_rounded_up; ++i) { + // last offset is the same so gather doesn't read out of bounds + query_offsets[i] = query_offsets[num_queries - 1]; + } + + for (size_t i = 0; i < qkv_dim; i++) { + size_t j = 0; + if (num_queries >= lanes) { + for (; j <= num_queries-lanes; j += lanes) { + const VI offsets = hn::LoadU(di, query_offsets.data() + j); + VF x = hn::GatherIndex(df, queries[0] + i, offsets); + hn::StoreU(x, df, transposed_queries.data() + i * num_queries + j); + } + } + if (j < num_queries) { + const VI offsets = hn::LoadU(di, query_offsets.data() + j); + VF x = hn::GatherIndex(df, queries[0] + i, offsets); + hn::StoreN(x, df, transposed_queries.data() + i * num_queries + j, + num_queries - j); + } + } +} + +std::pair> TransposeQueriesToGroupsOf4( + hwy::Span queries_ptrs, int qkv_dim) { + int num_queries = queries_ptrs.size(); + int num_groups = hwy::DivCeil(num_queries, 4); + AlignedFloatVector transposed_queries(num_groups * 4 * qkv_dim); + std::vector transposed_queries_ptrs; + for (int group_idx = 0; group_idx < num_groups; ++group_idx){ + int group_size = std::min(4, num_queries - group_idx * 4); + transposed_queries_ptrs.push_back(transposed_queries.data() + + group_idx * qkv_dim * 4); + TransposeStridedQueries( + hwy::Span(queries_ptrs.data() + group_idx * 4, + group_size), + qkv_dim, + hwy::Span(transposed_queries_ptrs.back(), qkv_dim * group_size)); + } + return std::make_pair(std::move(transposed_queries), + std::move(transposed_queries_ptrs)); +} + +std::pair> +TransposeTransposedQueriesAndPackIntoBF16(hwy::Span queries_ptrs, + int qkv_dim, int num_queries) { + constexpr int kMaxGroupSize = 4; + int num_groups = queries_ptrs.size(); + AlignedBF16Vector transposed_queries(num_groups * kMaxGroupSize * qkv_dim); + std::vector transposed_queries_ptrs; + transposed_queries_ptrs.reserve(num_groups); + for (int group_idx = 0; group_idx < num_groups; ++group_idx) { + int group_size = + std::min(kMaxGroupSize, num_queries - group_idx * kMaxGroupSize); + transposed_queries_ptrs.push_back(transposed_queries.data() + + group_idx * qkv_dim * kMaxGroupSize); + for (int dim_idx = 0; dim_idx < qkv_dim; dim_idx += 2) { + for (int query_idx = 0; query_idx < group_size; ++query_idx) { + transposed_queries_ptrs.back()[dim_idx * group_size + query_idx * 2] = + hwy::ConvertScalarTo( + queries_ptrs[group_idx][dim_idx * group_size + query_idx]); + transposed_queries_ptrs + .back()[dim_idx * group_size + query_idx * 2 + 1] = + hwy::ConvertScalarTo( + queries_ptrs[group_idx] + [(dim_idx + 1) * group_size + query_idx]); + } + } + } + return std::make_pair(std::move(transposed_queries), + std::move(transposed_queries_ptrs)); +} + +template +static HWY_INLINE void MaybeResizeMatStorage(MatStorageT& mat_storage, + int rows, int cols, + const char* name, + const Allocator& allocator) { + if (mat_storage.Rows() != rows || mat_storage.Cols() != cols) { + mat_storage = MatStorageT(name, Extents2D(rows, cols), allocator, + MatPadding::kOdd); + } +} + +// clang-format off +// Schedules TiledFlashAttention for all heads, tokens and batch. +// Returns partial results in the same order as queries in `activations.q`. +// Might not work yet for prefix lm. +// To help understanding how to use this function below is description of how +// parameters are used: +// +// attention_impl - Used to determine attention kernel to use. +// num_query_tokens - number of tokens/timesteps in processed in a single batch +// it will influence how many queries kvs are evaluated against. +// num_kv_tokens - number of tokens/timesteps in kv cache +// layer_idx - layer index +// layer - used to get kv_heads, heads, qkv_dim +// activations - reads: activations.q queries, att_cap, IsGlobalLayer +// qbatch - kv cache, Pos / EndPrefix +// ctx - threading context +// clang-format on +void LocalAttentionForAllHeadsTokensAndBatch( + AttentionImpl attention_impl, const size_t num_query_tokens, + const size_t layer_idx, const LayerWeightsPtrs& layer, + AttentionActivationsPtrs& activations, QBatch& qbatch, + ThreadingContext& ctx) { + const size_t heads_per_kv_head = + layer.layer_config.heads / layer.layer_config.kv_heads; + + int core_count = ctx.pools.MaxWorkers(); + int task_multiplier = 1; + while (qbatch.Size() * layer.layer_config.kv_heads * task_multiplier < + core_count * 2) { + task_multiplier++; + } + // Finding the smallest context we need to attend to avoid unnecessary + // overhead when sub-splitting doesn't make sense. This check overestimates + // context sizes because it ignores [local] layer sizes and explicit + // qbatch.Prefix settings. + size_t min_pos = qbatch.Pos(0); + for (size_t qi = 0; qi < qbatch.Size(); ++qi) { + min_pos = std::min(min_pos, qbatch.Pos(qi)); + } + if (min_pos / task_multiplier < num_query_tokens) { + // In case where min_pos / task_multiplier < num_tokens + // To make sure we don't over count tokens or read out of bounds code + // requires quite a bit more involved logic. + // Also there is not much point to splitting the work into more tasks, when + // amount of work is small. + task_multiplier = 1; + } + [[maybe_unused]] int num_tasks = qbatch.Size() * layer.layer_config.kv_heads; + [[maybe_unused]] int num_sub_tasks = + qbatch.Size() * layer.layer_config.kv_heads * task_multiplier; + HWY_DASSERT_M(activations.q.Rows() == num_query_tokens * qbatch.Size(), + "qbatch size mismatch"); + int qkv_dim = layer.layer_config.qkv_dim; + + // sizes of all should be in sync + if (num_sub_tasks > activations.sub_task_att_out->size()) { + activations.sub_task_att_out->resize(num_sub_tasks); + activations.sub_task_exp_denominator_sums->resize(num_sub_tasks); + activations.sub_task_max_logits->resize(num_sub_tasks); + } + std::vector skip_sub_task(num_sub_tasks, 0); + + // This loop parallelizes over qbatch, kv_head and substrings of context + // tokens. Each parallel invocation handles all query tokens of the given + // qbatch. + ParallelFor( + Parallelism::kHierarchical, num_sub_tasks, ctx, + /*cluster_idx=*/0, Callers::kFlashAttention, + [&](size_t task_idx, size_t worker) HWY_ATTR { + size_t main_task_idx = task_idx / task_multiplier; + size_t sub_task_idx = task_idx % task_multiplier; + size_t current_qbatch_idx = + main_task_idx / layer.layer_config.kv_heads; + size_t kv_head_idx = main_task_idx % layer.layer_config.kv_heads; + // First and last context token we will attend to. + size_t global_start_context_pos = StartPos( + qbatch.Pos(current_qbatch_idx), activations.config, layer_idx); + // Keep in mind this is overestimation because some timesteps might not + // need all tokens due to causal mask. + // We will use it to determine how to divide work between sub tasks + // and make sure PrefixEnd is taken into account + size_t start_context_pos = global_start_context_pos; + size_t last_context_pos = + qbatch.Pos(current_qbatch_idx) + num_query_tokens - 1; + // In some models, context is limited to some prefix - make sure we take + // that into account. + const size_t prefix_end = qbatch.PrefixEnd(current_qbatch_idx); + if (prefix_end > 0 && prefix_end - 1 > last_context_pos) { + last_context_pos = prefix_end - 1; + } + size_t total_num_context_tokens = + last_context_pos - start_context_pos + 1; + size_t context_tokens_per_sub_task = + hwy::DivCeil(total_num_context_tokens, task_multiplier); + // Restrict tokens to attend to the substring of context tokens that + // this subtask is responsible for. + start_context_pos = + start_context_pos + context_tokens_per_sub_task * sub_task_idx; + if (start_context_pos > last_context_pos) { + skip_sub_task[task_idx] = 1; + return; + } + last_context_pos = + std::min(last_context_pos, + start_context_pos + context_tokens_per_sub_task - 1); + // pre-initialize memory [to avoid racy resizes laters]. + int num_queries = num_query_tokens * heads_per_kv_head; + std::vector queries_ptrs; + queries_ptrs.reserve(num_queries); + for (int token_idx = 0; token_idx < num_query_tokens; ++token_idx) { + for (int q_head_idx = 0; q_head_idx < heads_per_kv_head; + ++q_head_idx) { + queries_ptrs.push_back( + activations.q.Row(token_idx * qbatch.Size() + + current_qbatch_idx) + + (kv_head_idx * heads_per_kv_head + q_head_idx) * qkv_dim); + } + } + hwy::Span queries_ptrs_span(queries_ptrs.data(), + queries_ptrs.size()); + + auto [transposed_queries, transposed_queries_ptrs] = + TransposeQueriesToGroupsOf4(queries_ptrs_span, qkv_dim); + + MatStorageT& att_out = + activations.sub_task_att_out->at(task_idx); + AlignedFloatVector& exp_denominator_sums = + activations.sub_task_exp_denominator_sums->at(task_idx); + AlignedFloatVector& max_logits = + activations.sub_task_max_logits->at(task_idx); + MaybeResizeMatStorage(att_out, num_queries, qkv_dim, "att_out", + ctx.allocator); + for (int i = 0; i < num_queries; ++i) { + hwy::ZeroBytes(att_out.Row(i), + att_out.Cols() * sizeof(decltype(att_out.Row(i)[0]))); + } + + int num_queries_rounded_to_8 = hwy::RoundUpTo(num_queries, 8); + exp_denominator_sums.resize(num_queries_rounded_to_8); + max_logits.resize(num_queries_rounded_to_8); + for (int i = 0; i < num_queries_rounded_to_8; ++i) { + exp_denominator_sums[i] = 0.0f; + max_logits[i] = -std::numeric_limits::max() / 2.0f; + } + // Get pointers to the KVCache tiles, starting at global_start_pos + // Returns multiple matrices for non-contiguous memory, for example as a + // result of the wraparound in local layers. + std::vector kv_ptrs = + qbatch.KV(current_qbatch_idx) + .cache->GetPointers( + layer_idx, kv_head_idx, layer.layer_config.kv_heads, + global_start_context_pos, + activations.config.IsGlobalLayer(layer_idx)); + + std::vector> start_pos_per_query; + std::vector> last_pos_per_query; + start_pos_per_query.reserve(num_queries); + last_pos_per_query.reserve(num_queries); + // Position of the first token in the first tile whose pointer was + // returned above. Allows for handling of token positions relative to + // the KV tiles returned above. + size_t rounded_down_global_start_pos = + hwy::RoundDownTo(global_start_context_pos, KVCache::kTileSize); + for (int token_idx = 0; token_idx < num_query_tokens; ++token_idx) { + int64_t global_query_pos = + qbatch.Pos(current_qbatch_idx) + token_idx; + // Intersect context to attend to for this specific query token + // to the context tokens of the current subtask. + int64_t query_last_context_pos = std::min( + static_cast(last_context_pos), global_query_pos); + // This max is to not go into negative values, for the same reason we + // use int64_t and not size_t here. + int64_t query_start_context_pos = std::max( + global_query_pos - + static_cast( + activations.config.attention_window_sizes[layer_idx]) + + 1, + static_cast(start_context_pos)); + + // Turn token position into KV-tile relative token positions. + query_last_context_pos -= rounded_down_global_start_pos; + query_start_context_pos -= rounded_down_global_start_pos; + for (int q_head_idx = 0; q_head_idx < heads_per_kv_head; + ++q_head_idx) { + start_pos_per_query.push_back(query_start_context_pos); + last_pos_per_query.push_back(query_last_context_pos); + } + } + if (attention_impl == AttentionImpl::kFlashTransposedQsBF16) { + // pack transposed queries into BF16 + hwy::Span queries_span(transposed_queries_ptrs.data(), + transposed_queries_ptrs.size()); + auto [_, transposed_queries_ptrs_bf16] = + TransposeTransposedQueriesAndPackIntoBF16(queries_span, qkv_dim, + num_queries); + hwy::Span queries_span_bf16( + const_cast(transposed_queries_ptrs_bf16.data()), + transposed_queries_ptrs_bf16.size()); + DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsBF16( + kv_ptrs, num_queries, queries_span_bf16, + hwy::Span(start_pos_per_query), + hwy::Span(last_pos_per_query), + activations.config.att_cap, att_out, exp_denominator_sums.data(), + max_logits.data()); + } else { + DispatchTileFlashAttentionReturnExpSumsAndMaxLogits( + kv_ptrs, num_queries, + hwy::Span( + const_cast(transposed_queries_ptrs.data()), + transposed_queries_ptrs.size()), + hwy::Span(start_pos_per_query), + hwy::Span(last_pos_per_query), + activations.config.att_cap, att_out, exp_denominator_sums.data(), + max_logits.data()); + } + }); + + // This loop takes results from separate subtasks (subsequence of kv) and + // merges them into single att_out over whole kv sequence. + ParallelFor( + Parallelism::kFlat, num_tasks, ctx, + /*cluster_idx=*/0, Callers::kFlashAttention, + [&](size_t main_task_idx, size_t worker) HWY_ATTR { + size_t current_qbatch_idx = main_task_idx / layer.layer_config.kv_heads; + size_t kv_head_idx = main_task_idx % layer.layer_config.kv_heads; + for (int token_idx = 0; token_idx < num_query_tokens; ++token_idx) { + for (int head_in_group_idx = 0; head_in_group_idx < heads_per_kv_head; + ++head_in_group_idx) { + const size_t batch_index = + current_qbatch_idx * num_query_tokens + token_idx; + const size_t q_head_idx = + kv_head_idx * heads_per_kv_head + head_in_group_idx; + const size_t att_out_row_idx = + token_idx * heads_per_kv_head + head_in_group_idx; + const size_t activations_att_out_start_idx = q_head_idx * qkv_dim; + auto& att_out_0 = activations.sub_task_att_out->at( + main_task_idx * task_multiplier + 0); + auto& exp_denominator_sums_0 = + activations.sub_task_exp_denominator_sums->at( + main_task_idx * task_multiplier + 0); + auto& max_logits_0 = activations.sub_task_max_logits->at( + main_task_idx * task_multiplier + 0); + + hwy::CopyBytes(att_out_0.Row(att_out_row_idx), + activations.att_out.Row(batch_index) + + activations_att_out_start_idx, + qkv_dim * sizeof(float)); + activations.softmax_d.Row(batch_index)[q_head_idx] = + exp_denominator_sums_0[token_idx * heads_per_kv_head + + head_in_group_idx]; + activations.softmax_max.Row(batch_index)[q_head_idx] = + max_logits_0[token_idx * heads_per_kv_head + head_in_group_idx]; + for (int sub_task_idx = 1; sub_task_idx < task_multiplier; + ++sub_task_idx) { + int task_idx = main_task_idx * task_multiplier + sub_task_idx; + if (skip_sub_task[task_idx] == 1) { + continue; + } + auto& att_out = activations.sub_task_att_out->at(task_idx); + auto& exp_denominator_sums = + activations.sub_task_exp_denominator_sums->at(task_idx); + auto& max_logits = activations.sub_task_max_logits->at(task_idx); + MergeOnlineSoftmax( + att_out.Row(att_out_row_idx), + max_logits[token_idx * heads_per_kv_head + head_in_group_idx], + exp_denominator_sums[token_idx * heads_per_kv_head + + head_in_group_idx], + qkv_dim, + activations.att_out.Row(batch_index) + + activations_att_out_start_idx, + activations.softmax_max.Row(batch_index)[q_head_idx], + activations.softmax_d.Row(batch_index)[q_head_idx]); + } + } + } + }); +} + +void TiledAttention(AttentionImpl attention_impl, size_t num_tokens, + const size_t layer_idx, const LayerWeightsPtrs& layer, + AttentionActivationsPtrs& activations, QBatch& qbatch, + MatMulEnv& env, int flags) { + static const auto zone = env.ctx.profiler.AddZone( + "Gen.TiledAttention", hwy::ProfilerFlags::kInclusive); + PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::Thread(), zone); + + const LayerConfig& layer_config = layer.layer_config; + + HWY_DASSERT_M((layer_config.heads % layer_config.kv_heads) == 0, + "query heads must be a multiple of key-value heads"); + (void)layer_config; // only used in HWY_DASSERT + if (qbatch.KV(0).cache->compact_kv_cache_ptr.GetType() == Type::kBF16) { + ComputeQKVTransposedTile(num_tokens, layer_idx, layer, attention_impl, + activations, qbatch, flags, env); + } else { + ComputeQKVTransposedTile(num_tokens, layer_idx, layer, attention_impl, + activations, qbatch, flags, env); + } + RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q, + layer.query_norm_scale, layer_idx, activations, + env.ctx); + LocalAttentionForAllHeadsTokensAndBatch(attention_impl, num_tokens, layer_idx, + layer, activations, qbatch, env.ctx); + SumHeads(layer, activations, env); +} + +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); diff --git a/gemma/tiled_attention.h b/gemma/tiled_attention.h new file mode 100644 index 00000000..e7256d66 --- /dev/null +++ b/gemma/tiled_attention.h @@ -0,0 +1,42 @@ +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TILED_ATTENTION_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_TILED_ATTENTION_H_ + +#include + +#include +#include +#include + +#include "gemma/gemma.h" +#include "util/allocator.h" +#include "hwy/aligned_allocator.h" +#include "hwy/highway.h" + +namespace gcpp { + +// Passed to HWY_VISIT_TARGETS; declares for one target. +#define GEMMA_DECL_TILED_ATTENTION(TARGET, NAMESPACE) \ + namespace NAMESPACE { \ + void TiledAttention(AttentionImpl attention_impl, size_t num_tokens, \ + size_t layer_idx, const LayerWeightsPtrs& layer, \ + AttentionActivationsPtrs& activations, QBatch& qbatch, \ + MatMulEnv& env, int flags); \ + void TransposeStridedQueries(hwy::Span queries, int qkv_dim, \ + hwy::Span transposed_queries); \ + void LocalAttentionForAllHeadsTokensAndBatch( \ + AttentionImpl attention_impl, const size_t num_tokens, \ + const size_t layer_idx, const LayerWeightsPtrs& layer, \ + AttentionActivationsPtrs& activations, QBatch& qbatch, \ + ThreadingContext& ctx); \ + /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ + } // namespace NAMESPACE + +// Function declarations for each SIMD target. Allows direct call from the +// per-target namespace. We may later replace this with dynamic dispatch if +// the overhead is acceptable. +HWY_VISIT_TARGETS(GEMMA_DECL_TILED_ATTENTION) + +#undef GEMMA_DECL_TILED_ATTENTION +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_TILED_ATTENTION_H_ diff --git a/gemma/tiled_attention_test.cc b/gemma/tiled_attention_test.cc new file mode 100644 index 00000000..7f9c8ca9 --- /dev/null +++ b/gemma/tiled_attention_test.cc @@ -0,0 +1,749 @@ +#include + +#include +#include +#include +#include + +#include "compression/types.h" // GEMMA_DISABLED_TARGETS +#include "gemma/activations.h" +#include "gemma/configs.h" +#include "gemma/gemma.h" +#include "gemma/gemma_args.h" +#include "gemma/kv_cache.h" +#include "gemma/weights.h" +#include "util/mat.h" +#include "util/threading_context.h" +#ifndef HWY_DISABLED_TARGETS +#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS +#endif // HWY_DISABLED_TARGETS + +// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "gemma/tiled_attention_test.cc" // NOLINT +// clang-format on +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +// After highway.h +#include "gemma/tiled_attention.h" +#include "util/test_util.h" +#include "hwy/aligned_allocator.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { + +using ::testing::FloatNear; +using ::testing::Pointwise; + +struct AttentionTestEnv { + AttentionTestEnv( + int qkv_dim, int kv_seq_len, int attention_window_size, int num_kv_heads, + int num_heads, int num_tokens, int last_pos, float att_cap, int layer_idx, + int layers_total, int qbatch_size, AttentionImpl attention_impl, + ) + : ctx(threading_args), env(ctx) { + layer_config.heads = num_heads; + layer_config.kv_heads = num_kv_heads; + layer_config.qkv_dim = qkv_dim; + layer_config.model_dim = qkv_dim * num_heads; + + model_config.attention_window_sizes = { + static_cast(attention_window_size)}; + model_config.att_cap = att_cap; + model_config.max_seq_len = kv_seq_len; + model_config.num_layers = layers_total; + model_config.model_dim = layer_config.model_dim; + model_config.vocab_size = 1; // not vit + + for (int i = 0; i < model_config.num_layers; ++i) { + model_config.layer_configs.push_back(layer_config); + } + tensor_info_registry = std::make_unique(model_config); + layer = std::make_unique(layer_idx, layer_config, + *tensor_info_registry); + + runtime_config.attention_impl = attention_impl; + inference_args.seq_len = kv_seq_len; + + all_queries.Reserve(qbatch_size); + kv_caches.reserve(qbatch_size); + for (int q = 0; q < qbatch_size; ++q) { + kv_caches.emplace_back(model_config, inference_args, runtime_config, + ctx.allocator); + if (attention_impl == AttentionImpl::kFlashTransposedQsBF16) { + MatPtrT compact_kv_cache = kv_caches.back().compact_kv_cache_ptr; + for (int i = 0; i < compact_kv_cache.Rows(); ++i) { + for (int j = 0; j < compact_kv_cache.Cols(); ++j) { + BF16 val = hwy::ConvertScalarTo(hwy::Unpredictable1() * + 0.01f * (i + j + 1)); + // split j into if k/v + if (j < qkv_dim * gcpp::KVCache::kTileSize) { + // split j into dim and in tile offset + const int dim = j / gcpp::KVCache::kTileSize; + const int in_tile_offset = j % gcpp::KVCache::kTileSize; + const int dim_mod_2 = dim % 2; + compact_kv_cache.Row( + i)[(dim - dim_mod_2) * gcpp::KVCache::kTileSize + + in_tile_offset * 2 + dim_mod_2] = val; + } else { + const int in_tile_offset = j / qkv_dim; + const int dim = j % qkv_dim; + const int in_tile_offset_mod_2 = in_tile_offset % 2; + compact_kv_cache.Row( + i)[(in_tile_offset - in_tile_offset_mod_2) * qkv_dim + + dim * 2 + in_tile_offset_mod_2] = val; + } + } + } + } else if (kv_caches.back().compact_kv_cache_ptr.HasPtr()) { + MatPtrT compact_kv_cache = kv_caches.back().compact_kv_cache_ptr; + FillMatPtrT(compact_kv_cache); + } else { + FillMatPtrT(kv_caches.back().kv_cache); + } + all_queries.Append({ + .prompt = PromptTokens({1, 2, 3}), + .mutable_pos = static_cast(last_pos), + .initial_pos = 0, + .prefix_end = 0, + .kv_cache = kv_caches.back().ToPtr(), + }); + } + + activations = std::make_unique(runtime_config, model_config, + qbatch_size * num_tokens, + kv_seq_len, ctx, env.row_ptrs); + + qbatch = + std::make_unique(/*start_pos=*/0, qbatch_size, all_queries); + } + + void SetupWeights() { + int model_dim = layer_config.model_dim; + int qkv_dim = layer_config.qkv_dim; + int num_heads = layer_config.heads; + int num_kv_heads = layer_config.kv_heads; + + qkv1_w_storage = + MatStorageT("qkv1", Extents2D(model_dim, qkv_dim * num_heads), + ctx.allocator, MatPadding::kPacked); + qkv2_w_storage = MatStorageT( + "qkv2", Extents2D(model_dim, num_kv_heads * 2 * qkv_dim), ctx.allocator, + MatPadding::kPacked); + wo_w_storage = MatStorageT("wo", Extents2D(model_dim, model_dim), + ctx.allocator, MatPadding::kPacked); + + FillMatPtrT(wo_w_storage); + layer->att_weights = wo_w_storage; + FillMatPtrT(qkv1_w_storage); + FillMatPtrT(qkv2_w_storage); + layer->qkv_einsum_w1 = qkv1_w_storage; + layer->qkv_einsum_w2 = qkv2_w_storage; + + query_norm_scale = MatStorageT("query_norm", qkv_dim, ctx.allocator); + FillMatPtrT(query_norm_scale); + layer->query_norm_scale = query_norm_scale; + + key_norm_scale = MatStorageT("key_norm", qkv_dim, ctx.allocator); + FillMatPtrT(key_norm_scale); + layer->key_norm_scale = key_norm_scale; + } + + AttentionTestEnv(const AttentionTestEnv&) = delete; + AttentionTestEnv& operator=(const AttentionTestEnv&) = delete; + AttentionTestEnv(AttentionTestEnv&&) = delete; + AttentionTestEnv& operator=(AttentionTestEnv&&) = delete; + + ThreadingArgs threading_args; + ThreadingContext ctx; + MatMulEnv env; + LayerConfig layer_config; + ModelConfig model_config; + std::unique_ptr tensor_info_registry; + std::unique_ptr layer; + RuntimeConfig runtime_config; + InferenceArgs inference_args; + AllQueries all_queries; + std::vector kv_caches; + std::unique_ptr activations; + std::unique_ptr qbatch; + + // Weights storage for later tests + MatStorageT qkv1_w_storage; + MatStorageT qkv2_w_storage; + MatStorageT wo_w_storage; + MatStorageT query_norm_scale; + MatStorageT key_norm_scale; +}; + +void TestTransposeStridedQueries() { + ThreadingArgs threading_args; + ThreadingContext ctx(threading_args); + int qkv_dim = 64; + int num_queries = 24; + AlignedPtr input_queries = + ctx.allocator.Alloc(qkv_dim * num_queries); + AlignedPtr output_queries = + ctx.allocator.Alloc(qkv_dim * num_queries); + for (int i = 0; i < num_queries; ++i) { + for (int j = 0; j < qkv_dim; ++j) { + input_queries[i * qkv_dim + j] = i * qkv_dim + j; + } + } + std::vector queries; + for (int i = 0; i < num_queries; ++i) { + queries.push_back(input_queries.get() + i * qkv_dim); + } + hwy::Span queries_span(queries.data(), queries.size()); + + TransposeStridedQueries( + queries_span, qkv_dim, + hwy::Span(output_queries.get(), qkv_dim * num_queries)); + for (int i = 0; i < num_queries; ++i) { + for (int j = 0; j < qkv_dim; ++j) { + EXPECT_EQ(output_queries[j * num_queries + i], + input_queries[i * qkv_dim + j]) + << "i=" << i << " j=" << j; + } + } +} + +void TestLocalAttentionForAllHeadsTokensAndBatch() { + int qkv_dim = 64; + int kv_seq_len = 64; + int num_kv_heads = 2; + int num_heads = 2; + int num_tokens = 2; + int last_pos = 62; // so token 0 will have 63 and token 1 will have 64 tokens + // to attend to. + float att_cap = 10.0f; + int layer_idx = 0; + int layers_total = 1; + int qbatch_size = 2; + AttentionImpl attention_impl = AttentionImpl::kFlashTransposedQs; + AttentionTestEnv test_env(qkv_dim, kv_seq_len, kv_seq_len, num_kv_heads, + num_heads, num_tokens, last_pos, att_cap, layer_idx, + layers_total, qbatch_size, attention_impl); + FillMatPtrT(test_env.activations->attention.q); + LocalAttentionForAllHeadsTokensAndBatch( + attention_impl, num_tokens, layer_idx, *test_env.layer, + test_env.activations->attention, *test_env.qbatch, test_env.ctx); + + // print states; + std::vector exp_denominator_sums_gold = {63, 63, 64, 64, + 63, 63, 64, 64}; + std::vector max_logits_gold = {10, 10, 10, 10, 10, 10, 10, 10}; + std::vector att_out_gold = { + 30.2575, 30.2675, 30.2775, 30.2875, 30.2975, 30.3075, 30.3175, 30.3275, + 30.3375, 30.3475, 30.3575, 30.3675, 30.3775, 30.3875, 30.3975, 30.4075, + 30.4175, 30.4275, 30.4375, 30.4475, 30.4575, 30.4675, 30.4775, 30.4875, + 30.4975, 30.5075, 30.5175, 30.5275, 30.5375, 30.5475, 30.5575, 30.5675, + 30.5775, 30.5875, 30.5975, 30.6075, 30.6175, 30.6275, 30.6375, 30.6475, + 30.6575, 30.6675, 30.6775, 30.6875, 30.6975, 30.7075, 30.7175, 30.7275, + 30.7375, 30.7475, 30.7575, 30.7675, 30.7775, 30.7875, 30.7975, 30.8075, + 30.8175, 30.8275, 30.8375, 30.8475, 30.8575, 30.8675, 30.8775, 30.8875, + 30.2775, 30.2875, 30.2975, 30.3075, 30.3175, 30.3275, 30.3375, 30.3475, + 30.3575, 30.3675, 30.3775, 30.3875, 30.3975, 30.4075, 30.4175, 30.4275, + 30.4375, 30.4475, 30.4575, 30.4675, 30.4775, 30.4875, 30.4975, 30.5075, + 30.5175, 30.5275, 30.5375, 30.5475, 30.5575, 30.5675, 30.5775, 30.5875, + 30.5975, 30.6075, 30.6175, 30.6275, 30.6375, 30.6475, 30.6575, 30.6675, + 30.6775, 30.6875, 30.6975, 30.7075, 30.7175, 30.7275, 30.7375, 30.7475, + 30.7575, 30.7675, 30.7775, 30.7875, 30.7975, 30.8075, 30.8175, 30.8275, + 30.8375, 30.8475, 30.8575, 30.8675, 30.8775, 30.8875, 30.8975, 30.9075, + 30.415, 30.425, 30.435, 30.445, 30.455, 30.465, 30.475, 30.485, + 30.495, 30.505, 30.515, 30.525, 30.535, 30.545, 30.555, 30.565, + 30.575, 30.585, 30.595, 30.605, 30.615, 30.625, 30.635, 30.645, + 30.655, 30.665, 30.675, 30.685, 30.695, 30.705, 30.715, 30.725, + 30.735, 30.745, 30.755, 30.765, 30.775, 30.785, 30.795, 30.805, + 30.815, 30.825, 30.835, 30.845, 30.855, 30.865, 30.875, 30.885, + 30.895, 30.905, 30.915, 30.925, 30.935, 30.945, 30.955, 30.965, + 30.975, 30.985, 30.995, 31.005, 31.015, 31.025, 31.035, 31.045, + 30.435, 30.445, 30.455, 30.465, 30.475, 30.485, 30.495, 30.505, + 30.515, 30.525, 30.535, 30.545, 30.555, 30.565, 30.575, 30.585, + 30.595, 30.605, 30.615, 30.625, 30.635, 30.645, 30.655, 30.665, + 30.675, 30.685, 30.695, 30.705, 30.715, 30.725, 30.735, 30.745, + 30.755, 30.765, 30.775, 30.785, 30.795, 30.805, 30.815, 30.825, + 30.835, 30.845, 30.855, 30.865, 30.875, 30.885, 30.895, 30.905, + 30.915, 30.925, 30.935, 30.945, 30.955, 30.965, 30.975, 30.985, + 30.995, 31.005, 31.015, 31.025, 31.035, 31.045, 31.055, 31.065, + 30.2575, 30.2675, 30.2775, 30.2875, 30.2975, 30.3075, 30.3175, 30.3275, + 30.3375, 30.3475, 30.3575, 30.3675, 30.3775, 30.3875, 30.3975, 30.4075, + 30.4175, 30.4275, 30.4375, 30.4475, 30.4575, 30.4675, 30.4775, 30.4875, + 30.4975, 30.5075, 30.5175, 30.5275, 30.5375, 30.5475, 30.5575, 30.5675, + 30.5775, 30.5875, 30.5975, 30.6075, 30.6175, 30.6275, 30.6375, 30.6475, + 30.6575, 30.6675, 30.6775, 30.6875, 30.6975, 30.7075, 30.7175, 30.7275, + 30.7375, 30.7475, 30.7575, 30.7675, 30.7775, 30.7875, 30.7975, 30.8075, + 30.8175, 30.8275, 30.8375, 30.8475, 30.8575, 30.8675, 30.8775, 30.8875, + 30.2775, 30.2875, 30.2975, 30.3075, 30.3175, 30.3275, 30.3375, 30.3475, + 30.3575, 30.3675, 30.3775, 30.3875, 30.3975, 30.4075, 30.4175, 30.4275, + 30.4375, 30.4475, 30.4575, 30.4675, 30.4775, 30.4875, 30.4975, 30.5075, + 30.5175, 30.5275, 30.5375, 30.5475, 30.5575, 30.5675, 30.5775, 30.5875, + 30.5975, 30.6075, 30.6175, 30.6275, 30.6375, 30.6475, 30.6575, 30.6675, + 30.6775, 30.6875, 30.6975, 30.7075, 30.7175, 30.7275, 30.7375, 30.7475, + 30.7575, 30.7675, 30.7775, 30.7875, 30.7975, 30.8075, 30.8175, 30.8275, + 30.8375, 30.8475, 30.8575, 30.8675, 30.8775, 30.8875, 30.8975, 30.9075, + 30.415, 30.425, 30.435, 30.445, 30.455, 30.465, 30.475, 30.485, + 30.495, 30.505, 30.515, 30.525, 30.535, 30.545, 30.555, 30.565, + 30.575, 30.585, 30.595, 30.605, 30.615, 30.625, 30.635, 30.645, + 30.655, 30.665, 30.675, 30.685, 30.695, 30.705, 30.715, 30.725, + 30.735, 30.745, 30.755, 30.765, 30.775, 30.785, 30.795, 30.805, + 30.815, 30.825, 30.835, 30.845, 30.855, 30.865, 30.875, 30.885, + 30.895, 30.905, 30.915, 30.925, 30.935, 30.945, 30.955, 30.965, + 30.975, 30.985, 30.995, 31.005, 31.015, 31.025, 31.035, 31.045, + 30.435, 30.445, 30.455, 30.465, 30.475, 30.485, 30.495, 30.505, + 30.515, 30.525, 30.535, 30.545, 30.555, 30.565, 30.575, 30.585, + 30.595, 30.605, 30.615, 30.625, 30.635, 30.645, 30.655, 30.665, + 30.675, 30.685, 30.695, 30.705, 30.715, 30.725, 30.735, 30.745, + 30.755, 30.765, 30.775, 30.785, 30.795, 30.805, 30.815, 30.825, + 30.835, 30.845, 30.855, 30.865, 30.875, 30.885, 30.895, 30.905, + 30.915, 30.925, 30.935, 30.945, 30.955, 30.965, 30.975, 30.985, + 30.995, 31.005, 31.015, 31.025, 31.035, 31.045, 31.055, 31.065, + }; + const int group_size = num_heads / num_kv_heads; + for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { + for (int q_batch_idx = 0; q_batch_idx < qbatch_size; ++q_batch_idx) { + int b = token_idx * qbatch_size + q_batch_idx; + EXPECT_THAT( + absl::MakeSpan(test_env.activations->attention.softmax_d.Row(b), + num_heads), + Pointwise(FloatNear(1e-3f), absl::MakeSpan(exp_denominator_sums_gold) + .subspan(b * num_heads, num_heads))); + EXPECT_THAT( + absl::MakeSpan(test_env.activations->attention.softmax_max.Row(b), + num_heads), + Pointwise(FloatNear(1e-3f), absl::MakeSpan(max_logits_gold) + .subspan(b * num_heads, num_heads))); + for (int kv_h = 0; kv_h < num_kv_heads; ++kv_h) { + for (int g = 0; g < group_size; ++g) { + const int q_h = kv_h * group_size + g; + size_t expected_q_idx = b * num_heads + q_h; + EXPECT_THAT( + absl::MakeSpan(test_env.activations->attention.att_out.Row(b) + + q_h * qkv_dim, + qkv_dim), + Pointwise(FloatNear(1e-3f), + absl::MakeSpan(att_out_gold) + .subspan(expected_q_idx * qkv_dim, qkv_dim))); + } + } + } + } +} + +const std::vector AttentionMultipleTokensAttentionGoldens = { + 34.7414, 34.7717, 34.8022, 34.8327, 34.8631, 34.8936, 34.9241, 34.9545, + 34.985, 35.0156, 35.046, 35.0765, 35.1068, 35.1373, 35.1678, 35.1982, + 35.2286, 35.2592, 35.2895, 35.32, 35.3506, 35.381, 35.4115, 35.4421, + 35.4725, 35.503, 35.5334, 35.5638, 35.5943, 35.6247, 35.6552, 35.6857, + 35.7161, 35.7466, 35.7772, 35.8076, 35.8381, 35.8685, 35.8989, 35.9294, + 35.9598, 35.9902, 36.0208, 36.0512, 36.0816, 36.1122, 36.1426, 36.1731, + 36.2037, 36.2341, 36.2646, 36.295, 36.3254, 36.356, 36.3863, 36.4168, + 36.4474, 36.4778, 36.5082, 36.5388, 36.5692, 36.5997, 36.6301, 36.6605, + 34.6687, 34.6987, 34.7288, 34.759, 34.7891, 34.8192, 34.8495, 34.8795, + 34.9097, 34.9399, 34.97, 35.0002, 35.0302, 35.0604, 35.0906, 35.1206, + 35.1507, 35.181, 35.211, 35.2412, 35.2714, 35.3015, 35.3317, 35.3619, + 35.3921, 35.4222, 35.4523, 35.4824, 35.5126, 35.5427, 35.5728, 35.603, + 35.6331, 35.6633, 35.6935, 35.7236, 35.7538, 35.7838, 35.814, 35.8442, + 35.8742, 35.9043, 35.9346, 35.9646, 35.9948, 36.025, 36.0551, 36.0853, + 36.1155, 36.1456, 36.1759, 36.2059, 36.236, 36.2662, 36.2963, 36.3264, + 36.3566, 36.3867, 36.4169, 36.4471, 36.4772, 36.5074, 36.5374, 36.5676, + 37.0338, 37.0634, 37.0929, 37.1222, 37.1519, 37.1813, 37.2107, 37.2403, + 37.2698, 37.2992, 37.3288, 37.3584, 37.3877, 37.4174, 37.447, 37.4764, + 37.5056, 37.5352, 37.5646, 37.5938, 37.6234, 37.6528, 37.6821, 37.7117, + 37.7412, 37.7705, 37.8001, 37.8295, 37.8589, 37.8885, 37.918, 37.9473, + 37.977, 38.0065, 38.0358, 38.0655, 38.095, 38.1244, 38.1541, 38.1836, + 38.213, 38.2422, 38.2718, 38.3012, 38.3305, 38.36, 38.3895, 38.4187, + 38.4484, 38.4778, 38.5071, 38.5367, 38.5662, 38.5955, 38.6251, 38.6546, + 38.6839, 38.7136, 38.7431, 38.7725, 38.8021, 38.8316, 38.861, 38.8907, + 36.9872, 37.0167, 37.046, 37.0752, 37.1047, 37.1341, 37.1633, 37.1928, + 37.2222, 37.2514, 37.2809, 37.3103, 37.3396, 37.3691, 37.3985, 37.4278, + 37.4569, 37.4863, 37.5156, 37.5447, 37.5742, 37.6035, 37.6326, 37.6621, + 37.6914, 37.7206, 37.7501, 37.7794, 37.8086, 37.8381, 37.8674, 37.8966, + 37.9262, 37.9555, 37.9848, 38.0143, 38.0437, 38.0729, 38.1025, 38.1319, + 38.1612, 38.1903, 38.2197, 38.249, 38.2781, 38.3075, 38.3368, 38.366, + 38.3955, 38.4248, 38.4539, 38.4834, 38.5127, 38.5419, 38.5714, 38.6008, + 38.63, 38.6595, 38.6889, 38.7181, 38.7477, 38.777, 38.8063, 38.8358, + 39.0984, 39.1479, 39.1976, 39.2475, 39.297, 39.3468, 39.3967, 39.4463, + 39.4961, 39.546, 39.5957, 39.6455, 39.695, 39.7447, 39.7946, 39.8441, + 39.8939, 39.9438, 39.9934, 40.0431, 40.0931, 40.1427, 40.1925, 40.2425, + 40.2921, 40.342, 40.3915, 40.4412, 40.4911, 40.5407, 40.5904, 40.6403, + 40.6899, 40.7397, 40.7897, 40.8393, 40.8892, 40.9387, 40.9884, 41.0382, + 41.0878, 41.1375, 41.1874, 41.237, 41.2868, 41.3367, 41.3863, 41.4361, + 41.4861, 41.5358, 41.5856, 41.6351, 41.6849, 41.7347, 41.7843, 41.834, + 41.884, 41.9336, 41.9834, 42.0333, 42.083, 42.1328, 42.1823, 42.232, + 38.9699, 39.0188, 39.068, 39.1173, 39.1663, 39.2155, 39.2648, 39.3138, + 39.3631, 39.4124, 39.4615, 39.5108, 39.5597, 39.6089, 39.6581, 39.7071, + 39.7563, 39.8056, 39.8546, 39.9039, 39.9532, 40.0023, 40.0515, 40.1009, + 40.15, 40.1993, 40.2483, 40.2974, 40.3467, 40.3957, 40.4449, 40.4942, + 40.5433, 40.5925, 40.6419, 40.691, 40.7402, 40.7892, 40.8383, 40.8876, + 40.9366, 40.9857, 41.035, 41.0841, 41.1333, 41.1826, 41.2317, 41.2809, + 41.3303, 41.3794, 41.4287, 41.4777, 41.5268, 41.5761, 41.6251, 41.6743, + 41.7237, 41.7727, 41.8219, 41.8713, 41.9204, 41.9697, 42.0186, 42.0677, + 43.4945, 43.5425, 43.5902, 43.6376, 43.6856, 43.7334, 43.7808, 43.8289, + 43.8766, 43.9241, 43.9722, 44.02, 44.0675, 44.1157, 44.1635, 44.2111, + 44.2583, 44.3062, 44.3538, 44.4011, 44.449, 44.4966, 44.544, 44.5919, + 44.6396, 44.6869, 44.735, 44.7826, 44.8301, 44.8781, 44.9258, 44.9733, + 45.0213, 45.0691, 45.1166, 45.1647, 45.2125, 45.26, 45.3081, 45.356, + 45.4035, 45.4508, 45.4987, 45.5462, 45.5936, 45.6415, 45.6891, 45.7364, + 45.7844, 45.832, 45.8794, 45.9274, 45.9751, 46.0225, 46.0705, 46.1183, + 46.1657, 46.2138, 46.2615, 46.309, 46.3571, 46.4049, 46.4525, 46.5006, + 43.4125, 43.4603, 43.5077, 43.5549, 43.6027, 43.6502, 43.6974, 43.7453, + 43.7928, 43.84, 43.8879, 43.9355, 43.9828, 44.0307, 44.0783, 44.1256, + 44.1726, 44.2203, 44.2676, 44.3147, 44.3624, 44.4098, 44.4569, 44.5046, + 44.552, 44.5992, 44.6469, 44.6944, 44.7416, 44.7894, 44.8369, 44.8841, + 44.9319, 44.9795, 45.0267, 45.0746, 45.1222, 45.1694, 45.2173, 45.265, + 45.3123, 45.3593, 45.407, 45.4543, 45.5014, 45.5491, 45.5965, 45.6436, + 45.6913, 45.7387, 45.7859, 45.8336, 45.8811, 45.9283, 45.9761, 46.0236, + 46.0708, 46.1186, 46.1661, 46.2134, 46.2613, 46.3088, 46.3561, 46.404, + 34.7729, 34.8035, 34.8341, 34.8648, 34.8953, 34.9259, 34.9567, 34.9872, + 35.0179, 35.0486, 35.0792, 35.1098, 35.1404, 35.171, 35.2016, 35.2322, + 35.2628, 35.2935, 35.324, 35.3547, 35.3854, 35.416, 35.4466, 35.4774, + 35.508, 35.5387, 35.5692, 35.5998, 35.6305, 35.661, 35.6916, 35.7224, + 35.7529, 35.7836, 35.8143, 35.8449, 35.8755, 35.9061, 35.9367, 35.9674, + 35.9979, 36.0285, 36.0592, 36.0898, 36.1204, 36.1511, 36.1817, 36.2123, + 36.2431, 36.2737, 36.3044, 36.3349, 36.3655, 36.3962, 36.4267, 36.4574, + 36.4881, 36.5186, 36.5493, 36.58, 36.6106, 36.6413, 36.6718, 36.7024, + 34.6995, 34.7297, 34.76, 34.7904, 34.8206, 34.8509, 34.8813, 34.9115, + 34.9418, 34.9722, 35.0025, 35.0328, 35.063, 35.0933, 35.1237, 35.1539, + 35.1842, 35.2146, 35.2448, 35.2751, 35.3055, 35.3357, 35.3661, 35.3965, + 35.4268, 35.4571, 35.4873, 35.5176, 35.548, 35.5782, 35.6085, 35.6389, + 35.6691, 35.6994, 35.7298, 35.7601, 35.7904, 35.8206, 35.8509, 35.8813, + 35.9115, 35.9418, 35.9721, 36.0024, 36.0327, 36.0631, 36.0933, 36.1237, + 36.1541, 36.1843, 36.2147, 36.2449, 36.2752, 36.3056, 36.3358, 36.3661, + 36.3965, 36.4267, 36.457, 36.4874, 36.5177, 36.548, 36.5782, 36.6085, + 37.0829, 37.1127, 37.1423, 37.1717, 37.2015, 37.2312, 37.2607, 37.2905, + 37.3201, 37.3496, 37.3795, 37.4091, 37.4386, 37.4685, 37.4982, 37.5277, + 37.5571, 37.5868, 37.6164, 37.6458, 37.6755, 37.7051, 37.7346, 37.7643, + 37.7939, 37.8234, 37.8531, 37.8827, 37.9122, 37.942, 37.9716, 38.0011, + 38.0309, 38.0606, 38.0901, 38.1199, 38.1496, 38.1791, 38.209, 38.2387, + 38.2682, 38.2976, 38.3273, 38.3569, 38.3863, 38.416, 38.4456, 38.475, + 38.5048, 38.5344, 38.5638, 38.5936, 38.6232, 38.6527, 38.6825, 38.7121, + 38.7416, 38.7714, 38.8011, 38.8306, 38.8604, 38.8901, 38.9196, 38.9494, + 37.0359, 37.0655, 37.095, 37.1243, 37.154, 37.1835, 37.2129, 37.2425, + 37.2721, 37.3014, 37.3311, 37.3607, 37.39, 37.4198, 37.4493, 37.4787, + 37.508, 37.5376, 37.567, 37.5963, 37.6259, 37.6553, 37.6846, 37.7142, + 37.7437, 37.773, 37.8027, 37.8322, 37.8615, 37.8911, 37.9207, 37.95, + 37.9797, 38.0092, 38.0386, 38.0683, 38.0978, 38.1272, 38.1569, 38.1865, + 38.2159, 38.2451, 38.2747, 38.3042, 38.3334, 38.363, 38.3925, 38.4218, + 38.4514, 38.4809, 38.5102, 38.5398, 38.5693, 38.5986, 38.6283, 38.6578, + 38.6872, 38.7168, 38.7464, 38.7757, 38.8054, 38.835, 38.8644, 38.8941, + 39.1594, 39.2093, 39.2593, 39.3095, 39.3594, 39.4094, 39.4597, 39.5096, + 39.5597, 39.61, 39.6599, 39.7101, 39.7599, 39.8099, 39.8601, 39.91, + 39.96, 40.0102, 40.0601, 40.1102, 40.1605, 40.2104, 40.2605, 40.3108, + 40.3608, 40.411, 40.4608, 40.5108, 40.561, 40.6109, 40.661, 40.7112, + 40.7611, 40.8112, 40.8615, 40.9115, 40.9616, 41.0114, 41.0614, 41.1116, + 41.1615, 41.2115, 41.2617, 41.3116, 41.3617, 41.412, 41.4619, 41.512, + 41.5624, 41.6123, 41.6625, 41.7123, 41.7623, 41.8126, 41.8624, 41.9125, + 41.9627, 42.0127, 42.0628, 42.113, 42.163, 42.2131, 42.263, 42.313, + 39.0297, 39.079, 39.1284, 39.1781, 39.2274, 39.2769, 39.3265, 39.3759, + 39.4254, 39.4751, 39.5245, 39.5741, 39.6233, 39.6727, 39.7224, 39.7716, + 39.8211, 39.8708, 39.9201, 39.9696, 40.0193, 40.0686, 40.1182, 40.1679, + 40.2173, 40.2669, 40.3162, 40.3656, 40.4153, 40.4646, 40.514, 40.5637, + 40.6131, 40.6626, 40.7123, 40.7617, 40.8112, 40.8605, 40.9099, 40.9595, + 41.0088, 41.0583, 41.1079, 41.1573, 41.2068, 41.2565, 41.3058, 41.3554, + 41.4051, 41.4545, 41.5041, 41.5534, 41.6028, 41.6524, 41.7017, 41.7512, + 41.8009, 41.8502, 41.8998, 41.9495, 41.9988, 42.0484, 42.0977, 42.1471, + 43.5891, 43.6374, 43.6854, 43.7331, 43.7814, 43.8294, 43.8772, 43.9255, + 43.9736, 44.0214, 44.0698, 44.1179, 44.1657, 44.2141, 44.2623, 44.3101, + 44.3577, 44.4058, 44.4537, 44.5013, 44.5495, 44.5974, 44.6451, 44.6933, + 44.7413, 44.7889, 44.8372, 44.8852, 44.9329, 44.9812, 45.0293, 45.077, + 45.1254, 45.1734, 45.2212, 45.2696, 45.3177, 45.3655, 45.414, 45.4621, + 45.5099, 45.5575, 45.6057, 45.6535, 45.7011, 45.7493, 45.7973, 45.8449, + 45.8931, 45.9411, 45.9888, 46.037, 46.085, 46.1327, 46.1811, 46.2291, + 46.2768, 46.3252, 46.3733, 46.421, 46.4694, 46.5175, 46.5653, 46.6138, + 43.5064, 43.5544, 43.6022, 43.6497, 43.6978, 43.7456, 43.7931, 43.8412, + 43.889, 43.9366, 43.9847, 44.0326, 44.0802, 44.1284, 44.1763, 44.2239, + 44.2712, 44.3191, 44.3668, 44.4141, 44.4621, 44.5098, 44.5572, 44.6052, + 44.6529, 44.7004, 44.7484, 44.7962, 44.8436, 44.8918, 44.9395, 44.987, + 45.0352, 45.083, 45.1305, 45.1787, 45.2266, 45.2742, 45.3223, 45.3703, + 45.4179, 45.4652, 45.5131, 45.5608, 45.6081, 45.6561, 45.7038, 45.7512, + 45.7992, 45.8469, 45.8944, 45.9424, 45.9902, 46.0376, 46.0857, 46.1335, + 46.181, 46.2292, 46.277, 46.3245, 46.3727, 46.4206, 46.4682, 46.5164, +}; + +void TestAttentionMultipleTokens() { + int qkv_dim = 64; + int kv_seq_len = 64; + int num_kv_heads = 2; + int num_heads = 4; + int num_tokens = 2; + int last_pos = 62; // so in the tbatch token 0 will have 63 and token 1 + // will have 64 tokens to attend to. + float att_cap = 10.0f; + int layer_idx = 0; + int layers_total = 1; + int qbatch_size = 2; + AttentionImpl attention_impl = AttentionImpl::kFlashTransposedQs; + AttentionTestEnv test_env(qkv_dim, kv_seq_len, kv_seq_len, num_kv_heads, + num_heads, num_tokens, last_pos, att_cap, layer_idx, + layers_total, qbatch_size, attention_impl); + test_env.SetupWeights(); + FillMatPtrT(test_env.activations->attention.pre_att_rms_out); + FillMatPtrT(test_env.activations->attention.q); + FillMatPtrT(test_env.activations->attention.vit_Q); + FillMatPtrT(test_env.activations->attention.vit_K); + FillMatPtrT(test_env.activations->attention.att); + FillMatPtrT(test_env.activations->attention.att_out); + FillMatPtrT(test_env.activations->attention.softmax_max); + FillMatPtrT(test_env.activations->attention.softmax_d); + + int flags = AttentionImplToFlags(attention_impl, HWY_NATIVE_DOT_BF16); + TiledAttention(attention_impl, num_tokens, layer_idx, *test_env.layer, + test_env.activations->attention, *test_env.qbatch, + test_env.env, flags); + + std::cerr << "att_out\n"; + PrintMatPtr(test_env.activations->attention.att_out); + for (size_t i = 0; i < test_env.activations->attention.att_out.Rows(); ++i) { + EXPECT_TRUE(hwy::CompareArraySimilar( + AttentionMultipleTokensAttentionGoldens.data() + + i * test_env.activations->attention.att_out.Cols(), + test_env.activations->attention.att_out.Row(i), + test_env.activations->attention.att_out.Cols(), 1e-3, + hwy::TargetName(HWY_TARGET), __FILE__, __LINE__)) + << "att_out mismatch for query: " << i; + } +} + +void TestAttentionMultipleTokensAttentionWindowSizeEdgeCase() { + int qkv_dim = 64; + int kv_seq_len = 34; + int num_kv_heads = 2; + int num_heads = 4; + int num_tokens = 2; + int last_pos = 31; // so in the tbatch token 0 will have 63 and token 1 + // will have 64 tokens to attend to. + float att_cap = 10.0f; + int layer_idx = 0; + int layers_total = 1; + int qbatch_size = 2; + int attention_window_size = 32; + AttentionImpl attention_impl = AttentionImpl::kFlashTransposedQs; + AttentionTestEnv test_env(qkv_dim, kv_seq_len, attention_window_size, + num_kv_heads, num_heads, num_tokens, last_pos, + att_cap, layer_idx, layers_total, qbatch_size, + attention_impl); + test_env.SetupWeights(); + FillMatPtrT(test_env.activations->attention.pre_att_rms_out); + FillMatPtrT(test_env.activations->attention.q); + FillMatPtrT(test_env.activations->attention.vit_Q); + FillMatPtrT(test_env.activations->attention.vit_K); + FillMatPtrT(test_env.activations->attention.att); + FillMatPtrT(test_env.activations->attention.att_out); + FillMatPtrT(test_env.activations->attention.softmax_max); + FillMatPtrT(test_env.activations->attention.softmax_d); + + int flags = AttentionImplToFlags(attention_impl, HWY_NATIVE_DOT_BF16); + TiledAttention(attention_impl, num_tokens, layer_idx, *test_env.layer, + test_env.activations->attention, *test_env.qbatch, + test_env.env, flags); + + std::cerr << "att_out\n"; + std::vector att_out_golden_test_local = { + 39.3051, 39.3556, 39.4062, 39.4571, 39.5075, 39.5582, 39.6091, 39.6596, + 39.7103, 39.7612, 39.8118, 39.8626, 39.913, 39.9636, 40.0144, 40.0649, + 40.1155, 40.1664, 40.2169, 40.2676, 40.3185, 40.369, 40.4198, 40.4707, + 40.5213, 40.572, 40.6225, 40.6731, 40.724, 40.7744, 40.8251, 40.876, + 40.9265, 40.9772, 41.0281, 41.0787, 41.1295, 41.1799, 41.2305, 41.2813, + 41.3318, 41.3824, 41.4333, 41.4838, 41.5345, 41.5854, 41.6359, 41.6867, + 41.7376, 41.7882, 41.839, 41.8894, 41.94, 41.9908, 42.0413, 42.092, + 42.1429, 42.1934, 42.2441, 42.295, 42.3456, 42.3964, 42.4468, 42.4974, + 39.1614, 39.2113, 39.2613, 39.3114, 39.3613, 39.4113, 39.4616, 39.5115, + 39.5616, 39.6118, 39.6618, 39.7119, 39.7617, 39.8117, 39.8618, 39.9117, + 39.9617, 40.0119, 40.0618, 40.1118, 40.1621, 40.212, 40.2621, 40.3124, + 40.3623, 40.4125, 40.4623, 40.5123, 40.5625, 40.6123, 40.6624, 40.7126, + 40.7625, 40.8126, 40.8629, 40.9128, 40.9629, 41.0127, 41.0627, 41.1129, + 41.1627, 41.2127, 41.2629, 41.3128, 41.3629, 41.4131, 41.463, 41.5131, + 41.5634, 41.6134, 41.6635, 41.7133, 41.7634, 41.8135, 41.8634, 41.9134, + 41.9637, 42.0135, 42.0636, 42.1139, 42.1638, 42.214, 42.2637, 42.3137, + 43.8459, 43.895, 43.9437, 43.9921, 44.0411, 44.0898, 44.1383, 44.1874, + 44.2361, 44.2846, 44.3337, 44.3825, 44.4311, 44.4802, 44.529, 44.5776, + 44.6258, 44.6747, 44.7233, 44.7716, 44.8205, 44.8692, 44.9175, 44.9665, + 45.0151, 45.0635, 45.1125, 45.1612, 45.2096, 45.2586, 45.3074, 45.3558, + 45.4049, 45.4537, 45.5021, 45.5513, 45.6001, 45.6486, 45.6977, 45.7466, + 45.7951, 45.8434, 45.8923, 45.9409, 45.9891, 46.0381, 46.0867, 46.135, + 46.184, 46.2327, 46.281, 46.33, 46.3787, 46.4271, 46.4762, 46.5249, + 46.5733, 46.6224, 46.6712, 46.7197, 46.7688, 46.8176, 46.8661, 46.9153, + 43.7538, 43.8026, 43.851, 43.8992, 43.948, 43.9964, 44.0446, 44.0934, + 44.142, 44.1902, 44.239, 44.2876, 44.3358, 44.3847, 44.4333, 44.4816, + 44.5296, 44.5782, 44.6266, 44.6746, 44.7232, 44.7716, 44.8197, 44.8684, + 44.9168, 44.9649, 45.0136, 45.0621, 45.1102, 45.159, 45.2075, 45.2557, + 45.3045, 45.353, 45.4012, 45.4501, 45.4986, 45.5469, 45.5958, 45.6444, + 45.6927, 45.7406, 45.7893, 45.8376, 45.8856, 45.9343, 45.9827, 46.0307, + 46.0794, 46.1278, 46.1759, 46.2247, 46.2731, 46.3213, 46.3701, 46.4185, + 46.4667, 46.5155, 46.564, 46.6123, 46.6611, 46.7097, 46.7579, 46.8068, + 48.7531, 48.8438, 48.9348, 49.0262, 49.1169, 49.208, 49.2995, 49.3903, + 49.4815, 49.573, 49.6639, 49.7552, 49.8458, 49.9368, 50.0281, 50.1188, + 50.2099, 50.3013, 50.3921, 50.4832, 50.5747, 50.6656, 50.7568, 50.8484, + 50.9393, 51.0306, 51.1213, 51.2123, 51.3037, 51.3944, 51.4855, 51.577, + 51.6678, 51.759, 51.8505, 51.9414, 52.0327, 52.1233, 52.2143, 52.3056, + 52.3963, 52.4874, 52.5788, 52.6696, 52.7607, 52.8522, 52.9431, 53.0343, + 53.1259, 53.2168, 53.3081, 53.3988, 53.4898, 53.5812, 53.6719, 53.763, + 53.8545, 53.9453, 54.0365, 54.128, 54.2189, 54.3102, 54.4008, 54.4918, + 48.4943, 48.5838, 48.6737, 48.7639, 48.8535, 48.9435, 49.0338, 49.1235, + 49.2135, 49.3039, 49.3937, 49.4838, 49.5732, 49.6631, 49.7533, 49.8428, + 49.9328, 50.023, 50.1127, 50.2027, 50.293, 50.3827, 50.4728, 50.5632, + 50.653, 50.7432, 50.8327, 50.9226, 51.0128, 51.1024, 51.1924, 51.2827, + 51.3724, 51.4624, 51.5528, 51.6425, 51.7327, 51.8221, 51.912, 52.0022, + 52.0917, 52.1817, 52.2719, 52.3616, 52.4516, 52.5419, 52.6316, 52.7217, + 52.8121, 52.9019, 52.9921, 53.0816, 53.1715, 53.2617, 53.3513, 53.4413, + 53.5316, 53.6212, 53.7113, 53.8017, 53.8914, 53.9815, 54.071, 54.1609, + 57.7208, 57.8084, 57.8954, 57.9818, 58.0694, 58.1564, 58.2429, 58.3306, + 58.4177, 58.5043, 58.5921, 58.6793, 58.7659, 58.8537, 58.941, 59.0277, + 59.1137, 59.2011, 59.2878, 59.374, 59.4614, 59.5482, 59.6345, 59.722, + 59.8089, 59.8952, 59.9827, 60.0697, 60.1561, 60.2437, 60.3308, 60.4172, + 60.505, 60.5921, 60.6786, 60.7664, 60.8536, 60.9402, 61.0281, 61.1153, + 61.202, 61.2881, 61.3755, 61.4622, 61.5483, 61.6358, 61.7226, 61.8088, + 61.8963, 61.9832, 62.0695, 62.1571, 62.244, 62.3304, 62.4181, 62.5051, + 62.5916, 62.6793, 62.7664, 62.853, 62.9407, 63.0279, 63.1146, 63.2024, + 57.5554, 57.6426, 57.729, 57.815, 57.9021, 57.9887, 58.0747, 58.162, + 58.2486, 58.3347, 58.422, 58.5087, 58.5949, 58.6823, 58.7691, 58.8553, + 58.9409, 59.0278, 59.114, 59.1997, 59.2867, 59.373, 59.4588, 59.5458, + 59.6323, 59.7181, 59.8052, 59.8917, 59.9776, 60.0648, 60.1514, 60.2374, + 60.3246, 60.4113, 60.4974, 60.5847, 60.6714, 60.7576, 60.8449, 60.9317, + 61.018, 61.1036, 61.1905, 61.2767, 61.3624, 61.4494, 61.5357, 61.6215, + 61.7085, 61.7949, 61.8808, 61.9679, 62.0544, 62.1403, 62.2275, 62.3141, + 62.4001, 62.4873, 62.574, 62.66, 62.7474, 62.8341, 62.9202, 63.0076, + 39.3678, 39.4186, 39.4696, 39.5207, 39.5715, 39.6225, 39.6737, 39.7246, + 39.7756, 39.8268, 39.8777, 39.9288, 39.9796, 40.0305, 40.0816, 40.1324, + 40.1834, 40.2346, 40.2854, 40.3364, 40.3876, 40.4385, 40.4896, 40.5408, + 40.5917, 40.6428, 40.6936, 40.7446, 40.7957, 40.8466, 40.8975, 40.9487, + 40.9996, 41.0506, 41.1019, 41.1528, 41.2038, 41.2546, 41.3055, 41.3567, + 41.4075, 41.4584, 41.5096, 41.5605, 41.6115, 41.6627, 41.7136, 41.7646, + 41.8159, 41.8668, 41.9179, 41.9687, 42.0196, 42.0708, 42.1216, 42.1726, + 42.2238, 42.2746, 42.3256, 42.3769, 42.4278, 42.4789, 42.5296, 42.5806, + 39.2228, 39.2729, 39.3232, 39.3737, 39.4239, 39.4743, 39.5248, 39.575, + 39.6254, 39.676, 39.7263, 39.7767, 39.8268, 39.8771, 39.9276, 39.9778, + 40.0281, 40.0786, 40.1288, 40.1792, 40.2298, 40.28, 40.3304, 40.381, + 40.4313, 40.4818, 40.5319, 40.5822, 40.6327, 40.6829, 40.7333, 40.7838, + 40.834, 40.8844, 40.935, 40.9853, 41.0357, 41.0858, 41.1361, 41.1866, + 41.2368, 41.2871, 41.3376, 41.3878, 41.4382, 41.4888, 41.539, 41.5894, + 41.64, 41.6903, 41.7408, 41.7909, 41.8412, 41.8917, 41.9419, 41.9922, + 42.0428, 42.093, 42.1434, 42.194, 42.2442, 42.2947, 42.3448, 42.3951, + 43.9435, 43.9928, 44.0418, 44.0905, 44.1399, 44.1889, 44.2376, 44.287, + 44.3361, 44.3849, 44.4343, 44.4834, 44.5322, 44.5817, 44.6308, 44.6797, + 44.7283, 44.7774, 44.8263, 44.8749, 44.9241, 44.9731, 45.0217, 45.071, + 45.12, 45.1686, 45.2179, 45.2669, 45.3156, 45.365, 45.414, 45.4628, + 45.5122, 45.5613, 45.61, 45.6595, 45.7086, 45.7574, 45.8068, 45.856, + 45.9048, 45.9534, 46.0026, 46.0515, 46.1001, 46.1493, 46.1982, 46.2469, + 46.2961, 46.3451, 46.3938, 46.4431, 46.4921, 46.5408, 46.5901, 46.6392, + 46.6879, 46.7373, 46.7864, 46.8352, 46.8846, 46.9337, 46.9825, 47.032, + 43.8506, 43.8996, 43.9484, 43.9968, 44.0459, 44.0947, 44.1432, 44.1923, + 44.2411, 44.2896, 44.3388, 44.3876, 44.4362, 44.4854, 44.5343, 44.5829, + 44.6312, 44.6801, 44.7287, 44.7771, 44.826, 44.8747, 44.9231, 44.9721, + 45.0208, 45.0692, 45.1182, 45.167, 45.2154, 45.2645, 45.3133, 45.3617, + 45.4109, 45.4597, 45.5082, 45.5574, 45.6062, 45.6548, 45.704, 45.7529, + 45.8015, 45.8498, 45.8987, 45.9473, 45.9957, 46.0446, 46.0933, 46.1416, + 46.1906, 46.2394, 46.2878, 46.3368, 46.3856, 46.434, 46.4831, 46.5319, + 46.5803, 46.6295, 46.6783, 46.7268, 46.776, 46.8248, 46.8734, 46.9226, + 48.8777, 48.969, 49.0607, 49.1527, 49.2441, 49.3358, 49.4279, 49.5194, + 49.6112, 49.7034, 49.7949, 49.8868, 49.9781, 50.0697, 50.1617, 50.2531, + 50.3448, 50.4368, 50.5283, 50.62, 50.7122, 50.8037, 50.8956, 50.9878, + 51.0794, 51.1713, 51.2626, 51.3543, 51.4463, 51.5377, 51.6294, 51.7215, + 51.813, 51.9048, 51.997, 52.0885, 52.1805, 52.2717, 52.3633, 52.4553, + 52.5467, 52.6384, 52.7305, 52.8219, 52.9137, 53.0058, 53.0973, 53.1892, + 53.2814, 53.373, 53.4649, 53.5562, 53.6479, 53.7399, 53.8313, 53.923, + 54.0152, 54.1066, 54.1984, 54.2906, 54.3821, 54.4741, 54.5653, 54.6569, + 48.6164, 48.7066, 48.7971, 48.888, 48.9782, 49.0688, 49.1597, 49.25, + 49.3407, 49.4317, 49.5221, 49.6129, 49.703, 49.7934, 49.8843, 49.9745, + 50.065, 50.1559, 50.2462, 50.3368, 50.4278, 50.5181, 50.6089, 50.6999, + 50.7903, 50.8811, 50.9713, 51.0618, 51.1527, 51.2429, 51.3335, 51.4244, + 51.5147, 51.6054, 51.6964, 51.7868, 51.8776, 51.9677, 52.0581, 52.149, + 52.2392, 52.3297, 52.4206, 52.5109, 52.6015, 52.6925, 52.7828, 52.8736, + 52.9646, 53.055, 53.1458, 53.236, 53.3265, 53.4174, 53.5076, 53.5982, + 53.6891, 53.7794, 53.8701, 53.9611, 54.0515, 54.1423, 54.2324, 54.3228, + 57.914, 58.0021, 58.0897, 58.1767, 58.265, 58.3526, 58.4397, 58.528, + 58.6157, 58.7028, 58.7912, 58.879, 58.9662, 59.0547, 59.1426, 59.2299, + 59.3165, 59.4045, 59.4918, 59.5786, 59.6666, 59.754, 59.8408, 59.9289, + 60.0165, 60.1033, 60.1915, 60.2791, 60.3661, 60.4544, 60.542, 60.629, + 60.7174, 60.8051, 60.8922, 60.9806, 61.0684, 61.1556, 61.2441, 61.332, + 61.4193, 61.5059, 61.5939, 61.6812, 61.768, 61.856, 61.9434, 62.0302, + 62.1183, 62.2059, 62.2927, 62.3809, 62.4685, 62.5555, 62.6437, 62.7314, + 62.8184, 62.9068, 62.9945, 63.0816, 63.17, 63.2578, 63.345, 63.4335, + 57.7471, 57.8348, 57.9219, 58.0084, 58.0962, 58.1834, 58.27, 58.3578, + 58.4451, 58.5317, 58.6197, 58.707, 58.7937, 58.8817, 58.9691, 59.0559, + 59.1421, 59.2296, 59.3165, 59.4028, 59.4903, 59.5773, 59.6636, 59.7512, + 59.8383, 59.9247, 60.0124, 60.0995, 60.186, 60.2738, 60.361, 60.4476, + 60.5354, 60.6227, 60.7093, 60.7973, 60.8846, 60.9713, 61.0593, 61.1467, + 61.2335, 61.3197, 61.4072, 61.4941, 61.5804, 61.6679, 61.7549, 61.8412, + 61.9289, 62.0159, 62.1023, 62.19, 62.2772, 62.3636, 62.4514, 62.5386, + 62.6252, 62.7131, 62.8003, 62.887, 62.9749, 63.0622, 63.1489, 63.237}; + PrintMatPtr(test_env.activations->attention.att_out); + for (size_t i = 0; i < test_env.activations->attention.att_out.Rows(); ++i) { + EXPECT_TRUE(hwy::CompareArraySimilar( + att_out_golden_test_local.data() + + i * test_env.activations->attention.att_out.Cols(), + test_env.activations->attention.att_out.Row(i), + test_env.activations->attention.att_out.Cols(), 1e-3, + hwy::TargetName(HWY_TARGET), __FILE__, __LINE__)) + << "att_out mismatch for query: " << i; + } +} + +void TestAttentionMultipleTokensBF16() { + int qkv_dim = 64; + int kv_seq_len = 64; + int num_kv_heads = 2; + int num_heads = 4; + int num_tokens = 2; + int last_pos = 62; // so in the tbatch token 0 will have 63 and token 1 + // will have 64 tokens to attend to. + float att_cap = 10.0f; + int layer_idx = 0; + int layers_total = 1; + int qbatch_size = 2; + AttentionImpl attention_impl = AttentionImpl::kFlashTransposedQsBF16; + AttentionTestEnv test_env(qkv_dim, kv_seq_len, kv_seq_len, num_kv_heads, + num_heads, num_tokens, last_pos, att_cap, layer_idx, + layers_total, qbatch_size, attention_impl); + test_env.SetupWeights(); + FillMatPtrT(test_env.activations->attention.pre_att_rms_out); + FillMatPtrT(test_env.activations->attention.q); + FillMatPtrT(test_env.activations->attention.vit_Q); + FillMatPtrT(test_env.activations->attention.vit_K); + FillMatPtrT(test_env.activations->attention.att); + FillMatPtrT(test_env.activations->attention.att_out); + FillMatPtrT(test_env.activations->attention.softmax_max); + FillMatPtrT(test_env.activations->attention.softmax_d); + + int flags = AttentionImplToFlags(attention_impl, HWY_NATIVE_DOT_BF16); + TiledAttention(attention_impl, num_tokens, layer_idx, *test_env.layer, + test_env.activations->attention, *test_env.qbatch, + test_env.env, flags); + std::cerr << "att_out\n"; + PrintMatPtr(test_env.activations->attention.att_out); + for (size_t i = 0; i < test_env.activations->attention.att_out.Rows(); ++i) { + EXPECT_TRUE(hwy::CompareArraySimilar( + AttentionMultipleTokensAttentionGoldens.data() + + i * test_env.activations->attention.att_out.Cols(), + test_env.activations->attention.att_out.Row(i), + test_env.activations->attention.att_out.Cols(), 1e-1, + hwy::TargetName(HWY_TARGET), __FILE__, __LINE__)) + << "att_out mismatch for query: " << i; + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace gcpp { +HWY_BEFORE_TEST(TiledAttentionTest); +HWY_EXPORT_AND_TEST_P(TiledAttentionTest, TestTransposeStridedQueries); +HWY_EXPORT_AND_TEST_P(TiledAttentionTest, + TestLocalAttentionForAllHeadsTokensAndBatch); +HWY_EXPORT_AND_TEST_P(TiledAttentionTest, TestAttentionMultipleTokens); +HWY_EXPORT_AND_TEST_P(TiledAttentionTest, TestAttentionMultipleTokensBF16); +HWY_EXPORT_AND_TEST_P(TiledAttentionTest, + TestAttentionMultipleTokensAttentionWindowSizeEdgeCase); + +HWY_AFTER_TEST(); + +} // namespace gcpp + +#endif diff --git a/ops/ops-inl.h b/ops/ops-inl.h index 0eeec31b..5ff11a85 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -1026,6 +1026,450 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile4( } HWY_DASSERT(size == i); } +template > +static HWY_INLINE void StoreUpTo8Times2(DF df, MatPtrT& out, + size_t start_col, VF out0_0, VF out0_1, + VF out1_0, VF out1_1, VF out2_0, + VF out2_1, VF out3_0, VF out3_1, + VF out4_0, VF out4_1, VF out5_0, + VF out5_1, VF out6_0, VF out6_1, + VF out7_0, VF out7_1) { + namespace hn = hwy::HWY_NAMESPACE; + HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df); + hn::Store(out0_0, df, out.Row(0) + start_col); + hn::Store(out0_1, df, out.Row(0) + start_col + NF); + if constexpr (N >= 2) { + hn::Store(out1_0, df, out.Row(1) + start_col); + hn::Store(out1_1, df, out.Row(1) + start_col + NF); + } + if constexpr (N >= 3) { + hn::Store(out2_0, df, out.Row(2) + start_col); + hn::Store(out2_1, df, out.Row(2) + start_col + NF); + } + if constexpr (N >= 4) { + hn::Store(out3_0, df, out.Row(3) + start_col); + hn::Store(out3_1, df, out.Row(3) + start_col + NF); + } + if constexpr (N >= 5) { + hn::Store(out4_0, df, out.Row(4) + start_col); + hn::Store(out4_1, df, out.Row(4) + start_col + NF); + } + if constexpr (N >= 6) { + hn::Store(out5_0, df, out.Row(5) + start_col); + hn::Store(out5_1, df, out.Row(5) + start_col + NF); + } + if constexpr (N >= 7) { + hn::Store(out6_0, df, out.Row(6) + start_col); + hn::Store(out6_1, df, out.Row(6) + start_col + NF); + } + if constexpr (N >= 8) { + hn::Store(out7_0, df, out.Row(7) + start_col); + hn::Store(out7_1, df, out.Row(7) + start_col + NF); + } +} + +template > +static HWY_INLINE void LoadAndMulUpTo8Times2( + DF df, MatPtrT& out, size_t column, const float* HWY_RESTRICT scales, + VF& out0_0, VF& out0_1, VF& out1_0, VF& out1_1, VF& out2_0, VF& out2_1, + VF& out3_0, VF& out3_1, VF& out4_0, VF& out4_1, VF& out5_0, VF& out5_1, + VF& out6_0, VF& out6_1, VF& out7_0, VF& out7_1) { + namespace hn = hwy::HWY_NAMESPACE; + HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df); + out0_0 = hn::Load(df, out.Row(0) + column); + out0_0 = hn::Mul(out0_0, hn::Set(df, scales[0])); + out0_1 = hn::Load(df, out.Row(0) + column + NF); + out0_1 = hn::Mul(out0_1, hn::Set(df, scales[0])); + if constexpr (N >= 2) { + out1_0 = hn::Load(df, out.Row(1) + column); + out1_0 = hn::Mul(out1_0, hn::Set(df, scales[1])); + out1_1 = hn::Load(df, out.Row(1) + column + NF); + out1_1 = hn::Mul(out1_1, hn::Set(df, scales[1])); + } + if constexpr (N >= 3) { + out2_0 = hn::Load(df, out.Row(2) + column); + out2_0 = hn::Mul(out2_0, hn::Set(df, scales[2])); + out2_1 = hn::Load(df, out.Row(2) + column + NF); + out2_1 = hn::Mul(out2_1, hn::Set(df, scales[2])); + } + if constexpr (N >= 4) { + out3_0 = hn::Load(df, out.Row(3) + column); + out3_0 = hn::Mul(out3_0, hn::Set(df, scales[3])); + out3_1 = hn::Load(df, out.Row(3) + column + NF); + out3_1 = hn::Mul(out3_1, hn::Set(df, scales[3])); + } + if constexpr (N >= 5) { + out4_0 = hn::Load(df, out.Row(4) + column); + out4_0 = hn::Mul(out4_0, hn::Set(df, scales[4])); + out4_1 = hn::Load(df, out.Row(4) + column + NF); + out4_1 = hn::Mul(out4_1, hn::Set(df, scales[4])); + } + if constexpr (N >= 6) { + out5_0 = hn::Load(df, out.Row(5) + column); + out5_0 = hn::Mul(out5_0, hn::Set(df, scales[5])); + out5_1 = hn::Load(df, out.Row(5) + column + NF); + out5_1 = hn::Mul(out5_1, hn::Set(df, scales[5])); + } + if constexpr (N >= 7) { + out6_0 = hn::Load(df, out.Row(6) + column); + out6_0 = hn::Mul(out6_0, hn::Set(df, scales[6])); + out6_1 = hn::Load(df, out.Row(6) + column + NF); + out6_1 = hn::Mul(out6_1, hn::Set(df, scales[6])); + } + if constexpr (N >= 8) { + out7_0 = hn::Load(df, out.Row(7) + column); + out7_0 = hn::Mul(out7_0, hn::Set(df, scales[7])); + out7_1 = hn::Load(df, out.Row(7) + column + NF); + out7_1 = hn::Mul(out7_1, hn::Set(df, scales[7])); + } +} + +template , typename VType> +HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddTileUpTo8( + DF df, const float* HWY_RESTRICT scales, const VF& c0_p0, const VF& c0_p1, + const VF& c1_p0, const VF& c1_p1, const VF& c2_p0, const VF& c2_p1, + const VF& c3_p0, const VF& c3_p1, const VF& c4_p0, const VF& c4_p1, + const VF& c5_p0, const VF& c5_p1, const VF& c6_p0, const VF& c6_p1, + const VF& c7_p0, const VF& c7_p1, const VType* HWY_RESTRICT v_tile, + MatPtrT& out) { + static_assert(N <= 8); + namespace hn = hwy::HWY_NAMESPACE; + const size_t qkv_dim = out.Cols(); + constexpr size_t kMaxLanes = hn::MaxLanes(df); + HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df); + + PackedSpan v_span = MakeConstSpan(v_tile, qkv_dim * 2 * NF); + + size_t i = 0; + HWY_DASSERT(qkv_dim % (NF * 2) == 0); + HWY_ALIGN float consts_buffer[kMaxLanes * N * 2]; + hn::Store(c0_p0, df, consts_buffer); + hn::Store(c0_p1, df, consts_buffer + kMaxLanes); + if constexpr (N >= 2) { + hn::Store(c1_p0, df, consts_buffer + 2 * kMaxLanes); + hn::Store(c1_p1, df, consts_buffer + 3 * kMaxLanes); + } + if constexpr (N >= 3) { + hn::Store(c2_p0, df, consts_buffer + 4 * kMaxLanes); + hn::Store(c2_p1, df, consts_buffer + 5 * kMaxLanes); + } + if constexpr (N >= 4) { + hn::Store(c3_p0, df, consts_buffer + 6 * kMaxLanes); + hn::Store(c3_p1, df, consts_buffer + 7 * kMaxLanes); + } + if constexpr (N >= 5) { + hn::Store(c4_p0, df, consts_buffer + 8 * kMaxLanes); + hn::Store(c4_p1, df, consts_buffer + 9 * kMaxLanes); + } + if constexpr (N >= 6) { + hn::Store(c5_p0, df, consts_buffer + 10 * kMaxLanes); + hn::Store(c5_p1, df, consts_buffer + 11 * kMaxLanes); + } + if constexpr (N >= 7) { + hn::Store(c6_p0, df, consts_buffer + 12 * kMaxLanes); + hn::Store(c6_p1, df, consts_buffer + 13 * kMaxLanes); + } + if constexpr (N >= 8) { + hn::Store(c7_p0, df, consts_buffer + 14 * kMaxLanes); + hn::Store(c7_p1, df, consts_buffer + 15 * kMaxLanes); + } + HWY_DASSERT(qkv_dim % (NF * 2) == 0); + while (i + NF * 2 <= qkv_dim) { + VF out0_0, out1_0, out2_0, out3_0, out4_0, out5_0, out6_0, out7_0; + VF out0_1, out1_1, out2_1, out3_1, out4_1, out5_1, out6_1, out7_1; + LoadAndMulUpTo8Times2(df, out, i, scales, out0_0, out0_1, out1_0, out1_1, + out2_0, out2_1, out3_0, out3_1, out4_0, out4_1, + out5_0, out5_1, out6_0, out6_1, out7_0, out7_1); + for (int lane = 0; lane < NF; ++lane) { + VF xI1, xI2; + Decompress2(df, v_span, qkv_dim * lane + i, xI1, xI2); + + out0_0 = hn::MulAdd(xI1, hn::Set(df, consts_buffer[lane + 0 * kMaxLanes]), + out0_0); + out0_1 = hn::MulAdd(xI2, hn::Set(df, consts_buffer[lane + 0 * kMaxLanes]), + out0_1); + if constexpr (N >= 2) { + out1_0 = hn::MulAdd( + xI1, hn::Set(df, consts_buffer[lane + 2 * kMaxLanes]), out1_0); + out1_1 = hn::MulAdd( + xI2, hn::Set(df, consts_buffer[lane + 2 * kMaxLanes]), out1_1); + } + if constexpr (N >= 3) { + out2_0 = hn::MulAdd( + xI1, hn::Set(df, consts_buffer[lane + 4 * kMaxLanes]), out2_0); + out2_1 = hn::MulAdd( + xI2, hn::Set(df, consts_buffer[lane + 4 * kMaxLanes]), out2_1); + } + if constexpr (N >= 4) { + out3_0 = hn::MulAdd( + xI1, hn::Set(df, consts_buffer[lane + 6 * kMaxLanes]), out3_0); + out3_1 = hn::MulAdd( + xI2, hn::Set(df, consts_buffer[lane + 6 * kMaxLanes]), out3_1); + } + if constexpr (N >= 5) { + out4_0 = hn::MulAdd( + xI1, hn::Set(df, consts_buffer[lane + 8 * kMaxLanes]), out4_0); + out4_1 = hn::MulAdd( + xI2, hn::Set(df, consts_buffer[lane + 8 * kMaxLanes]), out4_1); + } + if constexpr (N >= 6) { + out5_0 = hn::MulAdd( + xI1, hn::Set(df, consts_buffer[lane + 10 * kMaxLanes]), out5_0); + out5_1 = hn::MulAdd( + xI2, hn::Set(df, consts_buffer[lane + 10 * kMaxLanes]), out5_1); + } + if constexpr (N >= 7) { + out6_0 = hn::MulAdd( + xI1, hn::Set(df, consts_buffer[lane + 12 * kMaxLanes]), out6_0); + out6_1 = hn::MulAdd( + xI2, hn::Set(df, consts_buffer[lane + 12 * kMaxLanes]), out6_1); + } + if constexpr (N >= 8) { + out7_0 = hn::MulAdd( + xI1, hn::Set(df, consts_buffer[lane + 14 * kMaxLanes]), out7_0); + out7_1 = hn::MulAdd( + xI2, hn::Set(df, consts_buffer[lane + 14 * kMaxLanes]), out7_1); + } + VF xI3, xI4; + Decompress2(df, v_span, qkv_dim * (NF + lane) + i, xI3, xI4); + + out0_0 = hn::MulAdd(xI3, hn::Set(df, consts_buffer[lane + 1 * kMaxLanes]), + out0_0); + out0_1 = hn::MulAdd(xI4, hn::Set(df, consts_buffer[lane + 1 * kMaxLanes]), + out0_1); + if constexpr (N >= 2) { + out1_0 = hn::MulAdd( + xI3, hn::Set(df, consts_buffer[lane + 3 * kMaxLanes]), out1_0); + out1_1 = hn::MulAdd( + xI4, hn::Set(df, consts_buffer[lane + 3 * kMaxLanes]), out1_1); + } + if constexpr (N >= 3) { + out2_0 = hn::MulAdd( + xI3, hn::Set(df, consts_buffer[lane + 5 * kMaxLanes]), out2_0); + out2_1 = hn::MulAdd( + xI4, hn::Set(df, consts_buffer[lane + 5 * kMaxLanes]), out2_1); + } + if constexpr (N >= 4) { + out3_0 = hn::MulAdd( + xI3, hn::Set(df, consts_buffer[lane + 7 * kMaxLanes]), out3_0); + out3_1 = hn::MulAdd( + xI4, hn::Set(df, consts_buffer[lane + 7 * kMaxLanes]), out3_1); + } + if constexpr (N >= 5) { + out4_0 = hn::MulAdd( + xI3, hn::Set(df, consts_buffer[lane + 9 * kMaxLanes]), out4_0); + out4_1 = hn::MulAdd( + xI4, hn::Set(df, consts_buffer[lane + 9 * kMaxLanes]), out4_1); + } + if constexpr (N >= 6) { + out5_0 = hn::MulAdd( + xI3, hn::Set(df, consts_buffer[lane + 11 * kMaxLanes]), out5_0); + out5_1 = hn::MulAdd( + xI4, hn::Set(df, consts_buffer[lane + 11 * kMaxLanes]), out5_1); + } + if constexpr (N >= 7) { + out6_0 = hn::MulAdd( + xI3, hn::Set(df, consts_buffer[lane + 13 * kMaxLanes]), out6_0); + out6_1 = hn::MulAdd( + xI4, hn::Set(df, consts_buffer[lane + 13 * kMaxLanes]), out6_1); + } + if constexpr (N >= 8) { + out7_0 = hn::MulAdd( + xI3, hn::Set(df, consts_buffer[lane + 15 * kMaxLanes]), out7_0); + out7_1 = hn::MulAdd( + xI4, hn::Set(df, consts_buffer[lane + 15 * kMaxLanes]), out7_1); + } + } + StoreUpTo8Times2(df, out, i, out0_0, out0_1, out1_0, out1_1, out2_0, + out2_1, out3_0, out3_1, out4_0, out4_1, out5_0, out5_1, + out6_0, out6_1, out7_0, out7_1); + + i += 2 * NF; + } + HWY_DASSERT(qkv_dim == i); +} + +template , typename VType> +HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddTileUpTo8_BF16( + DF df, const float* HWY_RESTRICT scales, VF c0_p0, VF c0_p1, VF c1_p0, + VF c1_p1, VF c2_p0, VF c2_p1, VF c3_p0, VF c3_p1, VF c4_p0, VF c4_p1, + VF c5_p0, VF c5_p1, VF c6_p0, VF c6_p1, VF c7_p0, VF c7_p1, + VType* HWY_RESTRICT v_tile, MatPtrT& out) { + static_assert(N <= 8); + namespace hn = hwy::HWY_NAMESPACE; + const size_t qkv_dim = out.Cols(); + HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df); + constexpr size_t kMaxLanes = hn::MaxLanes(df); + using DBF = hn::ScalableTag; + const DBF dbf; + using VBF = hn::Vec; + PackedSpan v_span = MakeConstSpan(v_tile, qkv_dim * 2 * NF); + HWY_ALIGN BF16 cs[N * kMaxLanes * 2]; + PackedSpan cs_span = MakeSpan(cs, N * kMaxLanes * 2); + float* cs_as_float = HWY_RCAST_ALIGNED(float*, cs); + Compress2(df, c0_p0, c0_p1, cs_span, 0); + if constexpr (N >= 2) { + Compress2(df, c1_p0, c1_p1, cs_span, kMaxLanes * 2); + } + if constexpr (N >= 3) { + Compress2(df, c2_p0, c2_p1, cs_span, 2 * kMaxLanes * 2); + } + if constexpr (N >= 4) { + Compress2(df, c3_p0, c3_p1, cs_span, 3 * kMaxLanes * 2); + } + if constexpr (N >= 5) { + Compress2(df, c4_p0, c4_p1, cs_span, 4 * kMaxLanes * 2); + } + if constexpr (N >= 6) { + Compress2(df, c5_p0, c5_p1, cs_span, 5 * kMaxLanes * 2); + } + if constexpr (N >= 7) { + Compress2(df, c6_p0, c6_p1, cs_span, 6 * kMaxLanes * 2); + } + if constexpr (N >= 8) { + Compress2(df, c7_p0, c7_p1, cs_span, 7 * kMaxLanes * 2); + } + VF zero = hn::Zero(df); + size_t i = 0; + HWY_DASSERT(qkv_dim % (NF * 2) == 0); + while (i + NF * 2 <= qkv_dim) { + VF out0_0, out1_0, out2_0, out3_0; + VF out0_1, out1_1, out2_1, out3_1; + VF out4_0, out5_0, out6_0, out7_0; + VF out4_1, out5_1, out6_1, out7_1; + VF helper_out0_0 = hn::Zero(df), helper_out0_1 = hn::Zero(df), + helper_out1_0 = hn::Zero(df), helper_out1_1 = hn::Zero(df), + helper_out2_0 = hn::Zero(df), helper_out2_1 = hn::Zero(df), + helper_out3_0 = hn::Zero(df), helper_out3_1 = hn::Zero(df), + helper_out4_0 = hn::Zero(df), helper_out4_1 = hn::Zero(df), + helper_out5_0 = hn::Zero(df), helper_out5_1 = hn::Zero(df), + helper_out6_0 = hn::Zero(df), helper_out6_1 = hn::Zero(df), + helper_out7_0 = hn::Zero(df), helper_out7_1 = hn::Zero(df); + LoadAndMulUpTo8Times2(df, out, i, scales, out0_0, out0_1, out1_0, out1_1, + out2_0, out2_1, out3_0, out3_1, out4_0, out4_1, + out5_0, out5_1, out6_0, out6_1, out7_0, out7_1); + for (int lane = 0; lane < NF; ++lane) { + VBF xI, xI2; + Decompress2(dbf, v_span, 2 * qkv_dim * lane + i * 2, xI, xI2); + + // Set pair of c scales for 2 value vectors + out0_0 = hn::ReorderWidenMulAccumulate( + df, xI, hn::BitCast(dbf, hn::Set(df, cs_as_float[lane])), out0_0, + helper_out0_0); + out0_1 = hn::ReorderWidenMulAccumulate( + df, xI2, hn::BitCast(dbf, hn::Set(df, cs_as_float[lane])), out0_1, + helper_out0_1); + if constexpr (N >= 2) { + out1_0 = hn::ReorderWidenMulAccumulate( + df, xI, + hn::BitCast(dbf, hn::Set(df, cs_as_float[lane + kMaxLanes])), + out1_0, helper_out1_0); + out1_1 = hn::ReorderWidenMulAccumulate( + df, xI2, + hn::BitCast(dbf, hn::Set(df, cs_as_float[lane + kMaxLanes])), + out1_1, helper_out1_1); + } + if constexpr (N >= 3) { + out2_0 = hn::ReorderWidenMulAccumulate( + df, xI, + hn::BitCast(dbf, hn::Set(df, cs_as_float[lane + 2 * kMaxLanes])), + out2_0, helper_out2_0); + out2_1 = hn::ReorderWidenMulAccumulate( + df, xI2, + hn::BitCast(dbf, hn::Set(df, cs_as_float[lane + 2 * kMaxLanes])), + out2_1, helper_out2_1); + } + if constexpr (N >= 4) { + out3_0 = hn::ReorderWidenMulAccumulate( + df, xI, + hn::BitCast(dbf, hn::Set(df, cs_as_float[lane + 3 * kMaxLanes])), + out3_0, helper_out3_0); + out3_1 = hn::ReorderWidenMulAccumulate( + df, xI2, + hn::BitCast(dbf, hn::Set(df, cs_as_float[lane + 3 * kMaxLanes])), + out3_1, helper_out3_1); + } + if constexpr (N >= 5) { + out4_0 = hn::ReorderWidenMulAccumulate( + df, xI, + hn::BitCast(dbf, hn::Set(df, cs_as_float[lane + 4 * kMaxLanes])), + out4_0, helper_out4_0); + out4_1 = hn::ReorderWidenMulAccumulate( + df, xI2, + hn::BitCast(dbf, hn::Set(df, cs_as_float[lane + 4 * kMaxLanes])), + out4_1, helper_out4_1); + } + if constexpr (N >= 6) { + out5_0 = hn::ReorderWidenMulAccumulate( + df, xI, + hn::BitCast(dbf, hn::Set(df, cs_as_float[lane + 5 * kMaxLanes])), + out5_0, helper_out5_0); + out5_1 = hn::ReorderWidenMulAccumulate( + df, xI2, + hn::BitCast(dbf, hn::Set(df, cs_as_float[lane + 5 * kMaxLanes])), + out5_1, helper_out5_1); + } + if constexpr (N >= 7) { + out6_0 = hn::ReorderWidenMulAccumulate( + df, xI, + hn::BitCast(dbf, hn::Set(df, cs_as_float[lane + 6 * kMaxLanes])), + out6_0, helper_out6_0); + out6_1 = hn::ReorderWidenMulAccumulate( + df, xI2, + hn::BitCast(dbf, hn::Set(df, cs_as_float[lane + 6 * kMaxLanes])), + out6_1, helper_out6_1); + } + if constexpr (N >= 8) { + out7_0 = hn::ReorderWidenMulAccumulate( + df, xI, + hn::BitCast(dbf, hn::Set(df, cs_as_float[lane + 7 * kMaxLanes])), + out7_0, helper_out7_0); + out7_1 = hn::ReorderWidenMulAccumulate( + df, xI2, + hn::BitCast(dbf, hn::Set(df, cs_as_float[lane + 7 * kMaxLanes])), + out7_1, helper_out7_1); + } + } +#if HWY_NATIVE_DOT_BF16 == 0 + out0_0 = hn::Add(out0_0, helper_out0_0); + out0_1 = hn::Add(out0_1, helper_out0_1); + if constexpr (N >= 2) { + out1_0 = hn::Add(out1_0, helper_out1_0); + out1_1 = hn::Add(out1_1, helper_out1_1); + } + if constexpr (N >= 3) { + out2_0 = hn::Add(out2_0, helper_out2_0); + out2_1 = hn::Add(out2_1, helper_out2_1); + } + if constexpr (N >= 4) { + out3_0 = hn::Add(out3_0, helper_out3_0); + out3_1 = hn::Add(out3_1, helper_out3_1); + } + if constexpr (N >= 5) { + out4_0 = hn::Add(out4_0, helper_out4_0); + out4_1 = hn::Add(out4_1, helper_out4_1); + } + if constexpr (N >= 6) { + out5_0 = hn::Add(out5_0, helper_out5_0); + out5_1 = hn::Add(out5_1, helper_out5_1); + } + if constexpr (N >= 7) { + out6_0 = hn::Add(out6_0, helper_out6_0); + out6_1 = hn::Add(out6_1, helper_out6_1); + } + if constexpr (N >= 8) { + out7_0 = hn::Add(out7_0, helper_out7_0); + out7_1 = hn::Add(out7_1, helper_out7_1); + } +#endif + StoreUpTo8Times2(df, out, i, out0_0, out0_1, out1_0, out1_1, out2_0, + out2_1, out3_0, out3_1, out4_0, out4_1, out5_0, out5_1, + out6_0, out6_1, out7_0, out7_1); + + i += 2 * NF; + } + HWY_DASSERT(qkv_dim == i); +} // Prescales NF rows of out by scale, then multiplies 1 row of V by the // corresponding values in c0 and adds them to the NF rows of out.