diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index 4ffb877..f4be593 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -18,20 +18,28 @@ import Foundation import Tokenizers import Hub - /// Wrapper to store ModelContext in NSCache (requires NSObject subclass). - private final class CachedContext: NSObject, @unchecked Sendable { - let context: ModelContext - init(_ context: ModelContext) { self.context = context } + /// Wrapper to store model availability state in NSCache. + private final class CachedModelState: NSObject, @unchecked Sendable { + enum Value { + case loaded(ModelContext) + case failed(String) + } + + let value: Value + + init(_ value: Value) { + self.value = value + } } /// Coordinates a bounded in-memory cache with structured, coalesced loading. private final class ModelContextCache { - private let cache: NSCache - private let inFlight = Locked<[String: Task]>([:]) + private let cache: NSCache + private let inFlight = Locked<[String: Task]>([:]) /// Creates a cache with a count-based eviction limit. init(countLimit: Int) { - let cache = NSCache() + let cache = NSCache() cache.countLimit = countLimit self.cache = cache } @@ -42,23 +50,45 @@ import Foundation loader: @escaping @Sendable () async throws -> ModelContext ) async throws -> ModelContext { let cacheKey = key as NSString - if let cached = cache.object(forKey: cacheKey) { - return cached.context + if let cached = cache.object(forKey: cacheKey), + case .loaded(let context) = cached.value + { + return context } if let task = inFlightTask(for: key) { - return try await task.value.context + let cached = try await task.value + if case .loaded(let context) = cached.value { + return context + } + throw CancellationError() } - let task = Task { try await CachedContext(loader()) } + let task = Task { + let context = try await loader() + return CachedModelState(.loaded(context)) + } setInFlight(task, for: key) do { let cached = try await task.value cache.setObject(cached, forKey: cacheKey) clearInFlight(for: key) - return cached.context + if case .loaded(let context) = cached.value { + return context + } + throw CancellationError() } catch { + // Don't treat cancellations as load failures. + if error is CancellationError || Task.isCancelled { + cache.removeObject(forKey: cacheKey) + clearInFlight(for: key) + throw error + } + cache.setObject( + CachedModelState(.failed(String(reflecting: error))), + forKey: cacheKey + ) clearInFlight(for: key) throw error } @@ -74,6 +104,28 @@ import Foundation cache.removeAllObjects() } + /// Returns whether a cached context exists for the key. + func contains(_ key: String) -> Bool { + guard let cached = cache.object(forKey: key as NSString) else { + return false + } + if case .loaded = cached.value { + return true + } + return false + } + + /// Returns a description of the most recent load failure for the key. + func failureDescription(for key: String) -> String? { + guard let cached = cache.object(forKey: key as NSString) else { + return nil + } + if case .failed(let description) = cached.value { + return description + } + return nil + } + /// Cancels in-flight work and removes cached data for the key. func removeAndCancel(for key: String) async { let task = removeInFlight(for: key) @@ -88,11 +140,11 @@ import Foundation cache.removeAllObjects() } - private func inFlightTask(for key: String) -> Task? { + private func inFlightTask(for key: String) -> Task? { inFlight.withLock { $0[key] } } - private func setInFlight(_ task: Task, for key: String) { + private func setInFlight(_ task: Task, for key: String) { inFlight.withLock { $0[key] = task } } @@ -100,7 +152,7 @@ import Foundation inFlight.withLock { $0[key] = nil } } - private func removeInFlight(for key: String) -> Task? { + private func removeInFlight(for key: String) -> Task? { inFlight.withLock { let task = $0[key] $0[key] = nil @@ -108,7 +160,7 @@ import Foundation } } - private func removeAllInFlight() -> [Task] { + private func removeAllInFlight() -> [Task] { inFlight.withLock { let tasks = Array($0.values) $0.removeAll() @@ -132,8 +184,12 @@ import Foundation /// ``` public struct MLXLanguageModel: LanguageModel { /// The reason the model is unavailable. - /// This model is always available. - public typealias UnavailableReason = Never + public enum UnavailableReason: Sendable, Equatable, Hashable { + /// The model has not been loaded into memory yet. + case notLoaded + /// The model failed to load and includes the underlying error details. + case failedToLoad(String) + } /// The model identifier. public let modelId: String @@ -156,6 +212,20 @@ import Foundation self.directory = directory } + /// The current availability of this model in memory. + public var availability: Availability { + let key = directory?.absoluteString ?? modelId + if modelCache.contains(key) { + return .available + } + + if let failureDescription = modelCache.failureDescription(for: key) { + return .unavailable(.failedToLoad(failureDescription)) + } + + return .unavailable(.notLoaded) + } + /// Removes this model from the shared cache and cancels any in-flight load. /// /// Call this to free memory when the model is no longer needed. diff --git a/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift b/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift index 62827fd..82c5f85 100644 --- a/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift +++ b/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift @@ -35,6 +35,20 @@ import Testing let model = MLXLanguageModel(modelId: "mlx-community/Qwen3-0.6B-4bit") let visionModel = MLXLanguageModel(modelId: "mlx-community/Qwen2-VL-2B-Instruct-4bit") + @Test func availabilityBecomesAvailableAfterSuccessfulLoad() async throws { + await model.removeFromCache() + + #expect(model.availability == .unavailable(.notLoaded)) + #expect(model.isAvailable == false) + + let session = LanguageModelSession(model: model) + let response = try await session.respond(to: "Say hello") + #expect(!response.content.isEmpty) + + #expect(model.availability == .available) + #expect(model.isAvailable == true) + } + @Test func basicResponse() async throws { let session = LanguageModelSession(model: model) @@ -205,5 +219,25 @@ import Testing ) #expect([Priority.low, Priority.medium, Priority.high].contains(response.content)) } + + @Test func unavailableForNonexistentModel() async { + let model = MLXLanguageModel(modelId: "mlx-community/does-not-exist-anylanguagemodel-test") + await model.removeFromCache() + #expect(model.availability == .unavailable(.notLoaded)) + #expect(model.isAvailable == false) + + let session = LanguageModelSession(model: model) + await #expect(throws: Error.self) { + _ = try await session.respond(to: "Hello") + } + + switch model.availability { + case .unavailable(.failedToLoad(let description)): + #expect(!description.isEmpty) + default: + Issue.record("Expected model availability to report failedToLoad after failed request") + } + #expect(model.isAvailable == false) + } } #endif // MLX