From cb9afdb78f283d568688dfaf8df4c35637f1c700 Mon Sep 17 00:00:00 2001 From: Viktor Shipitsin Date: Wed, 25 Feb 2026 09:35:01 -0800 Subject: [PATCH] Use `gtl::fixed_flat_map_of` to manage the mapping between `AttentionImpl` enum values and their string names, simplifying `GetAttentionImplName` function. Add a test to ensure all valid `AttentionImpl` enums have a corresponding name and can be looked up. PiperOrigin-RevId: 875204584 --- BUILD.bazel | 2 ++ gemma/configs.cc | 30 +++++++++++++++++++++++------- gemma/configs_test.cc | 12 ++++++++++++ 3 files changed, 37 insertions(+), 7 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index 885bb665..fc7d5b39 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -191,10 +191,12 @@ cc_library( hdrs = ["gemma/configs.h"], deps = [ ":basics", + "//third_party/absl/strings:string_view", "//compression:types", "//io", "//io:fields", "@highway//:hwy", # base.h + "//util/gtl:flat_map", ], ) diff --git a/gemma/configs.cc b/gemma/configs.cc index 000e2786..3ca47b61 100644 --- a/gemma/configs.cc +++ b/gemma/configs.cc @@ -21,10 +21,12 @@ #include #include +#include "third_party/absl/strings/string_view.h" #include "compression/types.h" // Type #include "io/fields.h" // IFields #include "io/io.h" // Path #include "hwy/base.h" +#include "util/gtl/flat_map.h" namespace gcpp { @@ -678,7 +680,7 @@ Model DeduceModel(const Path& blob_path, size_t layers, int layer_types) { return Model::GEMMA3_270M; case 26: - if (layer_types & (kDeducedViT|kDeducedKqNorm)) { + if (layer_types & (kDeducedViT | kDeducedKqNorm)) { return Model::GEMMA3_1B; } return Model::GEMMA2_2B; @@ -718,17 +720,31 @@ const char* kAttentionImplNames[] = { "unknown" // keep last }; +constexpr auto kAttentionImplNameToEnum = + gtl::fixed_flat_map_of({ + {"old", AttentionImpl::kOld}, + {"flash", AttentionImpl::kFlash}, + {"flash_transposed_qs", AttentionImpl::kFlashTransposedQs}, + {"flash_transposed_qs_bf16", AttentionImpl::kFlashTransposedQsBF16}, + }); + std::string GetAttentionImplName(AttentionImpl impl) { - return kAttentionImplNames[static_cast(impl)]; + for (auto const& [name, attention_impl] : kAttentionImplNameToEnum) { + if (attention_impl == impl) { + return std::string(name); + } + } + return "unknown"; } AttentionImpl GetAttentionImpl(const std::string& impl) { - if (impl == GetAttentionImplName(AttentionImpl::kOld)) + auto pair = kAttentionImplNameToEnum.find(impl); + if (pair == kAttentionImplNameToEnum.end()) { + HWY_WARN("Unknown attention implementation: %s. Using kOld.\n", + impl.c_str()); return AttentionImpl::kOld; - if (impl == GetAttentionImplName(AttentionImpl::kFlash)) - return AttentionImpl::kFlash; - HWY_WARN("Unknown attention implementation: %s. Using kOld.\n", impl.c_str()); - return AttentionImpl::kOld; + } + return pair->second; } } // namespace gcpp diff --git a/gemma/configs_test.cc b/gemma/configs_test.cc index 0ca4a848..e6f02579 100644 --- a/gemma/configs_test.cc +++ b/gemma/configs_test.cc @@ -41,4 +41,16 @@ TEST(ConfigsTest, TestAll) { }); } +TEST(ConfigsTest, TestAttentionImpl) { + for (int i = 0; i < static_cast(AttentionImpl::kSentinel); ++i) { + AttentionImpl impl = static_cast(i); + std::string name = GetAttentionImplName(impl); + ASSERT_NE(name, "unknown"); + ASSERT_EQ(GetAttentionImpl(name), impl); + } + ASSERT_EQ(GetAttentionImplName(AttentionImpl::kSentinel), "unknown"); + ASSERT_EQ(GetAttentionImpl("unknown"), AttentionImpl::kOld); + ASSERT_EQ(GetAttentionImpl("invalid"), AttentionImpl::kOld); +} + } // namespace gcpp