diff --git a/Package.resolved b/Package.resolved index 837d776..1495dc8 100644 --- a/Package.resolved +++ b/Package.resolved @@ -4,10 +4,10 @@ { "identity" : "eventsource", "kind" : "remoteSourceControl", - "location" : "https://github.com/mattt/EventSource.git", + "location" : "https://github.com/mattt/EventSource", "state" : { - "revision" : "ca2a9d90cbe49e09b92f4b6ebd922c03ebea51d0", - "version" : "1.3.0" + "revision" : "bd64824505da71a1a403adb221f6e25413c0bc7f", + "version" : "1.4.0" } }, { @@ -19,22 +19,13 @@ "version" : "1.3.1" } }, - { - "identity" : "llama.swift", - "kind" : "remoteSourceControl", - "location" : "https://github.com/mattt/llama.swift", - "state" : { - "revision" : "4d57cff84ba85914baa39850157e7c27684db9c8", - "version" : "2.7966.0" - } - }, { "identity" : "mlx-swift", "kind" : "remoteSourceControl", "location" : "https://github.com/ml-explore/mlx-swift", "state" : { - "revision" : "072b684acaae80b6a463abab3a103732f33774bf", - "version" : "0.29.1" + "revision" : "6ba4827fb82c97d012eec9ab4b2de21f85c3b33d", + "version" : "0.30.6" } }, { @@ -42,8 +33,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/ml-explore/mlx-swift-lm", "state" : { - "revision" : "5064b8c5d8ed3b0bbb71385c4124f0fc102e74a2", - "version" : "2.29.3" + "revision" : "7e19e09027923d89ac47dd087d9627f610e5a91a", + "version" : "2.30.6" } }, { @@ -55,6 +46,15 @@ "version" : "1.0.0" } }, + { + "identity" : "swift-asn1", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-asn1.git", + "state" : { + "revision" : "810496cf121e525d660cd0ea89a758740476b85f", + "version" : "1.5.1" + } + }, { "identity" : "swift-collections", "kind" : "remoteSourceControl", @@ -64,13 +64,22 @@ "version" : "1.3.0" } }, + { + "identity" : "swift-crypto", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-crypto.git", + "state" : { + "revision" : "6f70fa9eab24c1fd982af18c281c4525d05e3095", + "version" : "4.2.0" + } + }, { "identity" : "swift-jinja", "kind" : "remoteSourceControl", "location" : "https://github.com/huggingface/swift-jinja.git", "state" : { - "revision" : "d81197f35f41445bc10e94600795e68c6f5e94b0", - "version" : "2.3.1" + "revision" : "f731f03bf746481d4fda07f817c3774390c4d5b9", + "version" : "2.3.2" } }, { @@ -96,8 +105,17 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/huggingface/swift-transformers", "state" : { - "revision" : "573e5c9036c2f136b3a8a071da8e8907322403d0", - "version" : "1.1.6" + "revision" : "3aecdf18e62303fb5a5543f8e87502b13474573f", + "version" : "1.1.7" + } + }, + { + "identity" : "yyjson", + "kind" : "remoteSourceControl", + "location" : "https://github.com/ibireme/yyjson.git", + "state" : { + "revision" : "8b4a38dc994a110abaec8a400615567bd996105f", + "version" : "0.12.0" } } ], diff --git a/Package.swift b/Package.swift index 3916bf0..d3a6893 100644 --- a/Package.swift +++ b/Package.swift @@ -33,8 +33,9 @@ let package = Package( .package(url: "https://github.com/mattt/JSONSchema", from: "1.3.0"), .package(url: "https://github.com/mattt/llama.swift", .upToNextMajor(from: "2.7484.0")), .package(url: "https://github.com/mattt/PartialJSONDecoder", from: "1.0.0"), - // mlx-swift-lm must be >= 2.25.5 for ToolSpec/tool calls and UserInput(chat:processing:tools:). - .package(url: "https://github.com/ml-explore/mlx-swift-lm", from: "2.25.5"), + // mlx-swift-lm >= 2.30.3 for fast SDPA, Gemma3n per-layer intermediate_size, + // cache race fix, Memory API, and chat rehydration. >= 2.25.5 for ToolSpec/tool calls. + .package(url: "https://github.com/ml-explore/mlx-swift-lm", from: "2.30.3"), .package(url: "https://github.com/swiftlang/swift-syntax", from: "600.0.0"), ], targets: [ diff --git a/Sources/AnyLanguageModel/GenerationOptions.swift b/Sources/AnyLanguageModel/GenerationOptions.swift index 5d0f633..17da525 100644 --- a/Sources/AnyLanguageModel/GenerationOptions.swift +++ b/Sources/AnyLanguageModel/GenerationOptions.swift @@ -121,6 +121,25 @@ public struct GenerationOptions: Sendable, Equatable, Codable { /// an error will be thrown. public var maximumResponseTokens: Int? + /// Maximum number of tokens to retain in the KV cache. + /// + /// When set, uses a rotating cache that evicts oldest tokens beyond this limit. + /// When `nil` (default), the cache grows unbounded. + /// + /// Recommended values: 2048–4096 for iPhone, `nil` for Mac. + public var maxKVSize: Int? + + /// Bit width for KV cache quantization (for example, 4 or 8). + /// + /// Reduces cache memory usage at slight quality cost. + /// When `nil` (default), the cache uses full precision. + public var kvBits: Int? + + /// Group size for KV cache quantization. + /// + /// Only meaningful when ``kvBits`` is set. Default is 64. + public var kvGroupSize: Int + /// Storage for model-specific custom options. private var customOptionsStorage: CustomOptionsStorage = .init() @@ -157,14 +176,23 @@ public struct GenerationOptions: Sendable, Equatable, Codable { /// responses. Must be between `0` and `1`, inclusive. /// - maximumResponseTokens: The maximum number of tokens the model is allowed /// to produce before being artificially halted. Must be positive. + /// - maxKVSize: Maximum tokens in the KV cache. When set, enables a rotating cache. + /// - kvBits: Bit width for KV cache quantization. + /// - kvGroupSize: Group size for KV cache quantization. Default is 64. public init( sampling: SamplingMode? = nil, temperature: Double? = nil, - maximumResponseTokens: Int? = nil + maximumResponseTokens: Int? = nil, + maxKVSize: Int? = nil, + kvBits: Int? = nil, + kvGroupSize: Int = 64 ) { self.sampling = sampling self.temperature = temperature self.maximumResponseTokens = maximumResponseTokens + self.maxKVSize = maxKVSize + self.kvBits = kvBits + self.kvGroupSize = kvGroupSize } } diff --git a/Sources/AnyLanguageModel/LanguageModel.swift b/Sources/AnyLanguageModel/LanguageModel.swift index 635de68..2d801da 100644 --- a/Sources/AnyLanguageModel/LanguageModel.swift +++ b/Sources/AnyLanguageModel/LanguageModel.swift @@ -14,7 +14,8 @@ public protocol LanguageModel: Sendable { func prewarm( for session: LanguageModelSession, - promptPrefix: Prompt? + promptPrefix: Prompt?, + tools: [any Tool]? ) func respond( @@ -54,7 +55,8 @@ extension LanguageModel { public func prewarm( for session: LanguageModelSession, - promptPrefix: Prompt? = nil + promptPrefix: Prompt? = nil, + tools: [any Tool]? = nil ) { return } diff --git a/Sources/AnyLanguageModel/LanguageModelSession.swift b/Sources/AnyLanguageModel/LanguageModelSession.swift index ba38550..d3ea39b 100644 --- a/Sources/AnyLanguageModel/LanguageModelSession.swift +++ b/Sources/AnyLanguageModel/LanguageModelSession.swift @@ -98,7 +98,7 @@ public final class LanguageModelSession: @unchecked Sendable { } public func prewarm(promptPrefix: Prompt? = nil) { - model.prewarm(for: self, promptPrefix: promptPrefix) + model.prewarm(for: self, promptPrefix: promptPrefix, tools: tools.isEmpty ? nil : tools) } nonisolated private func beginResponding() { diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index f4be593..9d1f02e 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -18,6 +18,143 @@ import Foundation import Tokenizers import Hub + // MARK: - GPU Memory Configuration + + /// Controls Metal buffer pool behavior during and between MLX inference. + /// + /// MLX maintains a recycled buffer pool to avoid repeated Metal allocations. + /// This configuration sets the pool size during active inference (`activeCacheLimit`) + /// and between generations (`idleCacheLimit`). + /// + /// ```swift + /// // Automatic (scaled by device RAM): + /// let model = MLXLanguageModel(modelId: "...", gpuMemory: .automatic) + /// + /// // Custom: + /// let model = MLXLanguageModel(modelId: "...", gpuMemory: .init( + /// activeCacheLimit: 256_000_000, + /// idleCacheLimit: 50_000_000 + /// )) + /// ``` + public struct GPUMemoryConfiguration: Sendable, Equatable { + /// Maximum Metal buffer cache size in bytes during active inference. + public var activeCacheLimit: Int + + /// Maximum Metal buffer cache size in bytes when no inference is running. + public var idleCacheLimit: Int + + /// Whether to call `Memory.clearCache()` when a model is evicted. + public var clearCacheOnEviction: Bool + + public init( + activeCacheLimit: Int, + idleCacheLimit: Int, + clearCacheOnEviction: Bool = true + ) { + self.activeCacheLimit = activeCacheLimit + self.idleCacheLimit = idleCacheLimit + self.clearCacheOnEviction = clearCacheOnEviction + } + + /// Scaled by device RAM. Idle: 50 MB. Clear cache on eviction. + /// + /// Active limits: <4 GB → 128 MB, <6 GB → 256 MB, <8 GB → 512 MB, 8+ GB → 768 MB. + public static var automatic: GPUMemoryConfiguration { + let ramBytes = ProcessInfo.processInfo.physicalMemory + let ramGB = ramBytes / (1024 * 1024 * 1024) + + let active: Int + switch ramGB { + case ..<4: + active = 128_000_000 + case ..<6: + active = 256_000_000 + case ..<8: + active = 512_000_000 + default: + active = 768_000_000 + } + + return GPUMemoryConfiguration( + activeCacheLimit: active, + idleCacheLimit: 50_000_000, + clearCacheOnEviction: true + ) + } + + /// No management — MLX defaults, unbounded cache. + public static var unconstrained: GPUMemoryConfiguration { + GPUMemoryConfiguration( + activeCacheLimit: Int.max, + idleCacheLimit: Int.max, + clearCacheOnEviction: false + ) + } + } + + // MARK: - GPU Memory Manager + + /// Reference-counted active/idle toggling for the global Metal buffer cache. + /// + /// Multiple sessions can generate concurrently. The cache stays at `activeCacheLimit` + /// as long as ANY session is generating, and drops to `idleCacheLimit` only when ALL + /// sessions complete. + private final class GPUMemoryManager: @unchecked Sendable { + static let shared = GPUMemoryManager() + + private let lock = NSLock() + private var activeCount = 0 + private var config: GPUMemoryConfiguration = .automatic + private var hasCustomConfig = false + + private init() { + Memory.cacheLimit = config.idleCacheLimit + } + + /// Applies a GPU memory configuration. First custom configuration wins — + /// subsequent calls with a different configuration are ignored to prevent + /// multiple MLXLanguageModel instances from silently overwriting each other. + func configure(_ configuration: GPUMemoryConfiguration) { + lock.withLock { + if hasCustomConfig && config != configuration { + return + } + hasCustomConfig = true + config = configuration + if activeCount == 0 { + Memory.cacheLimit = configuration.idleCacheLimit + } + } + } + + func markActive() { + lock.withLock { + if activeCount == 0 { + Memory.cacheLimit = config.activeCacheLimit + } + activeCount += 1 + } + } + + func markIdle() { + lock.withLock { + activeCount = max(0, activeCount - 1) + if activeCount == 0 { + Memory.cacheLimit = config.idleCacheLimit + } + } + } + + func evict() { + lock.withLock { + Memory.cacheLimit = config.idleCacheLimit + if config.clearCacheOnEviction { + Memory.clearCache() + } + } + } + } + /// Wrapper to store model availability state in NSCache. private final class CachedModelState: NSObject, @unchecked Sendable { enum Value { @@ -172,6 +309,86 @@ import Foundation /// Shared cache across MLXLanguageModel instances. private nonisolated(unsafe) let modelCache = ModelContextCache(countLimit: 3) + // MARK: - Session KV Cache Store + + /// Stores a KV cache and its prefill token count for a session. + private final class SessionCacheEntry: NSObject, @unchecked Sendable { + var kvCache: [MLXLMCommon.KVCache] + var prefillTokenCount: Int + var prefillTokenHash: Int + + init(kvCache: [MLXLMCommon.KVCache], prefillTokenCount: Int, prefillTokenHash: Int) { + self.kvCache = kvCache + self.prefillTokenCount = prefillTokenCount + self.prefillTokenHash = prefillTokenHash + } + } + + /// Maps LanguageModelSession (weak key) → SessionCacheEntry. + /// When a session is deallocated, its cache entry is automatically released. + private nonisolated(unsafe) let sessionKVCache = NSMapTable.weakToStrongObjects() + private let sessionKVCacheLock = NSLock() + + private func getSessionCache(_ session: LanguageModelSession) -> SessionCacheEntry? { + sessionKVCacheLock.lock() + defer { sessionKVCacheLock.unlock() } + return sessionKVCache.object(forKey: session) + } + + private func setSessionCache(_ entry: SessionCacheEntry, for session: LanguageModelSession) { + sessionKVCacheLock.lock() + defer { sessionKVCacheLock.unlock() } + sessionKVCache.setObject(entry, forKey: session) + } + + private func removeSessionCache(for session: LanguageModelSession) { + sessionKVCacheLock.lock() + defer { sessionKVCacheLock.unlock() } + sessionKVCache.removeObject(forKey: session) + } + + /// Hashes up to the first `count` tokens of an MLXArray for cache identity checks. + private func hashTokenPrefix(_ tokens: MLXArray, count: Int = 64) -> Int { + let tokenCount = tokens.dim(0) + let n = min(count, tokenCount) + guard n > 0 else { return 0 } + let prefix = tokens[0.. (cache: [MLXLMCommon.KVCache], input: MLXLMCommon.LMInput) { + let existingEntry = getSessionCache(session) + let fullTokenCount = lmInput.text.tokens.dim(0) + let currentHash = hashTokenPrefix(lmInput.text.tokens) + + if let existingEntry, + existingEntry.prefillTokenCount > 0, + fullTokenCount > existingEntry.prefillTokenCount, + existingEntry.prefillTokenHash == currentHash, + lmInput.image == nil + { + let cachedCount = existingEntry.prefillTokenCount + let newTokens = lmInput.text.tokens[cachedCount...] + let partialText = MLXLMCommon.LMInput.Text(tokens: newTokens) + return (cache: existingEntry.kvCache, input: MLXLMCommon.LMInput(text: partialText)) + } + + if existingEntry != nil { + removeSessionCache(for: session) + } + let freshCache = context.model.newCache(parameters: generateParameters) + return (cache: freshCache, input: lmInput) + } + // MARK: - MLXLanguageModel /// A language model that runs locally using MLX. @@ -200,16 +417,22 @@ import Foundation /// The local directory containing the model files. public let directory: URL? + /// GPU memory management configuration for Metal buffer pools. + public let gpuMemory: GPUMemoryConfiguration + /// Creates an MLX language model. /// /// - Parameters: /// - modelId: The model identifier (for example, "mlx-community/Llama-3.2-3B-Instruct-4bit"). /// - hub: An optional Hub API instance for downloading models. If not provided, the default Hub API is used. /// - directory: An optional local directory URL containing the model files. If provided, the model is loaded from this directory instead of downloading. - public init(modelId: String, hub: HubApi? = nil, directory: URL? = nil) { + /// - gpuMemory: GPU memory configuration. Defaults to `.automatic` which scales by device RAM. + public init(modelId: String, hub: HubApi? = nil, directory: URL? = nil, gpuMemory: GPUMemoryConfiguration = .automatic) { self.modelId = modelId self.hub = hub self.directory = directory + self.gpuMemory = gpuMemory + GPUMemoryManager.shared.configure(gpuMemory) } /// The current availability of this model in memory. @@ -233,11 +456,16 @@ import Foundation public func removeFromCache() async { let key = directory?.absoluteString ?? modelId await modelCache.removeAndCancel(for: key) + GPUMemoryManager.shared.evict() } /// Removes all MLX models from the shared cache and cancels in-flight loads. public static func removeAllFromCache() async { await modelCache.removeAllAndCancel() + sessionKVCacheLock.withLock { + sessionKVCache.removeAllObjects() + } + GPUMemoryManager.shared.evict() } /// Get or load model context with caching @@ -263,6 +491,9 @@ import Foundation // Get cached or load fresh ModelContext let context = try await loadContext(modelId: modelId, hub: hub, directory: directory) + GPUMemoryManager.shared.markActive() + defer { GPUMemoryManager.shared.markIdle() } + if type != String.self { let jsonString = try await generateStructuredJSON( context: context, @@ -298,6 +529,18 @@ import Foundation var allTextChunks: [String] = [] var allEntries: [Transcript.Entry] = [] + // Track the KV cache across the tool-calling loop. + // On the first iteration we try to reuse the session's cached KV state; + // on subsequent iterations (tool results added) we must rebuild. + var isFirstIteration = true + + // Guard against infinite tool-call loops (e.g. model keeps retrying the + // same tool call). After this many iterations, break and return whatever + // text has been accumulated. + let maxToolIterations = 5 + var toolIteration = 0 + var previousToolCallSignature: String? + // Loop until no more tool calls while true { // Build user input with current chat history and tools @@ -308,9 +551,31 @@ import Foundation ) let lmInput = try await context.processor.prepare(input: userInput) + // Determine cache and input for generation + let cache: [MLXLMCommon.KVCache] + let inputForGeneration: MLXLMCommon.LMInput + + if isFirstIteration { + let resolved = resolveCache( + for: session, + lmInput: lmInput, + generateParameters: generateParameters, + context: context + ) + cache = resolved.cache + inputForGeneration = resolved.input + } else { + // Tool-calling iterations: fresh cache (chat has been mutated) + cache = context.model.newCache(parameters: generateParameters) + inputForGeneration = lmInput + } + + isFirstIteration = false + // Generate let stream = try MLXLMCommon.generate( - input: lmInput, + input: inputForGeneration, + cache: cache, parameters: generateParameters, context: context ) @@ -329,6 +594,15 @@ import Foundation } } + // Update session cache with current offset after generation + let currentOffset = cache.first?.offset ?? 0 + let cacheEntry = SessionCacheEntry( + kvCache: cache, + prefillTokenCount: currentOffset, + prefillTokenHash: hashTokenPrefix(lmInput.text.tokens) + ) + setSessionCache(cacheEntry, for: session) + let assistantText = chunks.joined() allTextChunks.append(assistantText) @@ -339,6 +613,29 @@ import Foundation // If there are tool calls, execute them and continue if !collectedToolCalls.isEmpty { + // Detect repeated tool calls — if the model generates the exact + // same tool call(s) as the previous iteration, it's stuck in a + // loop. Break and return whatever text we have so far. + let signature = collectedToolCalls + .map { "\($0.function.name):\($0.function.arguments)" } + .joined(separator: "|") + if signature == previousToolCallSignature { + allTextChunks.append(assistantText) + break + } + previousToolCallSignature = signature + + // Record the assistant text generated before the tool call + // as a transcript entry so convertTranscriptToMLXChat() can + // reproduce the exact same chat sequence on future turns + // (keeping the KV cache valid). + if !assistantText.isEmpty { + allEntries.append(.response(Transcript.Response( + assetIDs: [], + segments: [.text(.init(content: assistantText))] + ))) + } + let resolution = try await resolveToolCalls(collectedToolCalls, session: session) switch resolution { case .stop(let calls): @@ -398,6 +695,9 @@ import Foundation let stream: AsyncThrowingStream.Snapshot, any Error> = .init { continuation in + let didMarkIdle = Locked(false) + GPUMemoryManager.shared.markActive() + let task = Task { @Sendable in do { // Get cached or load fresh ModelContext @@ -414,8 +714,19 @@ import Foundation ) let lmInput = try await context.processor.prepare(input: userInput) + // Resolve KV cache for this session + let resolved = resolveCache( + for: session, + lmInput: lmInput, + generateParameters: generateParameters, + context: context + ) + let cache = resolved.cache + let inputForGeneration = resolved.input + let mlxStream = try MLXLMCommon.generate( - input: lmInput, + input: inputForGeneration, + cache: cache, parameters: generateParameters, context: context ) @@ -436,29 +747,88 @@ import Foundation } } + // Update the session cache with current offset + let currentOffset = cache.first?.offset ?? 0 + let entry = SessionCacheEntry( + kvCache: cache, + prefillTokenCount: currentOffset, + prefillTokenHash: hashTokenPrefix(lmInput.text.tokens) + ) + setSessionCache(entry, for: session) + + didMarkIdle.withLock { done in + if !done { GPUMemoryManager.shared.markIdle(); done = true } + } continuation.finish() } catch { + didMarkIdle.withLock { done in + if !done { GPUMemoryManager.shared.markIdle(); done = true } + } continuation.finish(throwing: error) } } - continuation.onTermination = { _ in task.cancel() } + continuation.onTermination = { _ in + didMarkIdle.withLock { done in + if !done { GPUMemoryManager.shared.markIdle(); done = true } + } + task.cancel() + } } return LanguageModelSession.ResponseStream(stream: stream) } - /// Prewarms the model + /// Prewarms the model by loading it and optionally prefilling the system prompt into a KV cache. + /// + /// - Parameters: + /// - session: The session whose instructions will be prefilled. + /// - promptPrefix: An optional prompt prefix (reserved for future use). + /// - tools: Tools that will be used with this session. Pass the same tools here + /// so the prefilled cache includes tool definitions in its tokenization, + /// avoiding a cache miss on the first real request. public func prewarm( for session: LanguageModelSession, - promptPrefix: Prompt? + promptPrefix: Prompt?, + tools: [any Tool]? = nil ) { let modelId = self.modelId let hub = self.hub let directory = self.directory Task { + GPUMemoryManager.shared.markActive() + defer { GPUMemoryManager.shared.markIdle() } + do { - _ = try await loadContext(modelId: modelId, hub: hub, directory: directory) + let context = try await loadContext(modelId: modelId, hub: hub, directory: directory) + + // Prefill the system prompt into a KV cache so the first turn is faster + if let instructions = session.instructions?.description, !instructions.isEmpty { + let params = MLXLMCommon.GenerateParameters() + let newCache = context.model.newCache(parameters: params) + let chat: [MLXLMCommon.Chat.Message] = [.init(role: .system, content: instructions)] + + // Convert tools to MLX ToolSpec format so the prefill tokenization + // matches what respond() will produce, ensuring cache hits. + let toolSpecs: [ToolSpec]? = tools.flatMap { toolList in + toolList.isEmpty ? nil : toolList.map { convertToolToMLXSpec($0) } + } + + let userInput = MLXLMCommon.UserInput( + chat: chat, + processing: .init(resize: .init(width: 512, height: 512)), + tools: toolSpecs + ) + let lmInput = try await context.processor.prepare(input: userInput) + _ = try context.model.prepare(lmInput, cache: newCache, windowSize: nil) + + let entry = SessionCacheEntry( + kvCache: newCache, + prefillTokenCount: newCache.first?.offset ?? 0, + prefillTokenHash: hashTokenPrefix(lmInput.text.tokens) + ) + setSessionCache(entry, for: session) + } } catch { // Ignore errors during prewarm } @@ -471,9 +841,9 @@ import Foundation private func toGenerateParameters(_ options: GenerationOptions) -> MLXLMCommon.GenerateParameters { MLXLMCommon.GenerateParameters( maxTokens: options.maximumResponseTokens, - maxKVSize: nil, - kvBits: nil, - kvGroupSize: 64, + maxKVSize: options.maxKVSize, + kvBits: options.kvBits, + kvGroupSize: options.kvGroupSize, quantizedKVStart: 0, temperature: Float(options.temperature ?? 0.6), topP: 1.0,