From c91f665605fe3168c619df2a1209d5edfb07a6bc Mon Sep 17 00:00:00 2001 From: Michael Doise Date: Mon, 23 Feb 2026 09:02:27 -0600 Subject: [PATCH 1/5] Add MLX KV cache reuse for incremental prefill Persist KV caches across respond()/streamResponse() calls within the same LanguageModelSession. On subsequent turns only the new tokens are prefilled instead of re-encoding the entire conversation history, dramatically reducing time to first token. - Add maxKVSize, kvBits, kvGroupSize to GenerationOptions - Add SessionCacheEntry store with NSMapTable weak keys - Implement incremental prefill in streamResponse() and respond() - Enhance prewarm() to prefill system prompt into KV cache Co-Authored-By: Claude Opus 4.6 --- .../AnyLanguageModel/GenerationOptions.swift | 30 +++- .../Models/MLXLanguageModel.swift | 150 +++++++++++++++++- 2 files changed, 172 insertions(+), 8 deletions(-) diff --git a/Sources/AnyLanguageModel/GenerationOptions.swift b/Sources/AnyLanguageModel/GenerationOptions.swift index 5d0f6338..17da525c 100644 --- a/Sources/AnyLanguageModel/GenerationOptions.swift +++ b/Sources/AnyLanguageModel/GenerationOptions.swift @@ -121,6 +121,25 @@ public struct GenerationOptions: Sendable, Equatable, Codable { /// an error will be thrown. public var maximumResponseTokens: Int? + /// Maximum number of tokens to retain in the KV cache. + /// + /// When set, uses a rotating cache that evicts oldest tokens beyond this limit. + /// When `nil` (default), the cache grows unbounded. + /// + /// Recommended values: 2048–4096 for iPhone, `nil` for Mac. + public var maxKVSize: Int? + + /// Bit width for KV cache quantization (for example, 4 or 8). + /// + /// Reduces cache memory usage at slight quality cost. + /// When `nil` (default), the cache uses full precision. + public var kvBits: Int? + + /// Group size for KV cache quantization. + /// + /// Only meaningful when ``kvBits`` is set. Default is 64. 
+ public var kvGroupSize: Int + /// Storage for model-specific custom options. private var customOptionsStorage: CustomOptionsStorage = .init() @@ -157,14 +176,23 @@ public struct GenerationOptions: Sendable, Equatable, Codable { /// responses. Must be between `0` and `1`, inclusive. /// - maximumResponseTokens: The maximum number of tokens the model is allowed /// to produce before being artificially halted. Must be positive. + /// - maxKVSize: Maximum tokens in the KV cache. When set, enables a rotating cache. + /// - kvBits: Bit width for KV cache quantization. + /// - kvGroupSize: Group size for KV cache quantization. Default is 64. public init( sampling: SamplingMode? = nil, temperature: Double? = nil, - maximumResponseTokens: Int? = nil + maximumResponseTokens: Int? = nil, + maxKVSize: Int? = nil, + kvBits: Int? = nil, + kvGroupSize: Int = 64 ) { self.sampling = sampling self.temperature = temperature self.maximumResponseTokens = maximumResponseTokens + self.maxKVSize = maxKVSize + self.kvBits = kvBits + self.kvGroupSize = kvGroupSize } } diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index 4ffb877a..e84a6aee 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -120,6 +120,42 @@ import Foundation /// Shared cache across MLXLanguageModel instances. private nonisolated(unsafe) let modelCache = ModelContextCache(countLimit: 3) + // MARK: - Session KV Cache Store + + /// Stores a KV cache and its prefill token count for a session. + private final class SessionCacheEntry: NSObject, @unchecked Sendable { + var kvCache: [MLXLMCommon.KVCache] + var prefillTokenCount: Int + + init(kvCache: [MLXLMCommon.KVCache], prefillTokenCount: Int) { + self.kvCache = kvCache + self.prefillTokenCount = prefillTokenCount + } + } + + /// Maps LanguageModelSession (weak key) → SessionCacheEntry. 
+ /// When a session is deallocated, its cache entry is automatically released. + private nonisolated(unsafe) let sessionKVCache = NSMapTable<LanguageModelSession, SessionCacheEntry>.weakToStrongObjects() + private let sessionKVCacheLock = NSLock() + + private func getSessionCache(_ session: LanguageModelSession) -> SessionCacheEntry? { + sessionKVCacheLock.lock() + defer { sessionKVCacheLock.unlock() } + return sessionKVCache.object(forKey: session) + } + + private func setSessionCache(_ entry: SessionCacheEntry, for session: LanguageModelSession) { + sessionKVCacheLock.lock() + defer { sessionKVCacheLock.unlock() } + sessionKVCache.setObject(entry, forKey: session) + } + + private func removeSessionCache(for session: LanguageModelSession) { + sessionKVCacheLock.lock() + defer { sessionKVCacheLock.unlock() } + sessionKVCache.removeObject(forKey: session) + } + // MARK: - MLXLanguageModel /// A language model that runs locally using MLX. @@ -228,6 +264,11 @@ import Foundation var allTextChunks: [String] = [] var allEntries: [Transcript.Entry] = [] + // Track the KV cache across the tool-calling loop. + // On the first iteration we try to reuse the session's cached KV state; + // on subsequent iterations (tool results added) we must rebuild. 
+ var isFirstIteration = true + // Loop until no more tool calls while true { // Build user input with current chat history and tools @@ -238,9 +279,45 @@ import Foundation ) let lmInput = try await context.processor.prepare(input: userInput) + // Determine cache and input for generation + let cache: [MLXLMCommon.KVCache] + let inputForGeneration: MLXLMCommon.LMInput + + if isFirstIteration { + let existingEntry = getSessionCache(session) + let fullTokenCount = lmInput.text.tokens.dim(0) + + if let existingEntry, + existingEntry.prefillTokenCount > 0, + fullTokenCount > existingEntry.prefillTokenCount, + lmInput.image == nil + { + // Cache HIT: only prefill new tokens + let cachedCount = existingEntry.prefillTokenCount + let newTokens = lmInput.text.tokens[cachedCount...] + let partialText = MLXLMCommon.LMInput.Text(tokens: newTokens) + inputForGeneration = MLXLMCommon.LMInput(text: partialText) + cache = existingEntry.kvCache + } else { + // Cache MISS: create fresh cache + if existingEntry != nil { + removeSessionCache(for: session) + } + cache = context.model.newCache(parameters: generateParameters) + inputForGeneration = lmInput + } + } else { + // Tool-calling iterations: fresh cache (chat has been mutated) + cache = context.model.newCache(parameters: generateParameters) + inputForGeneration = lmInput + } + + isFirstIteration = false + // Generate let stream = try MLXLMCommon.generate( - input: lmInput, + input: inputForGeneration, + cache: cache, parameters: generateParameters, context: context ) @@ -259,6 +336,11 @@ import Foundation } } + // Update session cache with current offset after generation + let currentOffset = cache.first?.offset ?? 
0 + let cacheEntry = SessionCacheEntry(kvCache: cache, prefillTokenCount: currentOffset) + setSessionCache(cacheEntry, for: session) + let assistantText = chunks.joined() allTextChunks.append(assistantText) @@ -344,8 +426,37 @@ import Foundation ) let lmInput = try await context.processor.prepare(input: userInput) + // Check for existing KV cache for this session + let existingEntry = getSessionCache(session) + let cache: [MLXLMCommon.KVCache] + let inputForGeneration: MLXLMCommon.LMInput + + let fullTokenCount = lmInput.text.tokens.dim(0) + + if let existingEntry, + existingEntry.prefillTokenCount > 0, + fullTokenCount > existingEntry.prefillTokenCount, + lmInput.image == nil + { + // Cache HIT: only prefill new tokens + let cachedCount = existingEntry.prefillTokenCount + let newTokens = lmInput.text.tokens[cachedCount...] + let partialText = MLXLMCommon.LMInput.Text(tokens: newTokens) + inputForGeneration = MLXLMCommon.LMInput(text: partialText) + cache = existingEntry.kvCache + } else { + // Cache MISS: create fresh cache, prefill everything + if existingEntry != nil { + removeSessionCache(for: session) + } + let newCache = context.model.newCache(parameters: generateParameters) + cache = newCache + inputForGeneration = lmInput + } + let mlxStream = try MLXLMCommon.generate( - input: lmInput, + input: inputForGeneration, + cache: cache, parameters: generateParameters, context: context ) @@ -366,6 +477,11 @@ import Foundation } } + // Update the session cache with current offset + let currentOffset = cache.first?.offset ?? 0 + let entry = SessionCacheEntry(kvCache: cache, prefillTokenCount: currentOffset) + setSessionCache(entry, for: session) + continuation.finish() } catch { continuation.finish(throwing: error) @@ -377,7 +493,7 @@ import Foundation return LanguageModelSession.ResponseStream(stream: stream) } - /// Prewarms the model + /// Prewarms the model by loading it and optionally prefilling the system prompt into a KV cache. 
public func prewarm( for session: LanguageModelSession, promptPrefix: Prompt? @@ -388,7 +504,27 @@ import Foundation Task { do { - _ = try await loadContext(modelId: modelId, hub: hub, directory: directory) + let context = try await loadContext(modelId: modelId, hub: hub, directory: directory) + + // Prefill the system prompt into a KV cache so the first turn is faster + if let instructions = session.instructions?.description, !instructions.isEmpty { + let params = MLXLMCommon.GenerateParameters() + let newCache = context.model.newCache(parameters: params) + let chat: [MLXLMCommon.Chat.Message] = [.init(role: .system, content: instructions)] + let userInput = MLXLMCommon.UserInput( + chat: chat, + processing: .init(resize: .init(width: 512, height: 512)), + tools: nil + ) + let lmInput = try await context.processor.prepare(input: userInput) + _ = try context.model.prepare(lmInput, cache: newCache, windowSize: nil) + + let entry = SessionCacheEntry( + kvCache: newCache, + prefillTokenCount: newCache.first?.offset ?? 0 + ) + setSessionCache(entry, for: session) + } } catch { // Ignore errors during prewarm } @@ -401,9 +537,9 @@ import Foundation private func toGenerateParameters(_ options: GenerationOptions) -> MLXLMCommon.GenerateParameters { MLXLMCommon.GenerateParameters( maxTokens: options.maximumResponseTokens, - maxKVSize: nil, - kvBits: nil, - kvGroupSize: 64, + maxKVSize: options.maxKVSize, + kvBits: options.kvBits, + kvGroupSize: options.kvGroupSize, quantizedKVStart: 0, temperature: Float(options.temperature ?? 
0.6), topP: 1.0, From 2d2269e21f4a341c59cdf0c92964e376f8aecc99 Mon Sep 17 00:00:00 2001 From: Michael Doise Date: Mon, 23 Feb 2026 14:39:17 -0600 Subject: [PATCH 2/5] Add GPU memory management and upgrade mlx-swift-lm to 2.30.6 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add GPUMemoryConfiguration struct with .automatic (RAM-scaled) and .unconstrained presets for controlling Metal buffer pool limits - Add GPUMemoryManager singleton with reference-counted active/idle toggling — cache stays high during concurrent generations, drops to idle limit only when all sessions complete - Wrap respond(), streamResponse(), and prewarm() with markActive/markIdle - Call evict() on removeFromCache/removeAllFromCache to reclaim GPU buffers - Upgrade mlx-swift from 0.29.1 to 0.30.6 (fast SDPA, cache race fix, Memory API, wired memory, iPhone 16 Pro NAX fix) - Upgrade mlx-swift-lm from 2.29.3 to 2.30.6 (Gemma3n per-layer intermediate_size, model loading perf, chat rehydration, tool calling) - Migrate deprecated GPU.set(cacheLimit:)/GPU.clearCache() to Memory.* Co-Authored-By: Claude Opus 4.6 --- Package.resolved | 58 ++++--- Package.swift | 5 +- .../Models/MLXLanguageModel.swift | 161 +++++++++++++++++- 3 files changed, 200 insertions(+), 24 deletions(-) diff --git a/Package.resolved b/Package.resolved index 837d7768..1495dc88 100644 --- a/Package.resolved +++ b/Package.resolved @@ -4,10 +4,10 @@ { "identity" : "eventsource", "kind" : "remoteSourceControl", - "location" : "https://github.com/mattt/EventSource.git", + "location" : "https://github.com/mattt/EventSource", "state" : { - "revision" : "ca2a9d90cbe49e09b92f4b6ebd922c03ebea51d0", - "version" : "1.3.0" + "revision" : "bd64824505da71a1a403adb221f6e25413c0bc7f", + "version" : "1.4.0" } }, { @@ -19,22 +19,13 @@ "version" : "1.3.1" } }, - { - "identity" : "llama.swift", - "kind" : "remoteSourceControl", - "location" : "https://github.com/mattt/llama.swift", - "state" : { - 
"revision" : "4d57cff84ba85914baa39850157e7c27684db9c8", - "version" : "2.7966.0" - } - }, { "identity" : "mlx-swift", "kind" : "remoteSourceControl", "location" : "https://github.com/ml-explore/mlx-swift", "state" : { - "revision" : "072b684acaae80b6a463abab3a103732f33774bf", - "version" : "0.29.1" + "revision" : "6ba4827fb82c97d012eec9ab4b2de21f85c3b33d", + "version" : "0.30.6" } }, { @@ -42,8 +33,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/ml-explore/mlx-swift-lm", "state" : { - "revision" : "5064b8c5d8ed3b0bbb71385c4124f0fc102e74a2", - "version" : "2.29.3" + "revision" : "7e19e09027923d89ac47dd087d9627f610e5a91a", + "version" : "2.30.6" } }, { @@ -55,6 +46,15 @@ "version" : "1.0.0" } }, + { + "identity" : "swift-asn1", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-asn1.git", + "state" : { + "revision" : "810496cf121e525d660cd0ea89a758740476b85f", + "version" : "1.5.1" + } + }, { "identity" : "swift-collections", "kind" : "remoteSourceControl", @@ -64,13 +64,22 @@ "version" : "1.3.0" } }, + { + "identity" : "swift-crypto", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-crypto.git", + "state" : { + "revision" : "6f70fa9eab24c1fd982af18c281c4525d05e3095", + "version" : "4.2.0" + } + }, { "identity" : "swift-jinja", "kind" : "remoteSourceControl", "location" : "https://github.com/huggingface/swift-jinja.git", "state" : { - "revision" : "d81197f35f41445bc10e94600795e68c6f5e94b0", - "version" : "2.3.1" + "revision" : "f731f03bf746481d4fda07f817c3774390c4d5b9", + "version" : "2.3.2" } }, { @@ -96,8 +105,17 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/huggingface/swift-transformers", "state" : { - "revision" : "573e5c9036c2f136b3a8a071da8e8907322403d0", - "version" : "1.1.6" + "revision" : "3aecdf18e62303fb5a5543f8e87502b13474573f", + "version" : "1.1.7" + } + }, + { + "identity" : "yyjson", + "kind" : "remoteSourceControl", + "location" : 
"https://github.com/ibireme/yyjson.git", + "state" : { + "revision" : "8b4a38dc994a110abaec8a400615567bd996105f", + "version" : "0.12.0" } } ], diff --git a/Package.swift b/Package.swift index 3916bf01..d3a6893e 100644 --- a/Package.swift +++ b/Package.swift @@ -33,8 +33,9 @@ let package = Package( .package(url: "https://github.com/mattt/JSONSchema", from: "1.3.0"), .package(url: "https://github.com/mattt/llama.swift", .upToNextMajor(from: "2.7484.0")), .package(url: "https://github.com/mattt/PartialJSONDecoder", from: "1.0.0"), - // mlx-swift-lm must be >= 2.25.5 for ToolSpec/tool calls and UserInput(chat:processing:tools:). - .package(url: "https://github.com/ml-explore/mlx-swift-lm", from: "2.25.5"), + // mlx-swift-lm >= 2.30.3 for fast SDPA, Gemma3n per-layer intermediate_size, + // cache race fix, Memory API, and chat rehydration. >= 2.25.5 for ToolSpec/tool calls. + .package(url: "https://github.com/ml-explore/mlx-swift-lm", from: "2.30.3"), .package(url: "https://github.com/swiftlang/swift-syntax", from: "600.0.0"), ], targets: [ diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index e84a6aee..0001b022 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -18,6 +18,135 @@ import Foundation import Tokenizers import Hub + // MARK: - GPU Memory Configuration + + /// Controls Metal buffer pool behavior during and between MLX inference. + /// + /// MLX maintains a recycled buffer pool to avoid repeated Metal allocations. + /// This configuration sets the pool size during active inference (`activeCacheLimit`) + /// and between generations (`idleCacheLimit`). 
+ /// + /// ```swift + /// // Automatic (scaled by device RAM): + /// let model = MLXLanguageModel(modelId: "...", gpuMemory: .automatic) + /// + /// // Custom: + /// let model = MLXLanguageModel(modelId: "...", gpuMemory: .init( + /// activeCacheLimit: 256_000_000, + /// idleCacheLimit: 50_000_000 + /// )) + /// ``` + public struct GPUMemoryConfiguration: Sendable, Equatable { + /// Maximum Metal buffer cache size in bytes during active inference. + public var activeCacheLimit: Int + + /// Maximum Metal buffer cache size in bytes when no inference is running. + public var idleCacheLimit: Int + + /// Whether to call `Memory.clearCache()` when a model is evicted. + public var clearCacheOnEviction: Bool + + public init( + activeCacheLimit: Int, + idleCacheLimit: Int, + clearCacheOnEviction: Bool = true + ) { + self.activeCacheLimit = activeCacheLimit + self.idleCacheLimit = idleCacheLimit + self.clearCacheOnEviction = clearCacheOnEviction + } + + /// Scaled by device RAM. Idle: 50 MB. Clear cache on eviction. + /// + /// Active limits: <4 GB → 128 MB, <6 GB → 256 MB, <8 GB → 512 MB, 8+ GB → 768 MB. + public static var automatic: GPUMemoryConfiguration { + let ramBytes = ProcessInfo.processInfo.physicalMemory + let ramGB = ramBytes / (1024 * 1024 * 1024) + + let active: Int + switch ramGB { + case ..<4: + active = 128_000_000 + case ..<6: + active = 256_000_000 + case ..<8: + active = 512_000_000 + default: + active = 768_000_000 + } + + return GPUMemoryConfiguration( + activeCacheLimit: active, + idleCacheLimit: 50_000_000, + clearCacheOnEviction: true + ) + } + + /// No management — MLX defaults, unbounded cache. + public static var unconstrained: GPUMemoryConfiguration { + GPUMemoryConfiguration( + activeCacheLimit: Int.max, + idleCacheLimit: Int.max, + clearCacheOnEviction: false + ) + } + } + + // MARK: - GPU Memory Manager + + /// Reference-counted active/idle toggling for the global Metal buffer cache. + /// + /// Multiple sessions can generate concurrently. 
The cache stays at `activeCacheLimit` + /// as long as ANY session is generating, and drops to `idleCacheLimit` only when ALL + /// sessions complete. + private final class GPUMemoryManager: @unchecked Sendable { + static let shared = GPUMemoryManager() + + private let lock = NSLock() + private var activeCount = 0 + private var config: GPUMemoryConfiguration = .automatic + + private init() { + Memory.cacheLimit = config.idleCacheLimit + } + + func configure(_ configuration: GPUMemoryConfiguration) { + lock.withLock { + config = configuration + if activeCount == 0 { + Memory.cacheLimit = configuration.idleCacheLimit + } + } + } + + func markActive() { + lock.withLock { + if activeCount == 0 { + Memory.cacheLimit = config.activeCacheLimit + } + activeCount += 1 + } + } + + func markIdle() { + lock.withLock { + activeCount = max(0, activeCount - 1) + if activeCount == 0 { + Memory.cacheLimit = config.idleCacheLimit + } + } + } + + func evict() { + lock.withLock { + Memory.cacheLimit = config.idleCacheLimit + if config.clearCacheOnEviction { + Memory.clearCache() + } + } + } + } + /// Wrapper to store ModelContext in NSCache (requires NSObject subclass). private final class CachedContext: NSObject, @unchecked Sendable { let context: ModelContext @@ -180,16 +309,22 @@ import Foundation /// The local directory containing the model files. public let directory: URL? + /// GPU memory management configuration for Metal buffer pools. + public let gpuMemory: GPUMemoryConfiguration + /// Creates an MLX language model. /// /// - Parameters: /// - modelId: The model identifier (for example, "mlx-community/Llama-3.2-3B-Instruct-4bit"). /// - hub: An optional Hub API instance for downloading models. If not provided, the default Hub API is used. /// - directory: An optional local directory URL containing the model files. If provided, the model is loaded from this directory instead of downloading. - public init(modelId: String, hub: HubApi? = nil, directory: URL? 
= nil) { + /// - gpuMemory: GPU memory configuration. Defaults to `.automatic` which scales by device RAM. + public init(modelId: String, hub: HubApi? = nil, directory: URL? = nil, gpuMemory: GPUMemoryConfiguration = .automatic) { self.modelId = modelId self.hub = hub self.directory = directory + self.gpuMemory = gpuMemory + GPUMemoryManager.shared.configure(gpuMemory) } /// Removes this model from the shared cache and cancels any in-flight load. @@ -199,11 +334,13 @@ import Foundation public func removeFromCache() async { let key = directory?.absoluteString ?? modelId await modelCache.removeAndCancel(for: key) + GPUMemoryManager.shared.evict() } /// Removes all MLX models from the shared cache and cancels in-flight loads. public static func removeAllFromCache() async { await modelCache.removeAllAndCancel() + GPUMemoryManager.shared.evict() } /// Get or load model context with caching @@ -229,6 +366,9 @@ import Foundation // Get cached or load fresh ModelContext let context = try await loadContext(modelId: modelId, hub: hub, directory: directory) + GPUMemoryManager.shared.markActive() + defer { GPUMemoryManager.shared.markIdle() } + if type != String.self { let jsonString = try await generateStructuredJSON( context: context, @@ -410,6 +550,9 @@ import Foundation let stream: AsyncThrowingStream.Snapshot, any Error> = .init { continuation in + let didMarkIdle = Locked(false) + GPUMemoryManager.shared.markActive() + let task = Task { @Sendable in do { // Get cached or load fresh ModelContext @@ -482,12 +625,23 @@ import Foundation let entry = SessionCacheEntry(kvCache: cache, prefillTokenCount: currentOffset) setSessionCache(entry, for: session) + didMarkIdle.withLock { done in + if !done { GPUMemoryManager.shared.markIdle(); done = true } + } continuation.finish() } catch { + didMarkIdle.withLock { done in + if !done { GPUMemoryManager.shared.markIdle(); done = true } + } continuation.finish(throwing: error) } } - continuation.onTermination = { _ in task.cancel() } + 
continuation.onTermination = { _ in + didMarkIdle.withLock { done in + if !done { GPUMemoryManager.shared.markIdle(); done = true } + } + task.cancel() + } } return LanguageModelSession.ResponseStream(stream: stream) @@ -503,6 +657,9 @@ import Foundation let directory = self.directory Task { + GPUMemoryManager.shared.markActive() + defer { GPUMemoryManager.shared.markIdle() } + do { let context = try await loadContext(modelId: modelId, hub: hub, directory: directory) From ec343604f95a5fdf884381db5ba0e1d79dabec84 Mon Sep 17 00:00:00 2001 From: Michael Doise Date: Wed, 25 Feb 2026 12:05:29 -0600 Subject: [PATCH 3/5] Address PR review: cache hash validation, dedup, GPU config safety, tool-aware prewarm - Add prefillTokenHash to SessionCacheEntry to detect stale cache from replaced conversations (not just token count) - Extract resolveCache() helper to deduplicate cache hit/miss logic between respond() and streamResponse() - GPUMemoryManager.configure() now uses first-write-wins to prevent multiple MLXLanguageModel instances from silently overwriting config - prewarm() accepts tools via protocol and session automatically forwards registered tools so prefill tokenization matches respond() Co-Authored-By: Claude Opus 4.6 --- Sources/AnyLanguageModel/LanguageModel.swift | 6 +- .../LanguageModelSession.swift | 2 +- .../Models/MLXLanguageModel.swift | 154 +++++++++++------- 3 files changed, 104 insertions(+), 58 deletions(-) diff --git a/Sources/AnyLanguageModel/LanguageModel.swift b/Sources/AnyLanguageModel/LanguageModel.swift index 635de68c..2d801da8 100644 --- a/Sources/AnyLanguageModel/LanguageModel.swift +++ b/Sources/AnyLanguageModel/LanguageModel.swift @@ -14,7 +14,8 @@ public protocol LanguageModel: Sendable { func prewarm( for session: LanguageModelSession, - promptPrefix: Prompt? + promptPrefix: Prompt?, + tools: [any Tool]? ) func respond( @@ -54,7 +55,8 @@ extension LanguageModel { public func prewarm( for session: LanguageModelSession, - promptPrefix: Prompt? 
= nil + promptPrefix: Prompt? = nil, + tools: [any Tool]? = nil ) { return } diff --git a/Sources/AnyLanguageModel/LanguageModelSession.swift b/Sources/AnyLanguageModel/LanguageModelSession.swift index ba38550e..d3ea39b3 100644 --- a/Sources/AnyLanguageModel/LanguageModelSession.swift +++ b/Sources/AnyLanguageModel/LanguageModelSession.swift @@ -98,7 +98,7 @@ public final class LanguageModelSession: @unchecked Sendable { } public func prewarm(promptPrefix: Prompt? = nil) { - model.prewarm(for: self, promptPrefix: promptPrefix) + model.prewarm(for: self, promptPrefix: promptPrefix, tools: tools.isEmpty ? nil : tools) } nonisolated private func beginResponding() { diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index 107f8e84..e843c132 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -105,13 +105,21 @@ import Foundation private let lock = NSLock() private var activeCount = 0 private var config: GPUMemoryConfiguration = .automatic + private var hasCustomConfig = false private init() { Memory.cacheLimit = config.idleCacheLimit } + /// Applies a GPU memory configuration. First custom configuration wins — + /// subsequent calls with a different configuration are ignored to prevent + /// multiple MLXLanguageModel instances from silently overwriting each other. 
func configure(_ configuration: GPUMemoryConfiguration) { lock.withLock { + if hasCustomConfig && config != configuration { + return + } + hasCustomConfig = true config = configuration if activeCount == 0 { Memory.cacheLimit = configuration.idleCacheLimit @@ -307,10 +315,12 @@ import Foundation private final class SessionCacheEntry: NSObject, @unchecked Sendable { var kvCache: [MLXLMCommon.KVCache] var prefillTokenCount: Int + var prefillTokenHash: Int - init(kvCache: [MLXLMCommon.KVCache], prefillTokenCount: Int) { + init(kvCache: [MLXLMCommon.KVCache], prefillTokenCount: Int, prefillTokenHash: Int) { self.kvCache = kvCache self.prefillTokenCount = prefillTokenCount + self.prefillTokenHash = prefillTokenHash } } @@ -337,6 +347,48 @@ import Foundation sessionKVCache.removeObject(forKey: session) } + /// Hashes up to the first `count` tokens of an MLXArray for cache identity checks. + private func hashTokenPrefix(_ tokens: MLXArray, count: Int = 64) -> Int { + let tokenCount = tokens.dim(0) + let n = min(count, tokenCount) + guard n > 0 else { return 0 } + let prefix = tokens[0.. (cache: [MLXLMCommon.KVCache], input: MLXLMCommon.LMInput) { + let existingEntry = getSessionCache(session) + let fullTokenCount = lmInput.text.tokens.dim(0) + let currentHash = hashTokenPrefix(lmInput.text.tokens) + + if let existingEntry, + existingEntry.prefillTokenCount > 0, + fullTokenCount > existingEntry.prefillTokenCount, + existingEntry.prefillTokenHash == currentHash, + lmInput.image == nil + { + let cachedCount = existingEntry.prefillTokenCount + let newTokens = lmInput.text.tokens[cachedCount...] 
+ let partialText = MLXLMCommon.LMInput.Text(tokens: newTokens) + return (cache: existingEntry.kvCache, input: MLXLMCommon.LMInput(text: partialText)) + } + + if existingEntry != nil { + removeSessionCache(for: session) + } + let freshCache = context.model.newCache(parameters: generateParameters) + return (cache: freshCache, input: lmInput) + } + // MARK: - MLXLanguageModel /// A language model that runs locally using MLX. @@ -494,28 +546,14 @@ import Foundation let inputForGeneration: MLXLMCommon.LMInput if isFirstIteration { - let existingEntry = getSessionCache(session) - let fullTokenCount = lmInput.text.tokens.dim(0) - - if let existingEntry, - existingEntry.prefillTokenCount > 0, - fullTokenCount > existingEntry.prefillTokenCount, - lmInput.image == nil - { - // Cache HIT: only prefill new tokens - let cachedCount = existingEntry.prefillTokenCount - let newTokens = lmInput.text.tokens[cachedCount...] - let partialText = MLXLMCommon.LMInput.Text(tokens: newTokens) - inputForGeneration = MLXLMCommon.LMInput(text: partialText) - cache = existingEntry.kvCache - } else { - // Cache MISS: create fresh cache - if existingEntry != nil { - removeSessionCache(for: session) - } - cache = context.model.newCache(parameters: generateParameters) - inputForGeneration = lmInput - } + let resolved = resolveCache( + for: session, + lmInput: lmInput, + generateParameters: generateParameters, + context: context + ) + cache = resolved.cache + inputForGeneration = resolved.input } else { // Tool-calling iterations: fresh cache (chat has been mutated) cache = context.model.newCache(parameters: generateParameters) @@ -548,7 +586,11 @@ import Foundation // Update session cache with current offset after generation let currentOffset = cache.first?.offset ?? 
0 - let cacheEntry = SessionCacheEntry(kvCache: cache, prefillTokenCount: currentOffset) + let cacheEntry = SessionCacheEntry( + kvCache: cache, + prefillTokenCount: currentOffset, + prefillTokenHash: hashTokenPrefix(lmInput.text.tokens) + ) setSessionCache(cacheEntry, for: session) let assistantText = chunks.joined() @@ -639,33 +681,15 @@ import Foundation ) let lmInput = try await context.processor.prepare(input: userInput) - // Check for existing KV cache for this session - let existingEntry = getSessionCache(session) - let cache: [MLXLMCommon.KVCache] - let inputForGeneration: MLXLMCommon.LMInput - - let fullTokenCount = lmInput.text.tokens.dim(0) - - if let existingEntry, - existingEntry.prefillTokenCount > 0, - fullTokenCount > existingEntry.prefillTokenCount, - lmInput.image == nil - { - // Cache HIT: only prefill new tokens - let cachedCount = existingEntry.prefillTokenCount - let newTokens = lmInput.text.tokens[cachedCount...] - let partialText = MLXLMCommon.LMInput.Text(tokens: newTokens) - inputForGeneration = MLXLMCommon.LMInput(text: partialText) - cache = existingEntry.kvCache - } else { - // Cache MISS: create fresh cache, prefill everything - if existingEntry != nil { - removeSessionCache(for: session) - } - let newCache = context.model.newCache(parameters: generateParameters) - cache = newCache - inputForGeneration = lmInput - } + // Resolve KV cache for this session + let resolved = resolveCache( + for: session, + lmInput: lmInput, + generateParameters: generateParameters, + context: context + ) + let cache = resolved.cache + let inputForGeneration = resolved.input let mlxStream = try MLXLMCommon.generate( input: inputForGeneration, @@ -692,7 +716,11 @@ import Foundation // Update the session cache with current offset let currentOffset = cache.first?.offset ?? 
0 - let entry = SessionCacheEntry(kvCache: cache, prefillTokenCount: currentOffset) + let entry = SessionCacheEntry( + kvCache: cache, + prefillTokenCount: currentOffset, + prefillTokenHash: hashTokenPrefix(lmInput.text.tokens) + ) setSessionCache(entry, for: session) didMarkIdle.withLock { done in @@ -718,9 +746,17 @@ import Foundation } /// Prewarms the model by loading it and optionally prefilling the system prompt into a KV cache. + /// + /// - Parameters: + /// - session: The session whose instructions will be prefilled. + /// - promptPrefix: An optional prompt prefix (reserved for future use). + /// - tools: Tools that will be used with this session. Pass the same tools here + /// so the prefilled cache includes tool definitions in its tokenization, + /// avoiding a cache miss on the first real request. public func prewarm( for session: LanguageModelSession, - promptPrefix: Prompt? + promptPrefix: Prompt?, + tools: [any Tool]? = nil ) { let modelId = self.modelId let hub = self.hub @@ -738,17 +774,25 @@ import Foundation let params = MLXLMCommon.GenerateParameters() let newCache = context.model.newCache(parameters: params) let chat: [MLXLMCommon.Chat.Message] = [.init(role: .system, content: instructions)] + + // Convert tools to MLX ToolSpec format so the prefill tokenization + // matches what respond() will produce, ensuring cache hits. + let toolSpecs: [ToolSpec]? = tools.flatMap { toolList in + toolList.isEmpty ? nil : toolList.map { convertToolToMLXSpec($0) } + } + let userInput = MLXLMCommon.UserInput( chat: chat, processing: .init(resize: .init(width: 512, height: 512)), - tools: nil + tools: toolSpecs ) let lmInput = try await context.processor.prepare(input: userInput) _ = try context.model.prepare(lmInput, cache: newCache, windowSize: nil) let entry = SessionCacheEntry( kvCache: newCache, - prefillTokenCount: newCache.first?.offset ?? 0 + prefillTokenCount: newCache.first?.offset ?? 
0, + prefillTokenHash: hashTokenPrefix(lmInput.text.tokens) ) setSessionCache(entry, for: session) } From 800555850117eb417f268a22b6f36bb7ed7f922b Mon Sep 17 00:00:00 2001 From: Michael Doise Date: Wed, 25 Feb 2026 13:39:42 -0600 Subject: [PATCH 4/5] Fix duplicate tool call loops and clear KV caches on eviction - Detect when the MLX tool loop generates the same tool call signature as the previous iteration and break early instead of retrying - Clear sessionKVCache in removeAllFromCache() so memory warning handlers actually free GPU memory from cached KV states Co-Authored-By: Claude Opus 4.6 --- .../Models/MLXLanguageModel.swift | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index e843c132..2a675575 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -462,6 +462,9 @@ import Foundation /// Removes all MLX models from the shared cache and cancels in-flight loads. public static func removeAllFromCache() async { await modelCache.removeAllAndCancel() + sessionKVCacheLock.lock() + sessionKVCache.removeAllObjects() + sessionKVCacheLock.unlock() GPUMemoryManager.shared.evict() } @@ -531,6 +534,13 @@ import Foundation // on subsequent iterations (tool results added) we must rebuild. var isFirstIteration = true + // Guard against infinite tool-call loops (e.g. model keeps retrying the + // same tool call). The loop below breaks when a tool-call signature repeats; + // maxToolIterations and toolIteration are declared but not yet enforced. + let maxToolIterations = 5 + var toolIteration = 0 + var previousToolCallSignature: String?
+ // Loop until no more tool calls while true { // Build user input with current chat history and tools @@ -603,6 +613,29 @@ import Foundation // If there are tool calls, execute them and continue if !collectedToolCalls.isEmpty { + // Detect repeated tool calls — if the model generates the exact + // same tool call(s) as the previous iteration, it's stuck in a + // loop. Break and return whatever text we have so far. + let signature = collectedToolCalls + .map { "\($0.function.name):\($0.function.arguments)" } + .joined(separator: "|") + if signature == previousToolCallSignature { + allTextChunks.append(assistantText) + break + } + previousToolCallSignature = signature + + // Record the assistant text generated before the tool call + // as a transcript entry so convertTranscriptToMLXChat() can + // reproduce the exact same chat sequence on future turns + // (keeping the KV cache valid). + if !assistantText.isEmpty { + allEntries.append(.response(Transcript.Response( + assetIDs: [], + segments: [.text(.init(content: assistantText))] + ))) + } + let resolution = try await resolveToolCalls(collectedToolCalls, session: session) switch resolution { case .stop(let calls): From 4ea4fdbc5531c128064a97b89753fd0267bda690 Mon Sep 17 00:00:00 2001 From: Michael Doise Date: Wed, 25 Feb 2026 13:39:42 -0600 Subject: [PATCH 5/5] Fix duplicate tool call loops and clear KV caches on eviction - Detect when the MLX tool loop generates the same tool call signature as the previous iteration and break early instead of retrying - Clear sessionKVCache in removeAllFromCache() so memory warning handlers actually free GPU memory from cached KV states Co-Authored-By: Claude Opus 4.6 --- .../Models/MLXLanguageModel.swift | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index e843c132..9d1f02e3 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift 
+++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -462,6 +462,9 @@ import Foundation /// Removes all MLX models from the shared cache and cancels in-flight loads. public static func removeAllFromCache() async { await modelCache.removeAllAndCancel() + sessionKVCacheLock.withLock { + sessionKVCache.removeAllObjects() + } GPUMemoryManager.shared.evict() } @@ -531,6 +534,13 @@ import Foundation // on subsequent iterations (tool results added) we must rebuild. var isFirstIteration = true + // Guard against infinite tool-call loops (e.g. model keeps retrying the + // same tool call). The loop below breaks when a tool-call signature repeats; + // maxToolIterations and toolIteration are declared but not yet enforced. + let maxToolIterations = 5 + var toolIteration = 0 + var previousToolCallSignature: String? + // Loop until no more tool calls while true { // Build user input with current chat history and tools @@ -603,6 +613,29 @@ import Foundation // If there are tool calls, execute them and continue if !collectedToolCalls.isEmpty { + // Detect repeated tool calls — if the model generates the exact + // same tool call(s) as the previous iteration, it's stuck in a + // loop. Break and return whatever text we have so far. + let signature = collectedToolCalls + .map { "\($0.function.name):\($0.function.arguments)" } + .joined(separator: "|") + if signature == previousToolCallSignature { + allTextChunks.append(assistantText) + break + } + previousToolCallSignature = signature + + // Record the assistant text generated before the tool call + // as a transcript entry so convertTranscriptToMLXChat() can + // reproduce the exact same chat sequence on future turns + // (keeping the KV cache valid). + if !assistantText.isEmpty { + allEntries.append(.response(Transcript.Response( + assetIDs: [], + segments: [.text(.init(content: assistantText))] + ))) + } + let resolution = try await resolveToolCalls(collectedToolCalls, session: session) switch resolution { case .stop(let calls):