2 changes: 1 addition & 1 deletion mellea/backends/huggingface.py
@@ -541,7 +541,7 @@ def _make_merged_kv_cache(
         [toks["attention_mask"] for toks in tok_parts], dim=1
     )
     assert input_ids.shape == attention_mask.shape
-    merged_cache: DynamicCache = kv_block_helpers.merge_dynamic_caches(dc_parts)
+    merged_cache: DynamicCache = kv_block_helpers.merge_dynamic_caches_v5(dc_parts)
     # TODO: also assert that the merged cache is the correct shape given the input_ids and attention_mask shapes.
     # rewind merged cache by 1 for safety.
     merged_cache.crop(-1)  # type: ignore
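The shape check described by the TODO above could look roughly like the following. This is a hypothetical sketch, not part of the PR: it assumes the in-scope `input_ids`, `attention_mask`, and `merged_cache` shown in the hunk, and that `DynamicCache.get_seq_length()` reports the number of cached tokens; after `crop(-1)` the cache should hold exactly one token fewer than the merged inputs.

import torch
from transformers.cache_utils import DynamicCache


def check_merged_cache_shape(
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    merged_cache: DynamicCache,
) -> None:
    """Hypothetical helper: validate the merged cache against the merged inputs."""
    assert input_ids.shape == attention_mask.shape
    # After merged_cache.crop(-1), the cache should be one token shorter than input_ids.
    assert merged_cache.get_seq_length() == input_ids.shape[1] - 1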
133 changes: 96 additions & 37 deletions mellea/backends/kv_block_helpers.py
@@ -1,47 +1,106 @@
 """Utilities for KV smashing."""
 
 from collections.abc import Iterable
-from functools import reduce
-from typing import Any
 
 import torch
 from transformers import PreTrainedModel, PreTrainedTokenizerBase
 from transformers.cache_utils import DynamicCache
 from transformers.tokenization_utils_base import BatchEncoding
 
-TokenizedCacheIterleaving = Iterable[BatchEncoding | DynamicCache]
-LegacyCache = Any
-
-
-def legacy_cache_smash(a: LegacyCache, b: LegacyCache) -> LegacyCache:
-    """Concatenates two LegacyCache Ks and Vs along the time axis."""
-    legacy_merged = tuple(
-        (torch.cat([a[i][0], b[i][0]], dim=2), torch.cat([a[i][1], b[i][1]], dim=2))
-        for i in range(len(a))
-    )
-    return legacy_merged
-
-
-def merge_dynamic_caches(caches: Iterable[DynamicCache]) -> DynamicCache:
-    """Merges two DynamicCache Ks and Vs along the time axis."""
-    legacies = [c.to_legacy_cache() for c in caches]  # type: ignore
-    assert len(legacies) >= 1
-    rv = DynamicCache.from_legacy_cache(reduce(legacy_cache_smash, legacies))  # type: ignore
-    return rv  # type: ignore
-
-
-def tokens_to_legacy_cache(
-    model, device: str, tokens_or_cache: BatchEncoding | DynamicCache
-) -> Iterable[LegacyCache]:
-    """Prefills and returns Ks and Vs as a LegacyCache."""
-    if type(tokens_or_cache) is DynamicCache:
-        return tokens_or_cache.to_legacy_cache()  # type: ignore
-    else:
-        tokens = tokens_or_cache
-        dc = DynamicCache()
-        with torch.no_grad():
-            dc = model(
-                tokens["input_ids"].to(device),  # type: ignore
-                attention_mask=tokens["attention_mask"].to(device),  # type: ignore
-                past_key_values=dc,
-            ).past_key_values
-        return dc.to_legacy_cache()
+
+@torch.no_grad()
+def prefill_cache_v5(
+    model: PreTrainedModel,
+    tokenizer: PreTrainedTokenizerBase,
+    text: str,
+    device: torch.device,
+) -> tuple[dict, DynamicCache]:
+    """Prefills cache for transformers v5."""
+    toks = tokenizer(text, return_tensors="pt")
+    toks = {k: v.to(device) for k, v in toks.items()}
+
+    dc = DynamicCache()
+    out = model(
+        input_ids=toks["input_ids"],
+        attention_mask=toks["attention_mask"],
+        past_key_values=dc,
+        use_cache=True,
+    )
+    dc = out.past_key_values
+    dc.crop(-1)
+    return toks, dc  # v5 returns DynamicCache (not legacy tuple)
+
+
+def merge_dynamic_caches_v5(caches: Iterable[DynamicCache]) -> DynamicCache:
+    """Merge multiple v5 DynamicCache objects by concatenating KV states along the time axis."""
+    caches = list(caches)
+    assert len(caches) >= 1
+
+    for c in caches:
+        if any(
+            getattr(layer, "is_sliding", False) for layer in getattr(c, "layers", [])
+        ):
+            raise ValueError("Check the issue.")
+
+    merged = DynamicCache()
+
+    # reuse Cache.update() to append each segment's KV to the merged cache per layer.
+    # DynamicLayer.update(): self.keys = cat([self.keys, key_states], dim=-2).
+    for c in caches:
+        for layer_idx, layer in enumerate(c.layers):
+            if layer.keys is None or layer.values is None:
+                continue
+            merged.update(layer.keys, layer.values, layer_idx=layer_idx)
+
+    return merged
+
+
+def merge_v5(
+    model: PreTrainedModel,
+    tokenizer: PreTrainedTokenizerBase,
+    strs: list[str],
+    device: torch.device,
+):
+    """Merges DynamicCache for transformers>=5.0.0."""
+    strs_toks, strs_dcs = [], []
+    for s in strs:
+        toks, dc = prefill_cache_v5(model, tokenizer, s, device)
+        strs_toks.append(toks)
+        strs_dcs.append(dc)
+
+    merged_toks = torch.cat([t["input_ids"] for t in strs_toks], dim=1)
+    merged_masks = torch.cat([t["attention_mask"] for t in strs_toks], dim=1)
+
+    merged_dc = merge_dynamic_caches_v5(strs_dcs)
+
+    return merged_toks, merged_masks, merged_dc
+
+
+if __name__ == "__main__":
+    from mellea.backends.huggingface import LocalHFBackend
+    from mellea.backends.model_ids import IBM_GRANITE_3_3_8B
+
+    backend = LocalHFBackend(model_id=IBM_GRANITE_3_3_8B.hf_model_name)
+    model, tokenizer, device = backend._model, backend._tokenizer, backend._device
+    model: PreTrainedModel = model
+
+    docs = [
+        "Nathan Fulton is expert in large language models, formal verification, and reinforcement learning. He holds a Ph.D. from Carnegie Mellon University's Computer Science Department and has worked at Amazon Web Services and IBM Research. He currently works at IBM Research - Cambridge.",
+        "IBM Research has a headquarters at 1101 Kitchawan Rd in Yorktown Heights and a Cambridge office at 314 Main Street in Cambridge, MA.",
+        "What is the address of Nathan's place of work?",
+    ]
+
+    merged_tokens, merged_masks, merged_cache = merge_v5(
+        model, tokenizer, docs, device=backend._device
+    )
+    input_ids = merged_tokens.to(device)
+    result = model.generate(
+        input_ids=input_ids,
+        use_cache=True,
+        return_dict_in_generate=True,
+        past_key_values=merged_cache,
+        max_new_tokens=512,
+    )
+    result = tokenizer.decode(
+        result.sequences[0, input_ids.shape[1] :], skip_special_tokens=True
+    )
+    print(result)
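A small usage sketch of the concatenation behavior that the comments in merge_dynamic_caches_v5 describe: each prefilled segment contributes its cached tokens, so the merged cache length should equal the sum of the per-segment lengths (each already shortened by one via crop(-1)). This is only an illustrative sketch; it assumes transformers>=5 with DynamicCache.get_seq_length() available, and the model name here is an assumption, not part of the PR.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from mellea.backends.kv_block_helpers import merge_dynamic_caches_v5, prefill_cache_v5

model_name = "ibm-granite/granite-3.3-8b-instruct"  # illustrative; any HF causal LM
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device).eval()

toks_a, cache_a = prefill_cache_v5(model, tokenizer, "First document.", device)
toks_b, cache_b = prefill_cache_v5(model, tokenizer, "Second document.", device)

merged = merge_dynamic_caches_v5([cache_a, cache_b])

# Merging appends each segment's K/V along the time axis, so the cached
# sequence lengths should add up.
assert merged.get_seq_length() == cache_a.get_seq_length() + cache_b.get_seq_length()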
7 changes: 4 additions & 3 deletions pyproject.toml
@@ -74,13 +74,14 @@ hf = [
     "outlines-core==0.1.26",
     "outlines", # intentionally un-versioned, expecting a minor update. outlines-core version should be enough to specify it
     "peft>=0.18.0", # aLoRA support was added in Peft 0.18.0
-    "transformers>=4.53.2,<5",
+    "transformers==5.0.0",
     "trl==0.19.1",
     "granite-common[transformers]",
 ]
 
 vllm = [
-    "transformers<4.54.0",
+    "transformers", # Removing the <4.54.0 pin; need to figure out if this breaks anything. - TODO-nrf
+    # "transformers<4.54.0",
     # see https://github.com/vllm-project/vllm-ascend/issues/2046
     "numpy<2.0.0", # patching incorrect dependencies in vllm and outlines.
     # see https://github.com/vllm-project/vllm/issues/5587
@@ -96,7 +97,7 @@ watsonx = [
     "ibm-watsonx-ai>=1.3.31",
 ]
 docling = [
-    "docling>=2.45.0",
+    # TODO-nrf re-enable this "docling>=2.45.0",
 ]
 
 all = ["mellea[watsonx,docling,hf,vllm,litellm]"]
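Since the hf extra now pins transformers==5.0.0 while the vllm extra is left unpinned, one possible guard (a hedged sketch, not part of the PR) is to check the installed transformers version before exercising the v5-only cache-merging path:

# Hedged sketch: fail fast if the v5-only KV-cache helpers are used with older transformers.
from importlib.metadata import version

from packaging.version import Version

if Version(version("transformers")) < Version("5.0.0"):
    raise RuntimeError(
        "The KV cache merging helpers require transformers>=5.0.0; "
        f"found {version('transformers')}."
    )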