@@ -75,8 +75,10 @@ def __init__(
7575 self .model = model
7676 self .vocab = vocab
7777
78+ self ._lora_registry : Dict [str , LlamaLoraAdapter ] = {}
79+
7880 def close (self ):
79- """Manually free LlamaModel and Vocab resources."""
81+ """Manually free LlamaModel and Vocab/Lora resources."""
8082 if getattr (self , "model" , None ) is not None :
8183 try :
8284 llama_cpp .llama_model_free (self .model )
@@ -85,6 +87,10 @@ def close(self):
8587 self .model = None
8688 self .vocab = None
8789
90+ if hasattr (self , "_lora_registry" ) and self ._lora_registry :
91+ self .unload_all_loras ()
92+ self ._lora_registry = None
93+
8894 if getattr (self , "_exit_stack" , None ) is not None and hasattr (self ._exit_stack , "close" ):
8995 self ._exit_stack .close ()
9096 self ._exit_stack = None
@@ -311,8 +317,62 @@ def detokenize(self, tokens: List[int], special: bool = False) -> bytes:
311317
312318 return bytes (buffer [:n_chars ])
313319
320+ # Lora
321+
322+ def load_lora (self , name : str , path : str ):
323+ """Loads a LoRA adapter into VRAM without applying it yet."""
324+ # Skip if it's already loaded
325+ if name in self ._lora_registry :
326+ return
327+
328+ adapter = LlamaLoraAdapter (self .model , path )
329+ self ._lora_registry [name ] = adapter
330+
331+ self ._exit_stack .callback (adapter .free )
332+
333+ if self .verbose :
334+ print (f"Loaded LoRA '{ name } ' into memory." )
335+
336+ def unload_lora (self , name : str ):
337+ """Actively unloads a specific LoRA to free up VRAM."""
338+ if name in self ._lora_registry :
339+ adapter = self ._lora_registry .pop (name )
340+ adapter .free ()
341+ if self .verbose :
342+ print (f"Unloaded LoRA '{ name } ' and freed memory." )
343+
344+ @property
345+ def loaded_lora_count (self ) -> int :
346+ """
347+ Returns the total number of LoRA adapters currently loaded in VRAM.
348+ """
349+ return len (self ._lora_registry )
350+
351+ def list_loras (self ) -> List [str ]:
352+ """
353+ Returns a list of all registered LoRA names.
354+ """
355+ return list (self ._lora_registry .keys ())
356+
357+ def unload_all_loras (self ):
358+ """
359+ Iterates through the registry and forces VRAM release for all loaded LoRAs.
360+ """
361+ if not self ._lora_registry :
362+ return
363+
364+ # Cast keys to a list first to avoid RuntimeError:
365+ # 'dictionary changed size during iteration' when pop() is called inside unload_lora.
366+ loaded_names = list (self ._lora_registry .keys ())
367+
368+ for name in loaded_names :
369+ self .unload_lora (name )
370+
371+ if self .verbose :
372+ print (f"Successfully unloaded all { len (loaded_names )} LoRA adapters and cleared the registry." )
314373
315374 # Extra
375+
316376 def metadata (self ) -> Dict [str , str ]:
317377 metadata : Dict [str , str ] = {}
318378 # Pre-allocate a 16KB buffer. This is large enough to handle almost all
@@ -356,6 +416,61 @@ def default_params():
356416 return llama_cpp .llama_model_default_params ()
357417
358418
class LlamaLoraAdapter:
    """Wrapper for llama_adapter_lora_p to safely manage C++ memory lifecycle."""

    def __init__(self, model: llama_cpp.llama_model_p, path: str):
        """
        Initializes and loads the LoRA adapter into memory.

        Args:
            model: The pointer to the base Llama model.
            path: The file path to the LoRA adapter (.gguf).

        Raises:
            RuntimeError: If the adapter could not be loaded from `path`.
        """
        self.path = path
        # Load the LoRA adapter from file into memory via llama.cpp API.
        # Note: The path string must be encoded to UTF-8 bytes for ctypes compatibility.
        self.adapter = llama_cpp.llama_adapter_lora_init(
            model,
            path.encode("utf-8")
        )
        if not self.adapter:
            # BUGFIX: normalize a failed (NULL) handle to None so that free()
            # and __del__ never hand a null pointer to llama_adapter_lora_free.
            # A NULL ctypes pointer is falsy but *not* None, so the previous
            # `is not None` guard in free() would still have called free on it.
            self.adapter = None
            raise RuntimeError(f"Failed to load LoRA from {path}")

    def free(self):
        """
        Explicitly frees the underlying C++ memory allocated for the LoRA adapter.
        Should be called when the adapter is actively unloaded to instantly release VRAM.
        Idempotent: repeated calls (and calls on a partially-constructed
        instance) are safe no-ops.
        """
        # Truthiness check (not `is not None`): skips both None and a NULL
        # ctypes pointer, either of which must not reach the C free function.
        if getattr(self, "adapter", None):
            llama_cpp.llama_adapter_lora_free(self.adapter)
        self.adapter = None
        self.path = None

    def __del__(self):
        # Best-effort cleanup at garbage collection; free() tolerates
        # already-freed and partially-constructed instances.
        self.free()

    @property
    def alora_invocation_tokens(self) -> List[int]:
        """
        Retrieves the list of invocation (trigger) tokens if this adapter is an
        ALoRA (Activation LoRA). Returns an empty list for standard LoRA adapters
        or an already-freed adapter.
        """
        if not getattr(self, "adapter", None):
            return []

        # 1. Query the C++ backend for the exact number of trigger tokens.
        n_tokens = llama_cpp.llama_adapter_get_alora_n_invocation_tokens(self.adapter)
        if n_tokens == 0:
            return []

        # 2. Retrieve the underlying C pointer to the contiguous token array.
        tokens_ptr = llama_cpp.llama_adapter_get_alora_invocation_tokens(self.adapter)

        # 3. Copy the C memory block into a native Python list.
        return [tokens_ptr[i] for i in range(n_tokens)]
473+
359474class LlamaContext :
360475 """Intermediate Python wrapper for a llama.cpp llama_context.
361476 NOTE: For stability it's recommended you use the Llama class instead."""
@@ -577,21 +692,43 @@ def decode(self, batch: 'LlamaBatch') -> int:
577692 raise RuntimeError (f"llama_decode failed (code { return_code } ): { msg } " )
578693
579694 def set_n_threads (self , n_threads : int , n_threads_batch : int ):
695+ """
696+ Set the number of threads used for decoding
697+
698+ Args:
699+ n_threads: the number of threads used for generation (single token)
700+ n_threads_batch: the number of threads used for prompt and batch processing (multiple tokens)
701+ """
580702 llama_cpp .llama_set_n_threads (self .ctx , n_threads , n_threads_batch )
581703
582704 def n_threads (self ) -> int :
705+ """Get the number of threads used for generation of a single token."""
583706 return llama_cpp .llama_n_threads (self .ctx )
584707
585708 def n_threads_batch (self ) -> int :
709+ """Get the number of threads used for prompt and batch processing (multiple token)."""
586710 return llama_cpp .llama_n_threads_batch (self .ctx )
587711
588712 def set_causal_attn (self , causal_attn : bool ):
713+ """
714+ Set whether to use causal attention or not
715+ If set to true, the model will only attend to the past tokens
716+ """
589717 llama_cpp .llama_set_causal_attn (self .ctx , causal_attn )
590718
591719 def set_warmup (self , warmup : bool ):
720+ """
721+ Set whether the model is in warmup mode or not
722+ If true, all model tensors are activated during llama_decode() to load and cache their weights.
723+ """
592724 llama_cpp .llama_set_warmup (self .ctx , warmup )
593725
594726 def synchronize (self ):
727+ """
728+ Wait until all computations are finished
729+ This is automatically done when using one of the functions below to obtain the computation results
730+ and is not necessary to call it explicitly in most cases
731+ """
595732 llama_cpp .llama_synchronize (self .ctx )
596733
597734 def get_logits (self ):
@@ -619,9 +756,128 @@ def print_timings(self):
619756 llama_cpp .llama_perf_context_print (self .ctx )
620757
621758 def print_memory_breakdown (self ):
759+ """print a breakdown of per-device memory use via LLAMA_LOG"""
622760 llama_cpp .llama_memory_breakdown_print (self .ctx )
623761
762+ # LoRA / ALoRA Dynamic Routing Methods
763+
764+ def clear_loras (self ):
765+ """
766+ Clears all currently applied LoRA weights from the context.
767+ Restores the computational graph to the base model state.
768+ """
769+ llama_cpp .llama_set_adapters_lora (self .ctx , None , 0 , None )
770+
771+ def apply_loras (self , active_loras : List [Tuple ["LlamaLoraAdapter" , float ]]):
772+ """
773+ Dynamically mounts a combination of LoRAs and their scales to the current context.
774+ This must be called immediately before evaluating/decoding the computation graph.
775+
776+ Args:
777+ active_loras: A list of tuples containing (LlamaLoraAdapter instance, scale float).
778+ """
779+ # If no LoRAs are requested, ensure the context is wiped clean to prevent contamination
780+ if not active_loras :
781+ self .clear_loras ()
782+ return
783+
784+ n_adapters = len (active_loras )
785+
786+ # 1. Dynamically construct contiguous C-array types required by the C++ backend
787+ AdapterArrayType = llama_cpp .llama_adapter_lora_p_ctypes * n_adapters
788+ ScaleArrayType = ctypes .c_float * n_adapters
789+
790+ # 2. Instantiate the arrays in memory
791+ c_adapters = AdapterArrayType ()
792+ c_scales = ScaleArrayType ()
793+
794+ # 3. Populate the C-arrays with the underlying adapter pointers and float scales
795+ for i , (adapter_obj , scale ) in enumerate (active_loras ):
796+ c_adapters [i ] = adapter_obj .adapter
797+ c_scales [i ] = scale
798+
799+ # 4. Atomically apply the requested adapters to the computation graph
800+ ret = llama_cpp .llama_set_adapters_lora (
801+ self .ctx ,
802+ c_adapters ,
803+ n_adapters ,
804+ c_scales
805+ )
806+
807+ if ret != 0 :
808+ raise RuntimeError ("LlamaContext(apply_loras): Failed to set LoRA adapters dynamically." )
809+
810+ if self .verbose :
811+ print (f"LlamaContext(apply_loras): Successfully applied { n_adapters } LoRA adapter(s) to the compute graph." )
812+
813+ # Control Vector (CVec) Methods
814+
815+ def clear_cvec (self ):
816+ """
817+ Clears the currently loaded control vector from the context.
818+ Passing NULL (None) and zeros safely resets the graph.
819+ """
820+ llama_cpp .llama_set_adapter_cvec (self .ctx , None , 0 , 0 , 0 , 0 )
821+
822+ def apply_cvec (self , data : List [float ], n_embd : int , il_start : int , il_end : int ):
823+ """
824+ Dynamically applies a Control Vector (CVec) to the specified layer range.
825+
826+ Args:
827+ data: Flattened 1D list of floats.
828+ [CRITICAL_LAYOUT_RULE]: Based on llama.cpp source, the data buffer
829+ is strictly mapped starting from Layer 1. Even if il_start > 1,
830+ the `data` array must contain zero-padding for the skipped early layers.
831+ Total length MUST be >= n_embd * il_end.
832+ n_embd: The embedding dimension of the model.
833+ il_start: The starting layer to apply the vector (inclusive, 1-indexed).
834+ il_end: The ending layer to apply the vector (inclusive).
835+ """
836+ if not data :
837+ self .clear_cvec ()
838+ return
839+
840+ length = len (data )
841+
842+ # Strictly validate length based on C++ buffer mapping rules
843+ # The C++ backend uses offset: off = n_embd * (il - 1).
844+ # To successfully write up to il_end, the buffer length must be at least n_embd * il_end.
845+ minimum_required_len = n_embd * il_end
846+ if length < minimum_required_len :
847+ raise ValueError (
848+ f"LlamaContext(apply_cvec): "
849+ f"[Memory Layout Error] Control vector data length ({ length } ) is too short. "
850+ f"llama.cpp requires the buffer to map continuously from Layer 1. "
851+ f"To apply up to layer { il_end } , length must be at least { minimum_required_len } ."
852+ )
853+
854+ # 1. Convert to C Array
855+ CFloatArrayType = ctypes .c_float * length
856+ c_data = CFloatArrayType (* data )
857+
858+ # 2. Inject into graph
859+ ret = llama_cpp .llama_set_adapter_cvec (
860+ self .ctx ,
861+ c_data ,
862+ length ,
863+ n_embd ,
864+ il_start ,
865+ il_end
866+ )
867+
868+ # 3. Handle specific C++ boolean false (converted to -1)
869+ if ret != 0 :
870+ raise RuntimeError (
871+ f"LlamaContext(apply_cvec): "
872+ f"C++ backend rejected the Control Vector. "
873+ f"Usually indicates n_embd ({ n_embd } ) does not match the model's actual embedding dimension."
874+ )
875+
876+ if self .verbose :
877+ print (f"LlamaContext(apply_cvec): Applied Control Vector to layers { il_start } -{ il_end } (Buffer size matched C++ layout)." )
878+
624879 # Utility functions
880+
625881 @staticmethod
626882 def default_params ():
627883 """Get the default llama_context_params."""
0 commit comments