Skip to content

Commit 6aa6e71

Browse files
Krzysztof Rymski and copybara-github
authored and committed
Int8 + microscaling support for kv cache formats.
Right now multiplication is done by converting to the corresponding float format. Can yield up to 2x improvements for memory-bandwidth-constrained shapes. PiperOrigin-RevId: 874047973
1 parent df162ea commit 6aa6e71

8 files changed

Lines changed: 368 additions & 5 deletions

File tree

compression/compress-inl.h

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include <memory>
2525
#include <vector>
2626

27+
2728
#include "compression/compress.h" // IWYU pragma: export
2829
#include "compression/distortion.h"
2930
#include "util/threading_context.h"
@@ -444,6 +445,146 @@ struct CompressTraits<SfpStream> {
444445
}
445446
};
446447

448+
template <>
449+
struct CompressTraits<int8_t> {
450+
using Packed = int8_t;
451+
452+
static size_t CompressBound(size_t num) { return num * sizeof(Packed); }
453+
454+
template <class DF, HWY_IF_F32_D(DF)>
455+
static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT raw,
456+
size_t num, CompressPerThread& /*tls*/,
457+
const PackedSpan<Packed>& packed,
458+
const size_t packed_ofs) {
459+
const hn::Repartition<int32_t, DF> di32;
460+
const hn::Repartition<int16_t, DF> di16;
461+
const hn::Repartition<int8_t, DF> di8;
462+
using VF = hn::Vec<DF>;
463+
const size_t NF = hn::Lanes(df);
464+
465+
size_t i = 0;
466+
for (; i <= num - NF; i += NF) {
467+
const VF v = hn::LoadU(df, raw + i);
468+
auto vi32 = hn::NearestInt(v);
469+
auto vi16 = hn::DemoteTo(di16, vi32);
470+
auto vi8 = hn::DemoteTo(di8, vi16);
471+
hn::StoreU(vi8, di8, packed.ptr + packed_ofs + i);
472+
}
473+
const size_t remaining = num - i;
474+
if (remaining > 0) {
475+
const VF v = hn::LoadN(df, raw + i, remaining);
476+
auto vi32 = hn::NearestInt(v);
477+
auto vi16 = hn::DemoteTo(di16, vi32);
478+
auto vi8 = hn::DemoteTo(di8, vi16);
479+
hn::StoreN(vi8, di8, packed.ptr + packed_ofs + i, remaining);
480+
}
481+
}
482+
483+
static float ToFloatSlow(const Packed x) { return static_cast<float>(x); }
484+
485+
486+
template <class DF, HWY_IF_F32_D(DF)>
487+
static HWY_INLINE void Load2(DF df, const PackedSpan<const Packed>& packed,
488+
const size_t packed_ofs, hn::Vec<DF>& raw0,
489+
hn::Vec<DF>& raw1) {
490+
const hn::Repartition<int32_t, DF> di32;
491+
const hn::Repartition<int16_t, DF> di16;
492+
const hn::Repartition<int8_t, DF> di8;
493+
const hn::Half<decltype(di8)> di8_half;
494+
495+
const auto vec_i8 = hn::LoadU(di8_half, packed.ptr + packed_ofs);
496+
const auto vec_i8_full = hn::Combine(di8, hn::Zero(di8_half), vec_i8);
497+
const auto vec_i16 = hn::PromoteLowerTo(di16, vec_i8_full);
498+
const auto vec_i32_0 = hn::PromoteLowerTo(di32, vec_i16);
499+
const auto vec_i32_1 = hn::PromoteUpperTo(di32, vec_i16);
500+
501+
raw0 = hn::ConvertTo(df, vec_i32_0);
502+
raw1 = hn::ConvertTo(df, vec_i32_1);
503+
}
504+
505+
template <class DBF, HWY_IF_BF16_D(DBF)>
506+
static HWY_INLINE void Load2(DBF dbf, const PackedSpan<const Packed>& packed,
507+
const size_t packed_ofs, hn::Vec<DBF>& raw0,
508+
hn::Vec<DBF>& raw1) {
509+
const hn::Repartition<float, DBF> df;
510+
const hn::Repartition<int32_t, DBF> di32;
511+
const hn::Repartition<int16_t, DBF> di16;
512+
const hn::Repartition<int8_t, DBF> di8;
513+
514+
const auto v8 = hn::LoadU(di8, packed.ptr + packed_ofs);
515+
516+
const auto v16_0 = hn::PromoteLowerTo(di16, v8);
517+
const auto v16_1 = hn::PromoteUpperTo(di16, v8);
518+
519+
const auto v32_0_lo = hn::PromoteLowerTo(di32, v16_0);
520+
const auto v32_0_hi = hn::PromoteUpperTo(di32, v16_0);
521+
const auto f0_lo = hn::ConvertTo(df, v32_0_lo);
522+
const auto f0_hi = hn::ConvertTo(df, v32_0_hi);
523+
raw0 = hn::OrderedDemote2To(dbf, f0_lo, f0_hi);
524+
525+
const auto v32_1_lo = hn::PromoteLowerTo(di32, v16_1);
526+
const auto v32_1_hi = hn::PromoteUpperTo(di32, v16_1);
527+
const auto f1_lo = hn::ConvertTo(df, v32_1_lo);
528+
const auto f1_hi = hn::ConvertTo(df, v32_1_hi);
529+
raw1 = hn::OrderedDemote2To(dbf, f1_lo, f1_hi);
530+
}
531+
532+
template <class DF, HWY_IF_F32_D(DF)>
533+
static HWY_INLINE void DecompressAndZeroPad(
534+
DF df, const PackedSpan<const Packed>& packed, const size_t packed_ofs,
535+
float* HWY_RESTRICT raw, size_t num) {
536+
const hn::Rebind<int32_t, DF> di32;
537+
const hn::Rebind<int16_t, DF> di16;
538+
const hn::Rebind<int8_t, DF> di8;
539+
using VF = hn::Vec<DF>;
540+
const size_t NF = hn::Lanes(df);
541+
542+
size_t i = 0;
543+
if (num >= 2 * NF) {
544+
for (; i <= num - 2 * NF; i += 2 * NF) {
545+
VF raw0, raw1;
546+
Load2(df, packed, packed_ofs + i, raw0, raw1);
547+
hn::StoreU(raw0, df, raw + i);
548+
hn::StoreU(raw1, df, raw + i + NF);
549+
}
550+
}
551+
552+
const size_t remaining = num - i;
553+
if (HWY_UNLIKELY(remaining != 0)) {
554+
for (size_t j = 0; j < remaining; ++j) {
555+
raw[i + j] = static_cast<float>(packed.ptr[packed_ofs + i + j]);
556+
}
557+
}
558+
}
559+
560+
template <class DBF, HWY_IF_BF16_D(DBF)>
561+
static HWY_INLINE void DecompressAndZeroPad(
562+
DBF dbf, const PackedSpan<const Packed>& packed, const size_t packed_ofs,
563+
BF16* HWY_RESTRICT raw, size_t num) {
564+
const hn::Repartition<float, DBF> df;
565+
const size_t NF = hn::Lanes(df);
566+
size_t i = 0;
567+
const size_t NBF = hn::Lanes(dbf);
568+
if (num >= NBF) {
569+
for (; i <= num - NBF; i += NBF) {
570+
hn::Vec<decltype(df)> f0, f1;
571+
Load2(df, packed, packed_ofs + i, f0, f1);
572+
auto vbf = hn::OrderedDemote2To(dbf, f0, f1);
573+
hn::StoreU(vbf, dbf, raw + i);
574+
}
575+
}
576+
const size_t remaining = num - i;
577+
if (remaining > 0) {
578+
HWY_ALIGN float buf[2 * hn::MaxLanes(df)];
579+
DecompressAndZeroPad(df, packed, packed_ofs + i, buf, remaining);
580+
auto f0 = hn::LoadU(df, buf);
581+
auto f1 = hn::LoadU(df, buf + NF);
582+
auto vbf = hn::OrderedDemote2To(dbf, f0, f1);
583+
hn::StoreN(vbf, dbf, raw + i, remaining);
584+
}
585+
}
586+
};
587+
447588
// Integer quantization.
448589
template <>
449590
struct CompressTraits<I8Stream> {

compression/compress_test.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,8 @@ struct TestDecompress2 {
126126
HWY_ASSERT(stats.L1().Max() <= 0.08f);
127127
HWY_ASSERT(IsInside(0.02, 0.05, stats.WeightedAverageL1()));
128128
HWY_ASSERT(IsInside(18.0, 62.0, stats.GeomeanValueDivL1()));
129+
} else if constexpr (hwy::IsSame<Packed, int8_t>()) {
130+
HWY_ASSERT(stats.L1().Max() <= 0.6f);
129131
} else {
130132
HWY_ABORT("Unhandled type requested by ForeachPackedAndRawType");
131133
}
@@ -200,6 +202,8 @@ struct TestShortLengths {
200202
HWY_ASSERT(stats.L1().Max() <= 0.14f);
201203
HWY_ASSERT(IsInside(7E-5, 0.06, stats.WeightedAverageL1()));
202204
HWY_ASSERT(IsInside(11.0, 180.0, stats.GeomeanValueDivL1()));
205+
} else if constexpr (hwy::IsSame<Packed, int8_t>()) {
206+
HWY_ASSERT(stats.L1().Max() <= 0.6f);
203207
} else {
204208
HWY_ABORT("Unhandled type requested by ForeachPackedAndRawType");
205209
}

compression/test_util-inl.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ void ForeachPackedAndRawType() {
7070
if constexpr (GEMMA_ENABLE_NUQ) {
7171
ForeachRawType<NuqStream, TestT>();
7272
}
73+
7374
}
7475

7576
template <class Test, class D>

compression/types.h

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,11 @@ constexpr bool IsF32() {
192192
return hwy::IsSame<hwy::RemoveCvRef<Packed>, float>();
193193
}
194194

195+
template <typename Packed>
196+
constexpr bool IsInt8() {
197+
return hwy::IsSame<hwy::RemoveCvRef<Packed>, int8_t>();
198+
}
199+
195200
template <typename Packed>
196201
constexpr bool IsBF16() {
197202
return hwy::IsSame<hwy::RemoveCvRef<Packed>, BF16>();
@@ -231,12 +236,13 @@ enum class Type {
231236
kI8,
232237
kU16,
233238
kU8,
239+
kInt8,
234240
};
235241
// These are used in `ModelConfig.Specifier`, hence the strings will not
236242
// change, though new ones may be added.
237-
static constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp",
238-
"nuq", "f64", "u32", "u64",
239-
"i8", "u16", "u8"};
243+
static constexpr const char* kTypeStrings[] = {
244+
"unknown", "f32", "bf16", "sfp", "nuq", "f64",
245+
"u32", "u64", "i8", "u16", "u8", "int8"};
240246
static constexpr size_t kNumTypes =
241247
sizeof(kTypeStrings) / sizeof(kTypeStrings[0]);
242248
static constexpr size_t kTypeBits[] = {
@@ -251,6 +257,7 @@ static constexpr size_t kTypeBits[] = {
251257
8 * sizeof(I8Stream),
252258
8 * sizeof(uint16_t),
253259
8 * sizeof(uint8_t),
260+
8 * sizeof(int8_t),
254261
};
255262

256263
static inline bool EnumValid(Type type) {
@@ -281,6 +288,8 @@ constexpr Type TypeEnum() {
281288
return Type::kU16;
282289
} else if constexpr (hwy::IsSame<Packed, uint8_t>()) {
283290
return Type::kU8;
291+
} else if constexpr (hwy::IsSame<Packed, int8_t>()) {
292+
return Type::kInt8;
284293
} else {
285294
return Type::kUnknown;
286295
}

gemma/flash_attention.cc

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1288,6 +1288,49 @@ static HWY_NOINLINE void ApplyMasking(
12881288
}
12891289
}
12901290

1291+
template <int kNumQueries, class DF, class VF = hn::Vec<DF>>
1292+
static HWY_INLINE void MultiplyByScale(DF df, const float* scales, VF& x0_p0,
1293+
VF& x0_p1, VF& x1_p0, VF& x1_p1,
1294+
VF& x2_p0, VF& x2_p1, VF& x3_p0,
1295+
VF& x3_p1, VF& x4_p0, VF& x4_p1,
1296+
VF& x5_p0, VF& x5_p1, VF& x6_p0,
1297+
VF& x6_p1, VF& x7_p0, VF& x7_p1) {
1298+
VF scales_p0 = hn::LoadU(df, scales);
1299+
VF scales_p1 = hn::LoadU(df, scales + hn::Lanes(df));
1300+
if constexpr (kNumQueries >= 1) {
1301+
x0_p0 = hn::Mul(x0_p0, scales_p0);
1302+
x0_p1 = hn::Mul(x0_p1, scales_p1);
1303+
}
1304+
if constexpr (kNumQueries >= 2) {
1305+
x1_p0 = hn::Mul(x1_p0, scales_p0);
1306+
x1_p1 = hn::Mul(x1_p1, scales_p1);
1307+
}
1308+
if constexpr (kNumQueries >= 3) {
1309+
x2_p0 = hn::Mul(x2_p0, scales_p0);
1310+
x2_p1 = hn::Mul(x2_p1, scales_p1);
1311+
}
1312+
if constexpr (kNumQueries >= 4) {
1313+
x3_p0 = hn::Mul(x3_p0, scales_p0);
1314+
x3_p1 = hn::Mul(x3_p1, scales_p1);
1315+
}
1316+
if constexpr (kNumQueries >= 5) {
1317+
x4_p0 = hn::Mul(x4_p0, scales_p0);
1318+
x4_p1 = hn::Mul(x4_p1, scales_p1);
1319+
}
1320+
if constexpr (kNumQueries >= 6) {
1321+
x5_p0 = hn::Mul(x5_p0, scales_p0);
1322+
x5_p1 = hn::Mul(x5_p1, scales_p1);
1323+
}
1324+
if constexpr (kNumQueries >= 7) {
1325+
x6_p0 = hn::Mul(x6_p0, scales_p0);
1326+
x6_p1 = hn::Mul(x6_p1, scales_p1);
1327+
}
1328+
if constexpr (kNumQueries >= 8) {
1329+
x7_p0 = hn::Mul(x7_p0, scales_p0);
1330+
x7_p1 = hn::Mul(x7_p1, scales_p1);
1331+
}
1332+
}
1333+
12911334
// Performs tiled flash attention for arbitrary number of queries
12921335
// It depends on kv being tiled.
12931336
// Runs 2 loops one over tiles, and inner one over queries(up to 4 at a time).
@@ -1428,6 +1471,21 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits(
14281471
false,
14291472
"Query type type not supported, only float and BF16 are supported");
14301473
}
1474+
// microscaling
1475+
// TODO: Change to more generic function to inform if we should use
1476+
// microscaling or not.
1477+
constexpr bool kUseMicroScaling = IsInt8<KV_T>();
1478+
if constexpr (kUseMicroScaling) {
1479+
// After end of the tile, we have kTileSize * 2 floats for the
1480+
// microscaling scales for K and V.
1481+
const float* microscaling_scales_k =
1482+
reinterpret_cast<const float*>(tile_base + qkv_dim * 2 * kTileSize) +
1483+
pos_in_tile;
1484+
MultiplyByScale<kNumQueries>(df, microscaling_scales_k, x_0_p_0, x_0_p_1,
1485+
x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0,
1486+
x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1,
1487+
x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1);
1488+
}
14311489

14321490
constexpr int kFirstHalfAmountOfQueries = std::min(kNumQueries, 4);
14331491
constexpr int kSecondHalfAmountOfQueries =
@@ -1461,6 +1519,15 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits(
14611519
x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1,
14621520
x_7_p_0, x_7_p_1, max_logits, exp_denominator_sums, scales, q_group_idx,
14631521
kNumQueriesPerGroup);
1522+
if constexpr (kUseMicroScaling) {
1523+
const float* microscaling_scales_v =
1524+
reinterpret_cast<const float*>(tile_base + qkv_dim * 2 * kTileSize) +
1525+
kTileSize + pos_in_tile;
1526+
MultiplyByScale<kNumQueries>(
1527+
df, microscaling_scales_v, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1,
1528+
x_2_p_0, x_2_p_1, x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0,
1529+
x_5_p_1, x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1);
1530+
}
14641531
if constexpr (IsF32<Q_T>()) {
14651532
MulByConstAndAddTileUpTo8<kNumQueries>(
14661533
df, scales, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1,

0 commit comments

Comments (0)