From 6aa6e714f3cc4650912e660375116c370f4783bd Mon Sep 17 00:00:00 2001
From: Krzysztof Rymski 
Date: Mon, 23 Feb 2026 06:45:39 -0800
Subject: [PATCH] Int8 + microscaling support for kv cache formats.

Right now multiplication is done by converting to corresponding float format.
Can yield up to 2x improvements for membw-constrained shapes.

PiperOrigin-RevId: 874047973
---
 compression/compress-inl.h    | 141 ++++++++++++++++++++++++++++++++++
 compression/compress_test.cc  |   4 +
 compression/test_util-inl.h   |   1 +
 compression/types.h           |  15 +++-
 gemma/flash_attention.cc      |  67 ++++++++++++++++
 gemma/flash_attention_test.cc | 134 +++++++++++++++++++++++++++++++-
 gemma/kv_cache.cc             |   6 +-
 util/mat.h                    |   5 ++
 8 files changed, 368 insertions(+), 5 deletions(-)

diff --git a/compression/compress-inl.h b/compression/compress-inl.h
index e7bb9d68..d5a023ee 100644
--- a/compression/compress-inl.h
+++ b/compression/compress-inl.h
@@ -24,6 +24,7 @@
 #include 
 #include 
+
 #include "compression/compress.h"  // IWYU pragma: export
 #include "compression/distortion.h"
 #include "util/threading_context.h"
@@ -444,6 +445,146 @@ struct CompressTraits {
   }
 };
 
+template <>
+struct CompressTraits {
+  using Packed = int8_t;
+
+  static size_t CompressBound(size_t num) { return num * sizeof(Packed); }
+
+  template 
+  static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT raw,
+                                  size_t num, CompressPerThread& /*tls*/,
+                                  const PackedSpan& packed,
+                                  const size_t packed_ofs) {
+    const hn::Repartition di32;
+    const hn::Repartition di16;
+    const hn::Repartition di8;
+    using VF = hn::Vec;
+    const size_t NF = hn::Lanes(df);
+
+    size_t i = 0;
+    for (; i + NF <= num; i += NF) {  // i + NF avoids size_t underflow if num < NF
+      const VF v = hn::LoadU(df, raw + i);
+      auto vi32 = hn::NearestInt(v);
+      auto vi16 = hn::DemoteTo(di16, vi32);
+      auto vi8 = hn::DemoteTo(di8, vi16);
+      hn::StoreU(vi8, di8, packed.ptr + packed_ofs + i);
+    }
+    const size_t remaining = num - i;
+    if (remaining > 0) {
+      const VF v = hn::LoadN(df, raw + i, remaining);
+      auto vi32 = 
hn::NearestInt(v); + auto vi16 = hn::DemoteTo(di16, vi32); + auto vi8 = hn::DemoteTo(di8, vi16); + hn::StoreN(vi8, di8, packed.ptr + packed_ofs + i, remaining); + } + } + + static float ToFloatSlow(const Packed x) { return static_cast(x); } + + + template + static HWY_INLINE void Load2(DF df, const PackedSpan& packed, + const size_t packed_ofs, hn::Vec& raw0, + hn::Vec& raw1) { + const hn::Repartition di32; + const hn::Repartition di16; + const hn::Repartition di8; + const hn::Half di8_half; + + const auto vec_i8 = hn::LoadU(di8_half, packed.ptr + packed_ofs); + const auto vec_i8_full = hn::Combine(di8, hn::Zero(di8_half), vec_i8); + const auto vec_i16 = hn::PromoteLowerTo(di16, vec_i8_full); + const auto vec_i32_0 = hn::PromoteLowerTo(di32, vec_i16); + const auto vec_i32_1 = hn::PromoteUpperTo(di32, vec_i16); + + raw0 = hn::ConvertTo(df, vec_i32_0); + raw1 = hn::ConvertTo(df, vec_i32_1); + } + + template + static HWY_INLINE void Load2(DBF dbf, const PackedSpan& packed, + const size_t packed_ofs, hn::Vec& raw0, + hn::Vec& raw1) { + const hn::Repartition df; + const hn::Repartition di32; + const hn::Repartition di16; + const hn::Repartition di8; + + const auto v8 = hn::LoadU(di8, packed.ptr + packed_ofs); + + const auto v16_0 = hn::PromoteLowerTo(di16, v8); + const auto v16_1 = hn::PromoteUpperTo(di16, v8); + + const auto v32_0_lo = hn::PromoteLowerTo(di32, v16_0); + const auto v32_0_hi = hn::PromoteUpperTo(di32, v16_0); + const auto f0_lo = hn::ConvertTo(df, v32_0_lo); + const auto f0_hi = hn::ConvertTo(df, v32_0_hi); + raw0 = hn::OrderedDemote2To(dbf, f0_lo, f0_hi); + + const auto v32_1_lo = hn::PromoteLowerTo(di32, v16_1); + const auto v32_1_hi = hn::PromoteUpperTo(di32, v16_1); + const auto f1_lo = hn::ConvertTo(df, v32_1_lo); + const auto f1_hi = hn::ConvertTo(df, v32_1_hi); + raw1 = hn::OrderedDemote2To(dbf, f1_lo, f1_hi); + } + + template + static HWY_INLINE void DecompressAndZeroPad( + DF df, const PackedSpan& packed, const size_t packed_ofs, + float* 
HWY_RESTRICT raw, size_t num) {
+    const hn::Rebind di32;
+    const hn::Rebind di16;
+    const hn::Rebind di8;
+    using VF = hn::Vec;
+    const size_t NF = hn::Lanes(df);
+
+    size_t i = 0;
+    if (num >= 2 * NF) {
+      for (; i <= num - 2 * NF; i += 2 * NF) {
+        VF raw0, raw1;
+        Load2(df, packed, packed_ofs + i, raw0, raw1);
+        hn::StoreU(raw0, df, raw + i);
+        hn::StoreU(raw1, df, raw + i + NF);
+      }
+    }
+
+    const size_t remaining = num - i;
+    if (HWY_UNLIKELY(remaining != 0)) {
+      for (size_t j = 0; j < hwy::RoundUpTo(remaining, NF); ++j) {  // zero-pad tail, per contract
+        raw[i + j] = j < remaining ? static_cast<float>(packed.ptr[packed_ofs + i + j]) : 0.0f;
+      }
+    }
+  }
+
+  template 
+  static HWY_INLINE void DecompressAndZeroPad(
+      DBF dbf, const PackedSpan& packed, const size_t packed_ofs,
+      BF16* HWY_RESTRICT raw, size_t num) {
+    const hn::Repartition df;
+    const size_t NF = hn::Lanes(df);
+    size_t i = 0;
+    const size_t NBF = hn::Lanes(dbf);
+    if (num >= NBF) {
+      for (; i <= num - NBF; i += NBF) {
+        hn::Vec f0, f1;
+        Load2(df, packed, packed_ofs + i, f0, f1);
+        auto vbf = hn::OrderedDemote2To(dbf, f0, f1);
+        hn::StoreU(vbf, dbf, raw + i);
+      }
+    }
+    const size_t remaining = num - i;
+    if (remaining > 0) {
+      HWY_ALIGN float buf[2 * hn::MaxLanes(df)] = {};  // zeroed so padded lanes read as 0
+      DecompressAndZeroPad(df, packed, packed_ofs + i, buf, remaining);
+      auto f0 = hn::LoadU(df, buf);
+      auto f1 = hn::LoadU(df, buf + NF);
+      auto vbf = hn::OrderedDemote2To(dbf, f0, f1);
+      hn::StoreU(vbf, dbf, raw + i);  // full store zero-pads raw to NBF lanes
+    }
+  }
+};
+
 // Integer quantization. 
template <> struct CompressTraits { diff --git a/compression/compress_test.cc b/compression/compress_test.cc index 421492e9..4002184e 100644 --- a/compression/compress_test.cc +++ b/compression/compress_test.cc @@ -126,6 +126,8 @@ struct TestDecompress2 { HWY_ASSERT(stats.L1().Max() <= 0.08f); HWY_ASSERT(IsInside(0.02, 0.05, stats.WeightedAverageL1())); HWY_ASSERT(IsInside(18.0, 62.0, stats.GeomeanValueDivL1())); + } else if constexpr (hwy::IsSame()) { + HWY_ASSERT(stats.L1().Max() <= 0.6f); } else { HWY_ABORT("Unhandled type requested by ForeachPackedAndRawType"); } @@ -200,6 +202,8 @@ struct TestShortLengths { HWY_ASSERT(stats.L1().Max() <= 0.14f); HWY_ASSERT(IsInside(7E-5, 0.06, stats.WeightedAverageL1())); HWY_ASSERT(IsInside(11.0, 180.0, stats.GeomeanValueDivL1())); + } else if constexpr (hwy::IsSame()) { + HWY_ASSERT(stats.L1().Max() <= 0.6f); } else { HWY_ABORT("Unhandled type requested by ForeachPackedAndRawType"); } diff --git a/compression/test_util-inl.h b/compression/test_util-inl.h index bb2fadb0..e855131a 100644 --- a/compression/test_util-inl.h +++ b/compression/test_util-inl.h @@ -70,6 +70,7 @@ void ForeachPackedAndRawType() { if constexpr (GEMMA_ENABLE_NUQ) { ForeachRawType(); } + } template diff --git a/compression/types.h b/compression/types.h index dc22f4ca..e7f9bda0 100644 --- a/compression/types.h +++ b/compression/types.h @@ -192,6 +192,11 @@ constexpr bool IsF32() { return hwy::IsSame, float>(); } +template +constexpr bool IsInt8() { + return hwy::IsSame, int8_t>(); +} + template constexpr bool IsBF16() { return hwy::IsSame, BF16>(); @@ -231,12 +236,13 @@ enum class Type { kI8, kU16, kU8, + kInt8, }; // These are used in `ModelConfig.Specifier`, hence the strings will not // change, though new ones may be added. 
-static constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp", - "nuq", "f64", "u32", "u64", - "i8", "u16", "u8"}; +static constexpr const char* kTypeStrings[] = { + "unknown", "f32", "bf16", "sfp", "nuq", "f64", + "u32", "u64", "i8", "u16", "u8", "int8"}; static constexpr size_t kNumTypes = sizeof(kTypeStrings) / sizeof(kTypeStrings[0]); static constexpr size_t kTypeBits[] = { @@ -251,6 +257,7 @@ static constexpr size_t kTypeBits[] = { 8 * sizeof(I8Stream), 8 * sizeof(uint16_t), 8 * sizeof(uint8_t), + 8 * sizeof(int8_t), }; static inline bool EnumValid(Type type) { @@ -281,6 +288,8 @@ constexpr Type TypeEnum() { return Type::kU16; } else if constexpr (hwy::IsSame()) { return Type::kU8; + } else if constexpr (hwy::IsSame()) { + return Type::kInt8; } else { return Type::kUnknown; } diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index ebe8ee18..7b32c7c6 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -1288,6 +1288,49 @@ static HWY_NOINLINE void ApplyMasking( } } +template > +static HWY_INLINE void MultiplyByScale(DF df, const float* scales, VF& x0_p0, + VF& x0_p1, VF& x1_p0, VF& x1_p1, + VF& x2_p0, VF& x2_p1, VF& x3_p0, + VF& x3_p1, VF& x4_p0, VF& x4_p1, + VF& x5_p0, VF& x5_p1, VF& x6_p0, + VF& x6_p1, VF& x7_p0, VF& x7_p1) { + VF scales_p0 = hn::LoadU(df, scales); + VF scales_p1 = hn::LoadU(df, scales + hn::Lanes(df)); + if constexpr (kNumQueries >= 1) { + x0_p0 = hn::Mul(x0_p0, scales_p0); + x0_p1 = hn::Mul(x0_p1, scales_p1); + } + if constexpr (kNumQueries >= 2) { + x1_p0 = hn::Mul(x1_p0, scales_p0); + x1_p1 = hn::Mul(x1_p1, scales_p1); + } + if constexpr (kNumQueries >= 3) { + x2_p0 = hn::Mul(x2_p0, scales_p0); + x2_p1 = hn::Mul(x2_p1, scales_p1); + } + if constexpr (kNumQueries >= 4) { + x3_p0 = hn::Mul(x3_p0, scales_p0); + x3_p1 = hn::Mul(x3_p1, scales_p1); + } + if constexpr (kNumQueries >= 5) { + x4_p0 = hn::Mul(x4_p0, scales_p0); + x4_p1 = hn::Mul(x4_p1, scales_p1); + } + if constexpr (kNumQueries >= 
6) { + x5_p0 = hn::Mul(x5_p0, scales_p0); + x5_p1 = hn::Mul(x5_p1, scales_p1); + } + if constexpr (kNumQueries >= 7) { + x6_p0 = hn::Mul(x6_p0, scales_p0); + x6_p1 = hn::Mul(x6_p1, scales_p1); + } + if constexpr (kNumQueries >= 8) { + x7_p0 = hn::Mul(x7_p0, scales_p0); + x7_p1 = hn::Mul(x7_p1, scales_p1); + } +} + // Performs tiled flash attention for arbitrary number of queries // It depends on kv being tiled. // Runs 2 loops one over tiles, and inner one over queries(up to 4 at a time). @@ -1428,6 +1471,21 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits( false, "Query type type not supported, only float and BF16 are supported"); } + // microscaling + // TODO: Change to more generic function to inform if we should use + // microscaling or not. + constexpr bool kUseMicroScaling = IsInt8(); + if constexpr (kUseMicroScaling) { + // After end of the tile, we have kTileSize * 2 floats for the + // microscaling scales for K and V. + const float* microscaling_scales_k = + reinterpret_cast(tile_base + qkv_dim * 2 * kTileSize) + + pos_in_tile; + MultiplyByScale(df, microscaling_scales_k, x_0_p_0, x_0_p_1, + x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0, + x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, + x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1); + } constexpr int kFirstHalfAmountOfQueries = std::min(kNumQueries, 4); constexpr int kSecondHalfAmountOfQueries = @@ -1461,6 +1519,15 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits( x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1, max_logits, exp_denominator_sums, scales, q_group_idx, kNumQueriesPerGroup); + if constexpr (kUseMicroScaling) { + const float* microscaling_scales_v = + reinterpret_cast(tile_base + qkv_dim * 2 * kTileSize) + + kTileSize + pos_in_tile; + MultiplyByScale( + df, microscaling_scales_v, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, + x_2_p_0, x_2_p_1, x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, + x_5_p_1, x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1); + } if 
constexpr (IsF32()) { MulByConstAndAddTileUpTo8( df, scales, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, diff --git a/gemma/flash_attention_test.cc b/gemma/flash_attention_test.cc index bbb63f5a..64ebbebb 100644 --- a/gemma/flash_attention_test.cc +++ b/gemma/flash_attention_test.cc @@ -473,6 +473,138 @@ void TestTiledFlashAttentionBF16() { } } +void TestTiledFlashAttentionInt8() { + int qkv_dim = 64; + int kv_seq_len = 60; // number of tokens we will attend to. Not divisible by + // tiles size to test the padding logic. + int padded_kv_seq_len = hwy::RoundUpTo(kv_seq_len, gcpp::KVCache::kTileSize); + float att_cap = 10.0f; + int num_queries = 8; + int num_queries_per_timestep = 4; + int num_tokens = num_queries / num_queries_per_timestep; + int kv_seq_end = + kv_seq_len - hwy::DivCeil(num_queries, num_queries_per_timestep); + ThreadingArgs threading_args; + ThreadingContext ctx(threading_args); + + int num_tiles = padded_kv_seq_len / gcpp::KVCache::kTileSize; + int tile_size_bytes = + 2 * qkv_dim * gcpp::KVCache::kTileSize + 8 * gcpp::KVCache::kTileSize; + + MatStorageT kv("kv", Extents2D(num_tiles, tile_size_bytes), + ctx.allocator, MatPadding::kPacked); + + // fill in kvs with predictable, synthetic data + for (int i = 0; i < padded_kv_seq_len; ++i) { + int tile_idx = i / gcpp::KVCache::kTileSize; + int in_tile_offset = i % gcpp::KVCache::kTileSize; + int8_t* tile_ptr = kv.Row(tile_idx); + float* scales_ptr = reinterpret_cast( + tile_ptr + 2 * qkv_dim * gcpp::KVCache::kTileSize); + + // Generate float values for K and V + std::vector k_vals(qkv_dim); + std::vector v_vals(qkv_dim); + float max_abs_k = 0.0f; + float max_abs_v = 0.0f; + + for (int j = 0; j < qkv_dim; ++j) { + k_vals[j] = 0.01f * (i + 1) / (j + 1); + v_vals[j] = 0.02f * (i + 1) / (j + 1); + max_abs_k = std::max(max_abs_k, std::abs(k_vals[j])); + max_abs_v = std::max(max_abs_v, std::abs(v_vals[j])); + } + + // Quantize K + float scale_k = max_abs_k / 127.0f; + if (scale_k == 0.0f) scale_k = 
1.0f; + scales_ptr[in_tile_offset] = scale_k; + for (int j = 0; j < qkv_dim; ++j) { + int val = std::round(k_vals[j] / scale_k); + val = std::max(-127, std::min(127, val)); + tile_ptr[j * gcpp::KVCache::kTileSize + in_tile_offset] = + static_cast(val); + } + + // Quantize V + float scale_v = max_abs_v / 127.0f; + if (scale_v == 0.0f) scale_v = 1.0f; + scales_ptr[gcpp::KVCache::kTileSize + in_tile_offset] = scale_v; + size_t v_offset = qkv_dim * gcpp::KVCache::kTileSize; + for (int j = 0; j < qkv_dim; ++j) { + int val = std::round(v_vals[j] / scale_v); + val = std::max(-127, std::min(127, val)); + tile_ptr[v_offset + in_tile_offset * qkv_dim + j] = + static_cast(val); + } + } + + std::vector q_float(4 * qkv_dim); + std::vector q_float2(4 * qkv_dim); + // fill in qs with predictable, synthetic data + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < qkv_dim; j++) { + float val_1 = 0.01f * (i + 1) / (j + 1); + float val_2 = 0.01f * (i + 4 + 1) / (j + 1); + q_float[j * 4 + i] = val_1; + q_float2[j * 4 + i] = val_2; + } + } + const float* q_T[2] = {q_float.data(), q_float2.data()}; + + MatStorageT att_out("att_out", Extents2D(num_queries, qkv_dim), + ctx.allocator, MatPadding::kPacked); + using DF = hn::ScalableTag; + const DF df; + HWY_LANES_CONSTEXPR size_t lanes = hn::Lanes(df); + size_t num_queries_rounded_to_laness = hwy::RoundUpTo(num_queries, lanes); + std::vector exp_denominator_sums(num_queries_rounded_to_laness); + std::vector max_logits(num_queries_rounded_to_laness); + for (size_t i = 0; i < num_queries; ++i) { + hwy::ZeroBytes(att_out.Row(i), + qkv_dim * sizeof(decltype(att_out.Row(i)[0]))); + exp_denominator_sums[i] = 0.0f; + max_logits[i] = -std::numeric_limits::max() / 2.0f; + } + std::vector> start_pos_per_query; + std::vector> last_pos_per_query; + start_pos_per_query.reserve(num_queries); + last_pos_per_query.reserve(num_queries); + for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { + ssize_t query_last_pos = kv_seq_end + token_idx; + 
ssize_t query_start_pos =
+        std::max(query_last_pos - 100000 + 1, static_cast(0));
+    for (int q_head_idx = 0; q_head_idx < num_queries_per_timestep;
+         ++q_head_idx) {
+      start_pos_per_query.push_back(query_start_pos);
+      last_pos_per_query.push_back(query_last_pos);
+    }
+  }
+
+  hwy::Span kvs(&kv, 1);
+  DispatchTileFlashAttentionReturnExpSumsAndMaxLogits(
+      kvs, num_queries, hwy::Span(q_T, 2),
+      hwy::Span(start_pos_per_query),
+      hwy::Span(last_pos_per_query), att_cap, att_out,
+      exp_denominator_sums.data(), max_logits.data());
+
+  // TODO: Replace with a reference implementation for generating goldens.
+  // Current values are taken from a point in time where code was run with
+  // gemma and output looked good. Not ideal but should be good enough to
+  // test the plumbing and detect regressions.
+  PrintMatPtr(att_out);
+  for (int i = 0; i < num_queries; ++i) {
+    std::cerr << "exp_d: " << exp_denominator_sums[i]
+              << " max_logit: " << max_logits[i] << std::endl;
+    EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 1e-2f)
+        << "i=" << i;
+    EXPECT_NEAR(max_logits[i], max_logits_gold[i], 1e-3f) << "i=" << i;
+    for (int j = 0; j < qkv_dim; ++j) {
+      EXPECT_NEAR(att_out.Row(i)[j], att_out_gold[i * qkv_dim + j], 5e-3f);
+    }
+  }
+}
+
 // NOLINTNEXTLINE(google-readability-namespace-comments)
 }  // namespace HWY_NAMESPACE
 }  // namespace gcpp
@@ -482,7 +614,8 @@ HWY_AFTER_NAMESPACE();
 
 namespace gcpp {
 HWY_BEFORE_TEST(FlashAttentionTest);
 HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestAttention);
+HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestTiledFlashAttentionInt8);
 HWY_AFTER_TEST();
 
 }  // namespace gcpp
diff --git a/gemma/kv_cache.cc b/gemma/kv_cache.cc
index d225f526..01a5233c 100644
--- a/gemma/kv_cache.cc
+++ b/gemma/kv_cache.cc
@@ -61,7 +61,6 @@ KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
   const size_t num_tiles =
       hwy::DivCeil(CappedSeqLen(config, inference_args), kTileSize);
   tiled_seq_len = num_tiles * kTileSize;
-  int tile_length = 2 * 
config.layer_configs[0].qkv_dim * kTileSize; Type kv_cache_type; if (runtime_config.attention_impl == AttentionImpl::kFlashTransposedQsBF16 ) { @@ -69,6 +68,11 @@ KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args, } else { kv_cache_type = runtime_config.kv_cache_type.value_or(Type::kF32); } + + int tile_length = 2 * config.layer_configs[0].qkv_dim * kTileSize; + if (kv_cache_type == Type::kInt8) { + tile_length += 8 * kTileSize; + } auto num_tiles_per_head = [](size_t window_size, size_t prefill_tbatch_size, size_t max_seq_len) { return hwy::DivCeil( diff --git a/util/mat.h b/util/mat.h index 25f2cb2c..5cd84386 100644 --- a/util/mat.h +++ b/util/mat.h @@ -498,6 +498,11 @@ decltype(auto) CallUpcastedKVs(hwy::Span base, const Func& func, auto matptrs = convert_to_matptr_t.template operator()(); hwy::Span> matptrs_span(matptrs.data(), matptrs.size()); return func(matptrs_span, std::forward(args)...); + } else if (type == Type::kInt8) { + auto matptrs = convert_to_matptr_t.template operator()(); + hwy::Span> matptrs_span(matptrs.data(), + matptrs.size()); + return func(matptrs_span, std::forward(args)...); } else { HWY_ABORT("Unhandled type %s.", TypeName(type)); }