Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion llama_cpp/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,11 @@ def __init__(
type_v: KV cache data type for V (default: f16)
spm_infill: Use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this.
Note:
Recurrent and hybrid models (Mamba, RWKV, Nemotron-A3B, Jamba) cannot
rewind their state and require full reset on history edits. This is handled
automatically to maintain compatibility. Standard transformers are unaffected.
Raises:
ValueError: If the model path does not exist.
Expand Down Expand Up @@ -553,6 +558,11 @@ def free_lora_adapter():

self._sampler = None

# Cache recurrent/hybrid model detection to avoid repeated FFI calls
self._is_recurrent_model = llama_cpp.llama_model_is_recurrent(
self._model.model
) or llama_cpp.llama_model_is_hybrid(self._model.model)

@property
def ctx(self) -> llama_cpp.llama_context_p:
return self._ctx.ctx
Expand Down Expand Up @@ -580,6 +590,19 @@ def eval_logits(self) -> Deque[List[float]]:
maxlen=self._n_ctx if self._logits_all else 1,
)

@property
def _is_recurrent(self) -> bool:
    """Whether the loaded model has recurrent (SSM) or hybrid (SSM+Attention) state.

    Recurrent and hybrid models (e.g. Mamba, RWKV, Jamba) cannot rewind
    their internal state to an earlier position the way a pure-attention
    KV cache can; only strict forward progression or a full state reset
    is valid for them.

    Returns the value cached at construction time (a single
    ``llama_model_is_recurrent`` / ``llama_model_is_hybrid`` check) so
    callers avoid repeated FFI calls on the hot generation path.

    Returns:
        True if model has recurrent state that cannot be rewound.
    """
    return self._is_recurrent_model

def tokenize(
self, text: bytes, add_bos: bool = True, special: bool = False
) -> List[int]:
Expand Down Expand Up @@ -638,6 +661,11 @@ def reset(self):
"""Reset the model state."""
self.n_tokens = 0

if self._is_recurrent:
mem = llama_cpp.llama_get_memory(self._ctx.ctx)
if mem is not None:
llama_cpp.llama_memory_clear(mem, True)

def eval(self, tokens: Sequence[int]):
"""Evaluate a list of tokens.
Expand Down Expand Up @@ -888,11 +916,22 @@ def generate(
# Check for kv cache prefix match
if reset and self.n_tokens > 0:
longest_prefix = 0
for a, b in zip(self._input_ids, tokens[:-1]):
for a, b in zip(self._input_ids, tokens):
if a == b:
longest_prefix += 1
else:
break

# Recurrent models cannot rewind state; reset if needed
if self._is_recurrent and longest_prefix < self.n_tokens:
longest_prefix = 0
reset = True
if self.verbose:
print(
"Llama.generate: recurrent model requires full state reset",
file=sys.stderr,
)

if longest_prefix > 0:
if self._ctx.kv_cache_seq_rm(-1, longest_prefix, -1):
reset = False
Expand Down
2 changes: 1 addition & 1 deletion vendor/llama.cpp