2 changes: 2 additions & 0 deletions BUILD.bazel
@@ -652,10 +652,12 @@ cc_library(
name = "gemma_lib",
srcs = [
"gemma/gemma.cc",
"gemma/tiled_attention.cc",
"gemma/vit.cc",
],
hdrs = [
"gemma/gemma.h",
"gemma/tiled_attention.h",
"gemma/vit.h",
],
exec_properties = {
10 changes: 6 additions & 4 deletions CMakeLists.txt
@@ -93,6 +93,8 @@ set(SOURCES
   gemma/model_store.h
   gemma/tensor_info.cc
   gemma/tensor_info.h
+  gemma/tiled_attention.cc
+  gemma/tiled_attention.h
   gemma/tokenizer.cc
   gemma/tokenizer.h
   gemma/vit.cc
@@ -171,20 +173,20 @@ install(TARGETS libgemma DESTINATION lib)
 if(BUILD_GEMMA_DLL)
   add_library(gemma_shared SHARED ${SOURCES})
   set_property(TARGET gemma_shared PROPERTY CXX_STANDARD 17)
-  set_target_properties(gemma_shared PROPERTIES 
+  set_target_properties(gemma_shared PROPERTIES
     PREFIX ""
     OUTPUT_NAME "gemma"
   )
   set_property(TARGET gemma_shared PROPERTY POSITION_INDEPENDENT_CODE ON)
   target_include_directories(gemma_shared PUBLIC ./)
-  target_link_libraries(gemma_shared PRIVATE 
+  target_link_libraries(gemma_shared PRIVATE
     $<LINK_LIBRARY:WHOLE_ARCHIVE,hwy>
     $<LINK_LIBRARY:WHOLE_ARCHIVE,hwy_contrib>
     $<LINK_LIBRARY:WHOLE_ARCHIVE,sentencepiece-static>
   )
   target_include_directories(gemma_shared PUBLIC ${sentencepiece_SOURCE_DIR})
-  target_compile_definitions(gemma_shared 
-    PRIVATE 
+  target_compile_definitions(gemma_shared
+    PRIVATE
     GEMMA_EXPORTS
     $<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>
   )
18 changes: 18 additions & 0 deletions gemma/activations.h
@@ -153,6 +153,14 @@ struct AttentionActivations {
   // Accumulation of attention outputs over heads
   MatStorageT<BF16> att_sums;
 
+  MatStorageT<float> k_tile_vec;
+  MatStorageT<float> v_tile_vec;
+  std::vector<MatStorageT<float>> sub_task_att_out;
+  std::vector<AlignedFloatVector>
+      sub_task_exp_denominator_sums;
+  std::vector<AlignedFloatVector>
+      sub_task_max_logits;
+
   // Rope
   MatStorageT<float> inv_timescale;
   MatStorageT<float> inv_timescale_global;
@@ -244,6 +252,16 @@ struct AttentionActivationsPtrs {
   // Accumulation of attention outputs over heads, size batch_size x
   // model_dim.
   MatPtrT<BF16> att_sums;
+  // Stores intermediate results of computing QKV,
+  // [qbatch * kv_heads, k_tile_size * qkv_dim].
+  MatPtrT<float> k_tile_vec;
+  MatPtrT<float> v_tile_vec;
+  // Used by TiledFlashAttention to store intermediate results.
+  std::vector<MatStorageT<float>>* sub_task_att_out;
+  std::vector<AlignedFloatVector>*
+      sub_task_exp_denominator_sums;
+  std::vector<AlignedFloatVector>*
+      sub_task_max_logits;
   // Inverse timescales for RoPE computation.
   MatPtrT<float> inv_timescale;
   // Inverse timescales for global RoPE computation.
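The new per-sub-task buffers support the online-softmax merge that tiled flash attention relies on: each sub-task processes one K/V tile and records a partial (unnormalized) output, its running max logit, and its exp-denominator sum, and a final pass rescales and combines them. Below is a minimal C++ sketch of that merge step, assuming plain std::vector buffers in place of the repo's MatStorageT/AlignedFloatVector types; the names and signatures are illustrative, not the PR's actual API.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Combines per-tile partial attention outputs into one normalized row.
// Tile i contributes out[i] (sum over its keys of exp(logit - max_logit[i])
// * v), max_logit[i], and exp_sum[i] (sum of exp(logit - max_logit[i])).
std::vector<float> MergeTileResults(
    const std::vector<std::vector<float>>& out,
    const std::vector<float>& max_logit, const std::vector<float>& exp_sum) {
  const size_t num_tiles = out.size();
  const size_t dim = out[0].size();
  // The global max keeps every exponential <= 1, avoiding overflow.
  float global_max = max_logit[0];
  for (size_t i = 1; i < num_tiles; ++i) {
    global_max = std::max(global_max, max_logit[i]);
  }
  std::vector<float> merged(dim, 0.0f);
  float denom = 0.0f;
  for (size_t i = 0; i < num_tiles; ++i) {
    // Rescale tile i's contribution from its local max to the global max.
    const float scale = std::exp(max_logit[i] - global_max);
    for (size_t d = 0; d < dim; ++d) merged[d] += scale * out[i][d];
    denom += scale * exp_sum[i];
  }
  for (size_t d = 0; d < dim; ++d) merged[d] /= denom;
  return merged;
}

The division by the summed denominator happens once, after all tiles are merged, which is why the per-tile outputs can stay unnormalized in sub_task_att_out.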
4 changes: 4 additions & 0 deletions gemma/configs.h
@@ -83,6 +83,8 @@ static inline bool EnumValid(LayerAttentionType type) {
 enum class AttentionImpl {
   kOld,
   kFlash,
+  kFlashTransposedQs,
+  kFlashTransposedQsBF16,
   kSentinel,
 };
 
@@ -108,6 +110,8 @@ static inline int AttentionImplToFlags(AttentionImpl impl,
     case AttentionImpl::kOld:
       return kAttentionUseOld;
     case AttentionImpl::kFlash:
+    case AttentionImpl::kFlashTransposedQs:
+    case AttentionImpl::kFlashTransposedQsBF16:
     default:
       return 0;
   }
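For flag purposes the two new enum values fall through with kFlash (no kAttentionUseOld), so they differ only in kernel-internal layout. Judging by the names alone (the PR description is not shown here), they operate on a transposed query matrix, with the BF16 variant storing it in bfloat16. A generic sketch of that layout idea, not the repo's code: storing Q as [qkv_dim, num_queries] makes one feature dimension of every query contiguous, which suits vectorized dot products against a K tile.

#include <cstddef>
#include <vector>

// Hypothetical illustration: transpose row-major Q [num_q, dim] into
// [dim, num_q], so that q_t[d * num_q + i] holds dimension d of query i.
std::vector<float> TransposeQ(const std::vector<float>& q, size_t num_q,
                              size_t dim) {
  std::vector<float> q_t(q.size());
  for (size_t i = 0; i < num_q; ++i) {
    for (size_t d = 0; d < dim; ++d) {
      q_t[d * num_q + i] = q[i * dim + d];
    }
  }
  return q_t;
}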