Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 141 additions & 0 deletions compression/compress-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <memory>
#include <vector>


#include "compression/compress.h" // IWYU pragma: export
#include "compression/distortion.h"
#include "util/threading_context.h"
Expand Down Expand Up @@ -444,6 +445,146 @@ struct CompressTraits<SfpStream> {
}
};

template <>
struct CompressTraits<int8_t> {
  using Packed = int8_t;

  // Returns the number of bytes required to store `num` packed values.
  static size_t CompressBound(size_t num) { return num * sizeof(Packed); }

  // Quantizes `num` floats from `raw` into int8 at `packed_ofs` within
  // `packed`, rounding to nearest. Values outside the int8 range saturate
  // during the i32->i16->i8 demotions.
  template <class DF, HWY_IF_F32_D(DF)>
  static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT raw,
                                  size_t num, CompressPerThread& /*tls*/,
                                  const PackedSpan<Packed>& packed,
                                  const size_t packed_ofs) {
    const hn::Repartition<int16_t, DF> di16;
    const hn::Repartition<int8_t, DF> di8;
    using VF = hn::Vec<DF>;
    const size_t NF = hn::Lanes(df);

    size_t i = 0;
    // Guard: `num - NF` underflows (both operands are size_t) when
    // num < NF, which would make the loop bound huge and read past the end
    // of `raw`. Only enter the vector loop with at least one full vector.
    if (num >= NF) {
      for (; i <= num - NF; i += NF) {
        const VF v = hn::LoadU(df, raw + i);
        const auto vi32 = hn::NearestInt(v);
        const auto vi16 = hn::DemoteTo(di16, vi32);
        const auto vi8 = hn::DemoteTo(di8, vi16);
        hn::StoreU(vi8, di8, packed.ptr + packed_ofs + i);
      }
    }
    const size_t remaining = num - i;
    if (remaining > 0) {
      const VF v = hn::LoadN(df, raw + i, remaining);
      const auto vi32 = hn::NearestInt(v);
      const auto vi16 = hn::DemoteTo(di16, vi32);
      const auto vi8 = hn::DemoteTo(di8, vi16);
      hn::StoreN(vi8, di8, packed.ptr + packed_ofs + i, remaining);
    }
  }

  // Scalar reference conversion (used for testing/fallback paths).
  static float ToFloatSlow(const Packed x) { return static_cast<float>(x); }

  // Loads 2 * Lanes(df) int8 values starting at `packed_ofs` and widens them
  // into two float vectors: `raw0` holds the lower Lanes(df) values, `raw1`
  // the upper.
  template <class DF, HWY_IF_F32_D(DF)>
  static HWY_INLINE void Load2(DF df, const PackedSpan<const Packed>& packed,
                               const size_t packed_ofs, hn::Vec<DF>& raw0,
                               hn::Vec<DF>& raw1) {
    const hn::Repartition<int32_t, DF> di32;
    const hn::Repartition<int16_t, DF> di16;
    const hn::Repartition<int8_t, DF> di8;
    const hn::Half<decltype(di8)> di8_half;

    // Only half an i8 vector is required to fill two f32 vectors; widen the
    // loaded half into a full vector so PromoteLowerTo sees the data.
    const auto vec_i8 = hn::LoadU(di8_half, packed.ptr + packed_ofs);
    const auto vec_i8_full = hn::Combine(di8, hn::Zero(di8_half), vec_i8);
    const auto vec_i16 = hn::PromoteLowerTo(di16, vec_i8_full);
    const auto vec_i32_0 = hn::PromoteLowerTo(di32, vec_i16);
    const auto vec_i32_1 = hn::PromoteUpperTo(di32, vec_i16);

    raw0 = hn::ConvertTo(df, vec_i32_0);
    raw1 = hn::ConvertTo(df, vec_i32_1);
  }

  // Loads 2 * Lanes(dbf) int8 values starting at `packed_ofs` and widens them
  // (via i16 -> i32 -> f32) into two BF16 vectors `raw0` and `raw1`.
  template <class DBF, HWY_IF_BF16_D(DBF)>
  static HWY_INLINE void Load2(DBF dbf, const PackedSpan<const Packed>& packed,
                               const size_t packed_ofs, hn::Vec<DBF>& raw0,
                               hn::Vec<DBF>& raw1) {
    const hn::Repartition<float, DBF> df;
    const hn::Repartition<int32_t, DBF> di32;
    const hn::Repartition<int16_t, DBF> di16;
    const hn::Repartition<int8_t, DBF> di8;

    const auto v8 = hn::LoadU(di8, packed.ptr + packed_ofs);

    const auto v16_0 = hn::PromoteLowerTo(di16, v8);
    const auto v16_1 = hn::PromoteUpperTo(di16, v8);

    const auto v32_0_lo = hn::PromoteLowerTo(di32, v16_0);
    const auto v32_0_hi = hn::PromoteUpperTo(di32, v16_0);
    const auto f0_lo = hn::ConvertTo(df, v32_0_lo);
    const auto f0_hi = hn::ConvertTo(df, v32_0_hi);
    raw0 = hn::OrderedDemote2To(dbf, f0_lo, f0_hi);

    const auto v32_1_lo = hn::PromoteLowerTo(di32, v16_1);
    const auto v32_1_hi = hn::PromoteUpperTo(di32, v16_1);
    const auto f1_lo = hn::ConvertTo(df, v32_1_lo);
    const auto f1_hi = hn::ConvertTo(df, v32_1_hi);
    raw1 = hn::OrderedDemote2To(dbf, f1_lo, f1_hi);
  }

  // Decompresses `num` int8 values starting at `packed_ofs` into `raw` as
  // floats, then zero-pads `raw` up to a multiple of Lanes(df) so callers
  // may load whole vectors (required by the DecompressAndZeroPad contract).
  template <class DF, HWY_IF_F32_D(DF)>
  static HWY_INLINE void DecompressAndZeroPad(
      DF df, const PackedSpan<const Packed>& packed, const size_t packed_ofs,
      float* HWY_RESTRICT raw, size_t num) {
    using VF = hn::Vec<DF>;
    const size_t NF = hn::Lanes(df);

    size_t i = 0;
    if (num >= 2 * NF) {
      for (; i <= num - 2 * NF; i += 2 * NF) {
        VF raw0, raw1;
        Load2(df, packed, packed_ofs + i, raw0, raw1);
        hn::StoreU(raw0, df, raw + i);
        hn::StoreU(raw1, df, raw + i + NF);
      }
    }

    const size_t remaining = num - i;  // < 2 * NF
    if (HWY_UNLIKELY(remaining != 0)) {
      size_t j = 0;
      for (; j < remaining; ++j) {
        raw[i + j] = static_cast<float>(packed.ptr[packed_ofs + i + j]);
      }
      // Zero-pad the final partial vector(s), as the function name promises;
      // callers are required to provide capacity padded to a vector multiple.
      const size_t padded = hwy::RoundUpTo(remaining, NF);
      for (; j < padded; ++j) {
        raw[i + j] = 0.0f;
      }
    }
  }

  // Decompresses `num` int8 values starting at `packed_ofs` into `raw` as
  // BF16, then zero-pads `raw` up to a multiple of Lanes(dbf).
  template <class DBF, HWY_IF_BF16_D(DBF)>
  static HWY_INLINE void DecompressAndZeroPad(
      DBF dbf, const PackedSpan<const Packed>& packed, const size_t packed_ofs,
      BF16* HWY_RESTRICT raw, size_t num) {
    const hn::Repartition<float, DBF> df;
    const size_t NF = hn::Lanes(df);
    const size_t NBF = hn::Lanes(dbf);
    size_t i = 0;
    if (num >= NBF) {
      for (; i <= num - NBF; i += NBF) {
        hn::Vec<decltype(df)> f0, f1;
        Load2(df, packed, packed_ofs + i, f0, f1);
        const auto vbf = hn::OrderedDemote2To(dbf, f0, f1);
        hn::StoreU(vbf, dbf, raw + i);
      }
    }
    const size_t remaining = num - i;  // < NBF == 2 * NF
    if (remaining > 0) {
      // Zero-initialize so lanes beyond `remaining` are well-defined: the
      // f32 overload only pads up to a multiple of NF, so `buf + NF` would
      // otherwise be read uninitialized when remaining <= NF.
      HWY_ALIGN float buf[2 * hn::MaxLanes(df)] = {};
      DecompressAndZeroPad(df, packed, packed_ofs + i, buf, remaining);
      const auto f0 = hn::LoadU(df, buf);
      const auto f1 = hn::LoadU(df, buf + NF);
      const auto vbf = hn::OrderedDemote2To(dbf, f0, f1);
      // Store the full vector: lanes >= remaining are zero, which also
      // provides the zero padding of `raw` required by the contract.
      hn::StoreU(vbf, dbf, raw + i);
    }
  }
};

// Integer quantization.
template <>
struct CompressTraits<I8Stream> {
Expand Down
4 changes: 4 additions & 0 deletions compression/compress_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ struct TestDecompress2 {
HWY_ASSERT(stats.L1().Max() <= 0.08f);
HWY_ASSERT(IsInside(0.02, 0.05, stats.WeightedAverageL1()));
HWY_ASSERT(IsInside(18.0, 62.0, stats.GeomeanValueDivL1()));
} else if constexpr (hwy::IsSame<Packed, int8_t>()) {
HWY_ASSERT(stats.L1().Max() <= 0.6f);
} else {
HWY_ABORT("Unhandled type requested by ForeachPackedAndRawType");
}
Expand Down Expand Up @@ -200,6 +202,8 @@ struct TestShortLengths {
HWY_ASSERT(stats.L1().Max() <= 0.14f);
HWY_ASSERT(IsInside(7E-5, 0.06, stats.WeightedAverageL1()));
HWY_ASSERT(IsInside(11.0, 180.0, stats.GeomeanValueDivL1()));
} else if constexpr (hwy::IsSame<Packed, int8_t>()) {
HWY_ASSERT(stats.L1().Max() <= 0.6f);
} else {
HWY_ABORT("Unhandled type requested by ForeachPackedAndRawType");
}
Expand Down
1 change: 1 addition & 0 deletions compression/test_util-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ void ForeachPackedAndRawType() {
if constexpr (GEMMA_ENABLE_NUQ) {
ForeachRawType<NuqStream, TestT>();
}

}

template <class Test, class D>
Expand Down
15 changes: 12 additions & 3 deletions compression/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,11 @@ constexpr bool IsF32() {
return hwy::IsSame<hwy::RemoveCvRef<Packed>, float>();
}

// Returns true if `Packed` is int8_t, ignoring const/volatile/reference
// qualifiers. Distinct from I8Stream, which has its own traits.
template <typename Packed>
constexpr bool IsInt8() {
  return hwy::IsSame<hwy::RemoveCvRef<Packed>, int8_t>();
}

template <typename Packed>
constexpr bool IsBF16() {
return hwy::IsSame<hwy::RemoveCvRef<Packed>, BF16>();
Expand Down Expand Up @@ -231,12 +236,13 @@ enum class Type {
kI8,
kU16,
kU8,
kInt8,
};
// These are used in `ModelConfig.Specifier`, hence the strings will not
// change, though new ones may be added.
static constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp",
"nuq", "f64", "u32", "u64",
"i8", "u16", "u8"};
static constexpr const char* kTypeStrings[] = {
"unknown", "f32", "bf16", "sfp", "nuq", "f64",
"u32", "u64", "i8", "u16", "u8", "int8"};
static constexpr size_t kNumTypes =
sizeof(kTypeStrings) / sizeof(kTypeStrings[0]);
static constexpr size_t kTypeBits[] = {
Expand All @@ -251,6 +257,7 @@ static constexpr size_t kTypeBits[] = {
8 * sizeof(I8Stream),
8 * sizeof(uint16_t),
8 * sizeof(uint8_t),
8 * sizeof(int8_t),
};

static inline bool EnumValid(Type type) {
Expand Down Expand Up @@ -281,6 +288,8 @@ constexpr Type TypeEnum() {
return Type::kU16;
} else if constexpr (hwy::IsSame<Packed, uint8_t>()) {
return Type::kU8;
} else if constexpr (hwy::IsSame<Packed, int8_t>()) {
return Type::kInt8;
} else {
return Type::kUnknown;
}
Expand Down
67 changes: 67 additions & 0 deletions gemma/flash_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1288,6 +1288,49 @@ static HWY_NOINLINE void ApplyMasking(
}
}

// Multiplies the partial-sum vector pairs of the first kNumQueries queries
// (x<q>_p0 / x<q>_p1) in place by the per-position scales loaded from
// `scales` (first vector) and `scales + Lanes(df)` (second vector).
// Pairs beyond kNumQueries are left untouched.
template <int kNumQueries, class DF, class VF = hn::Vec<DF>>
static HWY_INLINE void MultiplyByScale(DF df, const float* scales, VF& x0_p0,
                                       VF& x0_p1, VF& x1_p0, VF& x1_p1,
                                       VF& x2_p0, VF& x2_p1, VF& x3_p0,
                                       VF& x3_p1, VF& x4_p0, VF& x4_p1,
                                       VF& x5_p0, VF& x5_p1, VF& x6_p0,
                                       VF& x6_p1, VF& x7_p0, VF& x7_p1) {
  const VF s_p0 = hn::LoadU(df, scales);
  const VF s_p1 = hn::LoadU(df, scales + hn::Lanes(df));
  // Scales one query's pair of partial vectors in place.
  const auto scale_pair = [&](VF& p0, VF& p1) HWY_ATTR {
    p0 = hn::Mul(p0, s_p0);
    p1 = hn::Mul(p1, s_p1);
  };
  if constexpr (kNumQueries >= 1) scale_pair(x0_p0, x0_p1);
  if constexpr (kNumQueries >= 2) scale_pair(x1_p0, x1_p1);
  if constexpr (kNumQueries >= 3) scale_pair(x2_p0, x2_p1);
  if constexpr (kNumQueries >= 4) scale_pair(x3_p0, x3_p1);
  if constexpr (kNumQueries >= 5) scale_pair(x4_p0, x4_p1);
  if constexpr (kNumQueries >= 6) scale_pair(x5_p0, x5_p1);
  if constexpr (kNumQueries >= 7) scale_pair(x6_p0, x6_p1);
  if constexpr (kNumQueries >= 8) scale_pair(x7_p0, x7_p1);
}

// Performs tiled flash attention for arbitrary number of queries
// It depends on kv being tiled.
// Runs 2 loops one over tiles, and inner one over queries(up to 4 at a time).
Expand Down Expand Up @@ -1428,6 +1471,21 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits(
false,
"Query type type not supported, only float and BF16 are supported");
}
// microscaling
// TODO: Change to more generic function to inform if we should use
// microscaling or not.
constexpr bool kUseMicroScaling = IsInt8<KV_T>();
if constexpr (kUseMicroScaling) {
// After end of the tile, we have kTileSize * 2 floats for the
// microscaling scales for K and V.
const float* microscaling_scales_k =
reinterpret_cast<const float*>(tile_base + qkv_dim * 2 * kTileSize) +
pos_in_tile;
MultiplyByScale<kNumQueries>(df, microscaling_scales_k, x_0_p_0, x_0_p_1,
x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0,
x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1,
x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1);
}

constexpr int kFirstHalfAmountOfQueries = std::min(kNumQueries, 4);
constexpr int kSecondHalfAmountOfQueries =
Expand Down Expand Up @@ -1461,6 +1519,15 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits(
x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1,
x_7_p_0, x_7_p_1, max_logits, exp_denominator_sums, scales, q_group_idx,
kNumQueriesPerGroup);
if constexpr (kUseMicroScaling) {
const float* microscaling_scales_v =
reinterpret_cast<const float*>(tile_base + qkv_dim * 2 * kTileSize) +
kTileSize + pos_in_tile;
MultiplyByScale<kNumQueries>(
df, microscaling_scales_v, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1,
x_2_p_0, x_2_p_1, x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0,
x_5_p_1, x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1);
}
if constexpr (IsF32<Q_T>()) {
MulByConstAndAddTileUpTo8<kNumQueries>(
df, scales, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1,
Expand Down
Loading
Loading