From 2dde78c4e59656bf385e23297208f9211f1efa7a Mon Sep 17 00:00:00 2001 From: Timothy Zelinsky Date: Mon, 23 Mar 2026 21:14:27 +1100 Subject: [PATCH 01/19] Add queryable runtime persistence with GRDB store --- .gitignore | 1 + .../Shared/AgentDemoRuntimeFactory.swift | 6 +- DemoApp/README.md | 4 +- Package.resolved | 15 + Package.swift | 8 +- README.md | 65 +- .../CodexKit/Approval/ApprovalModels.swift | 4 + Sources/CodexKit/Runtime/AgentHistory.swift | 757 +++++++++++ Sources/CodexKit/Runtime/AgentModels.swift | 9 +- Sources/CodexKit/Runtime/AgentQuerying.swift | 455 +++++++ .../Runtime/AgentRuntime+History.swift | 345 +++++ .../Runtime/AgentRuntime+Skills.swift | 1 + .../Runtime/AgentRuntime+Threads.swift | 70 +- .../AgentRuntime+TurnConsumption.swift | 469 ++++++- Sources/CodexKit/Runtime/AgentRuntime.swift | 17 +- .../Runtime/GRDBRuntimeStateStore.swift | 1148 +++++++++++++++++ .../CodexKit/Runtime/RuntimeStateStore.swift | 1011 ++++++++++++++- Sources/CodexKit/Tools/ToolModels.swift | 56 +- .../AgentRuntimeHistoryTests.swift | 818 ++++++++++++ 19 files changed, 5205 insertions(+), 54 deletions(-) create mode 100644 Package.resolved create mode 100644 Sources/CodexKit/Runtime/AgentHistory.swift create mode 100644 Sources/CodexKit/Runtime/AgentQuerying.swift create mode 100644 Sources/CodexKit/Runtime/AgentRuntime+History.swift create mode 100644 Sources/CodexKit/Runtime/GRDBRuntimeStateStore.swift create mode 100644 Tests/CodexKitTests/AgentRuntimeHistoryTests.swift diff --git a/.gitignore b/.gitignore index 40fcf50..60c79ad 100644 --- a/.gitignore +++ b/.gitignore @@ -102,6 +102,7 @@ iOSInjectionProject/ !*/xcshareddata/swiftpm/Package.resolved ### Xcode Patch ### **/xcshareddata/WorkspaceSettings.xcsettings +DemoApp/AssistantRuntimeDemoApp.xcodeproj/project.xcworkspace/ #Cocoapods Pods/ diff --git a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoRuntimeFactory.swift b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoRuntimeFactory.swift index f2ab5fc..5b3d625 100644 --- a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoRuntimeFactory.swift +++ b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoRuntimeFactory.swift @@ -102,7 +102,7 @@ enum AgentDemoRuntimeFactory { ) ), approvalPresenter: approvalInbox, - stateStore: FileRuntimeStateStore(url: stateURL ?? defaultStateURL()), + stateStore: try! GRDBRuntimeStateStore(url: stateURL ?? defaultStateURL()), memory: .init( store: try! SQLiteMemoryStore(url: defaultMemoryURL()), automaticCapturePolicy: .init( @@ -143,7 +143,7 @@ enum AgentDemoRuntimeFactory { ) ), approvalPresenter: NonInteractiveApprovalPresenter(), - stateStore: FileRuntimeStateStore(url: defaultStateURL()), + stateStore: try! GRDBRuntimeStateStore(url: defaultStateURL()), memory: .init( store: try! SQLiteMemoryStore(url: defaultMemoryURL()), automaticCapturePolicy: .init( @@ -169,7 +169,7 @@ enum AgentDemoRuntimeFactory { return baseDirectory .appendingPathComponent("AssistantRuntimeDemoApp", isDirectory: true) - .appendingPathComponent("runtime-state.json") + .appendingPathComponent("runtime-state.sqlite") } static func defaultMemoryURL() -> URL { diff --git a/DemoApp/README.md b/DemoApp/README.md index e11f578..16ad85f 100644 --- a/DemoApp/README.md +++ b/DemoApp/README.md @@ -44,10 +44,10 @@ The demo uses the new configuration-first surface: - `ChatGPTAuthProvider` - `KeychainSessionSecureStore` - `CodexResponsesBackend` -- `FileRuntimeStateStore` +- `GRDBRuntimeStateStore` - `ApprovalInbox` and `DeviceCodePromptCoordinator` from `CodexKitUI` -The app links `CodexKit` and `CodexKitUI` from the repo's local `Package.swift`, so it exercises the same SPM integration path a host app would use. +The app links `CodexKit` and `CodexKitUI` from the repo's local `Package.swift`, so it exercises the same SPM integration path a host app would use. Runtime state is stored in `runtime-state.sqlite`, memory is stored in `memory.sqlite`, and the GRDB-backed runtime store will import an older sibling `runtime-state.json` file automatically on first launch if one exists. ## Files diff --git a/Package.resolved b/Package.resolved new file mode 100644 index 0000000..73d017e --- /dev/null +++ b/Package.resolved @@ -0,0 +1,15 @@ +{ + "originHash" : "13f5bc1889c60a454f2efe4b8c8bf5dbbb8486c6de415c9c3c333f21df57d574", + "pins" : [ + { + "identity" : "grdb.swift", + "kind" : "remoteSourceControl", + "location" : "https://github.com/groue/GRDB.swift.git", + "state" : { + "revision" : "36e30a6f1ef10e4194f6af0cff90888526f0c115", + "version" : "7.10.0" + } + } + ], + "version" : 3 +} diff --git a/Package.swift b/Package.swift index d4a1bef..0d520fd 100644 --- a/Package.swift +++ b/Package.swift @@ -17,9 +17,15 @@ let package = Package( targets: ["CodexKitUI"] ), ], + dependencies: [ + .package(url: "https://github.com/groue/GRDB.swift.git", from: "7.10.0"), + ], targets: [ .target( - name: "CodexKit" + name: "CodexKit", + dependencies: [ + .product(name: "GRDB", package: "GRDB.swift"), + ] ), .target( name: "CodexKitUI", diff --git a/README.md b/README.md index 7826c58..8e4d375 100644 --- a/README.md +++ b/README.md @@ -89,12 +89,12 @@ let runtime = try AgentRuntime(configuration: .init( ) ), approvalPresenter: approvalInbox, - stateStore: FileRuntimeStateStore( + stateStore: try GRDBRuntimeStateStore( url: FileManager.default.urls( for: .applicationSupportDirectory, in: .userDomainMask ).first! - .appendingPathComponent("CodexKit/runtime-state.json") + .appendingPathComponent("CodexKit/runtime-state.sqlite") ) )) @@ -140,7 +140,7 @@ flowchart LR A["SwiftUI App"] --> B["AgentRuntime"] B --> C["ChatGPTAuthProvider"] B --> D["SessionSecureStore
KeychainSessionSecureStore"] - B --> E["RuntimeStateStore
FileRuntimeStateStore"] + B --> E["RuntimeStateStore
GRDBRuntimeStateStore"] B --> F["CodexResponsesBackend"] B --> G["ToolRegistry + Executors"] B --> H["ApprovalPresenter
ApprovalInbox"] @@ -154,9 +154,20 @@ The recommended production path for iOS is: - `ChatGPTAuthProvider` - `KeychainSessionSecureStore` - `CodexResponsesBackend` -- `FileRuntimeStateStore` +- `GRDBRuntimeStateStore` - `ApprovalInbox` and `DeviceCodePromptCoordinator` from `CodexKitUI` +Bundled runtime-state stores now include: + +- `GRDBRuntimeStateStore` + The recommended production store. Uses SQLite through GRDB, supports migrations, query pushdown, redaction, whole-thread deletion, paged history reads, and lightweight restore/inspection. +- `FileRuntimeStateStore` + A simple JSON-backed fallback for small apps, tests, or export/import-style workflows. +- `InMemoryRuntimeStateStore` + Useful for previews and tests. + +If you are migrating from the older file-backed store, `GRDBRuntimeStateStore(url:)` automatically imports a sibling `*.json` runtime state file on first open. For example, `runtime-state.sqlite` will import from `runtime-state.json` if it exists and the SQLite store is still empty. + `ChatGPTAuthProvider` supports: - `.deviceCode` for the most reliable sign-in path @@ -211,6 +222,43 @@ Available values: - `.high` - `.extraHigh` +## Persistent State And Queries + +`CodexKit` now treats runtime persistence as a queryable store instead of a single “load the whole thread” blob. For most apps, the main thing to know is: + +- use `GRDBRuntimeStateStore` for persisted production state +- use `fetchThreadHistory(id:query:)` and `fetchLatestStructuredOutputMetadata(id:)` for common thread inspection +- use the typed `execute(_:)` query surface when you need more control over filtering, sorting, paging, or cross-thread reads + +```swift +let stateStore = try GRDBRuntimeStateStore( + url: FileManager.default.urls( + for: .applicationSupportDirectory, + in: .userDomainMask + ).first! + .appendingPathComponent("CodexKit/runtime-state.sqlite") +) + +let runtime = try AgentRuntime(configuration: .init( + authProvider: authProvider, + secureStore: secureStore, + backend: backend, + approvalPresenter: approvalPresenter, + stateStore: stateStore +)) + +let page = try await runtime.fetchThreadHistory( + id: thread.id, + query: .init(limit: 40, direction: .backward) +) + +let snapshots = try await runtime.execute( + ThreadSnapshotQuery(limit: 20) +) +``` + +This path also supports explicit history redaction and whole-thread deletion without forcing hosts to replay raw event streams themselves. + ## Typed Completions For most apps, there are now three common send paths: @@ -342,7 +390,7 @@ let runtime = try AgentRuntime(configuration: .init( configuration: .init(model: "gpt-5.4") ), approvalPresenter: approvalPresenter, - stateStore: FileRuntimeStateStore(url: stateURL), + stateStore: try GRDBRuntimeStateStore(url: stateURL), memory: .init( store: try SQLiteMemoryStore(url: memoryURL), automaticCapturePolicy: .init( @@ -456,12 +504,12 @@ let runtime = try AgentRuntime(configuration: .init( ), backend: CodexResponsesBackend(), approvalPresenter: approvalInbox, - stateStore: FileRuntimeStateStore( + stateStore: try GRDBRuntimeStateStore( url: FileManager.default.urls( for: .applicationSupportDirectory, in: .userDomainMask ).first! - .appendingPathComponent("CodexKit/runtime-state.json") + .appendingPathComponent("CodexKit/runtime-state.sqlite") ), memory: .init(store: memoryStore) )) @@ -646,6 +694,7 @@ The demo app exercises: - thread-pinned personas and one-turn overrides - a one-tap skill policy probe that compares tool behavior in normal vs skill-constrained threads - a Health Coach tab with HealthKit steps, AI-generated coaching, local reminders, and tone switching +- GRDB-backed runtime persistence with automatic import from older `runtime-state.json` state on first launch Each tab is focused on a single story: @@ -762,7 +811,7 @@ print(preview) ## Production Checklist - Store sessions in keychain (`KeychainSessionSecureStore`) -- Use persistent runtime state (`FileRuntimeStateStore`) +- Use persistent runtime state (`GRDBRuntimeStateStore`) - Gate impactful tools with approvals - Handle auth cancellation and sign-out resets cleanly - Tune retry/backoff policy for your app’s UX and latency targets diff --git a/Sources/CodexKit/Approval/ApprovalModels.swift b/Sources/CodexKit/Approval/ApprovalModels.swift index 32a6289..9aab485 100644 --- a/Sources/CodexKit/Approval/ApprovalModels.swift +++ b/Sources/CodexKit/Approval/ApprovalModels.swift @@ -30,6 +30,8 @@ public struct ApprovalRequest: Identifiable, Hashable, Sendable { } } +extension ApprovalRequest: Codable {} + public struct ApprovalResolution: Hashable, Sendable { public let requestID: String public let threadID: String @@ -52,6 +54,8 @@ public struct ApprovalResolution: Hashable, Sendable { } } +extension ApprovalResolution: Codable {} + public protocol ApprovalPresenting: Sendable { func requestApproval(_ request: ApprovalRequest) async throws -> ApprovalDecision } diff --git a/Sources/CodexKit/Runtime/AgentHistory.swift b/Sources/CodexKit/Runtime/AgentHistory.swift new file mode 100644 index 0000000..bd75dc5 --- /dev/null +++ b/Sources/CodexKit/Runtime/AgentHistory.swift @@ -0,0 +1,757 @@ +import Foundation + +public struct AgentHistoryRecord: Codable, Hashable, Sendable { + public let id: String + public let sequenceNumber: Int + public let createdAt: Date + public let item: AgentHistoryItem + public let redaction: AgentHistoryRedaction? + + public init( + id: String? = nil, + sequenceNumber: Int, + createdAt: Date, + item: AgentHistoryItem, + redaction: AgentHistoryRedaction? = nil + ) { + self.id = id ?? item.defaultRecordID + self.sequenceNumber = sequenceNumber + self.createdAt = createdAt + self.item = item + self.redaction = redaction + } + + enum CodingKeys: String, CodingKey { + case id + case sequenceNumber + case createdAt + case item + case redaction + } + + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + let sequenceNumber = try container.decode(Int.self, forKey: .sequenceNumber) + let createdAt = try container.decode(Date.self, forKey: .createdAt) + let item = try container.decode(AgentHistoryItem.self, forKey: .item) + let redaction = try container.decodeIfPresent(AgentHistoryRedaction.self, forKey: .redaction) + self.init( + id: try container.decodeIfPresent(String.self, forKey: .id), + sequenceNumber: sequenceNumber, + createdAt: createdAt, + item: item, + redaction: redaction + ) + } +} + +public struct AgentThreadHistoryPage: Sendable, Hashable { + public let threadID: String + public let items: [AgentHistoryItem] + public let nextCursor: AgentHistoryCursor? + public let previousCursor: AgentHistoryCursor? + public let hasMoreBefore: Bool + public let hasMoreAfter: Bool + + public init( + threadID: String, + items: [AgentHistoryItem], + nextCursor: AgentHistoryCursor?, + previousCursor: AgentHistoryCursor?, + hasMoreBefore: Bool, + hasMoreAfter: Bool + ) { + self.threadID = threadID + self.items = items + self.nextCursor = nextCursor + self.previousCursor = previousCursor + self.hasMoreBefore = hasMoreBefore + self.hasMoreAfter = hasMoreAfter + } +} + +public struct AgentHistoryCursor: Codable, Hashable, Sendable { + public let rawValue: String + + public init(rawValue: String) { + self.rawValue = rawValue + } +} + +public struct AgentHistoryQuery: Sendable, Hashable { + public var limit: Int + public var cursor: AgentHistoryCursor? + public var direction: AgentHistoryDirection + public var filter: AgentHistoryFilter? + + public init( + limit: Int = 50, + cursor: AgentHistoryCursor? = nil, + direction: AgentHistoryDirection = .backward, + filter: AgentHistoryFilter? = nil + ) { + self.limit = limit + self.cursor = cursor + self.direction = direction + self.filter = filter + } +} + +public enum AgentHistoryDirection: Sendable, Hashable { + case forward + case backward +} + +public struct AgentHistoryFilter: Sendable, Hashable { + public var includeMessages: Bool + public var includeToolCalls: Bool + public var includeToolResults: Bool + public var includeStructuredOutputs: Bool + public var includeApprovals: Bool + public var includeSystemEvents: Bool + + public init( + includeMessages: Bool = true, + includeToolCalls: Bool = true, + includeToolResults: Bool = true, + includeStructuredOutputs: Bool = true, + includeApprovals: Bool = true, + includeSystemEvents: Bool = true + ) { + self.includeMessages = includeMessages + self.includeToolCalls = includeToolCalls + self.includeToolResults = includeToolResults + self.includeStructuredOutputs = includeStructuredOutputs + self.includeApprovals = includeApprovals + self.includeSystemEvents = includeSystemEvents + } +} + +public struct AgentThreadSummary: Codable, Hashable, Sendable { + public let threadID: String + public let createdAt: Date + public let updatedAt: Date + public let latestItemAt: Date? + public let itemCount: Int? + public let latestAssistantMessagePreview: String? + public let latestStructuredOutputMetadata: AgentStructuredOutputMetadata? + public let latestPartialStructuredOutput: AgentPartialStructuredOutputSnapshot? + public let latestToolState: AgentLatestToolState? + public let latestTurnStatus: AgentTurnStatus? + public let pendingState: AgentThreadPendingState? + + public init( + threadID: String, + createdAt: Date, + updatedAt: Date, + latestItemAt: Date? = nil, + itemCount: Int? = nil, + latestAssistantMessagePreview: String? = nil, + latestStructuredOutputMetadata: AgentStructuredOutputMetadata? = nil, + latestPartialStructuredOutput: AgentPartialStructuredOutputSnapshot? = nil, + latestToolState: AgentLatestToolState? = nil, + latestTurnStatus: AgentTurnStatus? = nil, + pendingState: AgentThreadPendingState? = nil + ) { + self.threadID = threadID + self.createdAt = createdAt + self.updatedAt = updatedAt + self.latestItemAt = latestItemAt + self.itemCount = itemCount + self.latestAssistantMessagePreview = latestAssistantMessagePreview + self.latestStructuredOutputMetadata = latestStructuredOutputMetadata + self.latestPartialStructuredOutput = latestPartialStructuredOutput + self.latestToolState = latestToolState + self.latestTurnStatus = latestTurnStatus + self.pendingState = pendingState + } +} + +public protocol AgentRuntimeThreadInspecting: Sendable { + func fetchThreadSummary(id: String) async throws -> AgentThreadSummary + func fetchThreadHistory( + id: String, + query: AgentHistoryQuery + ) async throws -> AgentThreadHistoryPage + func fetchLatestStructuredOutputMetadata(id: String) async throws -> AgentStructuredOutputMetadata? +} + +public extension AgentThreadSummary { + var snapshot: AgentThreadSnapshot { + AgentThreadSnapshot( + threadID: threadID, + createdAt: createdAt, + updatedAt: updatedAt, + latestItemAt: latestItemAt, + itemCount: itemCount, + latestAssistantMessagePreview: latestAssistantMessagePreview, + latestStructuredOutputMetadata: latestStructuredOutputMetadata, + latestPartialStructuredOutput: latestPartialStructuredOutput, + latestToolState: latestToolState, + latestTurnStatus: latestTurnStatus, + pendingState: pendingState + ) + } +} + +public enum AgentHistoryItem: Hashable, Sendable { + case message(AgentMessage) + case toolCall(AgentToolCallRecord) + case toolResult(AgentToolResultRecord) + case structuredOutput(AgentStructuredOutputRecord) + case approval(AgentApprovalRecord) + case systemEvent(AgentSystemEventRecord) +} + +public struct AgentToolCallRecord: Codable, Hashable, Sendable { + public let invocation: ToolInvocation + public let requestedAt: Date + + public init( + invocation: ToolInvocation, + requestedAt: Date = Date() + ) { + self.invocation = invocation + self.requestedAt = requestedAt + } +} + +public struct AgentToolResultRecord: Codable, Hashable, Sendable { + public let threadID: String + public let turnID: String + public let result: ToolResultEnvelope + public let completedAt: Date + + public init( + threadID: String, + turnID: String, + result: ToolResultEnvelope, + completedAt: Date = Date() + ) { + self.threadID = threadID + self.turnID = turnID + self.result = result + self.completedAt = completedAt + } +} + +public struct AgentStructuredOutputRecord: Codable, Hashable, Sendable { + public let threadID: String + public let turnID: String + public let messageID: String? + public let metadata: AgentStructuredOutputMetadata + public let committedAt: Date + + public init( + threadID: String, + turnID: String, + messageID: String? = nil, + metadata: AgentStructuredOutputMetadata, + committedAt: Date = Date() + ) { + self.threadID = threadID + self.turnID = turnID + self.messageID = messageID + self.metadata = metadata + self.committedAt = committedAt + } +} + +public enum AgentApprovalEventKind: String, Codable, Hashable, Sendable { + case requested + case resolved +} + +public struct AgentApprovalRecord: Codable, Hashable, Sendable { + public let kind: AgentApprovalEventKind + public let request: ApprovalRequest? + public let resolution: ApprovalResolution? + public let occurredAt: Date + + public init( + kind: AgentApprovalEventKind, + request: ApprovalRequest? = nil, + resolution: ApprovalResolution? = nil, + occurredAt: Date = Date() + ) { + self.kind = kind + self.request = request + self.resolution = resolution + self.occurredAt = occurredAt + } +} + +public enum AgentSystemEventType: String, Codable, Hashable, Sendable { + case threadCreated + case threadResumed + case threadStatusChanged + case turnStarted + case turnCompleted + case turnFailed +} + +public struct AgentSystemEventRecord: Codable, Hashable, Sendable { + public let type: AgentSystemEventType + public let threadID: String + public let turnID: String? + public let status: AgentThreadStatus? + public let turnSummary: AgentTurnSummary? + public let error: AgentRuntimeError? + public let occurredAt: Date + + public init( + type: AgentSystemEventType, + threadID: String, + turnID: String? = nil, + status: AgentThreadStatus? = nil, + turnSummary: AgentTurnSummary? = nil, + error: AgentRuntimeError? = nil, + occurredAt: Date = Date() + ) { + self.type = type + self.threadID = threadID + self.turnID = turnID + self.status = status + self.turnSummary = turnSummary + self.error = error + self.occurredAt = occurredAt + } +} + +public enum AgentThreadPendingState: Hashable, Sendable { + case approval(AgentPendingApprovalState) + case userInput(AgentPendingUserInputState) + case toolWait(AgentPendingToolWaitState) +} + +public struct AgentPendingApprovalState: Codable, Hashable, Sendable { + public let request: ApprovalRequest + public let requestedAt: Date + + public init( + request: ApprovalRequest, + requestedAt: Date = Date() + ) { + self.request = request + self.requestedAt = requestedAt + } +} + +public struct AgentPendingUserInputState: Codable, Hashable, Sendable { + public let requestID: String + public let turnID: String + public let title: String + public let message: String + public let requestedAt: Date + + public init( + requestID: String, + turnID: String, + title: String, + message: String, + requestedAt: Date = Date() + ) { + self.requestID = requestID + self.turnID = turnID + self.title = title + self.message = message + self.requestedAt = requestedAt + } +} + +public struct AgentPendingToolWaitState: Codable, Hashable, Sendable { + public let invocationID: String + public let turnID: String + public let toolName: String + public let startedAt: Date + public let sessionID: String? + public let sessionStatus: String? + public let metadata: JSONValue? + public let resumable: Bool + + public init( + invocationID: String, + turnID: String, + toolName: String, + startedAt: Date = Date(), + sessionID: String? = nil, + sessionStatus: String? = nil, + metadata: JSONValue? = nil, + resumable: Bool = false + ) { + self.invocationID = invocationID + self.turnID = turnID + self.toolName = toolName + self.startedAt = startedAt + self.sessionID = sessionID + self.sessionStatus = sessionStatus + self.metadata = metadata + self.resumable = resumable + } +} + +public enum AgentToolSessionStatus: String, Codable, Hashable, Sendable { + case waiting + case running + case completed + case failed + case denied +} + +public struct AgentLatestToolState: Codable, Hashable, Sendable { + public let invocationID: String + public let turnID: String + public let toolName: String + public let status: AgentToolSessionStatus + public let success: Bool? + public let sessionID: String? + public let sessionStatus: String? + public let metadata: JSONValue? + public let resumable: Bool + public let updatedAt: Date + public let resultPreview: String? + + public init( + invocationID: String, + turnID: String, + toolName: String, + status: AgentToolSessionStatus, + success: Bool? = nil, + sessionID: String? = nil, + sessionStatus: String? = nil, + metadata: JSONValue? = nil, + resumable: Bool = false, + updatedAt: Date = Date(), + resultPreview: String? = nil + ) { + self.invocationID = invocationID + self.turnID = turnID + self.toolName = toolName + self.status = status + self.success = success + self.sessionID = sessionID + self.sessionStatus = sessionStatus + self.metadata = metadata + self.resumable = resumable + self.updatedAt = updatedAt + self.resultPreview = resultPreview + } +} + +public struct AgentPartialStructuredOutputSnapshot: Codable, Hashable, Sendable { + public let turnID: String + public let formatName: String + public let payload: JSONValue + public let updatedAt: Date + + public init( + turnID: String, + formatName: String, + payload: JSONValue, + updatedAt: Date = Date() + ) { + self.turnID = turnID + self.formatName = formatName + self.payload = payload + self.updatedAt = updatedAt + } +} + +extension AgentThreadPendingState: Codable { + private enum CodingKeys: String, CodingKey { + case kind + case approval + case userInput + case toolWait + } + + private enum Kind: String, Codable { + case approval + case userInput + case toolWait + } + + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + switch try container.decode(Kind.self, forKey: .kind) { + case .approval: + self = .approval(try container.decode(AgentPendingApprovalState.self, forKey: .approval)) + case .userInput: + self = .userInput(try container.decode(AgentPendingUserInputState.self, forKey: .userInput)) + case .toolWait: + self = .toolWait(try container.decode(AgentPendingToolWaitState.self, forKey: .toolWait)) + } + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + switch self { + case let .approval(state): + try container.encode(Kind.approval, forKey: .kind) + try container.encode(state, forKey: .approval) + case let .userInput(state): + try container.encode(Kind.userInput, forKey: .kind) + try container.encode(state, forKey: .userInput) + case let .toolWait(state): + try container.encode(Kind.toolWait, forKey: .kind) + try container.encode(state, forKey: .toolWait) + } + } +} + +extension AgentHistoryItem: Codable { + private enum CodingKeys: String, CodingKey { + case kind + case message + case toolCall + case toolResult + case structuredOutput + case approval + case systemEvent + } + + private enum Kind: String, Codable { + case message + case toolCall + case toolResult + case structuredOutput + case approval + case systemEvent + } + + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + switch try container.decode(Kind.self, forKey: .kind) { + case .message: + self = .message(try container.decode(AgentMessage.self, forKey: .message)) + case .toolCall: + self = .toolCall(try container.decode(AgentToolCallRecord.self, forKey: .toolCall)) + case .toolResult: + self = .toolResult(try container.decode(AgentToolResultRecord.self, forKey: .toolResult)) + case .structuredOutput: + self = .structuredOutput(try container.decode(AgentStructuredOutputRecord.self, forKey: .structuredOutput)) + case .approval: + self = .approval(try container.decode(AgentApprovalRecord.self, forKey: .approval)) + case .systemEvent: + self = .systemEvent(try container.decode(AgentSystemEventRecord.self, forKey: .systemEvent)) + } + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + switch self { + case let .message(message): + try container.encode(Kind.message, forKey: .kind) + try container.encode(message, forKey: .message) + case let .toolCall(record): + try container.encode(Kind.toolCall, forKey: .kind) + try container.encode(record, forKey: .toolCall) + case let .toolResult(record): + try container.encode(Kind.toolResult, forKey: .kind) + try container.encode(record, forKey: .toolResult) + case let .structuredOutput(record): + try container.encode(Kind.structuredOutput, forKey: .kind) + try container.encode(record, forKey: .structuredOutput) + case let .approval(record): + try container.encode(Kind.approval, forKey: .kind) + try container.encode(record, forKey: .approval) + case let .systemEvent(record): + try container.encode(Kind.systemEvent, forKey: .kind) + try container.encode(record, forKey: .systemEvent) + } + } +} + +extension AgentHistoryFilter { + func matches(_ item: AgentHistoryItem) -> Bool { + switch item { + case .message: + includeMessages + case .toolCall: + includeToolCalls + case .toolResult: + includeToolResults + case .structuredOutput: + includeStructuredOutputs + case .approval: + includeApprovals + case .systemEvent: + includeSystemEvents + } + } +} + +extension AgentHistoryItem { + var kind: AgentHistoryItemKind { + switch self { + case .message: + .message + case .toolCall: + .toolCall + case .toolResult: + .toolResult + case .structuredOutput: + .structuredOutput + case .approval: + .approval + case .systemEvent: + .systemEvent + } + } + + var turnID: String? { + switch self { + case let .message(message): + return message.structuredOutput == nil ? nil : nil + case let .toolCall(record): + return record.invocation.turnID + case let .toolResult(record): + return record.turnID + case let .structuredOutput(record): + return record.turnID + case let .approval(record): + return record.request?.turnID ?? record.resolution?.turnID + case let .systemEvent(record): + return record.turnID + } + } + + var defaultRecordID: String { + switch self { + case let .message(message): + return "message:\(message.id)" + case let .toolCall(record): + return "toolCall:\(record.invocation.id)" + case let .toolResult(record): + return "toolResult:\(record.result.invocationID)" + case let .structuredOutput(record): + return "structuredOutput:\(record.messageID ?? record.turnID)" + case let .approval(record): + return "approval:\(record.request?.id ?? record.resolution?.requestID ?? UUID().uuidString)" + case let .systemEvent(record): + return "systemEvent:\(record.type.rawValue):\(record.turnID ?? record.threadID)" + } + } +} + +extension AgentHistoryRecord { + func redacted(reason: AgentRedactionReason?) -> AgentHistoryRecord { + AgentHistoryRecord( + id: id, + sequenceNumber: sequenceNumber, + createdAt: createdAt, + item: item.redactedPayload(), + redaction: AgentHistoryRedaction(reason: reason) + ) + } +} + +private extension AgentHistoryItem { + func redactedPayload() -> AgentHistoryItem { + switch self { + case let .message(message): + return .message( + AgentMessage( + id: message.id, + threadID: message.threadID, + role: message.role, + text: "[Redacted]", + images: [], + structuredOutput: message.structuredOutput.map { + AgentStructuredOutputMetadata( + formatName: $0.formatName, + payload: .object(["redacted": .bool(true)]) + ) + }, + createdAt: message.createdAt + ) + ) + + case let .toolCall(record): + return .toolCall( + AgentToolCallRecord( + invocation: ToolInvocation( + id: record.invocation.id, + threadID: record.invocation.threadID, + turnID: record.invocation.turnID, + toolName: record.invocation.toolName, + arguments: .object(["redacted": .bool(true)]) + ), + requestedAt: record.requestedAt + ) + ) + + case let .toolResult(record): + return .toolResult( + AgentToolResultRecord( + threadID: record.threadID, + turnID: record.turnID, + result: ToolResultEnvelope( + invocationID: record.result.invocationID, + toolName: record.result.toolName, + success: record.result.success, + content: [.text("[Redacted]")], + errorMessage: record.result.errorMessage == nil ? nil : "[Redacted]", + session: record.result.session + ), + completedAt: record.completedAt + ) + ) + + case let .structuredOutput(record): + return .structuredOutput( + AgentStructuredOutputRecord( + threadID: record.threadID, + turnID: record.turnID, + messageID: record.messageID, + metadata: AgentStructuredOutputMetadata( + formatName: record.metadata.formatName, + payload: .object(["redacted": .bool(true)]) + ), + committedAt: record.committedAt + ) + ) + + case let .approval(record): + let request = record.request.map { + ApprovalRequest( + id: $0.id, + threadID: $0.threadID, + turnID: $0.turnID, + toolInvocation: ToolInvocation( + id: $0.toolInvocation.id, + threadID: $0.toolInvocation.threadID, + turnID: $0.toolInvocation.turnID, + toolName: $0.toolInvocation.toolName, + arguments: .object(["redacted": .bool(true)]) + ), + title: "[Redacted]", + message: "[Redacted]" + ) + } + return .approval( + AgentApprovalRecord( + kind: record.kind, + request: request, + resolution: record.resolution, + occurredAt: record.occurredAt + ) + ) + + case let .systemEvent(record): + return .systemEvent( + AgentSystemEventRecord( + type: record.type, + threadID: record.threadID, + turnID: record.turnID, + status: record.status, + turnSummary: record.turnSummary, + error: record.error.map { + AgentRuntimeError(code: $0.code, message: "[Redacted]") + }, + occurredAt: record.occurredAt + ) + ) + } + } +} diff --git a/Sources/CodexKit/Runtime/AgentModels.swift b/Sources/CodexKit/Runtime/AgentModels.swift index 56da2a0..8ec4c27 100644 --- a/Sources/CodexKit/Runtime/AgentModels.swift +++ b/Sources/CodexKit/Runtime/AgentModels.swift @@ -1,6 +1,6 @@ import Foundation -public struct AgentRuntimeError: Error, LocalizedError, Equatable, Sendable { +public struct AgentRuntimeError: Error, LocalizedError, Equatable, Hashable, Sendable, Codable { public let code: String public let message: String @@ -49,6 +49,13 @@ public struct AgentRuntimeError: Error, LocalizedError, Equatable, Sendable { ) } + public static func invalidHistoryCursor() -> AgentRuntimeError { + AgentRuntimeError( + code: "invalid_history_cursor", + message: "The requested history cursor is invalid for this thread." + ) + } + public static func structuredOutputDecodingFailed( typeName: String, underlyingMessage: String diff --git a/Sources/CodexKit/Runtime/AgentQuerying.swift b/Sources/CodexKit/Runtime/AgentQuerying.swift new file mode 100644 index 0000000..392ccf8 --- /dev/null +++ b/Sources/CodexKit/Runtime/AgentQuerying.swift @@ -0,0 +1,455 @@ +import Foundation + +public enum AgentLogicalSchemaVersion: Int, Sendable, Codable, Hashable { + case v1 = 1 +} + +public struct AgentStoreCapabilities: Sendable, Hashable, Codable { + public var supportsPushdownQueries: Bool + public var supportsCrossThreadQueries: Bool + public var supportsSorting: Bool + public var supportsFiltering: Bool + public var supportsMigrations: Bool + + public init( + supportsPushdownQueries: Bool, + supportsCrossThreadQueries: Bool, + supportsSorting: Bool, + supportsFiltering: Bool, + supportsMigrations: Bool + ) { + self.supportsPushdownQueries = supportsPushdownQueries + self.supportsCrossThreadQueries = supportsCrossThreadQueries + self.supportsSorting = supportsSorting + self.supportsFiltering = supportsFiltering + self.supportsMigrations = supportsMigrations + } +} + +public struct AgentStoreMetadata: Sendable, Hashable, Codable { + public let logicalSchemaVersion: AgentLogicalSchemaVersion + public let storeSchemaVersion: Int + public let capabilities: AgentStoreCapabilities + public let storeKind: String + + public init( + logicalSchemaVersion: AgentLogicalSchemaVersion, + storeSchemaVersion: Int, + capabilities: AgentStoreCapabilities, + storeKind: String + ) { + self.logicalSchemaVersion = logicalSchemaVersion + self.storeSchemaVersion = storeSchemaVersion + self.capabilities = capabilities + self.storeKind = storeKind + } +} + +public enum AgentStoreError: Error, Sendable { + case incompatibleLogicalSchema(found: Int, supported: [Int]) + case migrationRequired(from: Int, to: Int) + case migrationFailed(String) + case queryNotSupported(String) +} + +public protocol AgentRuntimeQueryable: Sendable { + func execute(_ query: Query) async throws -> Query.Result +} + +public protocol AgentQuerySpec: Sendable { + associatedtype Result: Sendable + + func execute(in state: StoredRuntimeState) throws -> Result +} + +public protocol AgentRuntimeQueryableStore: RuntimeStateStoring { + func execute(_ query: Query) async throws -> Query.Result +} + +public enum AgentSortOrder: String, Sendable, Hashable, Codable { + case ascending + case descending +} + +public struct AgentQueryPage: Sendable, Hashable, Codable { + public var limit: Int + public var cursor: AgentHistoryCursor? + + public init( + limit: Int = 50, + cursor: AgentHistoryCursor? = nil + ) { + self.limit = limit + self.cursor = cursor + } +} + +public enum AgentHistoryItemKind: String, Sendable, Hashable, Codable, CaseIterable { + case message + case toolCall + case toolResult + case structuredOutput + case approval + case systemEvent +} + +public struct AgentHistoryQueryResult: Sendable, Hashable { + public let threadID: String + public let records: [AgentHistoryRecord] + public let nextCursor: AgentHistoryCursor? + public let previousCursor: AgentHistoryCursor? + public let hasMoreBefore: Bool + public let hasMoreAfter: Bool + + public init( + threadID: String, + records: [AgentHistoryRecord], + nextCursor: AgentHistoryCursor?, + previousCursor: AgentHistoryCursor?, + hasMoreBefore: Bool, + hasMoreAfter: Bool + ) { + self.threadID = threadID + self.records = records + self.nextCursor = nextCursor + self.previousCursor = previousCursor + self.hasMoreBefore = hasMoreBefore + self.hasMoreAfter = hasMoreAfter + } +} + +public enum AgentHistorySort: Sendable, Hashable, Codable { + case sequence(AgentSortOrder) + case createdAt(AgentSortOrder) +} + +public struct HistoryItemsQuery: AgentQuerySpec { + public typealias Result = AgentHistoryQueryResult + + public var threadID: String + public var kinds: Set? + public var createdAtRange: ClosedRange? + public var turnID: String? + public var includeRedacted: Bool + public var sort: AgentHistorySort + public var page: AgentQueryPage? + + public init( + threadID: String, + kinds: Set? = nil, + createdAtRange: ClosedRange? = nil, + turnID: String? = nil, + includeRedacted: Bool = true, + sort: AgentHistorySort = .sequence(.ascending), + page: AgentQueryPage? = nil + ) { + self.threadID = threadID + self.kinds = kinds + self.createdAtRange = createdAtRange + self.turnID = turnID + self.includeRedacted = includeRedacted + self.sort = sort + self.page = page + } + + public func execute(in state: StoredRuntimeState) throws -> AgentHistoryQueryResult { + try state.execute(self) + } +} + +public enum AgentThreadMetadataSort: Sendable, Hashable, Codable { + case updatedAt(AgentSortOrder) + case createdAt(AgentSortOrder) +} + +public struct ThreadMetadataQuery: AgentQuerySpec { + public typealias Result = [AgentThread] + + public var threadIDs: Set? + public var statuses: Set? + public var updatedAtRange: ClosedRange? + public var sort: AgentThreadMetadataSort + public var limit: Int? + + public init( + threadIDs: Set? = nil, + statuses: Set? = nil, + updatedAtRange: ClosedRange? = nil, + sort: AgentThreadMetadataSort = .updatedAt(.descending), + limit: Int? = nil + ) { + self.threadIDs = threadIDs + self.statuses = statuses + self.updatedAtRange = updatedAtRange + self.sort = sort + self.limit = limit + } + + public func execute(in state: StoredRuntimeState) throws -> [AgentThread] { + state.execute(self) + } +} + +public enum AgentPendingStateKind: String, Sendable, Hashable, Codable, CaseIterable { + case approval + case userInput + case toolWait +} + +public struct AgentPendingStateRecord: Sendable, Hashable, Codable { + public let threadID: String + public let pendingState: AgentThreadPendingState + public let updatedAt: Date + + public init( + threadID: String, + pendingState: AgentThreadPendingState, + updatedAt: Date + ) { + self.threadID = threadID + self.pendingState = pendingState + self.updatedAt = updatedAt + } +} + +public enum AgentPendingStateSort: Sendable, Hashable, Codable { + case updatedAt(AgentSortOrder) +} + +public struct PendingStateQuery: AgentQuerySpec { + public typealias Result = [AgentPendingStateRecord] + + public var threadIDs: Set? + public var kinds: Set? + public var sort: AgentPendingStateSort + public var limit: Int? + + public init( + threadIDs: Set? = nil, + kinds: Set? = nil, + sort: AgentPendingStateSort = .updatedAt(.descending), + limit: Int? = nil + ) { + self.threadIDs = threadIDs + self.kinds = kinds + self.sort = sort + self.limit = limit + } + + public func execute(in state: StoredRuntimeState) throws -> [AgentPendingStateRecord] { + state.execute(self) + } +} + +public enum AgentStructuredOutputSort: Sendable, Hashable, Codable { + case committedAt(AgentSortOrder) +} + +public struct StructuredOutputQuery: AgentQuerySpec { + public typealias Result = [AgentStructuredOutputRecord] + + public var threadIDs: Set? + public var formatNames: Set? + public var latestOnly: Bool + public var sort: AgentStructuredOutputSort + public var limit: Int? + + public init( + threadIDs: Set? = nil, + formatNames: Set? = nil, + latestOnly: Bool = false, + sort: AgentStructuredOutputSort = .committedAt(.descending), + limit: Int? = nil + ) { + self.threadIDs = threadIDs + self.formatNames = formatNames + self.latestOnly = latestOnly + self.sort = sort + self.limit = limit + } + + public func execute(in state: StoredRuntimeState) throws -> [AgentStructuredOutputRecord] { + state.execute(self) + } +} + +public struct AgentThreadSnapshot: Sendable, Hashable, Codable { + public let threadID: String + public let createdAt: Date + public let updatedAt: Date + public let latestItemAt: Date? + public let itemCount: Int? + public let latestAssistantMessagePreview: String? + public let latestStructuredOutputMetadata: AgentStructuredOutputMetadata? + public let latestPartialStructuredOutput: AgentPartialStructuredOutputSnapshot? + public let latestToolState: AgentLatestToolState? + public let latestTurnStatus: AgentTurnStatus? + public let pendingState: AgentThreadPendingState? + + public init( + threadID: String, + createdAt: Date, + updatedAt: Date, + latestItemAt: Date? = nil, + itemCount: Int? = nil, + latestAssistantMessagePreview: String? = nil, + latestStructuredOutputMetadata: AgentStructuredOutputMetadata? = nil, + latestPartialStructuredOutput: AgentPartialStructuredOutputSnapshot? = nil, + latestToolState: AgentLatestToolState? = nil, + latestTurnStatus: AgentTurnStatus? = nil, + pendingState: AgentThreadPendingState? = nil + ) { + self.threadID = threadID + self.createdAt = createdAt + self.updatedAt = updatedAt + self.latestItemAt = latestItemAt + self.itemCount = itemCount + self.latestAssistantMessagePreview = latestAssistantMessagePreview + self.latestStructuredOutputMetadata = latestStructuredOutputMetadata + self.latestPartialStructuredOutput = latestPartialStructuredOutput + self.latestToolState = latestToolState + self.latestTurnStatus = latestTurnStatus + self.pendingState = pendingState + } +} + +public enum AgentThreadSnapshotSort: Sendable, Hashable, Codable { + case updatedAt(AgentSortOrder) + case createdAt(AgentSortOrder) +} + +public struct ThreadSnapshotQuery: AgentQuerySpec { + public typealias Result = [AgentThreadSnapshot] + + public var threadIDs: Set? + public var sort: AgentThreadSnapshotSort + public var limit: Int? + + public init( + threadIDs: Set? = nil, + sort: AgentThreadSnapshotSort = .updatedAt(.descending), + limit: Int? = nil + ) { + self.threadIDs = threadIDs + self.sort = sort + self.limit = limit + } + + public func execute(in state: StoredRuntimeState) throws -> [AgentThreadSnapshot] { + state.execute(self) + } +} + +public struct AgentRedactionReason: Sendable, Hashable, Codable { + public let code: String + public let message: String? + + public init( + code: String, + message: String? = nil + ) { + self.code = code + self.message = message + } +} + +public struct AgentHistoryRedaction: Sendable, Hashable, Codable { + public let redactedAt: Date + public let reason: AgentRedactionReason? + + public init( + redactedAt: Date = Date(), + reason: AgentRedactionReason? = nil + ) { + self.redactedAt = redactedAt + self.reason = reason + } +} + +public struct AgentToolSessionRecord: Sendable, Hashable, Codable { + public let threadID: String + public let invocationID: String + public let turnID: String + public let toolName: String + public let sessionID: String? + public let sessionStatus: String? + public let metadata: JSONValue? + public let resumable: Bool + public let updatedAt: Date + + public init( + threadID: String, + invocationID: String, + turnID: String, + toolName: String, + sessionID: String? = nil, + sessionStatus: String? = nil, + metadata: JSONValue? = nil, + resumable: Bool = false, + updatedAt: Date = Date() + ) { + self.threadID = threadID + self.invocationID = invocationID + self.turnID = turnID + self.toolName = toolName + self.sessionID = sessionID + self.sessionStatus = sessionStatus + self.metadata = metadata + self.resumable = resumable + self.updatedAt = updatedAt + } +} + +public enum AgentStoreWriteOperation: Sendable, Hashable { + case upsertThread(AgentThread) + case upsertSummary(threadID: String, summary: AgentThreadSummary) + case appendHistoryItems(threadID: String, items: [AgentHistoryRecord]) + case setPendingState(threadID: String, state: AgentThreadPendingState?) + case setPartialStructuredSnapshot(threadID: String, snapshot: AgentPartialStructuredOutputSnapshot?) + case upsertToolSession(threadID: String, session: AgentToolSessionRecord) + case redactHistoryItems(threadID: String, itemIDs: [String], reason: AgentRedactionReason?) + case deleteThread(threadID: String) +} + +public extension AgentRuntimeQueryableStore { + func execute(_ query: Query) async throws -> Query.Result { + let state = try await loadState() + return try query.execute(in: state) + } +} + +extension AgentStoreWriteOperation { + var affectedThreadID: String { + switch self { + case let .upsertThread(thread): + thread.id + case let .upsertSummary(threadID, _): + threadID + case let .appendHistoryItems(threadID, _): + threadID + case let .setPendingState(threadID, _): + threadID + case let .setPartialStructuredSnapshot(threadID, _): + threadID + case let .upsertToolSession(threadID, _): + threadID + case let .redactHistoryItems(threadID, _, _): + threadID + case let .deleteThread(threadID): + threadID + } + } +} + +extension AgentThreadPendingState { + var kind: AgentPendingStateKind { + switch self { + case .approval: + .approval + case .userInput: + .userInput + case .toolWait: + .toolWait + } + } +} diff --git a/Sources/CodexKit/Runtime/AgentRuntime+History.swift b/Sources/CodexKit/Runtime/AgentRuntime+History.swift new file mode 100644 index 0000000..85f745a --- /dev/null +++ b/Sources/CodexKit/Runtime/AgentRuntime+History.swift @@ -0,0 +1,345 @@ +import Foundation + +extension AgentRuntime: AgentRuntimeQueryable, AgentRuntimeThreadInspecting { + public func execute(_ query: Query) async throws -> Query.Result { + if let queryableStore = stateStore as? any AgentRuntimeQueryableStore { + return try await queryableStore.execute(query) + } + + let loadedState = try await stateStore.loadState() + return try query.execute(in: loadedState) + } + + public func fetchThreadSummary(id: String) async throws -> AgentThreadSummary { + if let inspectingStore = stateStore as? any RuntimeStateInspecting { + return try await inspectingStore.fetchThreadSummary(id: id) + } + + let snapshots = try await execute( + ThreadSnapshotQuery( + threadIDs: [id], + limit: 1 + ) + ) + guard let snapshot = snapshots.first else { + throw AgentRuntimeError.threadNotFound(id) + } + return snapshot.summary + } + + public func fetchThreadHistory( + id: String, + query: AgentHistoryQuery + ) async throws -> AgentThreadHistoryPage { + if let inspectingStore = stateStore as? any RuntimeStateInspecting { + return try await inspectingStore.fetchThreadHistory(id: id, query: query) + } + + let result = try await execute( + HistoryItemsQuery( + threadID: id, + kinds: query.filter?.includedKinds, + includeRedacted: true, + sort: query.direction == .forward ? .sequence(.ascending) : .sequence(.descending), + page: AgentQueryPage(limit: query.limit, cursor: query.cursor) + ) + ) + + return AgentThreadHistoryPage( + threadID: id, + items: result.records.map(\.item), + nextCursor: result.nextCursor, + previousCursor: result.previousCursor, + hasMoreBefore: result.hasMoreBefore, + hasMoreAfter: result.hasMoreAfter + ) + } + + public func fetchLatestStructuredOutputMetadata(id: String) async throws -> AgentStructuredOutputMetadata? { + if let inspectingStore = stateStore as? any RuntimeStateInspecting { + return try await inspectingStore.fetchLatestStructuredOutputMetadata(id: id) + } + + let records = try await execute( + StructuredOutputQuery( + threadIDs: [id], + latestOnly: true, + limit: 1 + ) + ) + return records.first?.metadata + } + + public func fetchLatestStructuredOutput( + id: String, + as outputType: Output.Type, + decoder: JSONDecoder = JSONDecoder() + ) async throws -> Output? { + guard let metadata = try await fetchLatestStructuredOutputMetadata(id: id) else { + return nil + } + + return try decodeStructuredValue( + metadata.payload, + as: outputType, + decoder: decoder + ) + } + + public func storeMetadata() async throws -> AgentStoreMetadata { + try await stateStore.readMetadata() + } + + @discardableResult + public func prepareStore() async throws -> AgentStoreMetadata { + try await stateStore.prepare() + } +} + +extension AgentRuntime { + public func deleteThread(id: String) async throws { + if !pendingStoreOperations.isEmpty { + try await persistState() + } + state = try state.applying([.deleteThread(threadID: id)]) + try await stateStore.apply([.deleteThread(threadID: id)]) + } + + public func redactHistoryItems( + _ itemIDs: [String], + in threadID: String, + reason: AgentRedactionReason? = nil + ) async throws { + guard !itemIDs.isEmpty else { + return + } + if !pendingStoreOperations.isEmpty { + try await persistState() + } + + let operation = AgentStoreWriteOperation.redactHistoryItems( + threadID: threadID, + itemIDs: itemIDs, + reason: reason + ) + state = try state.applying([operation]) + try await stateStore.apply([operation]) + } + + func appendHistoryItem( + _ item: AgentHistoryItem, + threadID: String, + createdAt: Date + ) { + let nextSequence = state.nextHistorySequenceByThread[threadID] + ?? ((state.historyByThread[threadID]?.last?.sequenceNumber ?? 0) + 1) + let record = AgentHistoryRecord( + sequenceNumber: nextSequence, + createdAt: createdAt, + item: item + ) + state.historyByThread[threadID, default: []].append(record) + state.nextHistorySequenceByThread[threadID] = nextSequence + 1 + enqueueStoreOperation( + .appendHistoryItems(threadID: threadID, items: [record]) + ) + } + + func updateThreadTimestamp( + _ timestamp: Date, + for threadID: String + ) { + guard let index = state.threads.firstIndex(where: { $0.id == threadID }) else { + return + } + + state.threads[index].updatedAt = max(state.threads[index].updatedAt, timestamp) + enqueueStoreOperation(.upsertThread(state.threads[index])) + } + + func updateSummary( + for threadID: String, + _ mutate: (AgentThreadSummary) -> AgentThreadSummary + ) throws { + guard let thread = thread(for: threadID) else { + throw AgentRuntimeError.threadNotFound(threadID) + } + + let current = state.summariesByThread[threadID] + ?? state.threadSummaryFallback(for: thread) + let updated = mutate(current) + state.summariesByThread[threadID] = updated + enqueueStoreOperation(.upsertSummary(threadID: threadID, summary: updated)) + } + + func setPendingState( + _ pendingState: AgentThreadPendingState?, + for threadID: String + ) throws { + try updateSummary(for: threadID) { summary in + AgentThreadSummary( + threadID: summary.threadID, + createdAt: summary.createdAt, + updatedAt: summary.updatedAt, + latestItemAt: summary.latestItemAt, + itemCount: summary.itemCount, + latestAssistantMessagePreview: summary.latestAssistantMessagePreview, + latestStructuredOutputMetadata: summary.latestStructuredOutputMetadata, + latestPartialStructuredOutput: summary.latestPartialStructuredOutput, + latestToolState: summary.latestToolState, + latestTurnStatus: summary.latestTurnStatus, + pendingState: pendingState + ) + } + } + + func setLatestPartialStructuredOutput( + _ snapshot: AgentPartialStructuredOutputSnapshot?, + for threadID: String + ) throws { + try updateSummary(for: threadID) { summary in + AgentThreadSummary( + threadID: summary.threadID, + createdAt: summary.createdAt, + updatedAt: summary.updatedAt, + latestItemAt: summary.latestItemAt, + itemCount: summary.itemCount, + latestAssistantMessagePreview: summary.latestAssistantMessagePreview, + latestStructuredOutputMetadata: summary.latestStructuredOutputMetadata, + latestPartialStructuredOutput: snapshot, + latestToolState: summary.latestToolState, + latestTurnStatus: summary.latestTurnStatus, + pendingState: summary.pendingState + ) + } + } + + func setLatestStructuredOutputMetadata( + _ metadata: AgentStructuredOutputMetadata?, + for threadID: String + ) throws { + try updateSummary(for: threadID) { summary in + AgentThreadSummary( + threadID: summary.threadID, + createdAt: summary.createdAt, + updatedAt: summary.updatedAt, + latestItemAt: summary.latestItemAt, + itemCount: summary.itemCount, + latestAssistantMessagePreview: summary.latestAssistantMessagePreview, + latestStructuredOutputMetadata: metadata, + latestPartialStructuredOutput: summary.latestPartialStructuredOutput, + latestToolState: summary.latestToolState, + latestTurnStatus: summary.latestTurnStatus, + pendingState: summary.pendingState + ) + } + } + + func setLatestToolState( + _ latestToolState: AgentLatestToolState?, + for threadID: String + ) throws { + try updateSummary(for: threadID) { summary in + AgentThreadSummary( + threadID: summary.threadID, + createdAt: summary.createdAt, + updatedAt: summary.updatedAt, + latestItemAt: summary.latestItemAt, + itemCount: summary.itemCount, + latestAssistantMessagePreview: summary.latestAssistantMessagePreview, + latestStructuredOutputMetadata: summary.latestStructuredOutputMetadata, + latestPartialStructuredOutput: summary.latestPartialStructuredOutput, + latestToolState: latestToolState, + latestTurnStatus: summary.latestTurnStatus, + pendingState: summary.pendingState + ) + } + } + + func setLatestTurnStatus( + _ latestTurnStatus: AgentTurnStatus?, + for threadID: String + ) throws { + try updateSummary(for: threadID) { summary in + AgentThreadSummary( + threadID: summary.threadID, + createdAt: summary.createdAt, + updatedAt: summary.updatedAt, + latestItemAt: summary.latestItemAt, + itemCount: summary.itemCount, + latestAssistantMessagePreview: summary.latestAssistantMessagePreview, + latestStructuredOutputMetadata: summary.latestStructuredOutputMetadata, + latestPartialStructuredOutput: summary.latestPartialStructuredOutput, + latestToolState: summary.latestToolState, + latestTurnStatus: latestTurnStatus, + pendingState: summary.pendingState + ) + } + } + + func latestToolState( + for invocation: ToolInvocation, + result: ToolResultEnvelope?, + updatedAt: Date + ) -> AgentLatestToolState { + let status: AgentToolSessionStatus + if let result { + if result.errorMessage == "Tool execution was denied by the user." { + status = .denied + } else if let session = result.session, !session.isTerminal { + status = .running + } else if result.success { + status = .completed + } else { + status = .failed + } + } else { + status = .waiting + } + + return AgentLatestToolState( + invocationID: invocation.id, + turnID: invocation.turnID, + toolName: invocation.toolName, + status: status, + success: result?.success, + sessionID: result?.session?.sessionID, + sessionStatus: result?.session?.status, + metadata: result?.session?.metadata, + resumable: result?.session?.resumable ?? false, + updatedAt: updatedAt, + resultPreview: result?.primaryText + ) + } +} + +private extension AgentHistoryFilter { + var includedKinds: Set { + var kinds: Set = [] + if includeMessages { kinds.insert(.message) } + if includeToolCalls { kinds.insert(.toolCall) } + if includeToolResults { kinds.insert(.toolResult) } + if includeStructuredOutputs { kinds.insert(.structuredOutput) } + if includeApprovals { kinds.insert(.approval) } + if includeSystemEvents { kinds.insert(.systemEvent) } + return kinds + } +} + +private extension AgentThreadSnapshot { + var summary: AgentThreadSummary { + AgentThreadSummary( + threadID: threadID, + createdAt: createdAt, + updatedAt: updatedAt, + latestItemAt: latestItemAt, + itemCount: itemCount, + latestAssistantMessagePreview: latestAssistantMessagePreview, + latestStructuredOutputMetadata: latestStructuredOutputMetadata, + latestPartialStructuredOutput: latestPartialStructuredOutput, + latestToolState: latestToolState, + latestTurnStatus: latestTurnStatus, + pendingState: pendingState + ) + } +} diff --git a/Sources/CodexKit/Runtime/AgentRuntime+Skills.swift b/Sources/CodexKit/Runtime/AgentRuntime+Skills.swift index 9192b21..7e709e8 100644 --- a/Sources/CodexKit/Runtime/AgentRuntime+Skills.swift +++ b/Sources/CodexKit/Runtime/AgentRuntime+Skills.swift @@ -81,6 +81,7 @@ extension AgentRuntime { state.threads[index].skillIDs = skillIDs state.threads[index].updatedAt = Date() + enqueueStoreOperation(.upsertThread(state.threads[index])) try await persistState() } diff --git a/Sources/CodexKit/Runtime/AgentRuntime+Threads.swift b/Sources/CodexKit/Runtime/AgentRuntime+Threads.swift index f32669e..3392db2 100644 --- a/Sources/CodexKit/Runtime/AgentRuntime+Threads.swift +++ b/Sources/CodexKit/Runtime/AgentRuntime+Threads.swift @@ -34,7 +34,20 @@ extension AgentRuntime { thread.personaStack = resolvedPersonaStack thread.skillIDs = skillIDs thread.memoryContext = memoryContext - try await upsertThread(thread) + try await upsertThread(thread, persist: false) + appendHistoryItem( + .systemEvent( + AgentSystemEventRecord( + type: .threadCreated, + threadID: thread.id, + occurredAt: thread.createdAt + ) + ), + threadID: thread.id, + createdAt: thread.createdAt + ) + updateThreadTimestamp(thread.createdAt, for: thread.id) + try await persistState() return thread } @@ -47,7 +60,20 @@ extension AgentRuntime { try await backend.resumeThread(id: id, session: session) } let thread = resume.result - try await upsertThread(thread) + try await upsertThread(thread, persist: false) + appendHistoryItem( + .systemEvent( + AgentSystemEventRecord( + type: .threadResumed, + threadID: thread.id, + occurredAt: Date() + ) + ), + threadID: thread.id, + createdAt: Date() + ) + updateThreadTimestamp(Date(), for: thread.id) + try await persistState() return thread } @@ -75,6 +101,7 @@ extension AgentRuntime { state.threads[index].personaStack = personaStack state.threads[index].updatedAt = Date() + enqueueStoreOperation(.upsertThread(state.threads[index])) try await persistState() } @@ -102,12 +129,16 @@ extension AgentRuntime { state.threads[index].memoryContext = memoryContext state.threads[index].updatedAt = Date() + enqueueStoreOperation(.upsertThread(state.threads[index])) try await persistState() } // MARK: - State Mutation - func upsertThread(_ thread: AgentThread) async throws { + func upsertThread( + _ thread: AgentThread, + persist: Bool = true + ) async throws { if let index = state.threads.firstIndex(where: { $0.id == thread.id }) { var mergedThread = thread if mergedThread.title == nil { @@ -123,10 +154,19 @@ extension AgentRuntime { mergedThread.memoryContext = state.threads[index].memoryContext } state.threads[index] = mergedThread + enqueueStoreOperation(.upsertThread(state.threads[index])) } else { state.threads.append(thread) + enqueueStoreOperation(.upsertThread(thread)) + } + if state.summariesByThread[thread.id] == nil { + let summary = state.threadSummaryFallback(for: thread) + state.summariesByThread[thread.id] = summary + enqueueStoreOperation(.upsertSummary(threadID: thread.id, summary: summary)) + } + if persist { + try await persistState() } - try await persistState() } func setThreadStatus( @@ -137,13 +177,34 @@ extension AgentRuntime { throw AgentRuntimeError.threadNotFound(threadID) } + let previousStatus = state.threads[index].status state.threads[index].status = status state.threads[index].updatedAt = Date() + enqueueStoreOperation(.upsertThread(state.threads[index])) + if previousStatus != status { + appendHistoryItem( + .systemEvent( + AgentSystemEventRecord( + type: .threadStatusChanged, + threadID: threadID, + status: status, + occurredAt: state.threads[index].updatedAt + ) + ), + threadID: threadID, + createdAt: state.threads[index].updatedAt + ) + } try await persistState() } func appendMessage(_ message: AgentMessage) async throws { state.messagesByThread[message.threadID, default: []].append(message) + appendHistoryItem( + .message(message), + threadID: message.threadID, + createdAt: message.createdAt + ) if let index = state.threads.firstIndex(where: { $0.id == message.threadID }) { state.threads[index].updatedAt = message.createdAt @@ -156,6 +217,7 @@ extension AgentRuntime { : "Image message (\(message.images.count))" } } + enqueueStoreOperation(.upsertThread(state.threads[index])) } try await persistState() diff --git a/Sources/CodexKit/Runtime/AgentRuntime+TurnConsumption.swift b/Sources/CodexKit/Runtime/AgentRuntime+TurnConsumption.swift index 2580619..7985428 100644 --- a/Sources/CodexKit/Runtime/AgentRuntime+TurnConsumption.swift +++ b/Sources/CodexKit/Runtime/AgentRuntime+TurnConsumption.swift @@ -17,11 +17,28 @@ extension AgentRuntime { nil } var assistantMessages: [AgentMessage] = [] + var currentTurnID: String? do { for try await backendEvent in turnStream.events { switch backendEvent { case let .turnStarted(turn): + currentTurnID = turn.id + appendHistoryItem( + .systemEvent( + AgentSystemEventRecord( + type: .turnStarted, + threadID: threadID, + turnID: turn.id, + occurredAt: turn.startedAt + ) + ), + threadID: threadID, + createdAt: turn.startedAt + ) + try setLatestTurnStatus(.running, for: threadID) + updateThreadTimestamp(turn.startedAt, for: threadID) + try await persistState() continuation.yield(.turnStarted(turn)) case let .assistantMessageDelta(threadID, turnID, delta): @@ -46,6 +63,22 @@ extension AgentRuntime { break case let .toolCallRequested(invocation): + appendHistoryItem( + .toolCall( + AgentToolCallRecord( + invocation: invocation, + requestedAt: Date() + ) + ), + threadID: invocation.threadID, + createdAt: Date() + ) + try setLatestToolState( + latestToolState(for: invocation, result: nil, updatedAt: Date()), + for: invocation.threadID + ) + updateThreadTimestamp(Date(), for: invocation.threadID) + try await persistState() continuation.yield(.toolCallStarted(invocation)) let result: ToolResultEnvelope @@ -72,6 +105,21 @@ extension AgentRuntime { case let .turnCompleted(summary): if let completionError = policyTracker?.completionError() { + appendHistoryItem( + .systemEvent( + AgentSystemEventRecord( + type: .turnFailed, + threadID: threadID, + turnID: currentTurnID, + error: completionError, + occurredAt: Date() + ) + ), + threadID: threadID, + createdAt: Date() + ) + try setLatestTurnStatus(.failed, for: threadID) + try setLatestPartialStructuredOutput(nil, for: threadID) try await setThreadStatus(.failed, for: threadID) continuation.yield(.threadStatusChanged(threadID: threadID, status: .failed)) continuation.yield(.turnFailed(completionError)) @@ -79,6 +127,21 @@ extension AgentRuntime { return } + appendHistoryItem( + .systemEvent( + AgentSystemEventRecord( + type: .turnCompleted, + threadID: threadID, + turnID: summary.turnID, + turnSummary: summary, + occurredAt: summary.completedAt + ) + ), + threadID: threadID, + createdAt: summary.completedAt + ) + try setLatestTurnStatus(.completed, for: threadID) + try setLatestPartialStructuredOutput(nil, for: threadID) try await setThreadStatus(.idle, for: threadID) await automaticallyCaptureMemoriesIfConfigured( for: threadID, @@ -97,6 +160,21 @@ extension AgentRuntime { code: "turn_failed", message: error.localizedDescription ) + appendHistoryItem( + .systemEvent( + AgentSystemEventRecord( + type: .turnFailed, + threadID: threadID, + turnID: currentTurnID, + error: runtimeError, + occurredAt: Date() + ) + ), + threadID: threadID, + createdAt: Date() + ) + try? setLatestTurnStatus(.failed, for: threadID) + try? setLatestPartialStructuredOutput(nil, for: threadID) try? await setThreadStatus(.failed, for: threadID) continuation.yield(.threadStatusChanged(threadID: threadID, status: .failed)) continuation.yield(.turnFailed(runtimeError)) @@ -123,11 +201,28 @@ extension AgentRuntime { } var assistantMessages: [AgentMessage] = [] var sawStructuredCommit = false + var currentTurnID: String? do { for try await backendEvent in turnStream.events { switch backendEvent { case let .turnStarted(turn): + currentTurnID = turn.id + appendHistoryItem( + .systemEvent( + AgentSystemEventRecord( + type: .turnStarted, + threadID: threadID, + turnID: turn.id, + occurredAt: turn.startedAt + ) + ), + threadID: threadID, + createdAt: turn.startedAt + ) + try setLatestTurnStatus(.running, for: threadID) + updateThreadTimestamp(turn.startedAt, for: threadID) + try await persistState() continuation.yield(.turnStarted(turn)) case let .assistantMessageDelta(threadID, turnID, delta): @@ -153,6 +248,19 @@ extension AgentRuntime { as: outputType, decoder: decoder ) + if let currentTurnID { + try setLatestPartialStructuredOutput( + AgentPartialStructuredOutputSnapshot( + turnID: currentTurnID, + formatName: responseFormat.name, + payload: value, + updatedAt: Date() + ), + for: threadID + ) + updateThreadTimestamp(Date(), for: threadID) + try await persistState() + } if options.emitPartials { continuation.yield(.structuredOutputPartial(decoded)) } @@ -176,6 +284,26 @@ extension AgentRuntime { decoder: decoder ) sawStructuredCommit = true + let metadata = AgentStructuredOutputMetadata( + formatName: responseFormat.name, + payload: value + ) + try setLatestStructuredOutputMetadata(metadata, for: threadID) + try setLatestPartialStructuredOutput(nil, for: threadID) + appendHistoryItem( + .structuredOutput( + AgentStructuredOutputRecord( + threadID: threadID, + turnID: currentTurnID ?? "", + metadata: metadata, + committedAt: Date() + ) + ), + threadID: threadID, + createdAt: Date() + ) + updateThreadTimestamp(Date(), for: threadID) + try await persistState() continuation.yield(.structuredOutputCommitted(decoded)) } catch { let validationFailure = AgentStructuredOutputValidationFailure( @@ -187,6 +315,21 @@ extension AgentRuntime { stage: validationFailure.stage, underlyingMessage: validationFailure.message ) + try? setLatestPartialStructuredOutput(nil, for: threadID) + appendHistoryItem( + .systemEvent( + AgentSystemEventRecord( + type: .turnFailed, + threadID: threadID, + turnID: currentTurnID, + error: runtimeError, + occurredAt: Date() + ) + ), + threadID: threadID, + createdAt: Date() + ) + try? setLatestTurnStatus(.failed, for: threadID) try await setThreadStatus(.failed, for: threadID) continuation.yield(.threadStatusChanged(threadID: threadID, status: .failed)) continuation.yield(.structuredOutputValidationFailed(validationFailure)) @@ -196,9 +339,27 @@ extension AgentRuntime { } case let .structuredOutputValidationFailed(validationFailure): + try? setLatestPartialStructuredOutput(nil, for: threadID) + try? await persistState() continuation.yield(.structuredOutputValidationFailed(validationFailure)) case let .toolCallRequested(invocation): + appendHistoryItem( + .toolCall( + AgentToolCallRecord( + invocation: invocation, + requestedAt: Date() + ) + ), + threadID: invocation.threadID, + createdAt: Date() + ) + try setLatestToolState( + latestToolState(for: invocation, result: nil, updatedAt: Date()), + for: invocation.threadID + ) + updateThreadTimestamp(Date(), for: invocation.threadID) + try await persistState() continuation.yield(.toolCallStarted(invocation)) let result: ToolResultEnvelope @@ -225,6 +386,21 @@ extension AgentRuntime { case let .turnCompleted(summary): if let completionError = policyTracker?.completionError() { + appendHistoryItem( + .systemEvent( + AgentSystemEventRecord( + type: .turnFailed, + threadID: threadID, + turnID: currentTurnID, + error: completionError, + occurredAt: Date() + ) + ), + threadID: threadID, + createdAt: Date() + ) + try setLatestTurnStatus(.failed, for: threadID) + try setLatestPartialStructuredOutput(nil, for: threadID) try await setThreadStatus(.failed, for: threadID) continuation.yield(.threadStatusChanged(threadID: threadID, status: .failed)) continuation.yield(.turnFailed(completionError)) @@ -236,6 +412,21 @@ extension AgentRuntime { let runtimeError = AgentRuntimeError.structuredOutputMissing( formatName: responseFormat.name ) + appendHistoryItem( + .systemEvent( + AgentSystemEventRecord( + type: .turnFailed, + threadID: threadID, + turnID: currentTurnID, + error: runtimeError, + occurredAt: Date() + ) + ), + threadID: threadID, + createdAt: Date() + ) + try setLatestTurnStatus(.failed, for: threadID) + try setLatestPartialStructuredOutput(nil, for: threadID) try await setThreadStatus(.failed, for: threadID) continuation.yield(.threadStatusChanged(threadID: threadID, status: .failed)) continuation.yield(.turnFailed(runtimeError)) @@ -243,6 +434,21 @@ extension AgentRuntime { return } + appendHistoryItem( + .systemEvent( + AgentSystemEventRecord( + type: .turnCompleted, + threadID: threadID, + turnID: summary.turnID, + turnSummary: summary, + occurredAt: summary.completedAt + ) + ), + threadID: threadID, + createdAt: summary.completedAt + ) + try setLatestTurnStatus(.completed, for: threadID) + try setLatestPartialStructuredOutput(nil, for: threadID) try await setThreadStatus(.idle, for: threadID) await automaticallyCaptureMemoriesIfConfigured( for: threadID, @@ -261,6 +467,21 @@ extension AgentRuntime { code: "turn_failed", message: error.localizedDescription ) + appendHistoryItem( + .systemEvent( + AgentSystemEventRecord( + type: .turnFailed, + threadID: threadID, + turnID: currentTurnID, + error: runtimeError, + occurredAt: Date() + ) + ), + threadID: threadID, + createdAt: Date() + ) + try? setLatestTurnStatus(.failed, for: threadID) + try? setLatestPartialStructuredOutput(nil, for: threadID) try? await setThreadStatus(.failed, for: threadID) continuation.yield(.threadStatusChanged(threadID: threadID, status: .failed)) continuation.yield(.turnFailed(runtimeError)) @@ -284,6 +505,26 @@ extension AgentRuntime { ?? "This tool requires explicit approval before it can run." ) + appendHistoryItem( + .approval( + AgentApprovalRecord( + kind: .requested, + request: approval, + occurredAt: Date() + ) + ), + threadID: invocation.threadID, + createdAt: Date() + ) + try setPendingState( + .approval( + AgentPendingApprovalState( + request: approval, + requestedAt: Date() + ) + ), + for: invocation.threadID + ) try await setThreadStatus(.waitingForApproval, for: invocation.threadID) continuation.yield( .threadStatusChanged( @@ -294,22 +535,71 @@ extension AgentRuntime { continuation.yield(.approvalRequested(approval)) let decision = try await approvalCoordinator.requestApproval(approval) + let resolution = ApprovalResolution( + requestID: approval.id, + threadID: approval.threadID, + turnID: approval.turnID, + decision: decision + ) + appendHistoryItem( + .approval( + AgentApprovalRecord( + kind: .resolved, + request: approval, + resolution: resolution, + occurredAt: resolution.decidedAt + ) + ), + threadID: invocation.threadID, + createdAt: resolution.decidedAt + ) + try setPendingState(nil, for: invocation.threadID) continuation.yield( .approvalResolved( - ApprovalResolution( - requestID: approval.id, - threadID: approval.threadID, - turnID: approval.turnID, - decision: decision - ) + resolution ) ) guard decision == .approved else { - return .denied(invocation: invocation) + let denied = ToolResultEnvelope.denied(invocation: invocation) + try setLatestToolState( + latestToolState(for: invocation, result: denied, updatedAt: resolution.decidedAt), + for: invocation.threadID + ) + appendHistoryItem( + .toolResult( + AgentToolResultRecord( + threadID: invocation.threadID, + turnID: invocation.turnID, + result: denied, + completedAt: resolution.decidedAt + ) + ), + threadID: invocation.threadID, + createdAt: resolution.decidedAt + ) + updateThreadTimestamp(resolution.decidedAt, for: invocation.threadID) + try await persistState() + return denied } } + let toolWaitStartedAt = Date() + try setPendingState( + .toolWait( + AgentPendingToolWaitState( + invocationID: invocation.id, + turnID: invocation.turnID, + toolName: invocation.toolName, + startedAt: toolWaitStartedAt + ) + ), + for: invocation.threadID + ) + try setLatestToolState( + latestToolState(for: invocation, result: nil, updatedAt: toolWaitStartedAt), + for: invocation.threadID + ) try await setThreadStatus(.waitingForToolResult, for: invocation.threadID) continuation.yield( .threadStatusChanged( @@ -318,7 +608,46 @@ extension AgentRuntime { ) ) - return await toolRegistry.execute(invocation, session: session) + let result = await toolRegistry.execute(invocation, session: session) + let resultDate = Date() + try setLatestToolState( + latestToolState(for: invocation, result: result, updatedAt: resultDate), + for: invocation.threadID + ) + if let session = result.session, !session.isTerminal { + try setPendingState( + .toolWait( + AgentPendingToolWaitState( + invocationID: invocation.id, + turnID: invocation.turnID, + toolName: invocation.toolName, + startedAt: toolWaitStartedAt, + sessionID: session.sessionID, + sessionStatus: session.status, + metadata: session.metadata, + resumable: session.resumable + ) + ), + for: invocation.threadID + ) + } else { + try setPendingState(nil, for: invocation.threadID) + appendHistoryItem( + .toolResult( + AgentToolResultRecord( + threadID: invocation.threadID, + turnID: invocation.turnID, + result: result, + completedAt: resultDate + ) + ), + threadID: invocation.threadID, + createdAt: resultDate + ) + } + updateThreadTimestamp(resultDate, for: invocation.threadID) + try await persistState() + return result } func resolveToolInvocation( @@ -337,6 +666,26 @@ extension AgentRuntime { ?? "This tool requires explicit approval before it can run." ) + appendHistoryItem( + .approval( + AgentApprovalRecord( + kind: .requested, + request: approval, + occurredAt: Date() + ) + ), + threadID: invocation.threadID, + createdAt: Date() + ) + try setPendingState( + .approval( + AgentPendingApprovalState( + request: approval, + requestedAt: Date() + ) + ), + for: invocation.threadID + ) try await setThreadStatus(.waitingForApproval, for: invocation.threadID) continuation.yield( .threadStatusChanged( @@ -347,22 +696,71 @@ extension AgentRuntime { continuation.yield(.approvalRequested(approval)) let decision = try await approvalCoordinator.requestApproval(approval) + let resolution = ApprovalResolution( + requestID: approval.id, + threadID: approval.threadID, + turnID: approval.turnID, + decision: decision + ) + appendHistoryItem( + .approval( + AgentApprovalRecord( + kind: .resolved, + request: approval, + resolution: resolution, + occurredAt: resolution.decidedAt + ) + ), + threadID: invocation.threadID, + createdAt: resolution.decidedAt + ) + try setPendingState(nil, for: invocation.threadID) continuation.yield( .approvalResolved( - ApprovalResolution( - requestID: approval.id, - threadID: approval.threadID, - turnID: approval.turnID, - decision: decision - ) + resolution ) ) guard decision == .approved else { - return .denied(invocation: invocation) + let denied = ToolResultEnvelope.denied(invocation: invocation) + try setLatestToolState( + latestToolState(for: invocation, result: denied, updatedAt: resolution.decidedAt), + for: invocation.threadID + ) + appendHistoryItem( + .toolResult( + AgentToolResultRecord( + threadID: invocation.threadID, + turnID: invocation.turnID, + result: denied, + completedAt: resolution.decidedAt + ) + ), + threadID: invocation.threadID, + createdAt: resolution.decidedAt + ) + updateThreadTimestamp(resolution.decidedAt, for: invocation.threadID) + try await persistState() + return denied } } + let toolWaitStartedAt = Date() + try setPendingState( + .toolWait( + AgentPendingToolWaitState( + invocationID: invocation.id, + turnID: invocation.turnID, + toolName: invocation.toolName, + startedAt: toolWaitStartedAt + ) + ), + for: invocation.threadID + ) + try setLatestToolState( + latestToolState(for: invocation, result: nil, updatedAt: toolWaitStartedAt), + for: invocation.threadID + ) try await setThreadStatus(.waitingForToolResult, for: invocation.threadID) continuation.yield( .threadStatusChanged( @@ -371,6 +769,45 @@ extension AgentRuntime { ) ) - return await toolRegistry.execute(invocation, session: session) + let result = await toolRegistry.execute(invocation, session: session) + let resultDate = Date() + try setLatestToolState( + latestToolState(for: invocation, result: result, updatedAt: resultDate), + for: invocation.threadID + ) + if let session = result.session, !session.isTerminal { + try setPendingState( + .toolWait( + AgentPendingToolWaitState( + invocationID: invocation.id, + turnID: invocation.turnID, + toolName: invocation.toolName, + startedAt: toolWaitStartedAt, + sessionID: session.sessionID, + sessionStatus: session.status, + metadata: session.metadata, + resumable: session.resumable + ) + ), + for: invocation.threadID + ) + } else { + try setPendingState(nil, for: invocation.threadID) + appendHistoryItem( + .toolResult( + AgentToolResultRecord( + threadID: invocation.threadID, + turnID: invocation.turnID, + result: result, + completedAt: resultDate + ) + ), + threadID: invocation.threadID, + createdAt: resultDate + ) + } + updateThreadTimestamp(resultDate, for: invocation.threadID) + try await persistState() + return result } } diff --git a/Sources/CodexKit/Runtime/AgentRuntime.swift b/Sources/CodexKit/Runtime/AgentRuntime.swift index 90668d3..a8faa73 100644 --- a/Sources/CodexKit/Runtime/AgentRuntime.swift +++ b/Sources/CodexKit/Runtime/AgentRuntime.swift @@ -64,6 +64,7 @@ public actor AgentRuntime { var skillsByID: [String: AgentSkill] var state: StoredRuntimeState = .empty + var pendingStoreOperations: [AgentStoreWriteOperation] = [] struct ResolvedTurnSkills { let threadSkills: [AgentSkill] @@ -170,7 +171,9 @@ public actor AgentRuntime { @discardableResult public func restore() async throws -> StoredRuntimeState { _ = try await sessionManager.restore() + _ = try await stateStore.prepare() state = try await stateStore.loadState() + pendingStoreOperations.removeAll() return state } @@ -216,7 +219,19 @@ public actor AgentRuntime { // MARK: - Instruction Resolution func persistState() async throws { - try await stateStore.saveState(state) + state = state.normalized() + guard !pendingStoreOperations.isEmpty else { + try await stateStore.saveState(state) + return + } + + let operations = pendingStoreOperations + try await stateStore.apply(operations) + pendingStoreOperations.removeAll() + } + + func enqueueStoreOperation(_ operation: AgentStoreWriteOperation) { + pendingStoreOperations.append(operation) } func resolveInstructions( diff --git a/Sources/CodexKit/Runtime/GRDBRuntimeStateStore.swift b/Sources/CodexKit/Runtime/GRDBRuntimeStateStore.swift new file mode 100644 index 0000000..ed632d6 --- /dev/null +++ b/Sources/CodexKit/Runtime/GRDBRuntimeStateStore.swift @@ -0,0 +1,1148 @@ +import Foundation +import GRDB + +public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, AgentRuntimeQueryableStore { + private static let currentStoreSchemaVersion = 1 + + private let url: URL + private let legacyStateURL: URL? + private let databaseExistedAtInitialization: Bool + private let dbQueue: DatabaseQueue + private let migrator: DatabaseMigrator + private var isPrepared = false + + public init( + url: URL, + importingLegacyStateFrom legacyStateURL: URL? = nil + ) throws { + self.url = url + let fileManager = FileManager.default + self.databaseExistedAtInitialization = fileManager.fileExists(atPath: url.path) + self.legacyStateURL = legacyStateURL ?? Self.defaultLegacyImportURL(for: url) + + let directory = url.deletingLastPathComponent() + if !directory.path.isEmpty { + try fileManager.createDirectory( + at: directory, + withIntermediateDirectories: true + ) + } + + var configuration = Configuration() + configuration.foreignKeysEnabled = true + configuration.label = "CodexKit.GRDBRuntimeStateStore" + dbQueue = try DatabaseQueue(path: url.path, configuration: configuration) + migrator = Self.makeMigrator() + } + + public func prepare() async throws -> AgentStoreMetadata { + try await ensurePrepared() + return try await readMetadata() + } + + public func readMetadata() async throws -> AgentStoreMetadata { + try await ensurePrepared() + let storeSchemaVersion = try await readUserVersion() + + return AgentStoreMetadata( + logicalSchemaVersion: .v1, + storeSchemaVersion: storeSchemaVersion, + capabilities: AgentStoreCapabilities( + supportsPushdownQueries: true, + supportsCrossThreadQueries: true, + supportsSorting: true, + supportsFiltering: true, + supportsMigrations: true + ), + storeKind: "GRDBRuntimeStateStore" + ) + } + + public func loadState() async throws -> StoredRuntimeState { + try await ensurePrepared() + + return try await dbQueue.read { db in + let threadRows = try RuntimeThreadRow.fetchAll(db) + let summaryRows = try RuntimeSummaryRow.fetchAll(db) + let historyRows = try RuntimeHistoryRow.fetchAll(db) + + let threads = try threadRows.map { try Self.decodeThread(from: $0) } + let summariesByThread = try Dictionary( + uniqueKeysWithValues: summaryRows.map { row in + (row.threadID, try Self.decodeSummary(from: row)) + } + ) + let historyByThread = try Dictionary( + grouping: historyRows.map { try Self.decodeHistoryRecord(from: $0) }, + by: \.item.threadID + ) + + return StoredRuntimeState( + threads: threads, + historyByThread: historyByThread, + summariesByThread: summariesByThread + ) + } + } + + public func saveState(_ state: StoredRuntimeState) async throws { + try await ensurePrepared() + + let normalized = state.normalized() + try await dbQueue.write { db in + try Self.replaceDatabaseContents(with: normalized, in: db) + } + } + + public func apply(_ operations: [AgentStoreWriteOperation]) async throws { + try await ensurePrepared() + guard !operations.isEmpty else { + return + } + + let affectedThreadIDs = Set(operations.map(\.affectedThreadID)) + guard !affectedThreadIDs.isEmpty else { + return + } + + try await dbQueue.write { db in + var partialState = try Self.loadPartialState(for: affectedThreadIDs, from: db) + partialState = try partialState.applying(operations) + + for threadID in affectedThreadIDs { + try Self.deletePersistedThread(threadID, in: db) + } + + try Self.persistThreads( + ids: affectedThreadIDs, + from: partialState, + in: db + ) + } + } + + public func fetchThreadSummary(id: String) async throws -> AgentThreadSummary { + try await ensurePrepared() + + return try await dbQueue.read { db in + guard let threadRow = try RuntimeThreadRow.fetchOne(db, key: id) else { + throw AgentRuntimeError.threadNotFound(id) + } + if let summaryRow = try RuntimeSummaryRow.fetchOne(db, key: id) { + return try Self.decodeSummary(from: summaryRow) + } + let thread = try Self.decodeThread(from: threadRow) + return StoredRuntimeState(threads: [thread]).threadSummaryFallback(for: thread) + } + } + + public func fetchThreadHistory( + id: String, + query: AgentHistoryQuery + ) async throws -> AgentThreadHistoryPage { + try await ensurePrepared() + + return try await dbQueue.read { db in + guard try RuntimeThreadRow.fetchOne(db, key: id) != nil else { + throw AgentRuntimeError.threadNotFound(id) + } + + return try Self.fetchHistoryPage( + threadID: id, + query: query, + in: db + ) + } + } + + public func fetchLatestStructuredOutputMetadata(id: String) async throws -> AgentStructuredOutputMetadata? { + let summary = try await fetchThreadSummary(id: id) + return summary.latestStructuredOutputMetadata + } + + public func execute(_ query: Query) async throws -> Query.Result { + try await ensurePrepared() + + if let historyQuery = query as? HistoryItemsQuery { + return try await executeHistoryQuery(historyQuery) as! Query.Result + } + if let threadQuery = query as? ThreadMetadataQuery { + return try await executeThreadQuery(threadQuery) as! Query.Result + } + if let pendingQuery = query as? PendingStateQuery { + return try await executePendingStateQuery(pendingQuery) as! Query.Result + } + if let structuredQuery = query as? StructuredOutputQuery { + return try await executeStructuredOutputQuery(structuredQuery) as! Query.Result + } + if let snapshotQuery = query as? ThreadSnapshotQuery { + return try await executeThreadSnapshotQuery(snapshotQuery) as! Query.Result + } + + let state = try await loadState() + return try query.execute(in: state) + } + + private func ensurePrepared() async throws { + if isPrepared { + return + } + + let version = try await readUserVersion() + guard version <= Self.currentStoreSchemaVersion else { + throw AgentStoreError.migrationFailed( + "Unsupported future GRDB runtime store schema version \(version)." + ) + } + + try migrator.migrate(dbQueue) + if try await shouldImportLegacyState() { + try await importLegacyState() + } + isPrepared = true + } + + private func executeHistoryQuery(_ query: HistoryItemsQuery) async throws -> AgentHistoryQueryResult { + try await dbQueue.read { db in + guard let threadRow = try RuntimeThreadRow.fetchOne(db, key: query.threadID) else { + return AgentHistoryQueryResult( + threadID: query.threadID, + records: [], + nextCursor: nil, + previousCursor: nil, + hasMoreBefore: false, + hasMoreAfter: false + ) + } + + let thread = try Self.decodeThread(from: threadRow) + let history = try Self.fetchHistoryRows( + threadID: query.threadID, + kinds: query.kinds, + createdAtRange: query.createdAtRange, + turnID: query.turnID, + includeRedacted: query.includeRedacted, + in: db + ) + + let state = StoredRuntimeState( + threads: [thread], + historyByThread: [query.threadID: history] + ) + return try state.execute(query) + } + } + + private func executeThreadQuery(_ query: ThreadMetadataQuery) async throws -> [AgentThread] { + try await dbQueue.read { db in + let rows = try RuntimeThreadRow.fetchAll( + db, + sql: Self.threadMetadataSQL(for: query), + arguments: StatementArguments(Self.threadMetadataArguments(for: query)) + ) + return try rows.map { try Self.decodeThread(from: $0) } + } + } + + private func executePendingStateQuery(_ query: PendingStateQuery) async throws -> [AgentPendingStateRecord] { + try await dbQueue.read { db in + let summaries = try RuntimeSummaryRow.fetchAll( + db, + sql: Self.pendingStateSQL(for: query), + arguments: StatementArguments(Self.pendingStateArguments(for: query)) + ) + let records = try summaries.compactMap { row -> AgentPendingStateRecord? in + let summary = try Self.decodeSummary(from: row) + guard let pendingState = summary.pendingState else { + return nil + } + return AgentPendingStateRecord( + threadID: summary.threadID, + pendingState: pendingState, + updatedAt: summary.updatedAt + ) + } + return records + } + } + + private func executeStructuredOutputQuery(_ query: StructuredOutputQuery) async throws -> [AgentStructuredOutputRecord] { + try await dbQueue.read { db in + var records = try RuntimeStructuredOutputRow.fetchAll( + db, + sql: Self.structuredOutputSQL(for: query), + arguments: StatementArguments(Self.structuredOutputArguments(for: query)) + ).map { try Self.decodeStructuredOutputRecord(from: $0) } + + if query.latestOnly { + var seen = Set() + records = records.filter { seen.insert($0.threadID).inserted } + } + + if let limit = query.limit { + records = Array(records.prefix(max(0, limit))) + } + return records + } + } + + private func executeThreadSnapshotQuery(_ query: ThreadSnapshotQuery) async throws -> [AgentThreadSnapshot] { + try await dbQueue.read { db in + let snapshots = try RuntimeSummaryRow.fetchAll( + db, + sql: Self.threadSnapshotSQL(for: query), + arguments: StatementArguments(Self.threadSnapshotArguments(for: query)) + ) + .map { try Self.decodeSummary(from: $0) } + .map(\.snapshot) + return snapshots + } + } + + private static func replaceDatabaseContents( + with normalized: StoredRuntimeState, + in db: Database + ) throws { + let threadRows = try normalized.threads.map(Self.makeThreadRow) + let summaryRows = try normalized.threads.compactMap { thread -> RuntimeSummaryRow? in + guard let summary = normalized.summariesByThread[thread.id] else { + return nil + } + return try Self.makeSummaryRow(from: summary) + } + let historyRows = try normalized.historyByThread.values + .flatMap { $0 } + .map(Self.makeHistoryRow) + let structuredOutputRows = try Self.structuredOutputRows(from: normalized.historyByThread) + + try RuntimeStructuredOutputRow.deleteAll(db) + try RuntimeHistoryRow.deleteAll(db) + try RuntimeSummaryRow.deleteAll(db) + try RuntimeThreadRow.deleteAll(db) + + for row in threadRows { + try row.insert(db) + } + for row in summaryRows { + try row.insert(db) + } + for row in historyRows { + try row.insert(db) + } + for row in structuredOutputRows { + try row.insert(db) + } + } + + private func shouldImportLegacyState() async throws -> Bool { + guard let legacyStateURL else { + return false + } + guard legacyStateURL != url else { + return false + } + guard FileManager.default.fileExists(atPath: legacyStateURL.path) else { + return false + } + guard !databaseExistedAtInitialization else { + return false + } + + let threadCount = try await dbQueue.read { db in + try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM \(RuntimeThreadRow.databaseTableName)") ?? 0 + } + return threadCount == 0 + } + + private func importLegacyState() async throws { + guard let legacyStateURL else { + return + } + + let legacyStore = FileRuntimeStateStore(url: legacyStateURL) + let state = try await legacyStore.loadState().normalized() + guard !state.threads.isEmpty || !state.historyByThread.isEmpty else { + return + } + + try await dbQueue.write { db in + try Self.replaceDatabaseContents(with: state, in: db) + } + } + + private static func loadPartialState( + for threadIDs: Set, + from db: Database + ) throws -> StoredRuntimeState { + guard !threadIDs.isEmpty else { + return .empty + } + + let ids = Array(threadIDs) + let placeholders = sqlPlaceholders(count: ids.count) + let threadRows = try RuntimeThreadRow.fetchAll( + db, + sql: "SELECT * FROM \(RuntimeThreadRow.databaseTableName) WHERE threadID IN \(placeholders)", + arguments: StatementArguments(ids) + ) + let summaryRows = try RuntimeSummaryRow.fetchAll( + db, + sql: "SELECT * FROM \(RuntimeSummaryRow.databaseTableName) WHERE threadID IN \(placeholders)", + arguments: StatementArguments(ids) + ) + let historyRows = try RuntimeHistoryRow.fetchAll( + db, + sql: """ + SELECT * FROM \(RuntimeHistoryRow.databaseTableName) + WHERE threadID IN \(placeholders) + ORDER BY threadID ASC, sequenceNumber ASC + """, + arguments: StatementArguments(ids) + ) + + let threads = try threadRows.map { try Self.decodeThread(from: $0) } + let summaries = try Dictionary( + uniqueKeysWithValues: summaryRows.map { ($0.threadID, try Self.decodeSummary(from: $0)) } + ) + let history = try Dictionary( + grouping: historyRows.map { try Self.decodeHistoryRecord(from: $0) }, + by: \.item.threadID + ) + let nextSequence = history.mapValues { ($0.last?.sequenceNumber ?? 0) + 1 } + + return StoredRuntimeState( + threads: threads, + historyByThread: history, + summariesByThread: summaries, + nextHistorySequenceByThread: nextSequence + ) + } + + private static func persistThreads( + ids threadIDs: Set, + from state: StoredRuntimeState, + in db: Database + ) throws { + let normalized = state.normalized() + let threads = normalized.threads.filter { threadIDs.contains($0.id) } + guard !threads.isEmpty else { + return + } + + for thread in threads { + try Self.makeThreadRow(from: thread).insert(db) + if let summary = normalized.summariesByThread[thread.id] { + try Self.makeSummaryRow(from: summary).insert(db) + } + for record in normalized.historyByThread[thread.id] ?? [] { + try Self.makeHistoryRow(from: record).insert(db) + } + } + + for row in try Self.structuredOutputRows( + from: normalized.historyByThread.filter { threadIDs.contains($0.key) } + ) { + try row.insert(db) + } + } + + private static func deletePersistedThread( + _ threadID: String, + in db: Database + ) throws { + try db.execute( + sql: "DELETE FROM \(RuntimeThreadRow.databaseTableName) WHERE threadID = ?", + arguments: [threadID] + ) + } + + private static func fetchHistoryRows( + threadID: String, + kinds: Set?, + createdAtRange: ClosedRange?, + turnID: String?, + includeRedacted: Bool, + in db: Database + ) throws -> [AgentHistoryRecord] { + var clauses = ["threadID = ?"] + var arguments: [any DatabaseValueConvertible] = [threadID] + + if let kinds, !kinds.isEmpty { + clauses.append("kind IN \(sqlPlaceholders(count: kinds.count))") + arguments.append(contentsOf: kinds.map(\.rawValue)) + } + if let createdAtRange { + clauses.append("createdAt >= ?") + clauses.append("createdAt <= ?") + arguments.append(createdAtRange.lowerBound.timeIntervalSince1970) + arguments.append(createdAtRange.upperBound.timeIntervalSince1970) + } + if let turnID { + clauses.append("turnID = ?") + arguments.append(turnID) + } + if !includeRedacted { + clauses.append("isRedacted = 0") + } + + let sql = """ + SELECT * FROM \(RuntimeHistoryRow.databaseTableName) + WHERE \(clauses.joined(separator: " AND ")) + ORDER BY sequenceNumber ASC + """ + return try RuntimeHistoryRow.fetchAll( + db, + sql: sql, + arguments: StatementArguments(arguments) + ).map { try Self.decodeHistoryRecord(from: $0) } + } + + private static func fetchHistoryPage( + threadID: String, + query: AgentHistoryQuery, + in db: Database + ) throws -> AgentThreadHistoryPage { + let limit = max(1, query.limit) + let kinds = historyKinds(from: query.filter) + let anchor = try decodeCursorSequence(query.cursor, expectedThreadID: threadID) + + switch query.direction { + case .backward: + var clauses = ["threadID = ?"] + var arguments: [any DatabaseValueConvertible] = [threadID] + if let kinds, !kinds.isEmpty { + clauses.append("kind IN \(sqlPlaceholders(count: kinds.count))") + for kind in kinds { arguments.append(kind.rawValue) } + } + if let anchor { + clauses.append("sequenceNumber < ?") + arguments.append(anchor) + } + + let sql = """ + SELECT * FROM \(RuntimeHistoryRow.databaseTableName) + WHERE \(clauses.joined(separator: " AND ")) + ORDER BY sequenceNumber DESC + LIMIT \(limit + 1) + """ + let fetched = try RuntimeHistoryRow.fetchAll( + db, + sql: sql, + arguments: StatementArguments(arguments) + ) + let hasMoreBefore = fetched.count > limit + let pageRowsDescending = Array(fetched.prefix(limit)) + let pageRecords = try pageRowsDescending + .map { try Self.decodeHistoryRecord(from: $0) } + .reversed() + + let hasMoreAfter: Bool + if let anchor { + hasMoreAfter = try historyRecordExists( + threadID: threadID, + kinds: kinds, + comparator: "sequenceNumber >= ?", + value: anchor, + in: db + ) + } else { + hasMoreAfter = false + } + + return AgentThreadHistoryPage( + threadID: threadID, + items: pageRecords.map(\.item), + nextCursor: hasMoreBefore ? makeCursor(threadID: threadID, sequenceNumber: pageRecords.first?.sequenceNumber) : nil, + previousCursor: hasMoreAfter ? makeCursor(threadID: threadID, sequenceNumber: pageRecords.last?.sequenceNumber) : nil, + hasMoreBefore: hasMoreBefore, + hasMoreAfter: hasMoreAfter + ) + + case .forward: + var clauses = ["threadID = ?"] + var arguments: [any DatabaseValueConvertible] = [threadID] + if let kinds, !kinds.isEmpty { + clauses.append("kind IN \(sqlPlaceholders(count: kinds.count))") + for kind in kinds { arguments.append(kind.rawValue) } + } + if let anchor { + clauses.append("sequenceNumber > ?") + arguments.append(anchor) + } + + let sql = """ + SELECT * FROM \(RuntimeHistoryRow.databaseTableName) + WHERE \(clauses.joined(separator: " AND ")) + ORDER BY sequenceNumber ASC + LIMIT \(limit + 1) + """ + let fetched = try RuntimeHistoryRow.fetchAll( + db, + sql: sql, + arguments: StatementArguments(arguments) + ) + let hasMoreAfter = fetched.count > limit + let pageRows = Array(fetched.prefix(limit)) + let pageRecords = try pageRows.map { try Self.decodeHistoryRecord(from: $0) } + + let hasMoreBefore: Bool + if let anchor { + hasMoreBefore = try historyRecordExists( + threadID: threadID, + kinds: kinds, + comparator: "sequenceNumber <= ?", + value: anchor, + in: db + ) + } else { + hasMoreBefore = false + } + + return AgentThreadHistoryPage( + threadID: threadID, + items: pageRecords.map(\.item), + nextCursor: hasMoreAfter ? makeCursor(threadID: threadID, sequenceNumber: pageRecords.last?.sequenceNumber) : nil, + previousCursor: hasMoreBefore ? makeCursor(threadID: threadID, sequenceNumber: pageRecords.first?.sequenceNumber) : nil, + hasMoreBefore: hasMoreBefore, + hasMoreAfter: hasMoreAfter + ) + } + } + + private static func historyRecordExists( + threadID: String, + kinds: Set?, + comparator: String, + value: Int, + in db: Database + ) throws -> Bool { + var clauses = ["threadID = ?", comparator] + var arguments: [any DatabaseValueConvertible] = [threadID, value] + if let kinds, !kinds.isEmpty { + clauses.append("kind IN \(sqlPlaceholders(count: kinds.count))") + for kind in kinds { arguments.append(kind.rawValue) } + } + + let sql = """ + SELECT EXISTS( + SELECT 1 FROM \(RuntimeHistoryRow.databaseTableName) + WHERE \(clauses.joined(separator: " AND ")) + ) + """ + return try Bool.fetchOne(db, sql: sql, arguments: StatementArguments(arguments)) ?? false + } + + private static func threadMetadataSQL(for query: ThreadMetadataQuery) -> String { + var sql = "SELECT * FROM \(RuntimeThreadRow.databaseTableName)" + var clauses: [String] = [] + if let threadIDs = query.threadIDs, !threadIDs.isEmpty { + clauses.append("threadID IN \(sqlPlaceholders(count: threadIDs.count))") + } + if let statuses = query.statuses, !statuses.isEmpty { + clauses.append("status IN \(sqlPlaceholders(count: statuses.count))") + } + if query.updatedAtRange != nil { + clauses.append("updatedAt >= ?") + clauses.append("updatedAt <= ?") + } + if !clauses.isEmpty { + sql += " WHERE " + clauses.joined(separator: " AND ") + } + sql += " ORDER BY " + threadSortClause(query.sort) + if let limit = query.limit { + sql += " LIMIT \(max(0, limit))" + } + return sql + } + + private static func threadMetadataArguments(for query: ThreadMetadataQuery) -> [any DatabaseValueConvertible] { + var arguments: [any DatabaseValueConvertible] = [] + if let threadIDs = query.threadIDs, !threadIDs.isEmpty { + for threadID in threadIDs { + arguments.append(threadID) + } + } + if let statuses = query.statuses, !statuses.isEmpty { + for status in statuses { + arguments.append(status.rawValue) + } + } + if let range = query.updatedAtRange { + arguments.append(range.lowerBound.timeIntervalSince1970) + arguments.append(range.upperBound.timeIntervalSince1970) + } + return arguments + } + + private static func pendingStateSQL(for query: PendingStateQuery) -> String { + var sql = "SELECT * FROM \(RuntimeSummaryRow.databaseTableName) WHERE pendingStateKind IS NOT NULL" + if let threadIDs = query.threadIDs, !threadIDs.isEmpty { + sql += " AND threadID IN \(sqlPlaceholders(count: threadIDs.count))" + } + if let kinds = query.kinds, !kinds.isEmpty { + sql += " AND pendingStateKind IN \(sqlPlaceholders(count: kinds.count))" + } + sql += " ORDER BY updatedAt " + sortDirection(for: query.sort) + if let limit = query.limit { + sql += " LIMIT \(max(0, limit))" + } + return sql + } + + private static func pendingStateArguments(for query: PendingStateQuery) -> [any DatabaseValueConvertible] { + var arguments: [any DatabaseValueConvertible] = [] + if let threadIDs = query.threadIDs, !threadIDs.isEmpty { + for threadID in threadIDs { + arguments.append(threadID) + } + } + if let kinds = query.kinds, !kinds.isEmpty { + for kind in kinds { + arguments.append(kind.rawValue) + } + } + return arguments + } + + private static func structuredOutputSQL(for query: StructuredOutputQuery) -> String { + var sql = "SELECT * FROM \(RuntimeStructuredOutputRow.databaseTableName)" + var clauses: [String] = [] + if let threadIDs = query.threadIDs, !threadIDs.isEmpty { + clauses.append("threadID IN \(sqlPlaceholders(count: threadIDs.count))") + } + if let formatNames = query.formatNames, !formatNames.isEmpty { + clauses.append("formatName IN \(sqlPlaceholders(count: formatNames.count))") + } + if !clauses.isEmpty { + sql += " WHERE " + clauses.joined(separator: " AND ") + } + sql += " ORDER BY committedAt " + sortDirection(for: query.sort) + if let limit = query.limit, !query.latestOnly { + sql += " LIMIT \(max(0, limit))" + } + return sql + } + + private static func structuredOutputArguments(for query: StructuredOutputQuery) -> [any DatabaseValueConvertible] { + var arguments: [any DatabaseValueConvertible] = [] + if let threadIDs = query.threadIDs, !threadIDs.isEmpty { + for threadID in threadIDs { + arguments.append(threadID) + } + } + if let formatNames = query.formatNames, !formatNames.isEmpty { + for formatName in formatNames { + arguments.append(formatName) + } + } + return arguments + } + + private static func threadSnapshotSQL(for query: ThreadSnapshotQuery) -> String { + var sql = "SELECT * FROM \(RuntimeSummaryRow.databaseTableName)" + if let threadIDs = query.threadIDs, !threadIDs.isEmpty { + sql += " WHERE threadID IN \(sqlPlaceholders(count: threadIDs.count))" + } + sql += " ORDER BY " + snapshotSortClause(query.sort) + if let limit = query.limit { + sql += " LIMIT \(max(0, limit))" + } + return sql + } + + private static func threadSnapshotArguments(for query: ThreadSnapshotQuery) -> [any DatabaseValueConvertible] { + guard let threadIDs = query.threadIDs, !threadIDs.isEmpty else { + return [] + } + return threadIDs.map { $0 as any DatabaseValueConvertible } + } + + private static func defaultLegacyImportURL(for url: URL) -> URL { + url.deletingPathExtension().appendingPathExtension("json") + } + + private static func sqlPlaceholders(count: Int) -> String { + "(" + Array(repeating: "?", count: count).joined(separator: ", ") + ")" + } + + private static func threadSortClause(_ sort: AgentThreadMetadataSort) -> String { + switch sort { + case let .updatedAt(order): + "updatedAt \(order == .ascending ? "ASC" : "DESC"), threadID ASC" + case let .createdAt(order): + "createdAt \(order == .ascending ? "ASC" : "DESC"), threadID ASC" + } + } + + private static func snapshotSortClause(_ sort: AgentThreadSnapshotSort) -> String { + switch sort { + case let .updatedAt(order): + "updatedAt \(order == .ascending ? "ASC" : "DESC"), threadID ASC" + case let .createdAt(order): + "createdAt \(order == .ascending ? "ASC" : "DESC"), threadID ASC" + } + } + + private static func sortDirection(for sort: AgentPendingStateSort) -> String { + switch sort { + case let .updatedAt(order): + order == .ascending ? "ASC" : "DESC" + } + } + + private static func sortDirection(for sort: AgentStructuredOutputSort) -> String { + switch sort { + case let .committedAt(order): + order == .ascending ? "ASC" : "DESC" + } + } + + private static func historyKinds(from filter: AgentHistoryFilter?) -> Set? { + guard let filter else { + return nil + } + + var kinds: Set = [] + if filter.includeMessages { kinds.insert(.message) } + if filter.includeToolCalls { kinds.insert(.toolCall) } + if filter.includeToolResults { kinds.insert(.toolResult) } + if filter.includeStructuredOutputs { kinds.insert(.structuredOutput) } + if filter.includeApprovals { kinds.insert(.approval) } + if filter.includeSystemEvents { kinds.insert(.systemEvent) } + return kinds + } + + private static func makeCursor(threadID: String, sequenceNumber: Int?) -> AgentHistoryCursor? { + guard let sequenceNumber else { + return nil + } + + let payload = GRDBHistoryCursorPayload( + version: 1, + threadID: threadID, + sequenceNumber: sequenceNumber + ) + let data = (try? JSONEncoder().encode(payload)) ?? Data() + let base64 = data.base64EncodedString() + .replacingOccurrences(of: "+", with: "-") + .replacingOccurrences(of: "/", with: "_") + .replacingOccurrences(of: "=", with: "") + return AgentHistoryCursor(rawValue: base64) + } + + private static func decodeCursorSequence( + _ cursor: AgentHistoryCursor?, + expectedThreadID: String + ) throws -> Int? { + guard let cursor else { + return nil + } + + let padded = cursor.rawValue + .replacingOccurrences(of: "-", with: "+") + .replacingOccurrences(of: "_", with: "/") + let remainder = padded.count % 4 + let adjusted = padded + String(repeating: "=", count: remainder == 0 ? 0 : 4 - remainder) + + guard let data = Data(base64Encoded: adjusted) else { + throw AgentRuntimeError.invalidHistoryCursor() + } + + let payload = try JSONDecoder().decode(GRDBHistoryCursorPayload.self, from: data) + guard payload.threadID == expectedThreadID else { + throw AgentRuntimeError.invalidHistoryCursor() + } + return payload.sequenceNumber + } + + private static func makeThreadRow(from thread: AgentThread) throws -> RuntimeThreadRow { + RuntimeThreadRow( + threadID: thread.id, + createdAt: thread.createdAt.timeIntervalSince1970, + updatedAt: thread.updatedAt.timeIntervalSince1970, + status: thread.status.rawValue, + encodedThread: try JSONEncoder().encode(thread) + ) + } + + private static func makeSummaryRow(from summary: AgentThreadSummary) throws -> RuntimeSummaryRow { + RuntimeSummaryRow( + threadID: summary.threadID, + createdAt: summary.createdAt.timeIntervalSince1970, + updatedAt: summary.updatedAt.timeIntervalSince1970, + latestItemAt: summary.latestItemAt?.timeIntervalSince1970, + itemCount: summary.itemCount, + pendingStateKind: summary.pendingState?.kind.rawValue, + latestStructuredOutputFormatName: summary.latestStructuredOutputMetadata?.formatName, + encodedSummary: try JSONEncoder().encode(summary) + ) + } + + private static func makeHistoryRow(from record: AgentHistoryRecord) throws -> RuntimeHistoryRow { + RuntimeHistoryRow( + storageID: "\(record.item.threadID):\(record.sequenceNumber)", + recordID: record.id, + threadID: record.item.threadID, + sequenceNumber: record.sequenceNumber, + createdAt: record.createdAt.timeIntervalSince1970, + kind: record.item.kind.rawValue, + turnID: record.item.turnID, + isRedacted: record.redaction != nil, + encodedRecord: try JSONEncoder().encode(record) + ) + } + + private static func structuredOutputRows( + from historyByThread: [String: [AgentHistoryRecord]] + ) throws -> [RuntimeStructuredOutputRow] { + try historyByThread.values + .flatMap { $0 } + .compactMap { record -> RuntimeStructuredOutputRow? in + switch record.item { + case let .structuredOutput(output): + return try Self.makeStructuredOutputRow( + id: "structured:\(record.id)", + record: output + ) + + case let .message(message): + guard let metadata = message.structuredOutput else { + return nil + } + return try Self.makeStructuredOutputRow( + id: "message:\(message.id)", + record: AgentStructuredOutputRecord( + threadID: message.threadID, + turnID: "", + messageID: message.id, + metadata: metadata, + committedAt: message.createdAt + ) + ) + + default: + return nil + } + } + } + + private static func makeStructuredOutputRow( + id: String, + record: AgentStructuredOutputRecord + ) throws -> RuntimeStructuredOutputRow { + RuntimeStructuredOutputRow( + outputID: id, + threadID: record.threadID, + formatName: record.metadata.formatName, + committedAt: record.committedAt.timeIntervalSince1970, + encodedRecord: try JSONEncoder().encode(record) + ) + } + + private static func decodeThread(from row: RuntimeThreadRow) throws -> AgentThread { + try JSONDecoder().decode(AgentThread.self, from: row.encodedThread) + } + + private static func decodeSummary(from row: RuntimeSummaryRow) throws -> AgentThreadSummary { + try JSONDecoder().decode(AgentThreadSummary.self, from: row.encodedSummary) + } + + private static func decodeHistoryRecord(from row: RuntimeHistoryRow) throws -> AgentHistoryRecord { + try JSONDecoder().decode(AgentHistoryRecord.self, from: row.encodedRecord) + } + + private static func decodeStructuredOutputRecord(from row: RuntimeStructuredOutputRow) throws -> AgentStructuredOutputRecord { + try JSONDecoder().decode(AgentStructuredOutputRecord.self, from: row.encodedRecord) + } + + private static func makeMigrator() -> DatabaseMigrator { + var migrator = DatabaseMigrator() + + migrator.registerMigration("runtime_store_v1") { db in + try db.create(table: RuntimeThreadRow.databaseTableName) { table in + table.column("threadID", .text).primaryKey() + table.column("createdAt", .double).notNull() + table.column("updatedAt", .double).notNull() + table.column("status", .text).notNull() + table.column("encodedThread", .blob).notNull() + } + + try db.create(table: RuntimeSummaryRow.databaseTableName) { table in + table.column("threadID", .text) + .primaryKey() + .references(RuntimeThreadRow.databaseTableName, onDelete: .cascade) + table.column("createdAt", .double).notNull() + table.column("updatedAt", .double).notNull() + table.column("latestItemAt", .double) + table.column("itemCount", .integer) + table.column("pendingStateKind", .text) + table.column("latestStructuredOutputFormatName", .text) + table.column("encodedSummary", .blob).notNull() + } + + try db.create(table: RuntimeHistoryRow.databaseTableName) { table in + table.column("storageID", .text).primaryKey() + table.column("recordID", .text).notNull() + table.column("threadID", .text) + .notNull() + .references(RuntimeThreadRow.databaseTableName, onDelete: .cascade) + table.column("sequenceNumber", .integer).notNull() + table.column("createdAt", .double).notNull() + table.column("kind", .text).notNull() + table.column("turnID", .text) + table.column("isRedacted", .boolean).notNull().defaults(to: false) + table.column("encodedRecord", .blob).notNull() + } + + try db.create(index: "runtime_history_thread_sequence", on: RuntimeHistoryRow.databaseTableName, columns: ["threadID", "sequenceNumber"], unique: true) + try db.create(index: "runtime_history_thread_created_at", on: RuntimeHistoryRow.databaseTableName, columns: ["threadID", "createdAt"]) + try db.create(index: "runtime_history_thread_kind", on: RuntimeHistoryRow.databaseTableName, columns: ["threadID", "kind"]) + try db.create(index: "runtime_history_thread_record_id", on: RuntimeHistoryRow.databaseTableName, columns: ["threadID", "recordID"]) + + try db.create(table: RuntimeStructuredOutputRow.databaseTableName) { table in + table.column("outputID", .text).primaryKey() + table.column("threadID", .text) + .notNull() + .references(RuntimeThreadRow.databaseTableName, onDelete: .cascade) + table.column("formatName", .text).notNull() + table.column("committedAt", .double).notNull() + table.column("encodedRecord", .blob).notNull() + } + + try db.create(index: "runtime_structured_outputs_thread_committed_at", on: RuntimeStructuredOutputRow.databaseTableName, columns: ["threadID", "committedAt"]) + try db.create(index: "runtime_structured_outputs_format_name", on: RuntimeStructuredOutputRow.databaseTableName, columns: ["formatName"]) + + try db.execute(sql: "PRAGMA user_version = \(currentStoreSchemaVersion)") + } + + return migrator + } + + private static func sortPendingStateRecords( + _ records: [AgentPendingStateRecord], + using sort: AgentPendingStateSort + ) -> [AgentPendingStateRecord] { + records.sorted { lhs, rhs in + switch sort { + case let .updatedAt(order): + if lhs.updatedAt == rhs.updatedAt { + return lhs.threadID < rhs.threadID + } + return order == .ascending ? lhs.updatedAt < rhs.updatedAt : lhs.updatedAt > rhs.updatedAt + } + } + } + + private static func sortStructuredOutputRecords( + _ records: [AgentStructuredOutputRecord], + using sort: AgentStructuredOutputSort + ) -> [AgentStructuredOutputRecord] { + records.sorted { lhs, rhs in + switch sort { + case let .committedAt(order): + if lhs.committedAt == rhs.committedAt { + return (lhs.messageID ?? lhs.turnID) < (rhs.messageID ?? rhs.turnID) + } + return order == .ascending ? lhs.committedAt < rhs.committedAt : lhs.committedAt > rhs.committedAt + } + } + } + + private static func sortThreadSnapshots( + _ snapshots: [AgentThreadSnapshot], + using sort: AgentThreadSnapshotSort + ) -> [AgentThreadSnapshot] { + snapshots.sorted { lhs, rhs in + switch sort { + case let .updatedAt(order): + if lhs.updatedAt == rhs.updatedAt { + return lhs.threadID < rhs.threadID + } + return order == .ascending ? lhs.updatedAt < rhs.updatedAt : lhs.updatedAt > rhs.updatedAt + case let .createdAt(order): + if lhs.createdAt == rhs.createdAt { + return lhs.threadID < rhs.threadID + } + return order == .ascending ? lhs.createdAt < rhs.createdAt : lhs.createdAt > rhs.createdAt + } + } + } + + private func readUserVersion() async throws -> Int { + try await dbQueue.read { db in + try Int.fetchOne(db, sql: "PRAGMA user_version;") ?? 0 + } + } +} + +private struct RuntimeThreadRow: Codable, FetchableRecord, PersistableRecord, TableRecord { + static let databaseTableName = "runtime_threads" + + let threadID: String + let createdAt: Double + let updatedAt: Double + let status: String + let encodedThread: Data +} + +private struct RuntimeSummaryRow: Codable, FetchableRecord, PersistableRecord, TableRecord { + static let databaseTableName = "runtime_summaries" + + let threadID: String + let createdAt: Double + let updatedAt: Double + let latestItemAt: Double? + let itemCount: Int? + let pendingStateKind: String? + let latestStructuredOutputFormatName: String? + let encodedSummary: Data +} + +private struct RuntimeHistoryRow: Codable, FetchableRecord, PersistableRecord, TableRecord { + static let databaseTableName = "runtime_history_items" + + let storageID: String + let recordID: String + let threadID: String + let sequenceNumber: Int + let createdAt: Double + let kind: String + let turnID: String? + let isRedacted: Bool + let encodedRecord: Data +} + +private struct RuntimeStructuredOutputRow: Codable, FetchableRecord, PersistableRecord, TableRecord { + static let databaseTableName = "runtime_structured_outputs" + + let outputID: String + let threadID: String + let formatName: String + let committedAt: Double + let encodedRecord: Data +} + +private struct GRDBHistoryCursorPayload: Codable { + let version: Int + let threadID: String + let sequenceNumber: Int +} + +private extension AgentHistoryItem { + var threadID: String { + switch self { + case let .message(message): + message.threadID + case let .toolCall(record): + record.invocation.threadID + case let .toolResult(record): + record.threadID + case let .structuredOutput(record): + record.threadID + case let .approval(record): + record.request?.threadID ?? record.resolution?.threadID ?? "" + case let .systemEvent(record): + record.threadID + } + } +} diff --git a/Sources/CodexKit/Runtime/RuntimeStateStore.swift b/Sources/CodexKit/Runtime/RuntimeStateStore.swift index f591074..650aec8 100644 --- a/Sources/CodexKit/Runtime/RuntimeStateStore.swift +++ b/Sources/CodexKit/Runtime/RuntimeStateStore.swift @@ -3,28 +3,118 @@ import Foundation public struct StoredRuntimeState: Codable, Hashable, Sendable { public var threads: [AgentThread] public var messagesByThread: [String: [AgentMessage]] + public var historyByThread: [String: [AgentHistoryRecord]] + public var summariesByThread: [String: AgentThreadSummary] + public var nextHistorySequenceByThread: [String: Int] public init( threads: [AgentThread] = [], - messagesByThread: [String: [AgentMessage]] = [:] + messagesByThread: [String: [AgentMessage]] = [:], + historyByThread: [String: [AgentHistoryRecord]] = [:], + summariesByThread: [String: AgentThreadSummary] = [:], + nextHistorySequenceByThread: [String: Int] = [:] + ) { + self.init( + threads: threads, + messagesByThread: messagesByThread, + historyByThread: historyByThread, + summariesByThread: summariesByThread, + nextHistorySequenceByThread: nextHistorySequenceByThread, + normalizeState: false + ) + self = normalized() + } + + init( + threads: [AgentThread], + messagesByThread: [String: [AgentMessage]], + historyByThread: [String: [AgentHistoryRecord]], + summariesByThread: [String: AgentThreadSummary], + nextHistorySequenceByThread: [String: Int], + normalizeState: Bool ) { self.threads = threads self.messagesByThread = messagesByThread + self.historyByThread = historyByThread + self.summariesByThread = summariesByThread + self.nextHistorySequenceByThread = nextHistorySequenceByThread + if normalizeState { + self = normalized() + } } public static let empty = StoredRuntimeState() + + enum CodingKeys: String, CodingKey { + case threads + case messagesByThread + case historyByThread + case summariesByThread + case nextHistorySequenceByThread + } + + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + self.init( + threads: try container.decodeIfPresent([AgentThread].self, forKey: .threads) ?? [], + messagesByThread: try container.decodeIfPresent([String: [AgentMessage]].self, forKey: .messagesByThread) ?? [:], + historyByThread: try container.decodeIfPresent([String: [AgentHistoryRecord]].self, forKey: .historyByThread) ?? [:], + summariesByThread: try container.decodeIfPresent([String: AgentThreadSummary].self, forKey: .summariesByThread) ?? [:], + nextHistorySequenceByThread: try container.decodeIfPresent([String: Int].self, forKey: .nextHistorySequenceByThread) ?? [:] + ) + } } public protocol RuntimeStateStoring: Sendable { func loadState() async throws -> StoredRuntimeState func saveState(_ state: StoredRuntimeState) async throws + func prepare() async throws -> AgentStoreMetadata + func readMetadata() async throws -> AgentStoreMetadata + func apply(_ operations: [AgentStoreWriteOperation]) async throws } -public actor InMemoryRuntimeStateStore: RuntimeStateStoring { +public protocol RuntimeStateInspecting: Sendable { + func fetchThreadSummary(id: String) async throws -> AgentThreadSummary + func fetchThreadHistory( + id: String, + query: AgentHistoryQuery + ) async throws -> AgentThreadHistoryPage + func fetchLatestStructuredOutputMetadata(id: String) async throws -> AgentStructuredOutputMetadata? +} + +public extension RuntimeStateStoring { + func prepare() async throws -> AgentStoreMetadata { + _ = try await loadState() + return try await readMetadata() + } + + func readMetadata() async throws -> AgentStoreMetadata { + AgentStoreMetadata( + logicalSchemaVersion: .v1, + storeSchemaVersion: 1, + capabilities: AgentStoreCapabilities( + supportsPushdownQueries: false, + supportsCrossThreadQueries: true, + supportsSorting: true, + supportsFiltering: true, + supportsMigrations: false + ), + storeKind: String(describing: Self.self) + ) + } + + func apply(_ operations: [AgentStoreWriteOperation]) async throws { + let state = try await loadState() + let updated = try state.applying(operations) + try await saveState(updated) + } +} + +public actor InMemoryRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, AgentRuntimeQueryableStore { private var state: StoredRuntimeState public init(initialState: StoredRuntimeState = .empty) { - state = initialState + state = initialState.normalized() } public func loadState() async throws -> StoredRuntimeState { @@ -32,38 +122,937 @@ public actor InMemoryRuntimeStateStore: RuntimeStateStoring { } public func saveState(_ state: StoredRuntimeState) async throws { - self.state = state + self.state = state.normalized() + } + + public func prepare() async throws -> AgentStoreMetadata { + state = state.normalized() + return try await readMetadata() + } + + public func readMetadata() async throws -> AgentStoreMetadata { + AgentStoreMetadata( + logicalSchemaVersion: .v1, + storeSchemaVersion: 1, + capabilities: AgentStoreCapabilities( + supportsPushdownQueries: true, + supportsCrossThreadQueries: true, + supportsSorting: true, + supportsFiltering: true, + supportsMigrations: false + ), + storeKind: "InMemoryRuntimeStateStore" + ) + } + + public func fetchThreadSummary(id: String) async throws -> AgentThreadSummary { + try state.threadSummary(id: id) + } + + public func fetchThreadHistory( + id: String, + query: AgentHistoryQuery + ) async throws -> AgentThreadHistoryPage { + try state.threadHistoryPage(id: id, query: query) + } + + public func fetchLatestStructuredOutputMetadata(id: String) async throws -> AgentStructuredOutputMetadata? { + try state.threadSummary(id: id).latestStructuredOutputMetadata } } -public actor FileRuntimeStateStore: RuntimeStateStoring { +public actor FileRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, AgentRuntimeQueryableStore { private let url: URL private let encoder = JSONEncoder() private let decoder = JSONDecoder() + private let fileManager = FileManager.default public init(url: URL) { self.url = url } public func loadState() async throws -> StoredRuntimeState { - guard FileManager.default.fileExists(atPath: url.path) else { + try loadNormalizedStateMigratingIfNeeded() + } + + public func saveState(_ state: StoredRuntimeState) async throws { + try persistLayout(for: state.normalized()) + } + + public func prepare() async throws -> AgentStoreMetadata { + _ = try loadNormalizedStateMigratingIfNeeded() + return try await readMetadata() + } + + public func readMetadata() async throws -> AgentStoreMetadata { + AgentStoreMetadata( + logicalSchemaVersion: .v1, + storeSchemaVersion: 1, + capabilities: AgentStoreCapabilities( + supportsPushdownQueries: false, + supportsCrossThreadQueries: false, + supportsSorting: true, + supportsFiltering: true, + supportsMigrations: true + ), + storeKind: "FileRuntimeStateStore" + ) + } + + public func fetchThreadSummary(id: String) async throws -> AgentThreadSummary { + if let manifest = try loadManifest() { + guard let thread = manifest.threads.first(where: { $0.id == id }) else { + throw AgentRuntimeError.threadNotFound(id) + } + return manifest.summariesByThread[id] + ?? StoredRuntimeState(threads: [thread]).threadSummaryFallback(for: thread) + } + + return try loadNormalizedStateMigratingIfNeeded().threadSummary(id: id) + } + + public func fetchThreadHistory( + id: String, + query: AgentHistoryQuery + ) async throws -> AgentThreadHistoryPage { + if let manifest = try loadManifest() { + guard manifest.threads.contains(where: { $0.id == id }) else { + throw AgentRuntimeError.threadNotFound(id) + } + + let history = try loadHistory(for: id) + let state = StoredRuntimeState( + threads: manifest.threads, + historyByThread: [id: history], + summariesByThread: manifest.summariesByThread, + nextHistorySequenceByThread: manifest.nextHistorySequenceByThread + ) + return try state.threadHistoryPage(id: id, query: query) + } + + return try loadNormalizedStateMigratingIfNeeded().threadHistoryPage(id: id, query: query) + } + + public func fetchLatestStructuredOutputMetadata(id: String) async throws -> AgentStructuredOutputMetadata? { + let summary = try await fetchThreadSummary(id: id) + return summary.latestStructuredOutputMetadata + } + + private func loadNormalizedStateMigratingIfNeeded() throws -> StoredRuntimeState { + guard fileManager.fileExists(atPath: url.path) else { return .empty } + if let manifest = try loadManifest() { + return try state(from: manifest) + } + let data = try Data(contentsOf: url) - return try decoder.decode(StoredRuntimeState.self, from: data) + let legacy = try decoder.decode(StoredRuntimeState.self, from: data).normalized() + try persistLayout(for: legacy) + return legacy } - public func saveState(_ state: StoredRuntimeState) async throws { + private func loadManifest() throws -> FileRuntimeStateManifest? { + guard fileManager.fileExists(atPath: url.path) else { + return nil + } + + let data = try Data(contentsOf: url) + return try? decoder.decode(FileRuntimeStateManifest.self, from: data) + } + + private func state(from manifest: FileRuntimeStateManifest) throws -> StoredRuntimeState { + var historyByThread: [String: [AgentHistoryRecord]] = [:] + for thread in manifest.threads { + historyByThread[thread.id] = try loadHistory(for: thread.id) + } + + return StoredRuntimeState( + threads: manifest.threads, + historyByThread: historyByThread, + summariesByThread: manifest.summariesByThread, + nextHistorySequenceByThread: manifest.nextHistorySequenceByThread + ) + } + + private func loadHistory(for threadID: String) throws -> [AgentHistoryRecord] { + let historyURL = historyFileURL(for: threadID) + guard fileManager.fileExists(atPath: historyURL.path) else { + return [] + } + + let data = try Data(contentsOf: historyURL) + return try decoder.decode([AgentHistoryRecord].self, from: data) + } + + private func persistLayout(for state: StoredRuntimeState) throws { + let normalized = state.normalized() let directory = url.deletingLastPathComponent() if !directory.path.isEmpty { - try FileManager.default.createDirectory( + try fileManager.createDirectory( at: directory, withIntermediateDirectories: true ) } - let data = try encoder.encode(state) - try data.write(to: url, options: .atomic) + try fileManager.createDirectory( + at: historyDirectoryURL, + withIntermediateDirectories: true + ) + + for thread in normalized.threads { + let historyURL = historyFileURL(for: thread.id) + let history = normalized.historyByThread[thread.id] ?? [] + let data = try encoder.encode(history) + try data.write(to: historyURL, options: .atomic) + } + + let manifest = FileRuntimeStateManifest( + threads: normalized.threads, + summariesByThread: normalized.summariesByThread, + nextHistorySequenceByThread: normalized.nextHistorySequenceByThread + ) + let manifestData = try encoder.encode(manifest) + try manifestData.write(to: url, options: .atomic) + } + + private var historyDirectoryURL: URL { + let basename = url.deletingPathExtension().lastPathComponent + return url.deletingLastPathComponent() + .appendingPathComponent("\(basename).codexkit-state", isDirectory: true) + .appendingPathComponent("threads", isDirectory: true) + } + + private func historyFileURL(for threadID: String) -> URL { + historyDirectoryURL.appendingPathComponent(safeThreadFilename(threadID)).appendingPathExtension("json") + } + + private func safeThreadFilename(_ threadID: String) -> String { + threadID.addingPercentEncoding(withAllowedCharacters: .alphanumerics) ?? threadID + } +} + +private struct FileRuntimeStateManifest: Codable { + let storageVersion: Int + let threads: [AgentThread] + let summariesByThread: [String: AgentThreadSummary] + let nextHistorySequenceByThread: [String: Int] + + init( + threads: [AgentThread], + summariesByThread: [String: AgentThreadSummary], + nextHistorySequenceByThread: [String: Int] + ) { + self.storageVersion = 1 + self.threads = threads + self.summariesByThread = summariesByThread + self.nextHistorySequenceByThread = nextHistorySequenceByThread + } +} + +extension StoredRuntimeState { + func normalized() -> StoredRuntimeState { + let sortedThreads = threads.sorted { lhs, rhs in + if lhs.updatedAt == rhs.updatedAt { + return lhs.id < rhs.id + } + return lhs.updatedAt > rhs.updatedAt + } + + var normalizedHistory = historyByThread + .mapValues { records in + records.sorted { lhs, rhs in + if lhs.sequenceNumber == rhs.sequenceNumber { + return lhs.createdAt < rhs.createdAt + } + return lhs.sequenceNumber < rhs.sequenceNumber + } + } + + for (threadID, messages) in messagesByThread where normalizedHistory[threadID]?.isEmpty != false { + normalizedHistory[threadID] = Self.syntheticHistory(from: messages) + } + + let normalizedMessages: [String: [AgentMessage]] = normalizedHistory.mapValues { records in + records.compactMap { record -> AgentMessage? in + guard case let .message(message) = record.item else { + return nil + } + return message + } + } + + var normalizedNextSequence = nextHistorySequenceByThread + for thread in sortedThreads { + let history = normalizedHistory[thread.id] ?? [] + let nextSequence = (history.last?.sequenceNumber ?? 0) + 1 + normalizedNextSequence[thread.id] = max(normalizedNextSequence[thread.id] ?? 0, nextSequence) + } + + var normalizedSummaries: [String: AgentThreadSummary] = [:] + for thread in sortedThreads { + let history = normalizedHistory[thread.id] ?? [] + normalizedSummaries[thread.id] = Self.rebuildSummary( + for: thread, + history: history, + existing: summariesByThread[thread.id] + ) + } + + return StoredRuntimeState( + threads: sortedThreads, + messagesByThread: normalizedMessages, + historyByThread: normalizedHistory, + summariesByThread: normalizedSummaries, + nextHistorySequenceByThread: normalizedNextSequence, + normalizeState: false + ) + } + + func threadSummary(id: String) throws -> AgentThreadSummary { + guard let thread = threads.first(where: { $0.id == id }) else { + throw AgentRuntimeError.threadNotFound(id) + } + + return summariesByThread[id] ?? threadSummaryFallback(for: thread) + } + + func threadSummaryFallback(for thread: AgentThread) -> AgentThreadSummary { + Self.rebuildSummary( + for: thread, + history: historyByThread[thread.id] ?? [], + existing: summariesByThread[thread.id] + ) + } + + func threadHistoryPage( + id: String, + query: AgentHistoryQuery + ) throws -> AgentThreadHistoryPage { + guard threads.contains(where: { $0.id == id }) else { + throw AgentRuntimeError.threadNotFound(id) + } + + let limit = max(1, query.limit) + let filter = query.filter ?? AgentHistoryFilter() + let records = (historyByThread[id] ?? []).filter { filter.matches($0.item) } + let anchor = try query.cursor?.decodedSequenceNumber(expectedThreadID: id) + + switch query.direction { + case .backward: + let endIndex = records.endIndexForBackward(anchor: anchor) + let startIndex = max(0, endIndex - limit) + let pageRecords = Array(records[startIndex ..< endIndex]) + let hasMoreBefore = startIndex > 0 + let hasMoreAfter = endIndex < records.count + return AgentThreadHistoryPage( + threadID: id, + items: pageRecords.map(\.item), + nextCursor: hasMoreBefore ? AgentHistoryCursor(threadID: id, sequenceNumber: pageRecords.first?.sequenceNumber) : nil, + previousCursor: hasMoreAfter ? AgentHistoryCursor(threadID: id, sequenceNumber: pageRecords.last?.sequenceNumber) : nil, + hasMoreBefore: hasMoreBefore, + hasMoreAfter: hasMoreAfter + ) + + case .forward: + let startIndex = records.startIndexForForward(anchor: anchor) + let endIndex = min(records.count, startIndex + limit) + let pageRecords = Array(records[startIndex ..< endIndex]) + let hasMoreBefore = startIndex > 0 + let hasMoreAfter = endIndex < records.count + return AgentThreadHistoryPage( + threadID: id, + items: pageRecords.map(\.item), + nextCursor: hasMoreAfter ? AgentHistoryCursor(threadID: id, sequenceNumber: pageRecords.last?.sequenceNumber) : nil, + previousCursor: hasMoreBefore ? AgentHistoryCursor(threadID: id, sequenceNumber: pageRecords.first?.sequenceNumber) : nil, + hasMoreBefore: hasMoreBefore, + hasMoreAfter: hasMoreAfter + ) + } + } + + func applying(_ operations: [AgentStoreWriteOperation]) throws -> StoredRuntimeState { + var updated = self + + for operation in operations { + switch operation { + case let .upsertThread(thread): + if let index = updated.threads.firstIndex(where: { $0.id == thread.id }) { + updated.threads[index] = thread + } else { + updated.threads.append(thread) + } + + case let .upsertSummary(threadID, summary): + updated.summariesByThread[threadID] = summary + + case let .appendHistoryItems(threadID, items): + updated.historyByThread[threadID, default: []].append(contentsOf: items) + let nextSequence = (updated.historyByThread[threadID]?.last?.sequenceNumber ?? 0) + 1 + updated.nextHistorySequenceByThread[threadID] = nextSequence + + case let .setPendingState(threadID, state): + if let thread = updated.threads.first(where: { $0.id == threadID }) { + let current = updated.summariesByThread[threadID] ?? updated.threadSummaryFallback(for: thread) + updated.summariesByThread[threadID] = AgentThreadSummary( + threadID: current.threadID, + createdAt: current.createdAt, + updatedAt: current.updatedAt, + latestItemAt: current.latestItemAt, + itemCount: current.itemCount, + latestAssistantMessagePreview: current.latestAssistantMessagePreview, + latestStructuredOutputMetadata: current.latestStructuredOutputMetadata, + latestPartialStructuredOutput: current.latestPartialStructuredOutput, + latestToolState: current.latestToolState, + latestTurnStatus: current.latestTurnStatus, + pendingState: state + ) + } + + case let .setPartialStructuredSnapshot(threadID, snapshot): + if let thread = updated.threads.first(where: { $0.id == threadID }) { + let current = updated.summariesByThread[threadID] ?? updated.threadSummaryFallback(for: thread) + updated.summariesByThread[threadID] = AgentThreadSummary( + threadID: current.threadID, + createdAt: current.createdAt, + updatedAt: current.updatedAt, + latestItemAt: current.latestItemAt, + itemCount: current.itemCount, + latestAssistantMessagePreview: current.latestAssistantMessagePreview, + latestStructuredOutputMetadata: current.latestStructuredOutputMetadata, + latestPartialStructuredOutput: snapshot, + latestToolState: current.latestToolState, + latestTurnStatus: current.latestTurnStatus, + pendingState: current.pendingState + ) + } + + case let .upsertToolSession(threadID, session): + if let thread = updated.threads.first(where: { $0.id == threadID }) { + let current = updated.summariesByThread[threadID] ?? updated.threadSummaryFallback(for: thread) + let latestToolState = AgentLatestToolState( + invocationID: session.invocationID, + turnID: session.turnID, + toolName: session.toolName, + status: .running, + success: nil, + sessionID: session.sessionID, + sessionStatus: session.sessionStatus, + metadata: session.metadata, + resumable: session.resumable, + updatedAt: session.updatedAt, + resultPreview: nil + ) + updated.summariesByThread[threadID] = AgentThreadSummary( + threadID: current.threadID, + createdAt: current.createdAt, + updatedAt: current.updatedAt, + latestItemAt: current.latestItemAt, + itemCount: current.itemCount, + latestAssistantMessagePreview: current.latestAssistantMessagePreview, + latestStructuredOutputMetadata: current.latestStructuredOutputMetadata, + latestPartialStructuredOutput: current.latestPartialStructuredOutput, + latestToolState: latestToolState, + latestTurnStatus: current.latestTurnStatus, + pendingState: .toolWait( + AgentPendingToolWaitState( + invocationID: session.invocationID, + turnID: session.turnID, + toolName: session.toolName, + startedAt: session.updatedAt, + sessionID: session.sessionID, + sessionStatus: session.sessionStatus, + metadata: session.metadata, + resumable: session.resumable + ) + ) + ) + } + + case let .redactHistoryItems(threadID, itemIDs, reason): + guard !itemIDs.isEmpty else { + continue + } + updated.historyByThread[threadID] = updated.historyByThread[threadID]?.map { record in + guard itemIDs.contains(record.id) else { + return record + } + return record.redacted(reason: reason) + } + + case let .deleteThread(threadID): + updated.threads.removeAll { $0.id == threadID } + updated.messagesByThread.removeValue(forKey: threadID) + updated.historyByThread.removeValue(forKey: threadID) + updated.summariesByThread.removeValue(forKey: threadID) + updated.nextHistorySequenceByThread.removeValue(forKey: threadID) + } + } + + return updated.normalized() + } + + func execute(_ query: HistoryItemsQuery) throws -> AgentHistoryQueryResult { + guard threads.contains(where: { $0.id == query.threadID }) else { + return AgentHistoryQueryResult( + threadID: query.threadID, + records: [], + nextCursor: nil, + previousCursor: nil, + hasMoreBefore: false, + hasMoreAfter: false + ) + } + + var records = historyByThread[query.threadID] ?? [] + if let kinds = query.kinds { + records = records.filter { kinds.contains($0.item.kind) } + } + if let createdAtRange = query.createdAtRange { + records = records.filter { createdAtRange.contains($0.createdAt) } + } + if let turnID = query.turnID { + records = records.filter { $0.item.turnID == turnID } + } + if !query.includeRedacted { + records = records.filter { $0.redaction == nil } + } + + records = sort(records, using: query.sort) + let page = try page(records, threadID: query.threadID, with: query.page, sort: query.sort) + return page + } + + func execute(_ query: ThreadMetadataQuery) -> [AgentThread] { + var filtered = threads + if let threadIDs = query.threadIDs { + filtered = filtered.filter { threadIDs.contains($0.id) } + } + if let statuses = query.statuses { + filtered = filtered.filter { statuses.contains($0.status) } + } + if let updatedAtRange = query.updatedAtRange { + filtered = filtered.filter { updatedAtRange.contains($0.updatedAt) } + } + filtered = sort(filtered, using: query.sort) + if let limit = query.limit { + filtered = Array(filtered.prefix(max(0, limit))) + } + return filtered + } + + func execute(_ query: PendingStateQuery) -> [AgentPendingStateRecord] { + var records = summariesByThread.compactMap { threadID, summary -> AgentPendingStateRecord? in + guard let pendingState = summary.pendingState else { + return nil + } + return AgentPendingStateRecord( + threadID: threadID, + pendingState: pendingState, + updatedAt: summary.updatedAt + ) + } + + if let threadIDs = query.threadIDs { + records = records.filter { threadIDs.contains($0.threadID) } + } + if let kinds = query.kinds { + records = records.filter { kinds.contains($0.pendingState.kind) } + } + records = sort(records, using: query.sort) + if let limit = query.limit { + records = Array(records.prefix(max(0, limit))) + } + return records + } + + func execute(_ query: StructuredOutputQuery) -> [AgentStructuredOutputRecord] { + var records = historyByThread.values + .flatMap { $0 } + .compactMap { record -> AgentStructuredOutputRecord? in + switch record.item { + case let .structuredOutput(structuredOutput): + return structuredOutput + + case let .message(message): + guard let metadata = message.structuredOutput else { + return nil + } + return AgentStructuredOutputRecord( + threadID: message.threadID, + turnID: "", + messageID: message.id, + metadata: metadata, + committedAt: message.createdAt + ) + + default: + return nil + } + } + + if let threadIDs = query.threadIDs { + records = records.filter { threadIDs.contains($0.threadID) } + } + if let formatNames = query.formatNames { + records = records.filter { formatNames.contains($0.metadata.formatName) } + } + + records = sort(records, using: query.sort) + + if query.latestOnly { + var seen = Set() + records = records.filter { record in + seen.insert(record.threadID).inserted + } + } + + if let limit = query.limit { + records = Array(records.prefix(max(0, limit))) + } + return records + } + + func execute(_ query: ThreadSnapshotQuery) -> [AgentThreadSnapshot] { + var snapshots = threads.compactMap { thread -> AgentThreadSnapshot? in + guard query.threadIDs?.contains(thread.id) ?? true else { + return nil + } + let summary = summariesByThread[thread.id] ?? threadSummaryFallback(for: thread) + return summary.snapshot + } + snapshots = sort(snapshots, using: query.sort) + if let limit = query.limit { + snapshots = Array(snapshots.prefix(max(0, limit))) + } + return snapshots + } + + private static func syntheticHistory(from messages: [AgentMessage]) -> [AgentHistoryRecord] { + let orderedMessages = messages.enumerated().sorted { lhs, rhs in + let left = lhs.element + let right = rhs.element + if left.createdAt == right.createdAt { + return lhs.offset < rhs.offset + } + return left.createdAt < right.createdAt + } + + return orderedMessages.enumerated().map { index, pair in + AgentHistoryRecord( + sequenceNumber: index + 1, + createdAt: pair.element.createdAt, + item: .message(pair.element) + ) + } + } + + private static func rebuildSummary( + for thread: AgentThread, + history: [AgentHistoryRecord], + existing: AgentThreadSummary? + ) -> AgentThreadSummary { + var latestAssistantMessagePreview = existing?.latestAssistantMessagePreview + var latestStructuredOutputMetadata = existing?.latestStructuredOutputMetadata + var latestToolState = existing?.latestToolState + var latestTurnStatus = existing?.latestTurnStatus + let latestPartialStructuredOutput = existing?.latestPartialStructuredOutput + let pendingState = existing?.pendingState + + for record in history { + switch record.item { + case let .message(message): + if message.role == .assistant { + latestAssistantMessagePreview = message.displayText + if let structuredOutput = message.structuredOutput { + latestStructuredOutputMetadata = structuredOutput + } + } + + case let .toolCall(toolCall): + latestToolState = AgentLatestToolState( + invocationID: toolCall.invocation.id, + turnID: toolCall.invocation.turnID, + toolName: toolCall.invocation.toolName, + status: .waiting, + updatedAt: toolCall.requestedAt + ) + + case let .toolResult(toolResult): + latestToolState = Self.latestToolState(from: toolResult) + + case let .structuredOutput(structuredOutput): + latestStructuredOutputMetadata = structuredOutput.metadata + + case .approval: + break + + case let .systemEvent(systemEvent): + switch systemEvent.type { + case .turnStarted: + latestTurnStatus = .running + case .turnCompleted: + latestTurnStatus = .completed + case .turnFailed: + latestTurnStatus = .failed + case .threadCreated, .threadResumed, .threadStatusChanged: + break + } + } + } + + return AgentThreadSummary( + threadID: thread.id, + createdAt: thread.createdAt, + updatedAt: thread.updatedAt, + latestItemAt: history.last?.createdAt, + itemCount: history.count, + latestAssistantMessagePreview: latestAssistantMessagePreview, + latestStructuredOutputMetadata: latestStructuredOutputMetadata, + latestPartialStructuredOutput: latestPartialStructuredOutput, + latestToolState: latestToolState, + latestTurnStatus: latestTurnStatus, + pendingState: pendingState + ) + } + + private static func latestToolState(from toolResult: AgentToolResultRecord) -> AgentLatestToolState { + let preview = toolResult.result.primaryText + let session = toolResult.result.session + let status: AgentToolSessionStatus + if toolResult.result.errorMessage == "Tool execution was denied by the user." { + status = .denied + } else if let session, !session.isTerminal { + status = .running + } else if toolResult.result.success { + status = .completed + } else { + status = .failed + } + + return AgentLatestToolState( + invocationID: toolResult.result.invocationID, + turnID: toolResult.turnID, + toolName: toolResult.result.toolName, + status: status, + success: toolResult.result.success, + sessionID: session?.sessionID, + sessionStatus: session?.status, + metadata: session?.metadata, + resumable: session?.resumable ?? false, + updatedAt: toolResult.completedAt, + resultPreview: preview + ) + } +} + +private extension Array where Element == AgentHistoryRecord { + func endIndexForBackward(anchor: Int?) -> Int { + guard let anchor else { + return count + } + + return firstIndex(where: { $0.sequenceNumber >= anchor }) ?? count + } + + func startIndexForForward(anchor: Int?) -> Int { + guard let anchor else { + return 0 + } + + return firstIndex(where: { $0.sequenceNumber > anchor }) ?? count + } +} + +private extension StoredRuntimeState { + func sort( + _ records: [AgentHistoryRecord], + using sort: AgentHistorySort + ) -> [AgentHistoryRecord] { + records.sorted { lhs, rhs in + switch sort { + case let .sequence(order): + if lhs.sequenceNumber == rhs.sequenceNumber { + return lhs.createdAt < rhs.createdAt + } + return order == .ascending + ? lhs.sequenceNumber < rhs.sequenceNumber + : lhs.sequenceNumber > rhs.sequenceNumber + + case let .createdAt(order): + if lhs.createdAt == rhs.createdAt { + return lhs.sequenceNumber < rhs.sequenceNumber + } + return order == .ascending + ? lhs.createdAt < rhs.createdAt + : lhs.createdAt > rhs.createdAt + } + } + } + + func page( + _ records: [AgentHistoryRecord], + threadID: String, + with page: AgentQueryPage?, + sort: AgentHistorySort + ) throws -> AgentHistoryQueryResult { + guard let page else { + let ordered = normalizePageRecords(records, sort: sort) + return AgentHistoryQueryResult( + threadID: threadID, + records: ordered, + nextCursor: nil, + previousCursor: nil, + hasMoreBefore: false, + hasMoreAfter: false + ) + } + + let limit = max(1, page.limit) + let anchor = try page.cursor?.decodedSequenceNumber(expectedThreadID: threadID) + let ascending = normalizePageRecords(records, sort: sort) + let endIndex = if let anchor { + ascending.firstIndex(where: { $0.sequenceNumber >= anchor }) ?? ascending.count + } else { + ascending.count + } + let startIndex = max(0, endIndex - limit) + let sliced = Array(ascending[startIndex ..< endIndex]) + return AgentHistoryQueryResult( + threadID: threadID, + records: sliced, + nextCursor: startIndex > 0 ? AgentHistoryCursor(threadID: threadID, sequenceNumber: sliced.first?.sequenceNumber) : nil, + previousCursor: endIndex < ascending.count ? AgentHistoryCursor(threadID: threadID, sequenceNumber: sliced.last?.sequenceNumber) : nil, + hasMoreBefore: startIndex > 0, + hasMoreAfter: endIndex < ascending.count + ) + } + + func normalizePageRecords( + _ records: [AgentHistoryRecord], + sort: AgentHistorySort + ) -> [AgentHistoryRecord] { + switch sort { + case .sequence(.ascending), .createdAt(.ascending): + return records + case .sequence(.descending), .createdAt(.descending): + return records.reversed() + } + } + + func sort( + _ threads: [AgentThread], + using sort: AgentThreadMetadataSort + ) -> [AgentThread] { + threads.sorted { lhs, rhs in + switch sort { + case let .updatedAt(order): + if lhs.updatedAt == rhs.updatedAt { + return lhs.id < rhs.id + } + return order == .ascending ? lhs.updatedAt < rhs.updatedAt : lhs.updatedAt > rhs.updatedAt + case let .createdAt(order): + if lhs.createdAt == rhs.createdAt { + return lhs.id < rhs.id + } + return order == .ascending ? lhs.createdAt < rhs.createdAt : lhs.createdAt > rhs.createdAt + } + } + } + + func sort( + _ records: [AgentPendingStateRecord], + using sort: AgentPendingStateSort + ) -> [AgentPendingStateRecord] { + records.sorted { lhs, rhs in + switch sort { + case let .updatedAt(order): + if lhs.updatedAt == rhs.updatedAt { + return lhs.threadID < rhs.threadID + } + return order == .ascending ? lhs.updatedAt < rhs.updatedAt : lhs.updatedAt > rhs.updatedAt + } + } + } + + func sort( + _ records: [AgentStructuredOutputRecord], + using sort: AgentStructuredOutputSort + ) -> [AgentStructuredOutputRecord] { + records.sorted { lhs, rhs in + switch sort { + case let .committedAt(order): + if lhs.committedAt == rhs.committedAt { + return lhs.threadID < rhs.threadID + } + return order == .ascending ? lhs.committedAt < rhs.committedAt : lhs.committedAt > rhs.committedAt + } + } + } + + func sort( + _ records: [AgentThreadSnapshot], + using sort: AgentThreadSnapshotSort + ) -> [AgentThreadSnapshot] { + records.sorted { lhs, rhs in + switch sort { + case let .updatedAt(order): + if lhs.updatedAt == rhs.updatedAt { + return lhs.threadID < rhs.threadID + } + return order == .ascending ? lhs.updatedAt < rhs.updatedAt : lhs.updatedAt > rhs.updatedAt + case let .createdAt(order): + if lhs.createdAt == rhs.createdAt { + return lhs.threadID < rhs.threadID + } + return order == .ascending ? lhs.createdAt < rhs.createdAt : lhs.createdAt > rhs.createdAt + } + } + } +} + +private struct AgentHistoryCursorPayload: Codable { + let version: Int + let threadID: String + let sequenceNumber: Int +} + +private extension AgentHistoryCursor { + init(threadID: String, sequenceNumber: Int?) { + guard let sequenceNumber else { + self.init(rawValue: "") + return + } + + let payload = AgentHistoryCursorPayload( + version: 1, + threadID: threadID, + sequenceNumber: sequenceNumber + ) + let data = (try? JSONEncoder().encode(payload)) ?? Data() + let base64 = data.base64EncodedString() + .replacingOccurrences(of: "+", with: "-") + .replacingOccurrences(of: "/", with: "_") + .replacingOccurrences(of: "=", with: "") + self.init(rawValue: base64) + } + + func decodedSequenceNumber(expectedThreadID: String) throws -> Int { + let padded = rawValue + .replacingOccurrences(of: "-", with: "+") + .replacingOccurrences(of: "_", with: "/") + let remainder = padded.count % 4 + let adjusted = padded + String(repeating: "=", count: remainder == 0 ? 0 : 4 - remainder) + + guard let data = Data(base64Encoded: adjusted) else { + throw AgentRuntimeError.invalidHistoryCursor() + } + + let payload = try JSONDecoder().decode(AgentHistoryCursorPayload.self, from: data) + guard payload.threadID == expectedThreadID else { + throw AgentRuntimeError.invalidHistoryCursor() + } + return payload.sequenceNumber } } diff --git a/Sources/CodexKit/Tools/ToolModels.swift b/Sources/CodexKit/Tools/ToolModels.swift index 3056c14..4839a96 100644 --- a/Sources/CodexKit/Tools/ToolModels.swift +++ b/Sources/CodexKit/Tools/ToolModels.swift @@ -56,6 +56,8 @@ public struct ToolInvocation: Identifiable, Hashable, Sendable { } } +extension ToolInvocation: Codable {} + public enum ToolResultContent: Hashable, Sendable { case text(String) case image(URL) @@ -68,25 +70,52 @@ public enum ToolResultContent: Hashable, Sendable { } } +extension ToolResultContent: Codable {} + +public struct ToolSessionDescriptor: Codable, Hashable, Sendable { + public let sessionID: String + public let status: String + public let metadata: JSONValue? + public let resumable: Bool + public let isTerminal: Bool + + public init( + sessionID: String, + status: String, + metadata: JSONValue? = nil, + resumable: Bool = false, + isTerminal: Bool = true + ) { + self.sessionID = sessionID + self.status = status + self.metadata = metadata + self.resumable = resumable + self.isTerminal = isTerminal + } +} + public struct ToolResultEnvelope: Hashable, Sendable { public let invocationID: String public let toolName: String public let success: Bool public let content: [ToolResultContent] public let errorMessage: String? + public let session: ToolSessionDescriptor? public init( invocationID: String, toolName: String, success: Bool, content: [ToolResultContent] = [], - errorMessage: String? = nil + errorMessage: String? = nil, + session: ToolSessionDescriptor? = nil ) { self.invocationID = invocationID self.toolName = toolName self.success = success self.content = content self.errorMessage = errorMessage + self.session = session } public var primaryText: String? { @@ -95,34 +124,47 @@ public struct ToolResultEnvelope: Hashable, Sendable { public static func success( invocation: ToolInvocation, - text: String + text: String, + session: ToolSessionDescriptor? = nil ) -> ToolResultEnvelope { ToolResultEnvelope( invocationID: invocation.id, toolName: invocation.toolName, success: true, - content: [.text(text)] + content: [.text(text)], + session: session ) } public static func failure( invocation: ToolInvocation, - message: String + message: String, + session: ToolSessionDescriptor? = nil ) -> ToolResultEnvelope { ToolResultEnvelope( invocationID: invocation.id, toolName: invocation.toolName, success: false, content: [.text(message)], - errorMessage: message + errorMessage: message, + session: session ) } - public static func denied(invocation: ToolInvocation) -> ToolResultEnvelope { - failure(invocation: invocation, message: "Tool execution was denied by the user.") + public static func denied( + invocation: ToolInvocation, + session: ToolSessionDescriptor? = nil + ) -> ToolResultEnvelope { + failure( + invocation: invocation, + message: "Tool execution was denied by the user.", + session: session + ) } } +extension ToolResultEnvelope: Codable {} + public struct ToolExecutionContext: Sendable { public let threadID: String public let turnID: String diff --git a/Tests/CodexKitTests/AgentRuntimeHistoryTests.swift b/Tests/CodexKitTests/AgentRuntimeHistoryTests.swift new file mode 100644 index 0000000..b20aaa6 --- /dev/null +++ b/Tests/CodexKitTests/AgentRuntimeHistoryTests.swift @@ -0,0 +1,818 @@ +import CodexKit +import CodexKitUI +import XCTest + +extension AgentRuntimeTests { + func testFetchThreadSummaryAndLatestStructuredOutputWorkWithoutRestore() async throws { + let backend = InMemoryAgentBackend( + structuredResponseText: #"{"reply":"Your order is already in transit.","priority":"high"}"# + ) + let stateStore = InMemoryRuntimeStateStore() + let runtime = try makeHistoryRuntime( + backend: backend, + approvalPresenter: AutoApprovalPresenter(), + stateStore: stateStore + ) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(title: "History Summary") + _ = try await runtime.sendMessage( + UserMessageRequest(text: "Draft a shipping reply."), + in: thread.id, + expecting: ShippingReplyDraft.self + ) + + let restoredRuntime = try makeHistoryRuntime( + backend: backend, + approvalPresenter: AutoApprovalPresenter(), + stateStore: stateStore + ) + + let summary = try await restoredRuntime.fetchThreadSummary(id: thread.id) + XCTAssertEqual(summary.threadID, thread.id) + XCTAssertEqual(summary.latestTurnStatus, .completed) + XCTAssertEqual(summary.latestStructuredOutputMetadata?.formatName, "shipping_reply_draft") + + let metadata = try await restoredRuntime.fetchLatestStructuredOutputMetadata(id: thread.id) + XCTAssertEqual(metadata, summary.latestStructuredOutputMetadata) + + let typed = try await restoredRuntime.fetchLatestStructuredOutput( + id: thread.id, + as: ShippingReplyDraft.self + ) + XCTAssertEqual( + typed, + ShippingReplyDraft( + reply: "Your order is already in transit.", + priority: "high" + ) + ) + } + + func testFetchThreadHistoryPagesMessagesBackwardChronologically() async throws { + let runtime = try makeHistoryRuntime( + backend: InMemoryAgentBackend(), + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore() + ) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(title: "Paged Messages") + _ = try await runtime.sendMessage(UserMessageRequest(text: "one"), in: thread.id) + _ = try await runtime.sendMessage(UserMessageRequest(text: "two"), in: thread.id) + _ = try await runtime.sendMessage(UserMessageRequest(text: "three"), in: thread.id) + + let filter = AgentHistoryFilter( + includeMessages: true, + includeToolCalls: false, + includeToolResults: false, + includeStructuredOutputs: false, + includeApprovals: false, + includeSystemEvents: false + ) + + let newestPage = try await runtime.fetchThreadHistory( + id: thread.id, + query: .init(limit: 2, direction: .backward, filter: filter) + ) + XCTAssertEqual(messageTexts(in: newestPage), ["three", "Echo: three"]) + XCTAssertTrue(newestPage.hasMoreBefore) + XCTAssertFalse(newestPage.hasMoreAfter) + XCTAssertNotNil(newestPage.nextCursor) + + let olderPage = try await runtime.fetchThreadHistory( + id: thread.id, + query: .init( + limit: 2, + cursor: newestPage.nextCursor, + direction: .backward, + filter: filter + ) + ) + XCTAssertEqual(messageTexts(in: olderPage), ["two", "Echo: two"]) + XCTAssertTrue(olderPage.hasMoreBefore) + XCTAssertTrue(olderPage.hasMoreAfter) + } + + @MainActor + func testSummaryReflectsPendingApprovalWithoutRestore() async throws { + let approvalInbox = ApprovalInbox() + let stateStore = InMemoryRuntimeStateStore() + let runtime = try makeHistoryRuntime( + backend: InMemoryAgentBackend(), + approvalPresenter: approvalInbox, + stateStore: stateStore, + tools: [ + .init( + definition: ToolDefinition( + name: "demo_lookup_profile", + description: "Lookup profile", + inputSchema: .object([:]), + approvalPolicy: .requiresApproval + ), + executor: AnyToolExecutor { invocation, _ in + .success(invocation: invocation, text: "approved-result") + } + ), + ] + ) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(title: "Pending Approval") + let stream = try await runtime.streamMessage( + UserMessageRequest(text: "please use the tool"), + in: thread.id + ) + + let drainTask = Task { + for try await _ in stream {} + } + + try await waitUntil { + await MainActor.run { + approvalInbox.currentRequest != nil + } + } + + let restoredRuntime = try makeHistoryRuntime( + backend: InMemoryAgentBackend(), + approvalPresenter: approvalInbox, + stateStore: stateStore + ) + let pendingSummary = try await restoredRuntime.fetchThreadSummary(id: thread.id) + + switch pendingSummary.pendingState { + case let .approval(state): + XCTAssertEqual(state.request.toolInvocation.toolName, "demo_lookup_profile") + default: + XCTFail("Expected approval pending state.") + } + + approvalInbox.approveCurrent() + _ = try await drainTask.value + + let completedSummary = try await restoredRuntime.fetchThreadSummary(id: thread.id) + XCTAssertNil(completedSummary.pendingState) + } + + func testSummaryReflectsPendingToolWaitWithoutRestore() async throws { + let gate = ToolExecutionGate() + let stateStore = InMemoryRuntimeStateStore() + let runtime = try makeHistoryRuntime( + backend: InMemoryAgentBackend(), + approvalPresenter: AutoApprovalPresenter(), + stateStore: stateStore, + tools: [ + .init( + definition: ToolDefinition( + name: "demo_lookup_profile", + description: "Lookup profile", + inputSchema: .object([:]) + ), + executor: AnyToolExecutor { invocation, _ in + await gate.wait() + return .success(invocation: invocation, text: "tool-finished") + } + ), + ] + ) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(title: "Pending Tool") + let stream = try await runtime.streamMessage( + UserMessageRequest(text: "please use the tool"), + in: thread.id + ) + + let drainTask = Task { + for try await _ in stream {} + } + + let restoredRuntime = try makeHistoryRuntime( + backend: InMemoryAgentBackend(), + approvalPresenter: AutoApprovalPresenter(), + stateStore: stateStore + ) + + try await waitUntil { + let summary = try await restoredRuntime.fetchThreadSummary(id: thread.id) + if case .toolWait = summary.pendingState { + return true + } + return false + } + + let waitingSummary = try await restoredRuntime.fetchThreadSummary(id: thread.id) + switch waitingSummary.pendingState { + case let .toolWait(state): + XCTAssertEqual(state.toolName, "demo_lookup_profile") + default: + XCTFail("Expected tool wait pending state.") + } + + await gate.release() + _ = try await drainTask.value + + let completedSummary = try await restoredRuntime.fetchThreadSummary(id: thread.id) + XCTAssertNil(completedSummary.pendingState) + XCTAssertEqual(completedSummary.latestToolState?.status, .completed) + } + + func testPartialStructuredSnapshotPersistsUntilCommit() async throws { + let backend = BlockingStructuredPartialBackend() + let stateStore = InMemoryRuntimeStateStore() + let runtime = try makeHistoryRuntime( + backend: backend, + approvalPresenter: AutoApprovalPresenter(), + stateStore: stateStore + ) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(title: "Partial Structured") + let stream = try await runtime.streamMessage( + UserMessageRequest(text: "Draft a shipping reply."), + in: thread.id, + expecting: ShippingReplyDraft.self + ) + + let drainTask = Task { + for try await _ in stream {} + } + + await backend.waitForPartialEmission() + + let restoredRuntime = try makeHistoryRuntime( + backend: backend, + approvalPresenter: AutoApprovalPresenter(), + stateStore: stateStore + ) + + try await waitUntil { + let summary = try await restoredRuntime.fetchThreadSummary(id: thread.id) + return summary.latestPartialStructuredOutput != nil + } + + let partialSummary = try await restoredRuntime.fetchThreadSummary(id: thread.id) + XCTAssertEqual(partialSummary.latestPartialStructuredOutput?.formatName, "shipping_reply_draft") + XCTAssertNil(partialSummary.latestStructuredOutputMetadata) + + await backend.releaseCommit() + _ = try await drainTask.value + + let committedSummary = try await restoredRuntime.fetchThreadSummary(id: thread.id) + XCTAssertNil(committedSummary.latestPartialStructuredOutput) + XCTAssertEqual(committedSummary.latestStructuredOutputMetadata?.formatName, "shipping_reply_draft") + } + + func testFileRuntimeStateStoreMigratesLegacyBlobForSummaryAndHistory() async throws { + let thread = AgentThread( + id: "legacy-thread", + title: "Legacy Thread", + createdAt: Date(timeIntervalSince1970: 100), + updatedAt: Date(timeIntervalSince1970: 100) + ) + let message = AgentMessage( + id: "legacy-message", + threadID: thread.id, + role: .assistant, + text: "Hello from legacy state", + createdAt: Date(timeIntervalSince1970: 101) + ) + let legacyState = StoredRuntimeState( + threads: [thread], + messagesByThread: [thread.id: [message]] + ) + + let url = FileManager.default.temporaryDirectory + .appendingPathComponent(UUID().uuidString) + .appendingPathExtension("json") + try JSONEncoder().encode(legacyState).write(to: url, options: .atomic) + + let store = FileRuntimeStateStore(url: url) + let summary = try await store.fetchThreadSummary(id: thread.id) + XCTAssertEqual(summary.latestAssistantMessagePreview, "Hello from legacy state") + + let history = try await store.fetchThreadHistory( + id: thread.id, + query: .init( + limit: 10, + direction: .backward, + filter: AgentHistoryFilter( + includeMessages: true, + includeToolCalls: false, + includeToolResults: false, + includeStructuredOutputs: false, + includeApprovals: false, + includeSystemEvents: false + ) + ) + ) + XCTAssertEqual(messageTexts(in: history), ["Hello from legacy state"]) + } + + func testPrepareStoreReturnsMetadataAndQueryableExecutionWorks() async throws { + let stateStore = InMemoryRuntimeStateStore() + let runtime = try makeHistoryRuntime( + backend: InMemoryAgentBackend(), + approvalPresenter: AutoApprovalPresenter(), + stateStore: stateStore + ) + + let metadata = try await runtime.prepareStore() + XCTAssertEqual(metadata.logicalSchemaVersion, .v1) + XCTAssertEqual(metadata.storeKind, "InMemoryRuntimeStateStore") + XCTAssertTrue(metadata.capabilities.supportsPushdownQueries) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(title: "Queryable") + _ = try await runtime.sendMessage(UserMessageRequest(text: "hello"), in: thread.id) + + let threads = try await runtime.execute( + ThreadMetadataQuery( + threadIDs: [thread.id], + limit: 1 + ) + ) + XCTAssertEqual(threads.first?.id, thread.id) + + let snapshots = try await runtime.execute( + ThreadSnapshotQuery( + threadIDs: [thread.id], + limit: 1 + ) + ) + XCTAssertEqual(snapshots.first?.threadID, thread.id) + + let history = try await runtime.execute( + HistoryItemsQuery( + threadID: thread.id, + kinds: [.message] + ) + ) + XCTAssertEqual(history.records.count, 2) + } + + func testRedactHistoryItemsPreservesRecordIdentityAndHidesPayload() async throws { + let runtime = try makeHistoryRuntime( + backend: InMemoryAgentBackend(), + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore() + ) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(title: "Redactions") + _ = try await runtime.sendMessage(UserMessageRequest(text: "top secret"), in: thread.id) + + let history = try await runtime.execute( + HistoryItemsQuery( + threadID: thread.id, + kinds: [.message] + ) + ) + guard let userRecord = history.records.first(where: { + if case let .message(message) = $0.item { + return message.role == .user + } + return false + }) else { + return XCTFail("Expected a persisted user message.") + } + + try await runtime.redactHistoryItems( + [userRecord.id], + in: thread.id, + reason: .init(code: "privacy", message: "User requested redaction") + ) + + let redactedHistory = try await runtime.execute( + HistoryItemsQuery( + threadID: thread.id, + kinds: [.message] + ) + ) + guard let redactedRecord = redactedHistory.records.first(where: { $0.id == userRecord.id }) else { + return XCTFail("Expected redacted record to remain in history.") + } + + XCTAssertNotNil(redactedRecord.redaction) + if case let .message(message) = redactedRecord.item { + XCTAssertEqual(message.text, "[Redacted]") + XCTAssertTrue(message.images.isEmpty) + } else { + XCTFail("Expected redacted message item.") + } + } + + func testDeleteThreadRemovesThreadFromQueries() async throws { + let runtime = try makeHistoryRuntime( + backend: InMemoryAgentBackend(), + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore() + ) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(title: "Delete Me") + _ = try await runtime.sendMessage(UserMessageRequest(text: "bye"), in: thread.id) + + try await runtime.deleteThread(id: thread.id) + + let threads = try await runtime.execute( + ThreadMetadataQuery(threadIDs: [thread.id]) + ) + XCTAssertTrue(threads.isEmpty) + + let snapshots = try await runtime.execute( + ThreadSnapshotQuery(threadIDs: [thread.id]) + ) + XCTAssertTrue(snapshots.isEmpty) + + let history = try await runtime.execute( + HistoryItemsQuery(threadID: thread.id) + ) + XCTAssertTrue(history.records.isEmpty) + } + + func testGRDBRuntimeStateStorePersistsSummariesAndQueriesAcrossReload() async throws { + let url = temporaryRuntimeSQLiteURL() + defer { try? FileManager.default.removeItem(at: url) } + + let backend = InMemoryAgentBackend( + structuredResponseText: #"{"reply":"The replacement is shipping today.","priority":"urgent"}"# + ) + let store = try GRDBRuntimeStateStore(url: url) + let runtime = try makeHistoryRuntime( + backend: backend, + approvalPresenter: AutoApprovalPresenter(), + stateStore: store + ) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(title: "GRDB Thread") + _ = try await runtime.sendMessage( + UserMessageRequest(text: "Draft the shipping update."), + in: thread.id, + expecting: ShippingReplyDraft.self + ) + + let reloadedStore = try GRDBRuntimeStateStore(url: url) + let reloadedRuntime = try makeHistoryRuntime( + backend: backend, + approvalPresenter: AutoApprovalPresenter(), + stateStore: reloadedStore + ) + + let metadata = try await reloadedRuntime.prepareStore() + XCTAssertEqual(metadata.storeKind, "GRDBRuntimeStateStore") + XCTAssertEqual(metadata.storeSchemaVersion, 1) + + let summary = try await reloadedRuntime.fetchThreadSummary(id: thread.id) + XCTAssertEqual(summary.latestTurnStatus, .completed) + XCTAssertEqual(summary.latestStructuredOutputMetadata?.formatName, "shipping_reply_draft") + + let snapshots = try await reloadedRuntime.execute( + ThreadSnapshotQuery(threadIDs: [thread.id]) + ) + XCTAssertEqual(snapshots.count, 1) + XCTAssertEqual(snapshots.first?.threadID, thread.id) + + let history = try await reloadedRuntime.execute( + HistoryItemsQuery( + threadID: thread.id, + kinds: [.message, .structuredOutput] + ) + ) + XCTAssertFalse(history.records.isEmpty) + + let typed = try await reloadedRuntime.fetchLatestStructuredOutput( + id: thread.id, + as: ShippingReplyDraft.self + ) + XCTAssertEqual( + typed, + ShippingReplyDraft( + reply: "The replacement is shipping today.", + priority: "urgent" + ) + ) + } + + func testGRDBRuntimeStateStorePersistsRedactionAndDeletion() async throws { + let url = temporaryRuntimeSQLiteURL() + defer { try? FileManager.default.removeItem(at: url) } + + let runtime = try makeHistoryRuntime( + backend: InMemoryAgentBackend(), + approvalPresenter: AutoApprovalPresenter(), + stateStore: try GRDBRuntimeStateStore(url: url) + ) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(title: "GRDB Mutations") + _ = try await runtime.sendMessage(UserMessageRequest(text: "please redact me"), in: thread.id) + + let messageHistory = try await runtime.execute( + HistoryItemsQuery(threadID: thread.id, kinds: [.message]) + ) + guard let firstMessage = messageHistory.records.first else { + return XCTFail("Expected a persisted message record.") + } + + try await runtime.redactHistoryItems([firstMessage.id], in: thread.id) + + let reloadedAfterRedaction = try makeHistoryRuntime( + backend: InMemoryAgentBackend(), + approvalPresenter: AutoApprovalPresenter(), + stateStore: try GRDBRuntimeStateStore(url: url) + ) + let redactedHistory = try await reloadedAfterRedaction.execute( + HistoryItemsQuery(threadID: thread.id, kinds: [.message]) + ) + + guard let redactedRecord = redactedHistory.records.first(where: { $0.id == firstMessage.id }) else { + return XCTFail("Expected the redacted record to still be queryable.") + } + XCTAssertNotNil(redactedRecord.redaction) + + try await reloadedAfterRedaction.deleteThread(id: thread.id) + + let deletedRuntime = try makeHistoryRuntime( + backend: InMemoryAgentBackend(), + approvalPresenter: AutoApprovalPresenter(), + stateStore: try GRDBRuntimeStateStore(url: url) + ) + let threads = try await deletedRuntime.execute( + ThreadMetadataQuery(threadIDs: [thread.id]) + ) + XCTAssertTrue(threads.isEmpty) + } + + func testGRDBRuntimeStateStoreImportsLegacyFileStateOnFirstPrepare() async throws { + let directory = FileManager.default.temporaryDirectory.appendingPathComponent(UUID().uuidString) + try FileManager.default.createDirectory(at: directory, withIntermediateDirectories: true) + defer { try? FileManager.default.removeItem(at: directory) } + + let legacyURL = directory.appendingPathComponent("runtime-state").appendingPathExtension("json") + let sqliteURL = directory.appendingPathComponent("runtime-state").appendingPathExtension("sqlite") + + let backend = InMemoryAgentBackend( + structuredResponseText: #"{"reply":"Legacy import payload.","priority":"normal"}"# + ) + let legacyRuntime = try makeHistoryRuntime( + backend: backend, + approvalPresenter: AutoApprovalPresenter(), + stateStore: FileRuntimeStateStore(url: legacyURL) + ) + + _ = try await legacyRuntime.restore() + _ = try await legacyRuntime.signIn() + + let thread = try await legacyRuntime.createThread(title: "Legacy File Thread") + _ = try await legacyRuntime.sendMessage( + UserMessageRequest(text: "Create a legacy payload."), + in: thread.id, + expecting: ShippingReplyDraft.self + ) + + let importedStore = try GRDBRuntimeStateStore(url: sqliteURL) + let importedRuntime = try makeHistoryRuntime( + backend: backend, + approvalPresenter: AutoApprovalPresenter(), + stateStore: importedStore + ) + + _ = try await importedRuntime.prepareStore() + + let summary = try await importedRuntime.fetchThreadSummary(id: thread.id) + XCTAssertEqual(summary.latestStructuredOutputMetadata?.formatName, "shipping_reply_draft") + + let history = try await importedRuntime.execute( + HistoryItemsQuery(threadID: thread.id, kinds: [.message, .structuredOutput]) + ) + XCTAssertFalse(history.records.isEmpty) + } +} + +private func makeHistoryRuntime( + backend: any AgentBackend, + approvalPresenter: any ApprovalPresenting, + stateStore: any RuntimeStateStoring, + tools: [AgentRuntime.ToolRegistration] = [] +) throws -> AgentRuntime { + try AgentRuntime(configuration: .init( + authProvider: DemoChatGPTAuthProvider(), + secureStore: KeychainSessionSecureStore( + service: "CodexKitTests.ChatGPTSession", + account: UUID().uuidString + ), + backend: backend, + approvalPresenter: approvalPresenter, + stateStore: stateStore, + tools: tools + )) +} + +private func temporaryRuntimeSQLiteURL() -> URL { + FileManager.default.temporaryDirectory + .appendingPathComponent(UUID().uuidString) + .appendingPathExtension("sqlite") +} + +private func messageTexts(in page: AgentThreadHistoryPage) -> [String] { + page.items.compactMap { item in + guard case let .message(message) = item else { + return nil + } + return message.displayText + } +} + +private func waitUntil( + timeoutNanoseconds: UInt64 = 2_000_000_000, + intervalNanoseconds: UInt64 = 20_000_000, + condition: @escaping @Sendable () async throws -> Bool +) async throws { + let deadline = DispatchTime.now().uptimeNanoseconds + timeoutNanoseconds + while DispatchTime.now().uptimeNanoseconds < deadline { + if try await condition() { + return + } + try await Task.sleep(nanoseconds: intervalNanoseconds) + } + + XCTFail("Timed out waiting for condition.") +} + +private actor ToolExecutionGate { + private var waiters: [CheckedContinuation] = [] + private var released = false + + func wait() async { + guard !released else { + return + } + + await withCheckedContinuation { continuation in + waiters.append(continuation) + } + } + + func release() { + released = true + let continuations = waiters + waiters.removeAll() + continuations.forEach { $0.resume() } + } +} + +private actor PartialEmissionGate { + private var partialEmitted = false + private var partialWaiters: [CheckedContinuation] = [] + private var commitReleased = false + private var commitWaiters: [CheckedContinuation] = [] + + func markPartialEmitted() { + partialEmitted = true + let continuations = partialWaiters + partialWaiters.removeAll() + continuations.forEach { $0.resume() } + } + + func waitForPartialEmission() async { + guard !partialEmitted else { + return + } + + await withCheckedContinuation { continuation in + partialWaiters.append(continuation) + } + } + + func waitForCommitRelease() async { + guard !commitReleased else { + return + } + + await withCheckedContinuation { continuation in + commitWaiters.append(continuation) + } + } + + func releaseCommit() { + commitReleased = true + let continuations = commitWaiters + commitWaiters.removeAll() + continuations.forEach { $0.resume() } + } +} + +private actor BlockingStructuredPartialBackend: AgentBackend { + private let gate = PartialEmissionGate() + private var latestThreadID: String? + + func createThread(session _: ChatGPTSession) async throws -> AgentThread { + AgentThread(id: UUID().uuidString) + } + + func resumeThread(id: String, session _: ChatGPTSession) async throws -> AgentThread { + AgentThread(id: id) + } + + func beginTurn( + thread: AgentThread, + history _: [AgentMessage], + message _: UserMessageRequest, + instructions _: String, + responseFormat _: AgentStructuredOutputFormat?, + streamedStructuredOutput _: AgentStreamedStructuredOutputRequest?, + tools _: [ToolDefinition], + session _: ChatGPTSession + ) async throws -> any AgentTurnStreaming { + latestThreadID = thread.id + return BlockingStructuredPartialTurnSession( + threadID: thread.id, + gate: gate + ) + } + + func waitForPartialEmission() async { + await gate.waitForPartialEmission() + } + + func releaseCommit() async { + await gate.releaseCommit() + } +} + +private final class BlockingStructuredPartialTurnSession: AgentTurnStreaming, @unchecked Sendable { + let events: AsyncThrowingStream + + init(threadID: String, gate: PartialEmissionGate) { + let turn = AgentTurn(id: UUID().uuidString, threadID: threadID) + let payload: JSONValue = .object([ + "reply": .string("Your order is already in transit."), + "priority": .string("high"), + ]) + + events = AsyncThrowingStream { continuation in + Task { + continuation.yield(.turnStarted(turn)) + continuation.yield( + .assistantMessageDelta( + threadID: threadID, + turnID: turn.id, + delta: "Echo: Draft a shipping reply." + ) + ) + continuation.yield(.structuredOutputPartial(payload)) + await gate.markPartialEmitted() + await gate.waitForCommitRelease() + continuation.yield(.structuredOutputCommitted(payload)) + continuation.yield( + .assistantMessageCompleted( + AgentMessage( + threadID: threadID, + role: .assistant, + text: "Echo: Draft a shipping reply.", + structuredOutput: AgentStructuredOutputMetadata( + formatName: "shipping_reply_draft", + payload: payload + ) + ) + ) + ) + continuation.yield( + .turnCompleted( + AgentTurnSummary( + threadID: threadID, + turnID: turn.id, + usage: AgentUsage(inputTokens: 1, outputTokens: 1) + ) + ) + ) + continuation.finish() + } + } + } + + func submitToolResult(_: ToolResultEnvelope, for _: String) async throws {} +} From 825a7b7b251794899fa393d76f80bbc2f225f299 Mon Sep 17 00:00:00 2001 From: Timothy Zelinsky Date: Mon, 23 Mar 2026 21:30:20 +1100 Subject: [PATCH 02/19] Improve demo logging and ignore cancellation noise --- .../Shared/AgentDemoView+ChatSections.swift | 17 ++++ .../Shared/AgentDemoView.swift | 1 - .../AgentDemoViewModel+HealthCoach.swift | 8 +- ...entDemoViewModel+HealthCoachPlatform.swift | 2 +- .../Shared/AgentDemoViewModel+Memory.swift | 10 +-- .../Shared/AgentDemoViewModel+Messaging.swift | 39 +++++++-- .../AgentDemoViewModel+StructuredOutput.swift | 21 ++++- .../Shared/AgentDemoViewModel+Tools.swift | 8 +- .../Shared/AgentDemoViewModel.swift | 83 +++++++++++++++++-- 9 files changed, 162 insertions(+), 27 deletions(-) diff --git a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoView+ChatSections.swift b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoView+ChatSections.swift index 0877a4a..b906ecb 100644 --- a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoView+ChatSections.swift +++ b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoView+ChatSections.swift @@ -89,6 +89,23 @@ extension AgentDemoView { .font(.subheadline) .foregroundStyle(.secondary) + Toggle( + "Developer Logging", + isOn: Binding( + get: { viewModel.developerLoggingEnabled }, + set: { viewModel.developerLoggingEnabled = $0 } + ) + ) + .toggleStyle(.switch) + + Text("Logs restore, sign-in, thread lifecycle, turn events, and tool activity to the Xcode console.") + .font(.caption) + .foregroundStyle(.secondary) + + Text("State store: \(viewModel.resolvedStateURL.lastPathComponent)") + .font(.caption.monospaced()) + .foregroundStyle(.secondary) + LazyVGrid(columns: tileColumns, spacing: 12) { ForEach(ReasoningEffort.allCases, id: \.self) { effort in reasoningEffortTile(for: effort) diff --git a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoView.swift b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoView.swift index 692981f..38a0bec 100644 --- a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoView.swift +++ b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoView.swift @@ -32,7 +32,6 @@ struct AgentDemoView: View { #endif .task { await viewModel.restore() - await viewModel.registerDemoTool() } .sheet(item: approvalRequestBinding) { request in approvalSheet(for: request) diff --git a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+HealthCoach.swift b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+HealthCoach.swift index d68261c..5a868b7 100644 --- a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+HealthCoach.swift +++ b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+HealthCoach.swift @@ -66,7 +66,7 @@ extension AgentDemoViewModel { } } catch { healthKitAuthorized = false - lastError = error.localizedDescription + reportError(error) } notificationAuthorized = await requestNotificationAuthorization() @@ -102,7 +102,7 @@ extension AgentDemoViewModel { await updateReminderScheduleIfPossible() await refreshAICoachFeedback() } catch { - lastError = error.localizedDescription + reportError(error) } #else healthCoachFeedback = "Health Coach is currently available on iOS only." @@ -139,7 +139,7 @@ extension AgentDemoViewModel { threads = await runtime.threads() lastError = nil } catch { - lastError = error.localizedDescription + reportError(error) } await updateReminderScheduleIfPossible() @@ -231,7 +231,7 @@ extension AgentDemoViewModel { } } } catch { - lastError = error.localizedDescription + reportError(error) } } diff --git a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+HealthCoachPlatform.swift b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+HealthCoachPlatform.swift index 6f7f13f..0246a9b 100644 --- a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+HealthCoachPlatform.swift +++ b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+HealthCoachPlatform.swift @@ -142,7 +142,7 @@ extension AgentDemoViewModel { } } } catch { - lastError = error.localizedDescription + reportError(error) } } } diff --git a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+Memory.swift b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+Memory.swift index 6e3d4bf..87ccf8b 100644 --- a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+Memory.swift +++ b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+Memory.swift @@ -51,7 +51,7 @@ extension AgentDemoViewModel { ) threads = await runtime.threads() } catch { - lastError = error.localizedDescription + reportError(error) } } @@ -90,7 +90,7 @@ extension AgentDemoViewModel { ) threads = await runtime.threads() } catch { - lastError = error.localizedDescription + reportError(error) } } @@ -115,7 +115,7 @@ extension AgentDemoViewModel { diagnostics: diagnostics ) } catch { - lastError = error.localizedDescription + reportError(error) } } @@ -143,7 +143,7 @@ extension AgentDemoViewModel { diagnostics: diagnostics ) } catch { - lastError = error.localizedDescription + reportError(error) } } @@ -199,7 +199,7 @@ extension AgentDemoViewModel { ) ) } catch { - lastError = error.localizedDescription + reportError(error) } } } diff --git a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+Messaging.swift b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+Messaging.swift index 82cbc99..9f22d41 100644 --- a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+Messaging.swift +++ b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+Messaging.swift @@ -18,6 +18,9 @@ extension AgentDemoViewModel { skillIDs: [String] = [] ) async { do { + developerLog( + "Creating thread. title=\(title ?? "") skills=\(skillIDs.joined(separator: ",")) personaLayers=\(personaStack?.layers.count ?? 0)" + ) let thread = try await runtime.createThread( title: title, personaStack: personaStack, @@ -26,8 +29,11 @@ extension AgentDemoViewModel { threads = await runtime.threads() activeThreadID = thread.id setMessages(await runtime.messages(for: thread.id)) + developerLog( + "Created thread. id=\(thread.id) title=\(thread.title ?? "") totalThreads=\(threads.count)" + ) } catch { - lastError = error.localizedDescription + reportError(error) } } @@ -58,6 +64,9 @@ extension AgentDemoViewModel { ) do { + developerLog( + "Sending message. threadID=\(activeThreadID) textLength=\(trimmedText.count) images=\(images.count) personaOverrideLayers=\(personaOverride?.layers.count ?? 0)" + ) _ = try await sendRequest( request, in: activeThreadID, @@ -65,7 +74,7 @@ extension AgentDemoViewModel { renderInActiveTranscript: true ) } catch { - lastError = error.localizedDescription + reportError(error) } } @@ -145,7 +154,7 @@ extension AgentDemoViewModel { lastError = "Probe completed, but result was inconclusive. Review the two thread summaries." } } catch { - lastError = error.localizedDescription + reportError(error) } } @@ -183,6 +192,9 @@ extension AgentDemoViewModel { ) let threadTitle = threads.first(where: { $0.id == threadID })?.title lastResolvedInstructionsThreadTitle = threadTitle ?? "Untitled Thread" + developerLog( + "Captured resolved instructions. threadID=\(threadID) title=\(lastResolvedInstructionsThreadTitle ?? "Untitled Thread")" + ) } catch { lastResolvedInstructions = nil lastResolvedInstructionsThreadTitle = nil @@ -198,6 +210,10 @@ extension AgentDemoViewModel { streamingText = "" } + developerLog( + "Starting streamed turn. threadID=\(threadID) textLength=\(request.text?.count ?? 0) imageCount=\(request.images.count)" + ) + let stream = try await runtime.streamMessage( request, in: threadID @@ -210,6 +226,9 @@ extension AgentDemoViewModel { switch event { case let .threadStarted(thread): threads = [thread] + threads.filter { $0.id != thread.id } + developerLog( + "Thread started event. id=\(thread.id) title=\(thread.title ?? "") status=\(thread.status.rawValue)" + ) case let .threadStatusChanged(threadID, status): threads = threads.map { thread in @@ -222,9 +241,10 @@ extension AgentDemoViewModel { updated.updatedAt = Date() return updated } + developerLog("Thread status changed. threadID=\(threadID) status=\(status.rawValue)") case .turnStarted: - break + developerLog("Turn started. threadID=\(threadID)") case let .assistantMessageDelta(_, _, delta): if renderInActiveTranscript { @@ -244,12 +264,15 @@ extension AgentDemoViewModel { streamingText = "" } } + developerLog( + "Message committed. threadID=\(threadID) role=\(message.role.rawValue) textLength=\(message.text.count)" + ) case .approvalRequested: - break + developerLog("Approval requested. threadID=\(threadID)") case .approvalResolved: - break + developerLog("Approval resolved. threadID=\(threadID)") case let .toolCallStarted(invocation): diagnostics.sawToolCall = true @@ -276,9 +299,13 @@ extension AgentDemoViewModel { setMessages(await runtime.messages(for: threadID)) } threads = await runtime.threads() + developerLog("Turn completed. threadID=\(threadID)") case let .turnFailed(error): diagnostics.turnFailedCode = error.code + developerErrorLog( + "Turn failed. threadID=\(threadID) code=\(error.code) message=\(error.message)" + ) lastError = error.message } } diff --git a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+StructuredOutput.swift b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+StructuredOutput.swift index be7262b..d667e8d 100644 --- a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+StructuredOutput.swift +++ b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+StructuredOutput.swift @@ -19,6 +19,7 @@ extension AgentDemoViewModel { } do { + developerLog("Running structured shipping reply demo.") let thread = try await runtime.createThread( title: "Structured Output: Shipping Draft", personaStack: Self.supportPersona @@ -47,8 +48,9 @@ extension AgentDemoViewModel { threads = await runtime.threads() activeThreadID = thread.id setMessages(await runtime.messages(for: thread.id)) + developerLog("Structured shipping reply demo finished. threadID=\(thread.id)") } catch { - lastError = error.localizedDescription + reportError(error) } } @@ -68,6 +70,7 @@ extension AgentDemoViewModel { } do { + developerLog("Running structured imported summary demo.") let thread = try await runtime.createThread( title: "Structured Output: Imported Summary" ) @@ -95,8 +98,9 @@ extension AgentDemoViewModel { threads = await runtime.threads() activeThreadID = thread.id setMessages(await runtime.messages(for: thread.id)) + developerLog("Structured imported summary demo finished. threadID=\(thread.id)") } catch { - lastError = error.localizedDescription + reportError(error) } } @@ -118,6 +122,7 @@ extension AgentDemoViewModel { } do { + developerLog("Running streamed structured output demo.") let thread = try await runtime.createThread( title: "Structured Output: Streamed Delivery Update", personaStack: Self.supportPersona @@ -155,9 +160,13 @@ extension AgentDemoViewModel { if partialSnapshots.last != partial { partialSnapshots.append(partial) } + developerLog("Structured partial received. threadID=\(thread.id)") case let .structuredOutputCommitted(payload): committedPayload = payload + developerLog( + "Structured payload committed. threadID=\(thread.id) format=\(StreamedStructuredDeliveryUpdate.responseFormat.name)" + ) case let .turnFailed(error): throw error @@ -188,9 +197,15 @@ extension AgentDemoViewModel { threads = await runtime.threads() activeThreadID = thread.id setMessages(messages) + developerLog( + "Streamed structured output demo finished. threadID=\(thread.id) partialCount=\(partialSnapshots.count) persistedMetadata=\(persistedMetadata != nil)" + ) } catch { + guard !Self.isCancellationError(error) else { + return + } structuredStreamingError = error.localizedDescription - lastError = error.localizedDescription + reportError(error) } } } diff --git a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+Tools.swift b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+Tools.swift index c078a91..2ffb15d 100644 --- a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+Tools.swift +++ b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+Tools.swift @@ -15,8 +15,9 @@ extension AgentDemoViewModel { do { try await runtime.replaceSkill(Self.healthCoachSkill) try await runtime.replaceSkill(Self.travelPlannerSkill) + developerLog("Registered demo skills: \(Self.healthCoachSkill.id), \(Self.travelPlannerSkill.id)") } catch { - lastError = error.localizedDescription + reportError(error) } } @@ -73,8 +74,11 @@ extension AgentDemoViewModel { try await registerTool(travelPlannerDefinition) { invocation, _ in Self.makeTravelDayPlan(invocation: invocation) } + developerLog( + "Registered demo tools: \(Self.healthCoachToolName), \(Self.travelPlannerToolName)" + ) } catch { - lastError = error.localizedDescription + reportError(error) } } diff --git a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel.swift b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel.swift index 2c63019..ba5d50e 100644 --- a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel.swift +++ b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel.swift @@ -83,6 +83,7 @@ struct AutomaticPolicyMemoryDemoResult: Sendable { @MainActor @Observable final class AgentDemoViewModel: @unchecked Sendable { + nonisolated static let developerLoggingDefaultsKey = "AssistantRuntimeDemoApp.developerLoggingEnabled" nonisolated static let logger = Logger( subsystem: "ai.assistantruntime.demoapp", category: "DemoTool" @@ -137,6 +138,19 @@ final class AgentDemoViewModel: @unchecked Sendable { var streamingText = "" var lastError: String? var showResolvedInstructionsDebug = false + var developerLoggingEnabled: Bool { + didSet { + UserDefaults.standard.set( + developerLoggingEnabled, + forKey: Self.developerLoggingDefaultsKey + ) + developerLog( + developerLoggingEnabled + ? "Developer logging enabled." + : "Developer logging disabled." + ) + } + } var lastResolvedInstructions: String? var lastResolvedInstructionsThreadTitle: String? var isRunningSkillPolicyProbe = false @@ -204,6 +218,9 @@ final class AgentDemoViewModel: @unchecked Sendable { self.model = model self.enableWebSearch = enableWebSearch self.reasoningEffort = reasoningEffort + self.developerLoggingEnabled = UserDefaults.standard.bool( + forKey: Self.developerLoggingDefaultsKey + ) self.stateURL = stateURL self.keychainAccount = keychainAccount self.approvalInbox = approvalInbox @@ -221,6 +238,14 @@ final class AgentDemoViewModel: @unchecked Sendable { personaSummary(for: activeThread) } + var resolvedStateURL: URL { + stateURL ?? AgentDemoRuntimeFactory.defaultStateURL() + } + + var legacyStateURL: URL { + resolvedStateURL.deletingPathExtension().appendingPathExtension("json") + } + var healthProgressFraction: Double { guard dailyStepGoal > 0 else { return 0 @@ -248,13 +273,19 @@ final class AgentDemoViewModel: @unchecked Sendable { } func restore() async { + developerLog( + "Restore started. store=\(resolvedStateURL.path) legacyJSONPresent=\(FileManager.default.fileExists(atPath: legacyStateURL.path))" + ) do { _ = try await runtime.restore() await registerDemoTool() await registerDemoSkills() await refreshSnapshot() + developerLog( + "Restore finished. sessionPresent=\(session != nil) threadCount=\(threads.count)" + ) } catch { - lastError = error.localizedDescription + reportError(error) } } @@ -266,6 +297,7 @@ final class AgentDemoViewModel: @unchecked Sendable { isAuthenticating = true lastError = nil currentAuthenticationMethod = authenticationMethod + developerLog("Sign-in started. method=\(authenticationMethod.rawValue)") runtime = AgentDemoRuntimeFactory.makeRuntime( authenticationMethod: authenticationMethod, model: model, @@ -290,10 +322,13 @@ final class AgentDemoViewModel: @unchecked Sendable { if healthCoachInitialized { await refreshHealthCoachProgress() } + developerLog( + "Sign-in finished. account=\(session?.account.email ?? "") threadCount=\(threads.count)" + ) } catch { await deviceCodePromptCoordinator.clear() await refreshSnapshot() - lastError = error.localizedDescription + reportError(error) } } @@ -308,6 +343,7 @@ final class AgentDemoViewModel: @unchecked Sendable { } self.reasoningEffort = reasoningEffort + developerLog("Reconfiguring runtime. reasoningEffort=\(reasoningEffort.rawValue)") let preservedActiveThreadID = activeThreadID let preservedHealthCoachThreadID = healthCoachThreadID @@ -337,8 +373,11 @@ final class AgentDemoViewModel: @unchecked Sendable { threads.contains(where: { $0.id == preservedHealthCoachThreadID }) { healthCoachThreadID = preservedHealthCoachThreadID } + developerLog( + "Runtime reconfigured. reasoningEffort=\(reasoningEffort.rawValue) threadCount=\(threads.count)" + ) } catch { - lastError = error.localizedDescription + reportError(error) } } @@ -369,7 +408,7 @@ final class AgentDemoViewModel: @unchecked Sendable { ) threads = await runtime.threads() } catch { - lastError = error.localizedDescription + reportError(error) } } @@ -454,9 +493,19 @@ final class AgentDemoViewModel: @unchecked Sendable { } func reportError(_ message: String) { + developerErrorLog(message) lastError = message } + func reportError(_ error: Error) { + guard !Self.isCancellationError(error) else { + developerLog("Ignoring CancellationError from async UI task.") + return + } + developerErrorLog(error.localizedDescription) + lastError = error.localizedDescription + } + func approvePendingRequest() { approvalInbox.approveCurrent() } @@ -469,6 +518,26 @@ final class AgentDemoViewModel: @unchecked Sendable { lastError = nil } + nonisolated static func isCancellationError(_ error: Error) -> Bool { + error is CancellationError + } + + func developerLog(_ message: String) { + guard developerLoggingEnabled else { + return + } + Self.logger.notice("\(message, privacy: .public)") + print("[CodexKit Demo] \(message)") + } + + func developerErrorLog(_ message: String) { + guard developerLoggingEnabled else { + return + } + Self.logger.error("\(message, privacy: .public)") + print("[CodexKit Demo][Error] \(message)") + } + func signOut() async { do { try await runtime.signOut() @@ -506,7 +575,7 @@ final class AgentDemoViewModel: @unchecked Sendable { cachedAIReminderGeneratedAt = nil lastError = nil } catch { - lastError = error.localizedDescription + reportError(error) } } @@ -514,10 +583,14 @@ final class AgentDemoViewModel: @unchecked Sendable { session = await runtime.currentSession() guard session != nil else { clearConversationSnapshot() + developerLog("Snapshot refreshed with no active session.") return } threads = await runtime.threads() + developerLog( + "Snapshot refreshed. session=\(session?.account.email ?? "") threadCount=\(threads.count)" + ) let selectedThreadID = activeThreadID if let selectedThreadID, From b1b9eb16d74ef1b4b030e6630adb8e8685b8576d Mon Sep 17 00:00:00 2001 From: Timothy Zelinsky Date: Mon, 23 Mar 2026 22:22:07 +1100 Subject: [PATCH 03/19] Externalize attachments and smooth demo streaming --- .../Shared/AgentDemoView+ChatSections.swift | 35 ++- .../Shared/AgentDemoViewModel+Messaging.swift | 38 ++- .../Shared/AgentDemoViewModel.swift | 15 +- .../Shared/ThreadDetailView.swift | 266 +++++++++------- .../Runtime/GRDBRuntimeStateStore.swift | 108 +++++-- .../RuntimeAttachmentPersistence.swift | 285 ++++++++++++++++++ .../CodexKit/Runtime/RuntimeStateStore.swift | 19 +- .../AgentRuntimeHistoryTests.swift | 61 ++++ 8 files changed, 655 insertions(+), 172 deletions(-) create mode 100644 Sources/CodexKit/Runtime/RuntimeAttachmentPersistence.swift diff --git a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoView+ChatSections.swift b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoView+ChatSections.swift index b906ecb..70242ac 100644 --- a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoView+ChatSections.swift +++ b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoView+ChatSections.swift @@ -41,6 +41,23 @@ extension AgentDemoView { .foregroundStyle(.secondary) } + Toggle( + "Developer Logging", + isOn: Binding( + get: { viewModel.developerLoggingEnabled }, + set: { viewModel.developerLoggingEnabled = $0 } + ) + ) + .toggleStyle(.switch) + + Text("Debug builds start with logging enabled. Logs print to the Xcode console for restore, sign-in, thread lifecycle, turn events, and tool activity.") + .font(.caption) + .foregroundStyle(.secondary) + + Text("State store: \(viewModel.resolvedStateURL.lastPathComponent)") + .font(.caption.monospaced()) + .foregroundStyle(.secondary) + LazyVGrid(columns: tileColumns, spacing: 12) { registerToolTile @@ -88,24 +105,6 @@ extension AgentDemoView { Text("Pick a thinking level for future requests. Existing threads stay intact; only new turns use the updated effort.") .font(.subheadline) .foregroundStyle(.secondary) - - Toggle( - "Developer Logging", - isOn: Binding( - get: { viewModel.developerLoggingEnabled }, - set: { viewModel.developerLoggingEnabled = $0 } - ) - ) - .toggleStyle(.switch) - - Text("Logs restore, sign-in, thread lifecycle, turn events, and tool activity to the Xcode console.") - .font(.caption) - .foregroundStyle(.secondary) - - Text("State store: \(viewModel.resolvedStateURL.lastPathComponent)") - .font(.caption.monospaced()) - .foregroundStyle(.secondary) - LazyVGrid(columns: tileColumns, spacing: 12) { ForEach(ReasoningEffort.allCases, id: \.self) { effort in reasoningEffortTile(for: effort) diff --git a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+Messaging.swift b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+Messaging.swift index 9f22d41..7b0e53f 100644 --- a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+Messaging.swift +++ b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+Messaging.swift @@ -209,9 +209,11 @@ extension AgentDemoViewModel { if renderInActiveTranscript { streamingText = "" } + var bufferedStreamingText = "" + var lastStreamingFlushAt = Date.distantPast developerLog( - "Starting streamed turn. threadID=\(threadID) textLength=\(request.text?.count ?? 0) imageCount=\(request.images.count)" + "Starting streamed turn. threadID=\(threadID) textLength=\(request.text.count) imageCount=\(request.images.count)" ) let stream = try await runtime.streamMessage( @@ -222,6 +224,24 @@ extension AgentDemoViewModel { setMessages(await runtime.messages(for: threadID)) } + func flushStreamingText(force: Bool = false) { + guard renderInActiveTranscript else { + return + } + guard !bufferedStreamingText.isEmpty else { + return + } + + let now = Date() + guard force || now.timeIntervalSince(lastStreamingFlushAt) >= 0.05 else { + return + } + + streamingText.append(bufferedStreamingText) + bufferedStreamingText.removeAll(keepingCapacity: true) + lastStreamingFlushAt = now + } + for try await event in stream { switch event { case let .threadStarted(thread): @@ -248,10 +268,12 @@ extension AgentDemoViewModel { case let .assistantMessageDelta(_, _, delta): if renderInActiveTranscript { - streamingText.append(delta) + bufferedStreamingText.append(delta) + flushStreamingText() } case let .messageCommitted(message): + flushStreamingText(force: true) if message.role == .assistant { let reply = message.text.trimmingCharacters(in: .whitespacesAndNewlines) if !reply.isEmpty { @@ -276,8 +298,8 @@ extension AgentDemoViewModel { case let .toolCallStarted(invocation): diagnostics.sawToolCall = true - Self.logger.info( - "Tool call requested: \(invocation.toolName, privacy: .public) with arguments: \(String(describing: invocation.arguments), privacy: .public)" + developerLog( + "Tool call requested. threadID=\(threadID) tool=\(invocation.toolName) arguments=\(String(describing: invocation.arguments))" ) case let .toolCallFinished(result): @@ -290,11 +312,12 @@ extension AgentDemoViewModel { diagnostics.firstFailureMessage = result.primaryText ?? result.errorMessage } } - Self.logger.info( - "Tool call finished: \(result.toolName, privacy: .public) success=\(result.success, privacy: .public) output=\(result.primaryText ?? "", privacy: .public)" + developerLog( + "Tool call finished. threadID=\(threadID) tool=\(result.toolName) success=\(result.success) output=\(result.primaryText ?? result.errorMessage ?? "")" ) case .turnCompleted: + flushStreamingText(force: true) if renderInActiveTranscript { setMessages(await runtime.messages(for: threadID)) } @@ -302,11 +325,12 @@ extension AgentDemoViewModel { developerLog("Turn completed. threadID=\(threadID)") case let .turnFailed(error): + flushStreamingText(force: true) diagnostics.turnFailedCode = error.code developerErrorLog( "Turn failed. threadID=\(threadID) code=\(error.code) message=\(error.message)" ) - lastError = error.message + reportError(error.message) } } diff --git a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel.swift b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel.swift index ba5d50e..afb81e4 100644 --- a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel.swift +++ b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel.swift @@ -218,9 +218,7 @@ final class AgentDemoViewModel: @unchecked Sendable { self.model = model self.enableWebSearch = enableWebSearch self.reasoningEffort = reasoningEffort - self.developerLoggingEnabled = UserDefaults.standard.bool( - forKey: Self.developerLoggingDefaultsKey - ) + self.developerLoggingEnabled = Self.initialDeveloperLoggingEnabled() self.stateURL = stateURL self.keychainAccount = keychainAccount self.approvalInbox = approvalInbox @@ -522,6 +520,17 @@ final class AgentDemoViewModel: @unchecked Sendable { error is CancellationError } + nonisolated static func initialDeveloperLoggingEnabled() -> Bool { + if UserDefaults.standard.object(forKey: developerLoggingDefaultsKey) != nil { + return UserDefaults.standard.bool(forKey: developerLoggingDefaultsKey) + } +#if DEBUG + return true +#else + return false +#endif + } + func developerLog(_ message: String) { guard developerLoggingEnabled else { return diff --git a/DemoApp/AssistantRuntimeDemoApp/Shared/ThreadDetailView.swift b/DemoApp/AssistantRuntimeDemoApp/Shared/ThreadDetailView.swift index d77e17f..680e718 100644 --- a/DemoApp/AssistantRuntimeDemoApp/Shared/ThreadDetailView.swift +++ b/DemoApp/AssistantRuntimeDemoApp/Shared/ThreadDetailView.swift @@ -114,52 +114,11 @@ private extension ThreadDetailView { } else { LazyVStack(alignment: .leading, spacing: 12) { ForEach(threadMessages) { message in - VStack(alignment: .leading, spacing: 6) { - Text(message.role.rawValue.capitalized) - .font(.caption.weight(.semibold)) - .foregroundStyle(.secondary) - - if shouldShowVisibleText(for: message) { - Text(message.displayText) - .frame(maxWidth: .infinity, alignment: .leading) - } - - if let structuredOutput = message.structuredOutput { - structuredOutputCard(structuredOutput, for: message) - } - - if !message.images.isEmpty { - attachmentGallery(for: message.images) - - Text(message.images.count == 1 ? "1 image attached" : "\(message.images.count) images attached") - .font(.caption) - .foregroundStyle(.secondary) - } - } - .padding(14) - .background( - RoundedRectangle(cornerRadius: 16, style: .continuous) - .fill( - message.role == .user - ? Color.accentColor.opacity(0.12) - : Color.primary.opacity(0.04) - ) - ) + ThreadMessageBubble(message: message) } if isStreamingActive { - VStack(alignment: .leading, spacing: 6) { - Text("Assistant") - .font(.caption.weight(.semibold)) - .foregroundStyle(.secondary) - Text(viewModel.streamingText) - .frame(maxWidth: .infinity, alignment: .leading) - } - .padding(14) - .background( - RoundedRectangle(cornerRadius: 16, style: .continuous) - .fill(Color.primary.opacity(0.04)) - ) + ThreadStreamingBubble(text: viewModel.streamingText) } } .frame(maxWidth: .infinity, alignment: .leading) @@ -277,64 +236,106 @@ private extension ThreadDetailView { } } - @ViewBuilder - func attachmentGallery(for images: [AgentImageAttachment]) -> some View { - ScrollView(.horizontal, showsIndicators: false) { - HStack(spacing: 10) { - ForEach(images) { image in - if let platformImage = platformImage(from: image.data) { - Image(platformImage: platformImage) - .resizable() - .scaledToFill() - .frame(width: 120, height: 120) - .clipShape(RoundedRectangle(cornerRadius: 14, style: .continuous)) - } - } + func importPhoto(from item: PhotosPickerItem) async { + isImportingPhoto = true + + defer { + isImportingPhoto = false + selectedPhotoItem = nil + } + + do { + guard let data = try await item.loadTransferable(type: Data.self) else { + viewModel.reportError("The selected photo could not be loaded.") + return } - .padding(.top, 4) + + let mimeType = preferredMIMEType(for: item) + viewModel.queueComposerImage( + data: data, + mimeType: mimeType + ) + } catch { + viewModel.reportError(error.localizedDescription) } } - func structuredOutputCard( - _ structuredOutput: AgentStructuredOutputMetadata, - for message: AgentMessage - ) -> some View { - VStack(alignment: .leading, spacing: 8) { - Text("Structured Payload") + func preferredMIMEType(for item: PhotosPickerItem) -> String { + for contentType in item.supportedContentTypes { + if let mimeType = contentType.preferredMIMEType { + return mimeType + } + } + + return "image/jpeg" + } + +} + +private struct ThreadStreamingBubble: View { + let text: String + + var body: some View { + VStack(alignment: .leading, spacing: 6) { + Text("Assistant") .font(.caption.weight(.semibold)) .foregroundStyle(.secondary) + Text(text) + .frame(maxWidth: .infinity, alignment: .leading) + } + .padding(14) + .background( + RoundedRectangle(cornerRadius: 16, style: .continuous) + .fill(Color.primary.opacity(0.04)) + ) + } +} - if isPureStructuredPayloadMessage(message) { - Text("This assistant turn resolved into a typed structured payload.") - .font(.callout) - .foregroundStyle(.secondary) - } +private struct ThreadMessageBubble: View { + let message: AgentMessage - Label(structuredOutput.formatName, systemImage: "square.stack.3d.up.fill") - .font(.caption.weight(.medium)) + var body: some View { + VStack(alignment: .leading, spacing: 6) { + Text(message.role.rawValue.capitalized) + .font(.caption.weight(.semibold)) .foregroundStyle(.secondary) - Text(structuredOutput.payload.prettyJSONString) - .font(.system(.footnote, design: .monospaced)) - .frame(maxWidth: .infinity, alignment: .leading) - .padding(12) - .background( - RoundedRectangle(cornerRadius: 14, style: .continuous) - .fill(Color.primary.opacity(0.04)) - ) - .textSelection(.enabled) + if shouldShowVisibleText { + Text(message.displayText) + .frame(maxWidth: .infinity, alignment: .leading) + } + + if let structuredOutput = message.structuredOutput { + structuredOutputCard(structuredOutput) + } + + if !message.images.isEmpty { + ThreadAttachmentGallery(images: message.images) + + Text(message.images.count == 1 ? "1 image attached" : "\(message.images.count) images attached") + .font(.caption) + .foregroundStyle(.secondary) + } } - .padding(.top, shouldShowVisibleText(for: message) ? 4 : 0) + .padding(14) + .background( + RoundedRectangle(cornerRadius: 16, style: .continuous) + .fill( + message.role == .user + ? Color.accentColor.opacity(0.12) + : Color.primary.opacity(0.04) + ) + ) } - func shouldShowVisibleText(for message: AgentMessage) -> Bool { - guard !isPureStructuredPayloadMessage(message) else { + private var shouldShowVisibleText: Bool { + guard !isPureStructuredPayloadMessage else { return false } return !message.displayText.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty } - func isPureStructuredPayloadMessage(_ message: AgentMessage) -> Bool { + private var isPureStructuredPayloadMessage: Bool { guard let structuredOutput = message.structuredOutput else { return false } @@ -349,47 +350,88 @@ private extension ThreadDetailView { return parsed == structuredOutput.payload } - func importPhoto(from item: PhotosPickerItem) async { - isImportingPhoto = true - - defer { - isImportingPhoto = false - selectedPhotoItem = nil - } + @ViewBuilder + private func structuredOutputCard(_ structuredOutput: AgentStructuredOutputMetadata) -> some View { + VStack(alignment: .leading, spacing: 8) { + Text("Structured Payload") + .font(.caption.weight(.semibold)) + .foregroundStyle(.secondary) - do { - guard let data = try await item.loadTransferable(type: Data.self) else { - viewModel.reportError("The selected photo could not be loaded.") - return + if isPureStructuredPayloadMessage { + Text("This assistant turn resolved into a typed structured payload.") + .font(.callout) + .foregroundStyle(.secondary) } - let mimeType = preferredMIMEType(for: item) - viewModel.queueComposerImage( - data: data, - mimeType: mimeType - ) - } catch { - viewModel.reportError(error.localizedDescription) + Label(structuredOutput.formatName, systemImage: "square.stack.3d.up.fill") + .font(.caption.weight(.medium)) + .foregroundStyle(.secondary) + + Text(structuredOutput.payload.prettyJSONString) + .font(.system(.footnote, design: .monospaced)) + .frame(maxWidth: .infinity, alignment: .leading) + .padding(12) + .background( + RoundedRectangle(cornerRadius: 14, style: .continuous) + .fill(Color.primary.opacity(0.04)) + ) + .textSelection(.enabled) } + .padding(.top, shouldShowVisibleText ? 4 : 0) } +} - func preferredMIMEType(for item: PhotosPickerItem) -> String { - for contentType in item.supportedContentTypes { - if let mimeType = contentType.preferredMIMEType { - return mimeType +private struct ThreadAttachmentGallery: View { + let images: [AgentImageAttachment] + + var body: some View { + ScrollView(.horizontal, showsIndicators: false) { + HStack(spacing: 10) { + ForEach(images) { image in + ThreadAttachmentThumbnail(image: image) + } } + .padding(.top, 4) } + } +} - return "image/jpeg" +private struct ThreadAttachmentThumbnail: View { + let image: AgentImageAttachment + + var body: some View { + Group { + if let platformImage = ThreadAttachmentImageCache.image(for: image) { + Image(platformImage: platformImage) + .resizable() + .scaledToFill() + .frame(width: 120, height: 120) + .clipShape(RoundedRectangle(cornerRadius: 14, style: .continuous)) + } + } } +} #if canImport(UIKit) - func platformImage(from data: Data) -> UIImage? { - UIImage(data: data) - } +private typealias ThreadPlatformImage = UIImage #elseif canImport(AppKit) - func platformImage(from data: Data) -> NSImage? { - NSImage(data: data) - } +private typealias ThreadPlatformImage = NSImage #endif + +@MainActor +private enum ThreadAttachmentImageCache { + private static let cache = NSCache() + + static func image(for attachment: AgentImageAttachment) -> ThreadPlatformImage? { + let key = attachment.id as NSString + if let cached = cache.object(forKey: key) { + return cached + } + + guard let image = ThreadPlatformImage(data: attachment.data) else { + return nil + } + cache.setObject(image, forKey: key) + return image + } } diff --git a/Sources/CodexKit/Runtime/GRDBRuntimeStateStore.swift b/Sources/CodexKit/Runtime/GRDBRuntimeStateStore.swift index ed632d6..3e5c8a3 100644 --- a/Sources/CodexKit/Runtime/GRDBRuntimeStateStore.swift +++ b/Sources/CodexKit/Runtime/GRDBRuntimeStateStore.swift @@ -6,6 +6,7 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, private let url: URL private let legacyStateURL: URL? + private let attachmentStore: RuntimeAttachmentStore private let databaseExistedAtInitialization: Bool private let dbQueue: DatabaseQueue private let migrator: DatabaseMigrator @@ -17,8 +18,14 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, ) throws { self.url = url let fileManager = FileManager.default + let basename = url.deletingPathExtension().lastPathComponent self.databaseExistedAtInitialization = fileManager.fileExists(atPath: url.path) self.legacyStateURL = legacyStateURL ?? Self.defaultLegacyImportURL(for: url) + self.attachmentStore = RuntimeAttachmentStore( + rootURL: url.deletingLastPathComponent() + .appendingPathComponent("\(basename).codexkit-state", isDirectory: true) + .appendingPathComponent("attachments", isDirectory: true) + ) let directory = url.deletingLastPathComponent() if !directory.path.isEmpty { @@ -72,10 +79,10 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, (row.threadID, try Self.decodeSummary(from: row)) } ) - let historyByThread = try Dictionary( - grouping: historyRows.map { try Self.decodeHistoryRecord(from: $0) }, - by: \.item.threadID - ) + let decodedHistoryRows = try historyRows.map { + try Self.decodeHistoryRecord(from: $0, attachmentStore: attachmentStore) + } + let historyByThread = Dictionary(grouping: decodedHistoryRows, by: { $0.item.threadID }) return StoredRuntimeState( threads: threads, @@ -89,8 +96,13 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, try await ensurePrepared() let normalized = state.normalized() + try attachmentStore.reset() try await dbQueue.write { db in - try Self.replaceDatabaseContents(with: normalized, in: db) + try Self.replaceDatabaseContents( + with: normalized, + in: db, + attachmentStore: attachmentStore + ) } } @@ -106,17 +118,23 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, } try await dbQueue.write { db in - var partialState = try Self.loadPartialState(for: affectedThreadIDs, from: db) + var partialState = try Self.loadPartialState( + for: affectedThreadIDs, + from: db, + attachmentStore: attachmentStore + ) partialState = try partialState.applying(operations) for threadID in affectedThreadIDs { try Self.deletePersistedThread(threadID, in: db) + try attachmentStore.removeThread(threadID) } try Self.persistThreads( ids: affectedThreadIDs, from: partialState, - in: db + in: db, + attachmentStore: attachmentStore ) } } @@ -150,7 +168,8 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, return try Self.fetchHistoryPage( threadID: id, query: query, - in: db + in: db, + attachmentStore: attachmentStore ) } } @@ -222,7 +241,8 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, createdAtRange: query.createdAtRange, turnID: query.turnID, includeRedacted: query.includeRedacted, - in: db + in: db, + attachmentStore: attachmentStore ) let state = StoredRuntimeState( @@ -301,7 +321,8 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, private static func replaceDatabaseContents( with normalized: StoredRuntimeState, - in db: Database + in db: Database, + attachmentStore: RuntimeAttachmentStore ) throws { let threadRows = try normalized.threads.map(Self.makeThreadRow) let summaryRows = try normalized.threads.compactMap { thread -> RuntimeSummaryRow? in @@ -312,7 +333,7 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, } let historyRows = try normalized.historyByThread.values .flatMap { $0 } - .map(Self.makeHistoryRow) + .map { try Self.makeHistoryRow(from: $0, attachmentStore: attachmentStore) } let structuredOutputRows = try Self.structuredOutputRows(from: normalized.historyByThread) try RuntimeStructuredOutputRow.deleteAll(db) @@ -366,20 +387,26 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, } try await dbQueue.write { db in - try Self.replaceDatabaseContents(with: state, in: db) + try attachmentStore.reset() + try Self.replaceDatabaseContents( + with: state, + in: db, + attachmentStore: attachmentStore + ) } } private static func loadPartialState( for threadIDs: Set, - from db: Database + from db: Database, + attachmentStore: RuntimeAttachmentStore ) throws -> StoredRuntimeState { guard !threadIDs.isEmpty else { return .empty } let ids = Array(threadIDs) - let placeholders = sqlPlaceholders(count: ids.count) + let placeholders = Self.sqlPlaceholders(count: ids.count) let threadRows = try RuntimeThreadRow.fetchAll( db, sql: "SELECT * FROM \(RuntimeThreadRow.databaseTableName) WHERE threadID IN \(placeholders)", @@ -401,13 +428,13 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, ) let threads = try threadRows.map { try Self.decodeThread(from: $0) } - let summaries = try Dictionary( + let summaries = try Dictionary( uniqueKeysWithValues: summaryRows.map { ($0.threadID, try Self.decodeSummary(from: $0)) } ) - let history = try Dictionary( - grouping: historyRows.map { try Self.decodeHistoryRecord(from: $0) }, - by: \.item.threadID - ) + let decodedHistoryRows = try historyRows.map { + try Self.decodeHistoryRecord(from: $0, attachmentStore: attachmentStore) + } + let history = Dictionary(grouping: decodedHistoryRows, by: { $0.item.threadID }) let nextSequence = history.mapValues { ($0.last?.sequenceNumber ?? 0) + 1 } return StoredRuntimeState( @@ -421,7 +448,8 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, private static func persistThreads( ids threadIDs: Set, from state: StoredRuntimeState, - in db: Database + in db: Database, + attachmentStore: RuntimeAttachmentStore ) throws { let normalized = state.normalized() let threads = normalized.threads.filter { threadIDs.contains($0.id) } @@ -435,7 +463,7 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, try Self.makeSummaryRow(from: summary).insert(db) } for record in normalized.historyByThread[thread.id] ?? [] { - try Self.makeHistoryRow(from: record).insert(db) + try Self.makeHistoryRow(from: record, attachmentStore: attachmentStore).insert(db) } } @@ -462,7 +490,8 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, createdAtRange: ClosedRange?, turnID: String?, includeRedacted: Bool, - in db: Database + in db: Database, + attachmentStore: RuntimeAttachmentStore ) throws -> [AgentHistoryRecord] { var clauses = ["threadID = ?"] var arguments: [any DatabaseValueConvertible] = [threadID] @@ -494,13 +523,14 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, db, sql: sql, arguments: StatementArguments(arguments) - ).map { try Self.decodeHistoryRecord(from: $0) } + ).map { try Self.decodeHistoryRecord(from: $0, attachmentStore: attachmentStore) } } private static func fetchHistoryPage( threadID: String, query: AgentHistoryQuery, - in db: Database + in db: Database, + attachmentStore: RuntimeAttachmentStore ) throws -> AgentThreadHistoryPage { let limit = max(1, query.limit) let kinds = historyKinds(from: query.filter) @@ -533,7 +563,7 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, let hasMoreBefore = fetched.count > limit let pageRowsDescending = Array(fetched.prefix(limit)) let pageRecords = try pageRowsDescending - .map { try Self.decodeHistoryRecord(from: $0) } + .map { try Self.decodeHistoryRecord(from: $0, attachmentStore: attachmentStore) } .reversed() let hasMoreAfter: Bool @@ -583,7 +613,9 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, ) let hasMoreAfter = fetched.count > limit let pageRows = Array(fetched.prefix(limit)) - let pageRecords = try pageRows.map { try Self.decodeHistoryRecord(from: $0) } + let pageRecords = try pageRows.map { + try Self.decodeHistoryRecord(from: $0, attachmentStore: attachmentStore) + } let hasMoreBefore: Bool if let anchor { @@ -878,8 +910,15 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, ) } - private static func makeHistoryRow(from record: AgentHistoryRecord) throws -> RuntimeHistoryRow { - RuntimeHistoryRow( + private static func makeHistoryRow( + from record: AgentHistoryRecord, + attachmentStore: RuntimeAttachmentStore + ) throws -> RuntimeHistoryRow { + let persisted = try PersistedAgentHistoryRecord( + record: record, + attachmentStore: attachmentStore + ) + return RuntimeHistoryRow( storageID: "\(record.item.threadID):\(record.sequenceNumber)", recordID: record.id, threadID: record.item.threadID, @@ -888,7 +927,7 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, kind: record.item.kind.rawValue, turnID: record.item.turnID, isRedacted: record.redaction != nil, - encodedRecord: try JSONEncoder().encode(record) + encodedRecord: try JSONEncoder().encode(persisted) ) } @@ -947,8 +986,15 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, try JSONDecoder().decode(AgentThreadSummary.self, from: row.encodedSummary) } - private static func decodeHistoryRecord(from row: RuntimeHistoryRow) throws -> AgentHistoryRecord { - try JSONDecoder().decode(AgentHistoryRecord.self, from: row.encodedRecord) + private static func decodeHistoryRecord( + from row: RuntimeHistoryRow, + attachmentStore: RuntimeAttachmentStore + ) throws -> AgentHistoryRecord { + let decoder = JSONDecoder() + if let persisted = try? decoder.decode(PersistedAgentHistoryRecord.self, from: row.encodedRecord) { + return try persisted.decode(using: attachmentStore) + } + return try decoder.decode(AgentHistoryRecord.self, from: row.encodedRecord) } private static func decodeStructuredOutputRecord(from row: RuntimeStructuredOutputRow) throws -> AgentStructuredOutputRecord { diff --git a/Sources/CodexKit/Runtime/RuntimeAttachmentPersistence.swift b/Sources/CodexKit/Runtime/RuntimeAttachmentPersistence.swift new file mode 100644 index 0000000..84b1a66 --- /dev/null +++ b/Sources/CodexKit/Runtime/RuntimeAttachmentPersistence.swift @@ -0,0 +1,285 @@ +import Foundation + +struct RuntimeAttachmentStore: Sendable { + let rootURL: URL + + func prepare() throws { + try FileManager.default.createDirectory( + at: rootURL, + withIntermediateDirectories: true + ) + } + + func reset() throws { + if FileManager.default.fileExists(atPath: rootURL.path) { + try FileManager.default.removeItem(at: rootURL) + } + try prepare() + } + + func removeThread(_ threadID: String) throws { + let threadURL = rootURL.appendingPathComponent(sanitizedPathComponent(threadID), isDirectory: true) + guard FileManager.default.fileExists(atPath: threadURL.path) else { + return + } + try FileManager.default.removeItem(at: threadURL) + } + + func persist( + _ attachment: AgentImageAttachment, + threadID: String, + recordID: String, + index: Int + ) throws -> PersistedImageAttachment { + try prepare() + + let threadComponent = sanitizedPathComponent(threadID) + let recordComponent = sanitizedPathComponent(recordID) + let fileName = "\(index)-\(sanitizedPathComponent(attachment.id)).\(fileExtension(for: attachment.mimeType))" + let relativePath = threadComponent + "/" + recordComponent + "/" + fileName + let fileURL = rootURL + .appendingPathComponent(threadComponent, isDirectory: true) + .appendingPathComponent(recordComponent, isDirectory: true) + .appendingPathComponent(fileName, isDirectory: false) + + try FileManager.default.createDirectory( + at: fileURL.deletingLastPathComponent(), + withIntermediateDirectories: true + ) + try attachment.data.write(to: fileURL, options: .atomic) + + return PersistedImageAttachment( + id: attachment.id, + mimeType: attachment.mimeType, + storageKey: relativePath + ) + } + + func load(_ attachment: PersistedImageAttachment) throws -> AgentImageAttachment { + let fileURL = rootURL.appendingPathComponent(attachment.storageKey, isDirectory: false) + let data = try Data(contentsOf: fileURL) + return AgentImageAttachment( + id: attachment.id, + mimeType: attachment.mimeType, + data: data + ) + } + + private func fileExtension(for mimeType: String) -> String { + switch mimeType.lowercased() { + case "image/jpeg", "image/jpg": + return "jpg" + case "image/png": + return "png" + case "image/gif": + return "gif" + case "image/webp": + return "webp" + case "image/heic": + return "heic" + default: + return "bin" + } + } + + private func sanitizedPathComponent(_ value: String) -> String { + let allowed = CharacterSet.alphanumerics.union(CharacterSet(charactersIn: "-_.")) + let scalars = value.unicodeScalars.map { scalar -> Character in + allowed.contains(scalar) ? Character(scalar) : "_" + } + let result = String(scalars) + return result.isEmpty ? UUID().uuidString : result + } +} + +struct PersistedImageAttachment: Codable, Hashable { + let id: String + let mimeType: String + let storageKey: String +} + +struct PersistedAgentMessage: Codable, Hashable { + let id: String + let threadID: String + let role: AgentRole + let text: String + let images: [PersistedImageAttachment] + let structuredOutput: AgentStructuredOutputMetadata? + let createdAt: Date + + init( + message: AgentMessage, + attachmentStore: RuntimeAttachmentStore + ) throws { + self.id = message.id + self.threadID = message.threadID + self.role = message.role + self.text = message.text + self.images = try message.images.enumerated().map { index, attachment in + try attachmentStore.persist( + attachment, + threadID: message.threadID, + recordID: message.id, + index: index + ) + } + self.structuredOutput = message.structuredOutput + self.createdAt = message.createdAt + } + + func decode(using attachmentStore: RuntimeAttachmentStore) throws -> AgentMessage { + AgentMessage( + id: id, + threadID: threadID, + role: role, + text: text, + images: try images.map { try attachmentStore.load($0) }, + structuredOutput: structuredOutput, + createdAt: createdAt + ) + } +} + +enum PersistedAgentHistoryItem: Hashable { + case message(PersistedAgentMessage) + case toolCall(AgentToolCallRecord) + case toolResult(AgentToolResultRecord) + case structuredOutput(AgentStructuredOutputRecord) + case approval(AgentApprovalRecord) + case systemEvent(AgentSystemEventRecord) + + init( + item: AgentHistoryItem, + attachmentStore: RuntimeAttachmentStore + ) throws { + switch item { + case let .message(message): + self = .message(try PersistedAgentMessage( + message: message, + attachmentStore: attachmentStore + )) + case let .toolCall(record): + self = .toolCall(record) + case let .toolResult(record): + self = .toolResult(record) + case let .structuredOutput(record): + self = .structuredOutput(record) + case let .approval(record): + self = .approval(record) + case let .systemEvent(record): + self = .systemEvent(record) + } + } + + func decode(using attachmentStore: RuntimeAttachmentStore) throws -> AgentHistoryItem { + switch self { + case let .message(message): + return .message(try message.decode(using: attachmentStore)) + case let .toolCall(record): + return .toolCall(record) + case let .toolResult(record): + return .toolResult(record) + case let .structuredOutput(record): + return .structuredOutput(record) + case let .approval(record): + return .approval(record) + case let .systemEvent(record): + return .systemEvent(record) + } + } +} + +extension PersistedAgentHistoryItem: Codable { + private enum CodingKeys: String, CodingKey { + case kind + case message + case toolCall + case toolResult + case structuredOutput + case approval + case systemEvent + } + + private enum Kind: String, Codable { + case message + case toolCall + case toolResult + case structuredOutput + case approval + case systemEvent + } + + init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + switch try container.decode(Kind.self, forKey: .kind) { + case .message: + self = .message(try container.decode(PersistedAgentMessage.self, forKey: .message)) + case .toolCall: + self = .toolCall(try container.decode(AgentToolCallRecord.self, forKey: .toolCall)) + case .toolResult: + self = .toolResult(try container.decode(AgentToolResultRecord.self, forKey: .toolResult)) + case .structuredOutput: + self = .structuredOutput(try container.decode(AgentStructuredOutputRecord.self, forKey: .structuredOutput)) + case .approval: + self = .approval(try container.decode(AgentApprovalRecord.self, forKey: .approval)) + case .systemEvent: + self = .systemEvent(try container.decode(AgentSystemEventRecord.self, forKey: .systemEvent)) + } + } + + func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + switch self { + case let .message(message): + try container.encode(Kind.message, forKey: .kind) + try container.encode(message, forKey: .message) + case let .toolCall(record): + try container.encode(Kind.toolCall, forKey: .kind) + try container.encode(record, forKey: .toolCall) + case let .toolResult(record): + try container.encode(Kind.toolResult, forKey: .kind) + try container.encode(record, forKey: .toolResult) + case let .structuredOutput(record): + try container.encode(Kind.structuredOutput, forKey: .kind) + try container.encode(record, forKey: .structuredOutput) + case let .approval(record): + try container.encode(Kind.approval, forKey: .kind) + try container.encode(record, forKey: .approval) + case let .systemEvent(record): + try container.encode(Kind.systemEvent, forKey: .kind) + try container.encode(record, forKey: .systemEvent) + } + } +} + +struct PersistedAgentHistoryRecord: Codable, Hashable { + let id: String + let sequenceNumber: Int + let createdAt: Date + let item: PersistedAgentHistoryItem + let redaction: AgentHistoryRedaction? + + init( + record: AgentHistoryRecord, + attachmentStore: RuntimeAttachmentStore + ) throws { + self.id = record.id + self.sequenceNumber = record.sequenceNumber + self.createdAt = record.createdAt + self.item = try PersistedAgentHistoryItem( + item: record.item, + attachmentStore: attachmentStore + ) + self.redaction = record.redaction + } + + func decode(using attachmentStore: RuntimeAttachmentStore) throws -> AgentHistoryRecord { + AgentHistoryRecord( + id: id, + sequenceNumber: sequenceNumber, + createdAt: createdAt, + item: try item.decode(using: attachmentStore), + redaction: redaction + ) + } +} diff --git a/Sources/CodexKit/Runtime/RuntimeStateStore.swift b/Sources/CodexKit/Runtime/RuntimeStateStore.swift index 650aec8..4c1c02d 100644 --- a/Sources/CodexKit/Runtime/RuntimeStateStore.swift +++ b/Sources/CodexKit/Runtime/RuntimeStateStore.swift @@ -166,9 +166,16 @@ public actor FileRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, private let encoder = JSONEncoder() private let decoder = JSONDecoder() private let fileManager = FileManager.default + private let attachmentStore: RuntimeAttachmentStore public init(url: URL) { self.url = url + let basename = url.deletingPathExtension().lastPathComponent + self.attachmentStore = RuntimeAttachmentStore( + rootURL: url.deletingLastPathComponent() + .appendingPathComponent("\(basename).codexkit-state", isDirectory: true) + .appendingPathComponent("attachments", isDirectory: true) + ) } public func loadState() async throws -> StoredRuntimeState { @@ -283,6 +290,9 @@ public actor FileRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, } let data = try Data(contentsOf: historyURL) + if let persisted = try? decoder.decode([PersistedAgentHistoryRecord].self, from: data) { + return try persisted.map { try $0.decode(using: attachmentStore) } + } return try decoder.decode([AgentHistoryRecord].self, from: data) } @@ -300,11 +310,18 @@ public actor FileRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, at: historyDirectoryURL, withIntermediateDirectories: true ) + try attachmentStore.reset() for thread in normalized.threads { let historyURL = historyFileURL(for: thread.id) let history = normalized.historyByThread[thread.id] ?? [] - let data = try encoder.encode(history) + let persisted = try history.map { + try PersistedAgentHistoryRecord( + record: $0, + attachmentStore: attachmentStore + ) + } + let data = try encoder.encode(persisted) try data.write(to: historyURL, options: .atomic) } diff --git a/Tests/CodexKitTests/AgentRuntimeHistoryTests.swift b/Tests/CodexKitTests/AgentRuntimeHistoryTests.swift index b20aaa6..f63d1f4 100644 --- a/Tests/CodexKitTests/AgentRuntimeHistoryTests.swift +++ b/Tests/CodexKitTests/AgentRuntimeHistoryTests.swift @@ -610,6 +610,67 @@ extension AgentRuntimeTests { ) XCTAssertFalse(history.records.isEmpty) } + + func testGRDBRuntimeStateStoreExternalizesImageAttachments() async throws { + let url = temporaryRuntimeSQLiteURL() + let attachmentsDirectory = url.deletingLastPathComponent() + .appendingPathComponent("\(url.deletingPathExtension().lastPathComponent).codexkit-state", isDirectory: true) + .appendingPathComponent("attachments", isDirectory: true) + defer { + try? FileManager.default.removeItem(at: url) + try? FileManager.default.removeItem(at: attachmentsDirectory.deletingLastPathComponent()) + } + + let imageData = Data([0x89, 0x50, 0x4E, 0x47, 0xDE, 0xAD, 0xBE, 0xEF]) + let runtime = try makeHistoryRuntime( + backend: InMemoryAgentBackend(), + approvalPresenter: AutoApprovalPresenter(), + stateStore: try GRDBRuntimeStateStore(url: url) + ) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(title: "Attachment Thread") + _ = try await runtime.sendMessage( + UserMessageRequest( + text: "here is an image", + images: [.png(imageData)] + ), + in: thread.id + ) + + let reloadedRuntime = try makeHistoryRuntime( + backend: InMemoryAgentBackend(), + approvalPresenter: AutoApprovalPresenter(), + stateStore: try GRDBRuntimeStateStore(url: url) + ) + + let history = try await reloadedRuntime.execute( + HistoryItemsQuery(threadID: thread.id, kinds: [.message]) + ) + guard let userMessage = history.records.compactMap({ record -> AgentMessage? in + guard case let .message(message) = record.item, message.role == .user else { + return nil + } + return message + }).first else { + return XCTFail("Expected a persisted user message with an attachment.") + } + + XCTAssertEqual(userMessage.images.count, 1) + XCTAssertEqual(userMessage.images.first?.data, imageData) + + let attachmentFiles = try FileManager.default.contentsOfDirectory( + at: attachmentsDirectory.appendingPathComponent(thread.id, isDirectory: true) + .appendingPathComponent(userMessage.id, isDirectory: true), + includingPropertiesForKeys: nil + ) + XCTAssertEqual(attachmentFiles.count, 1) + + let databaseData = try Data(contentsOf: url) + XCTAssertNil(databaseData.range(of: imageData.base64EncodedData())) + } } private func makeHistoryRuntime( From b22cef595f307989ec55360fc1f177cac913a354 Mon Sep 17 00:00:00 2001 From: Timothy Zelinsky Date: Mon, 23 Mar 2026 22:33:43 +1100 Subject: [PATCH 04/19] Migrate SQLiteMemoryStore to GRDB --- .../CodexKit/Memory/SQLiteMemoryStore.swift | 913 ++++++++---------- 1 file changed, 394 insertions(+), 519 deletions(-) diff --git a/Sources/CodexKit/Memory/SQLiteMemoryStore.swift b/Sources/CodexKit/Memory/SQLiteMemoryStore.swift index b7b6c10..b7303ed 100644 --- a/Sources/CodexKit/Memory/SQLiteMemoryStore.swift +++ b/Sources/CodexKit/Memory/SQLiteMemoryStore.swift @@ -1,67 +1,86 @@ import Foundation -import SQLite3 +import GRDB public actor SQLiteMemoryStore: MemoryStoring { private static let currentSchemaVersion = 1 - private let url: URL - private nonisolated(unsafe) var database: OpaquePointer? - private let encoder = JSONEncoder() - private let decoder = JSONDecoder() + private let url: URL + private let dbQueue: DatabaseQueue + private let migrator: DatabaseMigrator public init(url: URL) throws { self.url = url - self.database = try Self.openDatabase(at: url) - try Self.migrateIfNeeded(in: database) - } - deinit { - if let database { - sqlite3_close(database) + let directory = url.deletingLastPathComponent() + if !directory.path.isEmpty { + try FileManager.default.createDirectory( + at: directory, + withIntermediateDirectories: true + ) } + + var configuration = Configuration() + configuration.foreignKeysEnabled = true + configuration.label = "CodexKit.SQLiteMemoryStore" + dbQueue = try DatabaseQueue(path: url.path, configuration: configuration) + migrator = Self.makeMigrator() + + let existingVersion = try dbQueue.read { db in + try Self.schemaVersion(in: db) + } + if existingVersion > Self.currentSchemaVersion { + throw MemoryStoreError.unsupportedSchemaVersion(existingVersion) + } + + try migrator.migrate(dbQueue) } public func put(_ record: MemoryRecord) async throws { try MemoryQueryEngine.validateNamespace(record.namespace) - try transaction { - try ensureRecordIDAvailable(record.id, namespace: record.namespace) + try await writeTransaction { db in + try Self.ensureRecordIDAvailable(record.id, namespace: record.namespace, in: db) if let dedupeKey = record.dedupeKey { - try ensureDedupeKeyAvailable(dedupeKey, namespace: record.namespace) + try Self.ensureDedupeKeyAvailable(dedupeKey, namespace: record.namespace, in: db) } - try upsertRecord(record) + try Self.upsertRecord(record, in: db) } } public func putMany(_ records: [MemoryRecord]) async throws { - try transaction { + try await writeTransaction { db in for record in records { try MemoryQueryEngine.validateNamespace(record.namespace) - try ensureRecordIDAvailable(record.id, namespace: record.namespace) + try Self.ensureRecordIDAvailable(record.id, namespace: record.namespace, in: db) if let dedupeKey = record.dedupeKey { - try ensureDedupeKeyAvailable(dedupeKey, namespace: record.namespace) + try Self.ensureDedupeKeyAvailable(dedupeKey, namespace: record.namespace, in: db) } - try upsertRecord(record) + try Self.upsertRecord(record, in: db) } } } public func upsert(_ record: MemoryRecord, dedupeKey: String) async throws { try MemoryQueryEngine.validateNamespace(record.namespace) - try transaction { - try deleteRecord(withDedupeKey: dedupeKey, namespace: record.namespace) - try deleteRecord(id: record.id, namespace: record.namespace) + try await writeTransaction { db in + try Self.deleteRecord(withDedupeKey: dedupeKey, namespace: record.namespace, in: db) + try Self.deleteRecord(id: record.id, namespace: record.namespace, in: db) var updatedRecord = record updatedRecord.dedupeKey = dedupeKey - try upsertRecord(updatedRecord) + try Self.upsertRecord(updatedRecord, in: db) } } public func query(_ query: MemoryQuery) async throws -> MemoryQueryResult { try MemoryQueryEngine.validateNamespace(query.namespace) - let records = try loadRecords(namespace: query.namespace) - let rawScores = try loadFTSRawScores( - namespace: query.namespace, - queryText: query.text - ) + let records = try await dbQueue.read { db in + try Self.loadRecords(namespace: query.namespace, in: db) + } + let rawScores = try await dbQueue.read { db in + try Self.loadFTSRawScores( + namespace: query.namespace, + queryText: query.text, + in: db + ) + } let candidates = records.map { record in MemoryQueryEngine.Candidate( @@ -82,41 +101,50 @@ public actor SQLiteMemoryStore: MemoryStoring { namespace: String ) async throws -> MemoryRecord? { try MemoryQueryEngine.validateNamespace(namespace) - return try loadRecords(namespace: namespace).first { $0.id == id } + return try await dbQueue.read { db in + try Self.loadRecords(namespace: namespace, in: db).first { $0.id == id } + } } public func list(_ query: MemoryRecordListQuery) async throws -> [MemoryRecord] { try MemoryQueryEngine.validateNamespace(query.namespace) - return try loadRecords(namespace: query.namespace) - .filter { record in - if !query.includeArchived, record.status == .archived { - return false + return try await dbQueue.read { db in + try Self.loadRecords(namespace: query.namespace, in: db) + .filter { record in + if !query.includeArchived, record.status == .archived { + return false + } + if !query.scopes.isEmpty, !query.scopes.contains(record.scope) { + return false + } + if !query.kinds.isEmpty, !query.kinds.contains(record.kind) { + return false + } + return true } - if !query.scopes.isEmpty, !query.scopes.contains(record.scope) { - return false + .sorted { + if $0.effectiveDate == $1.effectiveDate { + return $0.id < $1.id + } + return $0.effectiveDate > $1.effectiveDate } - if !query.kinds.isEmpty, !query.kinds.contains(record.kind) { - return false - } - return true - } - .sorted { - if $0.effectiveDate == $1.effectiveDate { - return $0.id < $1.id - } - return $0.effectiveDate > $1.effectiveDate - } - .prefix(query.limit ?? .max) - .map { $0 } + .prefix(query.limit ?? .max) + .map { $0 } + } } public func diagnostics(namespace: String) async throws -> MemoryStoreDiagnostics { try MemoryQueryEngine.validateNamespace(namespace) - let records = try loadRecords(namespace: namespace) + let records = try await dbQueue.read { db in + try Self.loadRecords(namespace: namespace, in: db) + } + let schemaVersion = try await dbQueue.read { db in + try Self.schemaVersion(in: db) + } return MemoryStoreDiagnostics( namespace: namespace, implementation: "sqlite", - schemaVersion: try Self.schemaVersion(in: database), + schemaVersion: schemaVersion, totalRecords: records.count, activeRecords: records.filter { $0.status == .active }.count, archivedRecords: records.filter { $0.status == .archived }.count, @@ -127,32 +155,40 @@ public actor SQLiteMemoryStore: MemoryStoring { public func compact(_ request: MemoryCompactionRequest) async throws { try MemoryQueryEngine.validateNamespace(request.replacement.namespace) - try transaction { - try ensureRecordIDAvailable(request.replacement.id, namespace: request.replacement.namespace) + try await writeTransaction { db in + try Self.ensureRecordIDAvailable( + request.replacement.id, + namespace: request.replacement.namespace, + in: db + ) if let dedupeKey = request.replacement.dedupeKey { - try ensureDedupeKeyAvailable(dedupeKey, namespace: request.replacement.namespace) + try Self.ensureDedupeKeyAvailable( + dedupeKey, + namespace: request.replacement.namespace, + in: db + ) } - try upsertRecord(request.replacement) + try Self.upsertRecord(request.replacement, in: db) for sourceID in request.sourceIDs { - try archiveRecord(id: sourceID, namespace: request.replacement.namespace) + try Self.archiveRecord(id: sourceID, namespace: request.replacement.namespace, in: db) } } } public func archive(ids: [String], namespace: String) async throws { try MemoryQueryEngine.validateNamespace(namespace) - try transaction { + try await writeTransaction { db in for id in ids { - try archiveRecord(id: id, namespace: namespace) + try Self.archiveRecord(id: id, namespace: namespace, in: db) } } } public func delete(ids: [String], namespace: String) async throws { try MemoryQueryEngine.validateNamespace(namespace) - try transaction { + try await writeTransaction { db in for id in ids { - try deleteRecord(id: id, namespace: namespace) + try Self.deleteRecord(id: id, namespace: namespace, in: db) } } } @@ -163,484 +199,303 @@ public actor SQLiteMemoryStore: MemoryStoring { namespace: String ) async throws -> Int { try MemoryQueryEngine.validateNamespace(namespace) - let expiredIDs = try loadRecords(namespace: namespace) - .filter { record in - !record.isPinned && - record.status == .active && - (record.expiresAt?.compare(now) == .orderedAscending || - record.expiresAt?.compare(now) == .orderedSame) - } - .map(\.id) + let expiredIDs = try await dbQueue.read { db in + try Self.loadRecords(namespace: namespace, in: db) + .filter { record in + !record.isPinned && + record.status == .active && + (record.expiresAt?.compare(now) == .orderedAscending || + record.expiresAt?.compare(now) == .orderedSame) + } + .map(\.id) + } - try transaction { + try await writeTransaction { db in for id in expiredIDs { - try deleteRecord(id: id, namespace: namespace) + try Self.deleteRecord(id: id, namespace: namespace, in: db) } } return expiredIDs.count } - private static func openDatabase(at url: URL) throws -> OpaquePointer { - let directory = url.deletingLastPathComponent() - if !directory.path.isEmpty { - try FileManager.default.createDirectory( - at: directory, - withIntermediateDirectories: true - ) - } - - var database: OpaquePointer? - let result = sqlite3_open_v2( - url.path, - &database, - SQLITE_OPEN_CREATE | SQLITE_OPEN_READWRITE | SQLITE_OPEN_FULLMUTEX, - nil - ) - guard result == SQLITE_OK, let database else { - throw sqliteError( - database, - message: "Failed to open SQLite memory store." - ) - } - sqlite3_exec(database, "PRAGMA foreign_keys = ON;", nil, nil, nil) - return database - } - - private static func migrateIfNeeded(in database: OpaquePointer?) throws { - let existingVersion = try schemaVersion(in: database) - if existingVersion > currentSchemaVersion { - throw MemoryStoreError.unsupportedSchemaVersion(existingVersion) - } - - try createSchemaIfNeeded(in: database) - if existingVersion < currentSchemaVersion { - try setSchemaVersion(currentSchemaVersion, in: database) - } - } - - private static func createSchemaIfNeeded(in database: OpaquePointer?) throws { - let schema = """ - CREATE TABLE IF NOT EXISTS memory_records ( - namespace TEXT NOT NULL, - id TEXT NOT NULL, - scope TEXT NOT NULL, - kind TEXT NOT NULL, - summary TEXT NOT NULL, - evidence_json TEXT NOT NULL, - importance REAL NOT NULL, - created_at REAL NOT NULL, - observed_at REAL, - expires_at REAL, - tags_json TEXT NOT NULL, - related_ids_json TEXT NOT NULL, - dedupe_key TEXT, - is_pinned INTEGER NOT NULL, - attributes_json TEXT, - status TEXT NOT NULL, - PRIMARY KEY(namespace, id) - ); - CREATE UNIQUE INDEX IF NOT EXISTS memory_records_namespace_dedupe - ON memory_records(namespace, dedupe_key) - WHERE dedupe_key IS NOT NULL; - CREATE INDEX IF NOT EXISTS memory_records_namespace_scope - ON memory_records(namespace, scope); - CREATE INDEX IF NOT EXISTS memory_records_namespace_kind - ON memory_records(namespace, kind); - CREATE INDEX IF NOT EXISTS memory_records_namespace_status - ON memory_records(namespace, status); - CREATE TABLE IF NOT EXISTS memory_tags ( - namespace TEXT NOT NULL, - record_id TEXT NOT NULL, - tag TEXT NOT NULL, - FOREIGN KEY(namespace, record_id) - REFERENCES memory_records(namespace, id) - ON DELETE CASCADE - ); - CREATE INDEX IF NOT EXISTS memory_tags_lookup - ON memory_tags(namespace, tag, record_id); - CREATE TABLE IF NOT EXISTS memory_related_ids ( - namespace TEXT NOT NULL, - record_id TEXT NOT NULL, - related_id TEXT NOT NULL, - FOREIGN KEY(namespace, record_id) - REFERENCES memory_records(namespace, id) - ON DELETE CASCADE - ); - CREATE INDEX IF NOT EXISTS memory_related_lookup - ON memory_related_ids(namespace, related_id, record_id); - CREATE VIRTUAL TABLE IF NOT EXISTS memory_fts - USING fts5(namespace UNINDEXED, record_id UNINDEXED, content); - """ - try execSQL(database, schema) - } - - private static func schemaVersion(in database: OpaquePointer?) throws -> Int { - let statement = try prepareSQL(database, "PRAGMA user_version;") - defer { sqlite3_finalize(statement) } - guard sqlite3_step(statement) == SQLITE_ROW else { - throw sqliteError(database, message: "Failed to read SQLite schema version.") + private func writeTransaction(_ operation: @escaping @Sendable (Database) throws -> Void) async throws { + try await dbQueue.writeWithoutTransaction { db in + try db.inTransaction { + try operation(db) + return .commit + } } - return Int(sqlite3_column_int(statement, 0)) } - private static func setSchemaVersion( - _ version: Int, - in database: OpaquePointer? - ) throws { - try execSQL(database, "PRAGMA user_version = \(version);") - } - - private func ensureRecordIDAvailable( + private static func ensureRecordIDAvailable( _ id: String, - namespace: String + namespace: String, + in db: Database ) throws { - if try recordExists(id: id, namespace: namespace) { + if try recordExists(id: id, namespace: namespace, in: db) { throw MemoryStoreError.duplicateRecordID(id) } } - private func ensureDedupeKeyAvailable( + private static func ensureDedupeKeyAvailable( _ dedupeKey: String, - namespace: String + namespace: String, + in db: Database ) throws { - if try recordExists(dedupeKey: dedupeKey, namespace: namespace) { + if try recordExists(dedupeKey: dedupeKey, namespace: namespace, in: db) { throw MemoryStoreError.duplicateDedupeKey(dedupeKey) } } - private func loadRecords(namespace: String) throws -> [MemoryRecord] { - let sql = """ - SELECT - id, scope, kind, summary, evidence_json, importance, - created_at, observed_at, expires_at, tags_json, - related_ids_json, dedupe_key, is_pinned, attributes_json, status - FROM memory_records - WHERE namespace = ?; - """ - let statement = try prepare(sql) - defer { sqlite3_finalize(statement) } - try bindText(namespace, to: statement, index: 1) - - var records: [MemoryRecord] = [] - while sqlite3_step(statement) == SQLITE_ROW { - let id = try columnText(statement, index: 0) - let scope = MemoryScope(rawValue: try columnText(statement, index: 1)) - let kind = try columnText(statement, index: 2) - let summary = try columnText(statement, index: 3) - let evidence = try decodeJSON([String].self, from: try columnText(statement, index: 4)) - let importance = sqlite3_column_double(statement, 5) - let createdAt = Date(timeIntervalSince1970: sqlite3_column_double(statement, 6)) - let observedAt = sqlite3_column_type(statement, 7) == SQLITE_NULL - ? nil - : Date(timeIntervalSince1970: sqlite3_column_double(statement, 7)) - let expiresAt = sqlite3_column_type(statement, 8) == SQLITE_NULL - ? nil - : Date(timeIntervalSince1970: sqlite3_column_double(statement, 8)) - let tags = try decodeJSON([String].self, from: try columnText(statement, index: 9)) - let relatedIDs = try decodeJSON([String].self, from: try columnText(statement, index: 10)) - let dedupeKey = sqlite3_column_type(statement, 11) == SQLITE_NULL ? nil : try columnText(statement, index: 11) - let isPinned = sqlite3_column_int(statement, 12) == 1 - let attributes = sqlite3_column_type(statement, 13) == SQLITE_NULL - ? nil - : try decodeJSON(JSONValue.self, from: try columnText(statement, index: 13)) - let status = MemoryRecordStatus(rawValue: try columnText(statement, index: 14)) ?? .active - - records.append( - MemoryRecord( - id: id, - namespace: namespace, - scope: scope, - kind: kind, - summary: summary, - evidence: evidence, - importance: importance, - createdAt: createdAt, - observedAt: observedAt, - expiresAt: expiresAt, - tags: tags, - relatedIDs: relatedIDs, - dedupeKey: dedupeKey, - isPinned: isPinned, - attributes: attributes, - status: status - ) + private static func loadRecords( + namespace: String, + in db: Database + ) throws -> [MemoryRecord] { + let rows = try Row.fetchAll( + db, + sql: """ + SELECT + id, scope, kind, summary, evidence_json, importance, + created_at, observed_at, expires_at, tags_json, + related_ids_json, dedupe_key, is_pinned, attributes_json, status + FROM memory_records + WHERE namespace = ?; + """, + arguments: [namespace] + ) + + return try rows.map { row in + MemoryRecord( + id: row["id"], + namespace: namespace, + scope: MemoryScope(rawValue: row["scope"]), + kind: row["kind"], + summary: row["summary"], + evidence: try decodeJSON([String].self, from: row["evidence_json"]), + importance: row["importance"], + createdAt: Date(timeIntervalSince1970: row["created_at"]), + observedAt: (row["observed_at"] as Double?).map(Date.init(timeIntervalSince1970:)), + expiresAt: (row["expires_at"] as Double?).map(Date.init(timeIntervalSince1970:)), + tags: try decodeJSON([String].self, from: row["tags_json"]), + relatedIDs: try decodeJSON([String].self, from: row["related_ids_json"]), + dedupeKey: row["dedupe_key"], + isPinned: (row["is_pinned"] as Int64? ?? 0) == 1, + attributes: try decodeNullableJSON(JSONValue.self, from: row["attributes_json"]), + status: MemoryRecordStatus(rawValue: row["status"]) ?? .active ) } - - return records } - private func loadFTSRawScores( + private static func loadFTSRawScores( namespace: String, - queryText: String? + queryText: String?, + in db: Database ) throws -> [String: Double] { - let matchQuery = ftsQuery(from: queryText) + let matchQuery = Self.ftsQuery(from: queryText) guard !matchQuery.isEmpty else { return [:] } - let sql = """ - SELECT record_id, bm25(memory_fts) - FROM memory_fts - WHERE namespace = ? AND memory_fts MATCH ?; - """ - let statement = try prepare(sql) - defer { sqlite3_finalize(statement) } - try bindText(namespace, to: statement, index: 1) - try bindText(matchQuery, to: statement, index: 2) + let rows = try Row.fetchAll( + db, + sql: """ + SELECT record_id, bm25(memory_fts) AS score + FROM memory_fts + WHERE namespace = ? AND memory_fts MATCH ?; + """, + arguments: [namespace, matchQuery] + ) var scores: [String: Double] = [:] - while sqlite3_step(statement) == SQLITE_ROW { - let recordID = try columnText(statement, index: 0) - let score = sqlite3_column_double(statement, 1) + for row in rows { + let recordID: String = row["record_id"] + let score: Double = row["score"] scores[recordID] = score } return scores } - private func recordExists( + private static func recordExists( id: String, - namespace: String + namespace: String, + in db: Database ) throws -> Bool { - let sql = "SELECT 1 FROM memory_records WHERE namespace = ? AND id = ? LIMIT 1;" - let statement = try prepare(sql) - defer { sqlite3_finalize(statement) } - try bindText(namespace, to: statement, index: 1) - try bindText(id, to: statement, index: 2) - return sqlite3_step(statement) == SQLITE_ROW - } - - private func recordExists( + try Bool.fetchOne( + db, + sql: """ + SELECT EXISTS( + SELECT 1 FROM memory_records WHERE namespace = ? AND id = ? + ); + """, + arguments: [namespace, id] + ) ?? false + } + + private static func recordExists( dedupeKey: String, - namespace: String + namespace: String, + in db: Database ) throws -> Bool { - let sql = "SELECT 1 FROM memory_records WHERE namespace = ? AND dedupe_key = ? LIMIT 1;" - let statement = try prepare(sql) - defer { sqlite3_finalize(statement) } - try bindText(namespace, to: statement, index: 1) - try bindText(dedupeKey, to: statement, index: 2) - return sqlite3_step(statement) == SQLITE_ROW - } - - private func archiveRecord( + try Bool.fetchOne( + db, + sql: """ + SELECT EXISTS( + SELECT 1 FROM memory_records WHERE namespace = ? AND dedupe_key = ? + ); + """, + arguments: [namespace, dedupeKey] + ) ?? false + } + + private static func archiveRecord( id: String, - namespace: String + namespace: String, + in db: Database ) throws { - let sql = """ - UPDATE memory_records - SET status = ? - WHERE namespace = ? AND id = ?; - """ - let statement = try prepare(sql) - defer { sqlite3_finalize(statement) } - try bindText(MemoryRecordStatus.archived.rawValue, to: statement, index: 1) - try bindText(namespace, to: statement, index: 2) - try bindText(id, to: statement, index: 3) - try step(statement) + try db.execute( + sql: """ + UPDATE memory_records + SET status = ? + WHERE namespace = ? AND id = ?; + """, + arguments: [MemoryRecordStatus.archived.rawValue, namespace, id] + ) } - private func deleteRecord( + private static func deleteRecord( id: String, - namespace: String + namespace: String, + in db: Database ) throws { - try exec( - "DELETE FROM memory_fts WHERE namespace = ? AND record_id = ?;", - bindings: [.text(namespace), .text(id)] - ) - try exec( - "DELETE FROM memory_related_ids WHERE namespace = ? AND record_id = ?;", - bindings: [.text(namespace), .text(id)] + try db.execute( + sql: "DELETE FROM memory_fts WHERE namespace = ? AND record_id = ?;", + arguments: [namespace, id] ) - try exec( - "DELETE FROM memory_tags WHERE namespace = ? AND record_id = ?;", - bindings: [.text(namespace), .text(id)] - ) - try exec( - "DELETE FROM memory_records WHERE namespace = ? AND id = ?;", - bindings: [.text(namespace), .text(id)] + try db.execute( + sql: "DELETE FROM memory_records WHERE namespace = ? AND id = ?;", + arguments: [namespace, id] ) } - private func deleteRecord( + private static func deleteRecord( withDedupeKey dedupeKey: String, - namespace: String + namespace: String, + in db: Database + ) throws { + if let id = try String.fetchOne( + db, + sql: """ + SELECT id + FROM memory_records + WHERE namespace = ? AND dedupe_key = ? + LIMIT 1; + """, + arguments: [namespace, dedupeKey] + ) { + try deleteRecord(id: id, namespace: namespace, in: db) + } + } + + private static func upsertRecord( + _ record: MemoryRecord, + in db: Database ) throws { - let statement = try prepare( - "SELECT id FROM memory_records WHERE namespace = ? AND dedupe_key = ? LIMIT 1;" + try db.execute( + sql: """ + INSERT OR REPLACE INTO memory_records ( + namespace, id, scope, kind, summary, evidence_json, importance, + created_at, observed_at, expires_at, tags_json, related_ids_json, + dedupe_key, is_pinned, attributes_json, status + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?); + """, + arguments: [ + record.namespace, + record.id, + record.scope.rawValue, + record.kind, + record.summary, + try encodeJSON(record.evidence), + record.importance, + record.createdAt.timeIntervalSince1970, + record.observedAt?.timeIntervalSince1970, + record.expiresAt?.timeIntervalSince1970, + try encodeJSON(record.tags), + try encodeJSON(record.relatedIDs), + record.dedupeKey, + record.isPinned ? 1 : 0, + try encodeNullableJSON(record.attributes), + record.status.rawValue, + ] ) - defer { sqlite3_finalize(statement) } - try bindText(namespace, to: statement, index: 1) - try bindText(dedupeKey, to: statement, index: 2) - if sqlite3_step(statement) == SQLITE_ROW { - let id = try columnText(statement, index: 0) - try deleteRecord(id: id, namespace: namespace) - } - } - private func upsertRecord(_ record: MemoryRecord) throws { - let sql = """ - INSERT OR REPLACE INTO memory_records ( - namespace, id, scope, kind, summary, evidence_json, importance, - created_at, observed_at, expires_at, tags_json, related_ids_json, - dedupe_key, is_pinned, attributes_json, status - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?); - """ - let statement = try prepare(sql) - defer { sqlite3_finalize(statement) } - - try bindText(record.namespace, to: statement, index: 1) - try bindText(record.id, to: statement, index: 2) - try bindText(record.scope.rawValue, to: statement, index: 3) - try bindText(record.kind, to: statement, index: 4) - try bindText(record.summary, to: statement, index: 5) - try bindText(try encodeJSON(record.evidence), to: statement, index: 6) - sqlite3_bind_double(statement, 7, record.importance) - sqlite3_bind_double(statement, 8, record.createdAt.timeIntervalSince1970) - if let observedAt = record.observedAt { - sqlite3_bind_double(statement, 9, observedAt.timeIntervalSince1970) - } else { - sqlite3_bind_null(statement, 9) - } - if let expiresAt = record.expiresAt { - sqlite3_bind_double(statement, 10, expiresAt.timeIntervalSince1970) - } else { - sqlite3_bind_null(statement, 10) - } - try bindText(try encodeJSON(record.tags), to: statement, index: 11) - try bindText(try encodeJSON(record.relatedIDs), to: statement, index: 12) - if let dedupeKey = record.dedupeKey { - try bindText(dedupeKey, to: statement, index: 13) - } else { - sqlite3_bind_null(statement, 13) - } - sqlite3_bind_int(statement, 14, record.isPinned ? 1 : 0) - if let attributes = record.attributes { - try bindText(try encodeJSON(attributes), to: statement, index: 15) - } else { - sqlite3_bind_null(statement, 15) - } - try bindText(record.status.rawValue, to: statement, index: 16) - try step(statement) - - try exec( - "DELETE FROM memory_tags WHERE namespace = ? AND record_id = ?;", - bindings: [.text(record.namespace), .text(record.id)] + try db.execute( + sql: "DELETE FROM memory_tags WHERE namespace = ? AND record_id = ?;", + arguments: [record.namespace, record.id] ) - try exec( - "DELETE FROM memory_related_ids WHERE namespace = ? AND record_id = ?;", - bindings: [.text(record.namespace), .text(record.id)] + try db.execute( + sql: "DELETE FROM memory_related_ids WHERE namespace = ? AND record_id = ?;", + arguments: [record.namespace, record.id] ) - try exec( - "DELETE FROM memory_fts WHERE namespace = ? AND record_id = ?;", - bindings: [.text(record.namespace), .text(record.id)] + try db.execute( + sql: "DELETE FROM memory_fts WHERE namespace = ? AND record_id = ?;", + arguments: [record.namespace, record.id] ) for tag in record.tags { - try exec( - "INSERT INTO memory_tags(namespace, record_id, tag) VALUES (?, ?, ?);", - bindings: [.text(record.namespace), .text(record.id), .text(tag)] + try db.execute( + sql: "INSERT INTO memory_tags(namespace, record_id, tag) VALUES (?, ?, ?);", + arguments: [record.namespace, record.id, tag] ) } for relatedID in record.relatedIDs { - try exec( - "INSERT INTO memory_related_ids(namespace, record_id, related_id) VALUES (?, ?, ?);", - bindings: [.text(record.namespace), .text(record.id), .text(relatedID)] + try db.execute( + sql: """ + INSERT INTO memory_related_ids(namespace, record_id, related_id) + VALUES (?, ?, ?); + """, + arguments: [record.namespace, record.id, relatedID] ) } - let ftsContent = ([record.summary] + record.evidence + record.tags + [record.kind]).joined(separator: " ") - try exec( - "INSERT INTO memory_fts(namespace, record_id, content) VALUES (?, ?, ?);", - bindings: [.text(record.namespace), .text(record.id), .text(ftsContent)] + let ftsContent = ([record.summary] + record.evidence + record.tags + [record.kind]) + .joined(separator: " ") + try db.execute( + sql: "INSERT INTO memory_fts(namespace, record_id, content) VALUES (?, ?, ?);", + arguments: [record.namespace, record.id, ftsContent] ) } - private func transaction(_ operation: () throws -> Void) throws { - try exec("BEGIN IMMEDIATE;") - do { - try operation() - try exec("COMMIT;") - } catch { - try? exec("ROLLBACK;") - throw error - } - } - - private func prepare(_ sql: String) throws -> OpaquePointer? { - try prepareSQL(database, sql) - } - - private func exec( - _ sql: String, - bindings: [SQLiteBinding] = [] - ) throws { - let statement = try prepare(sql) - defer { sqlite3_finalize(statement) } - - for (index, binding) in bindings.enumerated() { - try bind(binding, to: statement, index: Int32(index + 1)) - } - - try step(statement) - } - - private func step(_ statement: OpaquePointer?) throws { - let result = sqlite3_step(statement) - guard result == SQLITE_DONE || result == SQLITE_ROW else { - throw sqliteError(message: "SQLite step failed.") - } - } - - private func bind( - _ binding: SQLiteBinding, - to statement: OpaquePointer?, - index: Int32 - ) throws { - switch binding { - case let .text(value): - try bindText(value, to: statement, index: index) - case let .double(value): - sqlite3_bind_double(statement, index, value) - case .null: - sqlite3_bind_null(statement, index) - } - } - - private func bindText( - _ value: String, - to statement: OpaquePointer?, - index: Int32 - ) throws { - let result = sqlite3_bind_text(statement, index, value, -1, SQLITE_TRANSIENT) - guard result == SQLITE_OK else { - throw sqliteError(message: "Failed to bind SQLite text value.") - } + private static func encodeJSON(_ value: T) throws -> String { + let data = try JSONEncoder().encode(value) + return String(decoding: data, as: UTF8.self) } - private func columnText( - _ statement: OpaquePointer?, - index: Int32 - ) throws -> String { - guard let cString = sqlite3_column_text(statement, index) else { - throw sqliteError(message: "SQLite column was unexpectedly null.") + private static func encodeNullableJSON(_ value: T?) throws -> String? { + guard let value else { + return nil } - return String(cString: cString) - } - - private func encodeJSON(_ value: T) throws -> String { - let data = try encoder.encode(value) - return String(decoding: data, as: UTF8.self) + return try Self.encodeJSON(value) } - private func decodeJSON( + private static func decodeJSON( _ type: T.Type, from string: String ) throws -> T { - try decoder.decode(type, from: Data(string.utf8)) + try JSONDecoder().decode(type, from: Data(string.utf8)) } - private func ftsQuery(from value: String?) -> String { + private static func decodeNullableJSON( + _ type: T.Type, + from string: String? + ) throws -> T? { + guard let string else { + return nil + } + return try Self.decodeJSON(type, from: string) + } + + private static func ftsQuery(from value: String?) -> String { let tokens = MemoryQueryEngine.tokenize(value) guard !tokens.isEmpty else { return "" @@ -648,72 +503,92 @@ public actor SQLiteMemoryStore: MemoryStoring { return tokens.joined(separator: " OR ") } - private static func sqliteError( - _ database: OpaquePointer?, - message: String - ) -> NSError { - let detail = if let database, let messagePointer = sqlite3_errmsg(database) { - String(cString: messagePointer) - } else { - "Unknown SQLite error" - } - - return NSError( - domain: "CodexKit.SQLiteMemoryStore", - code: Int(sqlite3_errcode(database)), - userInfo: [NSLocalizedDescriptionKey: "\(message) \(detail)"] - ) - } - - private func sqliteError(message: String) -> NSError { - Self.sqliteError(database, message: message) - } -} + private static func schemaVersion(in db: Database) throws -> Int { + try Int.fetchOne(db, sql: "PRAGMA user_version;") ?? 0 + } + + private static func makeMigrator() -> DatabaseMigrator { + var migrator = DatabaseMigrator() + + migrator.registerMigration("memory_store_v1") { db in + try db.execute(sql: """ + CREATE TABLE IF NOT EXISTS memory_records ( + namespace TEXT NOT NULL, + id TEXT NOT NULL, + scope TEXT NOT NULL, + kind TEXT NOT NULL, + summary TEXT NOT NULL, + evidence_json TEXT NOT NULL, + importance REAL NOT NULL, + created_at REAL NOT NULL, + observed_at REAL, + expires_at REAL, + tags_json TEXT NOT NULL, + related_ids_json TEXT NOT NULL, + dedupe_key TEXT, + is_pinned INTEGER NOT NULL, + attributes_json TEXT, + status TEXT NOT NULL, + PRIMARY KEY(namespace, id) + ); + """) + + try db.execute(sql: """ + CREATE UNIQUE INDEX IF NOT EXISTS memory_records_namespace_dedupe + ON memory_records(namespace, dedupe_key) + WHERE dedupe_key IS NOT NULL; + """) + try db.execute(sql: """ + CREATE INDEX IF NOT EXISTS memory_records_namespace_scope + ON memory_records(namespace, scope); + """) + try db.execute(sql: """ + CREATE INDEX IF NOT EXISTS memory_records_namespace_kind + ON memory_records(namespace, kind); + """) + try db.execute(sql: """ + CREATE INDEX IF NOT EXISTS memory_records_namespace_status + ON memory_records(namespace, status); + """) + + try db.execute(sql: """ + CREATE TABLE IF NOT EXISTS memory_tags ( + namespace TEXT NOT NULL, + record_id TEXT NOT NULL, + tag TEXT NOT NULL, + FOREIGN KEY(namespace, record_id) + REFERENCES memory_records(namespace, id) + ON DELETE CASCADE + ); + """) + try db.execute(sql: """ + CREATE INDEX IF NOT EXISTS memory_tags_lookup + ON memory_tags(namespace, tag, record_id); + """) + + try db.execute(sql: """ + CREATE TABLE IF NOT EXISTS memory_related_ids ( + namespace TEXT NOT NULL, + record_id TEXT NOT NULL, + related_id TEXT NOT NULL, + FOREIGN KEY(namespace, record_id) + REFERENCES memory_records(namespace, id) + ON DELETE CASCADE + ); + """) + try db.execute(sql: """ + CREATE INDEX IF NOT EXISTS memory_related_lookup + ON memory_related_ids(namespace, related_id, record_id); + """) -private enum SQLiteBinding { - case text(String) - case double(Double) - case null -} + try db.execute(sql: """ + CREATE VIRTUAL TABLE IF NOT EXISTS memory_fts + USING fts5(namespace UNINDEXED, record_id UNINDEXED, content); + """) -private let SQLITE_TRANSIENT = unsafeBitCast(-1, to: sqlite3_destructor_type.self) - -private func execSQL( - _ database: OpaquePointer?, - _ sql: String -) throws { - var errorPointer: UnsafeMutablePointer? - let result = sqlite3_exec(database, sql, nil, nil, &errorPointer) - guard result == SQLITE_OK else { - let detail = errorPointer.map { String(cString: $0) } ?? "Unknown SQLite error" - sqlite3_free(errorPointer) - throw NSError( - domain: "CodexKit.SQLiteMemoryStore", - code: Int(result), - userInfo: [NSLocalizedDescriptionKey: detail] - ) - } -} + try db.execute(sql: "PRAGMA user_version = \(currentSchemaVersion)") + } -private func prepareSQL( - _ database: OpaquePointer?, - _ sql: String -) throws -> OpaquePointer? { - guard let database else { - throw NSError( - domain: "CodexKit.SQLiteMemoryStore", - code: 0, - userInfo: [NSLocalizedDescriptionKey: "SQLite database is unavailable."] - ) - } - var statement: OpaquePointer? - let result = sqlite3_prepare_v2(database, sql, -1, &statement, nil) - guard result == SQLITE_OK else { - throw NSError( - domain: "CodexKit.SQLiteMemoryStore", - code: Int(result), - userInfo: [NSLocalizedDescriptionKey: "Failed to prepare SQLite statement."] - ) + return migrator } - return statement } From 4a55a51e0c7a7f44b390bb2dcdd1b8eb769b0bae Mon Sep 17 00:00:00 2001 From: Timothy Zelinsky Date: Tue, 24 Mar 2026 06:54:14 +1100 Subject: [PATCH 05/19] Refactor demo helpers and memory store collaborators --- .../AgentDemoViewModel+HealthCoach.swift | 176 ++-- ...entDemoViewModel+HealthCoachPlatform.swift | 16 +- .../Shared/AgentDemoViewModel+Messaging.swift | 2 +- .../AgentDemoViewModel+StructuredOutput.swift | 6 +- .../Shared/AgentDemoViewModel+Tools.swift | 131 +-- .../Shared/AgentDemoViewModel.swift | 192 ++-- .../CodexKit/Memory/SQLiteMemoryStore.swift | 837 +++++++++--------- 7 files changed, 726 insertions(+), 634 deletions(-) diff --git a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+HealthCoach.swift b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+HealthCoach.swift index 5a868b7..9104190 100644 --- a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+HealthCoach.swift +++ b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+HealthCoach.swift @@ -29,17 +29,96 @@ enum HealthCoachToneMode: String, CaseIterable, Identifiable { } } -@MainActor -extension AgentDemoViewModel { - nonisolated static let healthCoachThreadTitle = "Health Coach" - nonisolated static let healthReminderIdentifierPrefix = "health-coach-reminder" - nonisolated static let healthReminderSchedule: [(hour: Int, minute: Int)] = [ +struct DemoHealthCoachDesign { + let threadTitle = "Health Coach" + let reminderIdentifierPrefix = "health-coach-reminder" + let reminderSchedule: [(hour: Int, minute: Int)] = [ (10, 0), (13, 0), (16, 0), (19, 0), ] + func persona(for toneMode: HealthCoachToneMode) -> AgentPersonaStack { + let styleInstructions: String + + switch toneMode { + case .hardcorePersonal: + styleInstructions = """ + Be blunt, forceful, and no-nonsense. Push accountability hard, call out excuses directly, and give action commands. + Address the user directly as "you." + Never use slurs, body-shaming labels, identity-targeted insults, or humiliation. + If the goal is completed, switch to brief earned praise: "Well done, you pushed through." + """ + + case .firmCoach: + styleInstructions = """ + Be firm, direct, and practical. Keep pressure on execution with concise next actions. + Address the user directly as "you." + Avoid insults and avoid coddling. + If the goal is completed, give brief earned praise. + """ + } + + return AgentPersonaStack(layers: [ + .init( + name: "domain", + instructions: "You are a step-goal accountability coach for a mobile health app." + ), + .init( + name: "style", + instructions: styleInstructions + ), + ]) + } + + func coachFeedbackCacheKey( + steps: Int, + goal: Int, + toneMode: HealthCoachToneMode + ) -> String { + "\(toneMode.rawValue)-\(steps)-\(goal)-\(max(goal - steps, 0))" + } + + func fallbackReminderBody( + remaining: Int, + toneMode: HealthCoachToneMode + ) -> String { + if remaining <= 0 { + return "You are on pace. Stay consistent and finish strong." + } + + switch toneMode { + case .hardcorePersonal: + return "You still owe \(remaining) steps today. Move now and close it." + case .firmCoach: + return "\(remaining) steps remain. Take a focused walking block now." + } + } + + func reminderCopyCacheKey( + remaining: Int, + goal: Int, + toneMode: HealthCoachToneMode + ) -> String { + let ratio = goal > 0 ? Double(remaining) / Double(goal) : 0 + let band: String + switch ratio { + case ...0: + band = "complete" + case ..<0.2: + band = "close" + case ..<0.6: + band = "mid" + default: + band = "far" + } + return "\(toneMode.rawValue)-\(band)" + } +} + +@MainActor +extension AgentDemoViewModel { func initializeHealthCoachIfNeeded() async { guard !healthCoachInitialized else { return @@ -161,7 +240,7 @@ extension AgentDemoViewModel { return } - let cacheKey = Self.coachFeedbackCacheKey( + let cacheKey = healthCoachDesign.coachFeedbackCacheKey( steps: todayStepCount, goal: dailyStepGoal, toneMode: healthCoachToneMode @@ -244,7 +323,7 @@ extension AgentDemoViewModel { } let existingThreads = await runtime.threads() - if let existing = existingThreads.first(where: { $0.title == Self.healthCoachThreadTitle }) { + if let existing = existingThreads.first(where: { $0.title == healthCoachDesign.threadTitle }) { try await runtime.setPersonaStack(persona, for: existing.id) healthCoachThreadID = existing.id threads = await runtime.threads() @@ -252,7 +331,7 @@ extension AgentDemoViewModel { } let thread = try await runtime.createThread( - title: Self.healthCoachThreadTitle, + title: healthCoachDesign.threadTitle, personaStack: persona ) healthCoachThreadID = thread.id @@ -261,86 +340,7 @@ extension AgentDemoViewModel { } func currentHealthCoachPersona() -> AgentPersonaStack { - Self.healthCoachPersona(toneMode: healthCoachToneMode) - } - - nonisolated static func healthCoachPersona( - toneMode: HealthCoachToneMode - ) -> AgentPersonaStack { - let styleInstructions: String - - switch toneMode { - case .hardcorePersonal: - styleInstructions = """ - Be blunt, forceful, and no-nonsense. Push accountability hard, call out excuses directly, and give action commands. - Address the user directly as "you." - Never use slurs, body-shaming labels, identity-targeted insults, or humiliation. - If the goal is completed, switch to brief earned praise: "Well done, you pushed through." - """ - - case .firmCoach: - styleInstructions = """ - Be firm, direct, and practical. Keep pressure on execution with concise next actions. - Address the user directly as "you." - Avoid insults and avoid coddling. - If the goal is completed, give brief earned praise. - """ - } - - return AgentPersonaStack(layers: [ - .init( - name: "domain", - instructions: "You are a step-goal accountability coach for a mobile health app." - ), - .init( - name: "style", - instructions: styleInstructions - ), - ]) - } - - nonisolated static func coachFeedbackCacheKey( - steps: Int, - goal: Int, - toneMode: HealthCoachToneMode - ) -> String { - "\(toneMode.rawValue)-\(steps)-\(goal)-\(max(goal - steps, 0))" - } - - nonisolated static func fallbackReminderBody( - remaining: Int, - toneMode: HealthCoachToneMode - ) -> String { - if remaining <= 0 { - return "You are on pace. Stay consistent and finish strong." - } - - switch toneMode { - case .hardcorePersonal: - return "You still owe \(remaining) steps today. Move now and close it." - case .firmCoach: - return "\(remaining) steps remain. Take a focused walking block now." - } - } - - nonisolated static func reminderCopyCacheKey( - remaining: Int, - goal: Int, - toneMode: HealthCoachToneMode - ) -> String { - let ratio = goal > 0 ? Double(remaining) / Double(goal) : 0 - let band: String - switch ratio { - case ...0: - band = "complete" - case ..<0.2: - band = "close" - case ..<0.6: - band = "mid" - default: - band = "far" - } - return "\(toneMode.rawValue)-\(band)" + healthCoachDesign.persona(for: healthCoachToneMode) } func updateReminderScheduleIfPossible() async { diff --git a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+HealthCoachPlatform.swift b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+HealthCoachPlatform.swift index 0246a9b..de25cb5 100644 --- a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+HealthCoachPlatform.swift +++ b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+HealthCoachPlatform.swift @@ -85,8 +85,8 @@ extension AgentDemoViewModel { } func scheduleStepReminders() async { - let identifiers = Self.healthReminderSchedule.map { schedule in - "\(Self.healthReminderIdentifierPrefix)-\(schedule.hour)-\(schedule.minute)" + let identifiers = healthCoachDesign.reminderSchedule.map { schedule in + "\(healthCoachDesign.reminderIdentifierPrefix)-\(schedule.hour)-\(schedule.minute)" } notificationCenter.removePendingNotificationRequests(withIdentifiers: identifiers) @@ -100,7 +100,7 @@ extension AgentDemoViewModel { let title = "Health Coach Checkpoint" let body = await reminderNotificationBody(remaining: remaining) - for schedule in Self.healthReminderSchedule { + for schedule in healthCoachDesign.reminderSchedule { guard let reminderDate = calendar.date( bySettingHour: schedule.hour, minute: schedule.minute, @@ -124,7 +124,7 @@ extension AgentDemoViewModel { content.body = body content.sound = .default - let identifier = "\(Self.healthReminderIdentifierPrefix)-\(schedule.hour)-\(schedule.minute)" + let identifier = "\(healthCoachDesign.reminderIdentifierPrefix)-\(schedule.hour)-\(schedule.minute)" let request = UNNotificationRequest( identifier: identifier, content: content, @@ -148,7 +148,7 @@ extension AgentDemoViewModel { } func reminderNotificationBody(remaining: Int) async -> String { - let cacheKey = Self.reminderCopyCacheKey( + let cacheKey = healthCoachDesign.reminderCopyCacheKey( remaining: remaining, goal: dailyStepGoal, toneMode: healthCoachToneMode @@ -162,7 +162,7 @@ extension AgentDemoViewModel { } guard session != nil else { - return Self.fallbackReminderBody( + return healthCoachDesign.fallbackReminderBody( remaining: remaining, toneMode: healthCoachToneMode ) @@ -218,10 +218,10 @@ extension AgentDemoViewModel { return cleaned } } catch { - Self.logger.error("Failed to generate AI reminder copy: \(error.localizedDescription, privacy: .public)") + developerErrorLog("Failed to generate AI reminder copy: \(error.localizedDescription)") } - return Self.fallbackReminderBody( + return healthCoachDesign.fallbackReminderBody( remaining: remaining, toneMode: healthCoachToneMode ) diff --git a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+Messaging.swift b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+Messaging.swift index 7b0e53f..092f635 100644 --- a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+Messaging.swift +++ b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+Messaging.swift @@ -105,7 +105,7 @@ extension AgentDemoViewModel { ) let skillThread = try await runtime.createThread( title: "Skill Policy Probe: Health Coach", - skillIDs: [Self.healthCoachSkill.id] + skillIDs: [catalog.healthCoachSkill.id] ) threads = await runtime.threads() diff --git a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+StructuredOutput.swift b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+StructuredOutput.swift index d667e8d..14dd0e4 100644 --- a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+StructuredOutput.swift +++ b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+StructuredOutput.swift @@ -22,7 +22,7 @@ extension AgentDemoViewModel { developerLog("Running structured shipping reply demo.") let thread = try await runtime.createThread( title: "Structured Output: Shipping Draft", - personaStack: Self.supportPersona + personaStack: catalog.supportPersona ) let request = DemoStructuredOutputExamples.shippingReplyRequest() if showResolvedInstructionsDebug { @@ -125,7 +125,7 @@ extension AgentDemoViewModel { developerLog("Running streamed structured output demo.") let thread = try await runtime.createThread( title: "Structured Output: Streamed Delivery Update", - personaStack: Self.supportPersona + personaStack: catalog.supportPersona ) let request = DemoStructuredOutputExamples.streamedStructuredRequest() if showResolvedInstructionsDebug { @@ -201,7 +201,7 @@ extension AgentDemoViewModel { "Streamed structured output demo finished. threadID=\(thread.id) partialCount=\(partialSnapshots.count) persistedMetadata=\(persistedMetadata != nil)" ) } catch { - guard !Self.isCancellationError(error) else { + guard !diagnostics.isCancellationError(error) else { return } structuredStreamingError = error.localizedDescription diff --git a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+Tools.swift b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+Tools.swift index 2ffb15d..54ce6e8 100644 --- a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+Tools.swift +++ b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+Tools.swift @@ -1,7 +1,7 @@ import CodexKit import Foundation -private struct HealthCoachToolSnapshot: Sendable { +struct HealthCoachToolSnapshot: Sendable { let stepsToday: Int let dailyGoal: Int let remainingSteps: Int @@ -9,13 +9,63 @@ private struct HealthCoachToolSnapshot: Sendable { let healthKitAuthorized: Bool } +struct DemoToolOutputFactory { + func makeHealthCoachProgress( + invocation: ToolInvocation, + snapshot: HealthCoachToolSnapshot + ) -> ToolResultEnvelope { + let freshness = snapshot.healthKitAuthorized ? "live_or_cached_healthkit" : "app_cached_only" + + return .success( + invocation: invocation, + text: """ + health_progress[stepsToday=\(snapshot.stepsToday), dailyGoal=\(snapshot.dailyGoal), remainingSteps=\(snapshot.remainingSteps), hoursLeftToday=\(snapshot.hoursLeftToday), healthKitAuthorized=\(snapshot.healthKitAuthorized), freshness=\(freshness)] + """ + ) + } + + func makeTravelDayPlan(invocation: ToolInvocation) -> ToolResultEnvelope { + guard case let .object(arguments) = invocation.arguments else { + return .failure( + invocation: invocation, + message: "The travel planner tool expected object arguments." + ) + } + + let destination = arguments["destination"]?.stringValue? + .trimmingCharacters(in: .whitespacesAndNewlines) + let tripDays = max(Int(arguments["trip_days"]?.numberValue ?? 3), 1) + let budget = arguments["budget_level"]?.stringValue? + .trimmingCharacters(in: .whitespacesAndNewlines) + .lowercased() ?? "medium" + let companions = arguments["companions"]?.stringValue? + .trimmingCharacters(in: .whitespacesAndNewlines) + .lowercased() ?? "solo" + + guard let destination, !destination.isEmpty else { + return .failure(invocation: invocation, message: "destination is required.") + } + + let planLines = (1 ... min(tripDays, 10)).map { day in + "day\(day):arrival_walk=\(budget == "high" ? "taxi+priority-pass" : "public-transit"),focus=\(companions == "family" ? "kid-friendly highlight + early dinner" : "local highlight + flexible dinner")" + } + + return .success( + invocation: invocation, + text: """ + travel_day_plan[destination=\(destination), tripDays=\(tripDays), budget=\(budget), companions=\(companions), plan=\(planLines.joined(separator: " | "))] + """ + ) + } +} + @MainActor extension AgentDemoViewModel { func registerDemoSkills() async { do { - try await runtime.replaceSkill(Self.healthCoachSkill) - try await runtime.replaceSkill(Self.travelPlannerSkill) - developerLog("Registered demo skills: \(Self.healthCoachSkill.id), \(Self.travelPlannerSkill.id)") + try await runtime.replaceSkill(catalog.healthCoachSkill) + try await runtime.replaceSkill(catalog.travelPlannerSkill) + developerLog("Registered demo skills: \(catalog.healthCoachSkill.id), \(catalog.travelPlannerSkill.id)") } catch { reportError(error) } @@ -24,7 +74,7 @@ extension AgentDemoViewModel { func registerDemoTool() async { do { let healthCoachDefinition = ToolDefinition( - name: Self.healthCoachToolName, + name: catalog.healthCoachToolName, description: "Fetch a live health-coach progress snapshot from HealthKit-aware app state.", inputSchema: .object([ "type": .string("object"), @@ -33,7 +83,7 @@ extension AgentDemoViewModel { ) let travelPlannerDefinition = ToolDefinition( - name: Self.travelPlannerToolName, + name: catalog.travelPlannerToolName, description: "Build a compact deterministic day-by-day travel plan.", inputSchema: .object([ "type": .string("object"), @@ -66,16 +116,16 @@ extension AgentDemoViewModel { return .failure(invocation: invocation, message: "Health coach context is unavailable.") } let snapshot = await self.captureHealthCoachToolSnapshot() - return Self.makeHealthCoachProgress( + return self.toolOutputFactory.makeHealthCoachProgress( invocation: invocation, snapshot: snapshot ) } try await registerTool(travelPlannerDefinition) { invocation, _ in - Self.makeTravelDayPlan(invocation: invocation) + self.toolOutputFactory.makeTravelDayPlan(invocation: invocation) } developerLog( - "Registered demo tools: \(Self.healthCoachToolName), \(Self.travelPlannerToolName)" + "Registered demo tools: \(catalog.healthCoachToolName), \(catalog.travelPlannerToolName)" ) } catch { reportError(error) @@ -87,13 +137,17 @@ extension AgentDemoViewModel { execute: @escaping @Sendable (ToolInvocation, ToolExecutionContext) async throws -> ToolResultEnvelope ) async throws { try await runtime.replaceTool(definition, executor: AnyToolExecutor { invocation, context in - Self.logger.info( - "Executing tool \(invocation.toolName, privacy: .public) with arguments: \(String(describing: invocation.arguments), privacy: .public)" - ) + await MainActor.run { + self.developerLog( + "Executing tool \(invocation.toolName) with arguments: \(String(describing: invocation.arguments))" + ) + } let result = try await execute(invocation, context) - Self.logger.info( - "Tool \(invocation.toolName, privacy: .public) returned: \(result.primaryText ?? "", privacy: .public)" - ) + await MainActor.run { + self.developerLog( + "Tool \(invocation.toolName) returned: \(result.primaryText ?? "")" + ) + } return result }) } @@ -122,53 +176,6 @@ extension AgentDemoViewModel { ) } - private nonisolated static func makeHealthCoachProgress( - invocation: ToolInvocation, - snapshot: HealthCoachToolSnapshot - ) -> ToolResultEnvelope { - let freshness = snapshot.healthKitAuthorized ? "live_or_cached_healthkit" : "app_cached_only" - - return .success( - invocation: invocation, - text: """ - health_progress[stepsToday=\(snapshot.stepsToday), dailyGoal=\(snapshot.dailyGoal), remainingSteps=\(snapshot.remainingSteps), hoursLeftToday=\(snapshot.hoursLeftToday), healthKitAuthorized=\(snapshot.healthKitAuthorized), freshness=\(freshness)] - """ - ) - } - - nonisolated static func makeTravelDayPlan(invocation: ToolInvocation) -> ToolResultEnvelope { - guard case let .object(arguments) = invocation.arguments else { - return .failure( - invocation: invocation, - message: "The travel planner tool expected object arguments." - ) - } - - let destination = arguments["destination"]?.stringValue? - .trimmingCharacters(in: .whitespacesAndNewlines) - let tripDays = max(Int(arguments["trip_days"]?.numberValue ?? 3), 1) - let budget = arguments["budget_level"]?.stringValue? - .trimmingCharacters(in: .whitespacesAndNewlines) - .lowercased() ?? "medium" - let companions = arguments["companions"]?.stringValue? - .trimmingCharacters(in: .whitespacesAndNewlines) - .lowercased() ?? "solo" - - guard let destination, !destination.isEmpty else { - return .failure(invocation: invocation, message: "destination is required.") - } - - let planLines = (1 ... min(tripDays, 10)).map { day in - "day\(day):arrival_walk=\(budget == "high" ? "taxi+priority-pass" : "public-transit"),focus=\(companions == "family" ? "kid-friendly highlight + early dinner" : "local highlight + flexible dinner")" - } - - return .success( - invocation: invocation, - text: """ - travel_day_plan[destination=\(destination), tripDays=\(tripDays), budget=\(budget), companions=\(companions), plan=\(planLines.joined(separator: " | "))] - """ - ) - } } private extension JSONValue { diff --git a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel.swift b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel.swift index afb81e4..6b9885e 100644 --- a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel.swift +++ b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel.swift @@ -80,58 +80,106 @@ struct AutomaticPolicyMemoryDemoResult: Sendable { let records: [MemoryRecord] } -@MainActor -@Observable -final class AgentDemoViewModel: @unchecked Sendable { - nonisolated static let developerLoggingDefaultsKey = "AssistantRuntimeDemoApp.developerLoggingEnabled" - nonisolated static let logger = Logger( - subsystem: "ai.assistantruntime.demoapp", - category: "DemoTool" - ) - nonisolated static let supportPersona = AgentPersonaStack(layers: [ - .init( - name: "domain", - instructions: "You are an expert customer support agent for a shipping app." - ), - .init( - name: "style", - instructions: "Be concise, calm, and action-oriented." - ), - ]) - nonisolated static let plannerPersona = AgentPersonaStack(layers: [ - .init( - name: "planner", - instructions: "Act as a careful technical planner. Focus on tradeoffs and implementation sequencing." - ), - ]) - nonisolated static let reviewerOverridePersona = AgentPersonaStack(layers: [ - .init( - name: "reviewer", - instructions: "For this reply only, act as a strict reviewer and call out risks first." - ), - ]) - nonisolated static let healthCoachToolName = "health_coach_fetch_progress" - nonisolated static let travelPlannerToolName = "travel_planner_build_day_plan" - nonisolated static let healthCoachSkill = AgentSkill( - id: "health_coach", - name: "Health Coach", - instructions: "You are a health coach focused on daily step goals and execution. For every user turn, call the \(healthCoachToolName) tool exactly once before your final reply, then provide one practical walking plan and one accountability line.", - executionPolicy: .init( - allowedToolNames: [healthCoachToolName], - requiredToolNames: [healthCoachToolName], - maxToolCalls: 1 +struct DemoCatalog { + let supportPersona: AgentPersonaStack + let plannerPersona: AgentPersonaStack + let reviewerOverridePersona: AgentPersonaStack + let healthCoachToolName: String + let travelPlannerToolName: String + let healthCoachSkill: AgentSkill + let travelPlannerSkill: AgentSkill + + init() { + supportPersona = AgentPersonaStack(layers: [ + .init( + name: "domain", + instructions: "You are an expert customer support agent for a shipping app." + ), + .init( + name: "style", + instructions: "Be concise, calm, and action-oriented." + ), + ]) + plannerPersona = AgentPersonaStack(layers: [ + .init( + name: "planner", + instructions: "Act as a careful technical planner. Focus on tradeoffs and implementation sequencing." + ), + ]) + reviewerOverridePersona = AgentPersonaStack(layers: [ + .init( + name: "reviewer", + instructions: "For this reply only, act as a strict reviewer and call out risks first." + ), + ]) + + healthCoachToolName = "health_coach_fetch_progress" + travelPlannerToolName = "travel_planner_build_day_plan" + healthCoachSkill = AgentSkill( + id: "health_coach", + name: "Health Coach", + instructions: "You are a health coach focused on daily step goals and execution. For every user turn, call the \(healthCoachToolName) tool exactly once before your final reply, then provide one practical walking plan and one accountability line.", + executionPolicy: .init( + allowedToolNames: [healthCoachToolName], + requiredToolNames: [healthCoachToolName], + maxToolCalls: 1 + ) ) - ) - nonisolated static let travelPlannerSkill = AgentSkill( - id: "travel_planner", - name: "Travel Planner", - instructions: "You are a travel planning assistant for mobile users. Provide concise day-by-day itineraries, practical logistics, and a compact packing checklist.", - executionPolicy: .init( - allowedToolNames: [travelPlannerToolName], - maxToolCalls: 1 + travelPlannerSkill = AgentSkill( + id: "travel_planner", + name: "Travel Planner", + instructions: "You are a travel planning assistant for mobile users. Provide concise day-by-day itineraries, practical logistics, and a compact packing checklist.", + executionPolicy: .init( + allowedToolNames: [travelPlannerToolName], + maxToolCalls: 1 + ) ) + } +} + +struct DemoDiagnostics { + private let developerLoggingDefaultsKey = "AssistantRuntimeDemoApp.developerLoggingEnabled" + private let logger = Logger( + subsystem: "ai.assistantruntime.demoapp", + category: "DemoTool" ) + func initialDeveloperLoggingEnabled(userDefaults: UserDefaults = .standard) -> Bool { + if userDefaults.object(forKey: developerLoggingDefaultsKey) != nil { + return userDefaults.bool(forKey: developerLoggingDefaultsKey) + } +#if DEBUG + return true +#else + return false +#endif + } + + func persistDeveloperLoggingEnabled( + _ enabled: Bool, + userDefaults: UserDefaults = .standard + ) { + userDefaults.set(enabled, forKey: developerLoggingDefaultsKey) + } + + func isCancellationError(_ error: Error) -> Bool { + error is CancellationError + } + + func log(_ message: String) { + logger.notice("\(message, privacy: .public)") + print("[CodexKit Demo] \(message)") + } + + func error(_ message: String) { + logger.error("\(message, privacy: .public)") + print("[CodexKit Demo][Error] \(message)") + } +} + +@MainActor +@Observable +final class AgentDemoViewModel: @unchecked Sendable { var session: ChatGPTSession? var threads: [AgentThread] = [] var messages: [AgentMessage] = [] @@ -140,10 +188,7 @@ final class AgentDemoViewModel: @unchecked Sendable { var showResolvedInstructionsDebug = false var developerLoggingEnabled: Bool { didSet { - UserDefaults.standard.set( - developerLoggingEnabled, - forKey: Self.developerLoggingDefaultsKey - ) + diagnostics.persistDeveloperLoggingEnabled(developerLoggingEnabled) developerLog( developerLoggingEnabled ? "Developer logging enabled." @@ -194,6 +239,10 @@ final class AgentDemoViewModel: @unchecked Sendable { let enableWebSearch: Bool let stateURL: URL? let keychainAccount: String + let catalog: DemoCatalog + let diagnostics: DemoDiagnostics + let toolOutputFactory: DemoToolOutputFactory + let healthCoachDesign: DemoHealthCoachDesign var runtime: AgentRuntime var activeThreadID: String? @@ -214,11 +263,15 @@ final class AgentDemoViewModel: @unchecked Sendable { approvalInbox: ApprovalInbox, deviceCodePromptCoordinator: DeviceCodePromptCoordinator = DeviceCodePromptCoordinator() ) { + self.catalog = DemoCatalog() + self.diagnostics = DemoDiagnostics() + self.toolOutputFactory = DemoToolOutputFactory() + self.healthCoachDesign = DemoHealthCoachDesign() self.runtime = runtime self.model = model self.enableWebSearch = enableWebSearch self.reasoningEffort = reasoningEffort - self.developerLoggingEnabled = Self.initialDeveloperLoggingEnabled() + self.developerLoggingEnabled = diagnostics.initialDeveloperLoggingEnabled() self.stateURL = stateURL self.keychainAccount = keychainAccount self.approvalInbox = approvalInbox @@ -389,7 +442,7 @@ final class AgentDemoViewModel: @unchecked Sendable { func createSupportPersonaThread() async { await createThreadInternal( title: "Support Persona Demo", - personaStack: Self.supportPersona + personaStack: catalog.supportPersona ) } @@ -401,7 +454,7 @@ final class AgentDemoViewModel: @unchecked Sendable { do { try await runtime.setPersonaStack( - Self.plannerPersona, + catalog.plannerPersona, for: activeThreadID ) threads = await runtime.threads() @@ -417,7 +470,7 @@ final class AgentDemoViewModel: @unchecked Sendable { await sendMessageInternal( "Review this conversation setup and tell me the biggest risks first.", - personaOverride: Self.reviewerOverridePersona + personaOverride: catalog.reviewerOverridePersona ) } @@ -425,7 +478,7 @@ final class AgentDemoViewModel: @unchecked Sendable { await createThreadInternal( title: "Skill Demo: Health Coach", personaStack: nil, - skillIDs: [Self.healthCoachSkill.id] + skillIDs: [catalog.healthCoachSkill.id] ) } @@ -433,7 +486,7 @@ final class AgentDemoViewModel: @unchecked Sendable { await createThreadInternal( title: "Skill Demo: Travel Planner", personaStack: nil, - skillIDs: [Self.travelPlannerSkill.id] + skillIDs: [catalog.travelPlannerSkill.id] ) } @@ -496,7 +549,7 @@ final class AgentDemoViewModel: @unchecked Sendable { } func reportError(_ error: Error) { - guard !Self.isCancellationError(error) else { + guard !diagnostics.isCancellationError(error) else { developerLog("Ignoring CancellationError from async UI task.") return } @@ -516,35 +569,18 @@ final class AgentDemoViewModel: @unchecked Sendable { lastError = nil } - nonisolated static func isCancellationError(_ error: Error) -> Bool { - error is CancellationError - } - - nonisolated static func initialDeveloperLoggingEnabled() -> Bool { - if UserDefaults.standard.object(forKey: developerLoggingDefaultsKey) != nil { - return UserDefaults.standard.bool(forKey: developerLoggingDefaultsKey) - } -#if DEBUG - return true -#else - return false -#endif - } - func developerLog(_ message: String) { guard developerLoggingEnabled else { return } - Self.logger.notice("\(message, privacy: .public)") - print("[CodexKit Demo] \(message)") + diagnostics.log(message) } func developerErrorLog(_ message: String) { guard developerLoggingEnabled else { return } - Self.logger.error("\(message, privacy: .public)") - print("[CodexKit Demo][Error] \(message)") + diagnostics.error(message) } func signOut() async { diff --git a/Sources/CodexKit/Memory/SQLiteMemoryStore.swift b/Sources/CodexKit/Memory/SQLiteMemoryStore.swift index b7303ed..74d8f22 100644 --- a/Sources/CodexKit/Memory/SQLiteMemoryStore.swift +++ b/Sources/CodexKit/Memory/SQLiteMemoryStore.swift @@ -1,299 +1,188 @@ import Foundation import GRDB -public actor SQLiteMemoryStore: MemoryStoring { - private static let currentSchemaVersion = 1 +private struct SQLiteMemoryStoreSchema: Sendable { + let currentVersion = 1 - private let url: URL - private let dbQueue: DatabaseQueue - private let migrator: DatabaseMigrator - public init(url: URL) throws { - self.url = url + func existingVersion(in db: Database) throws -> Int { + try Int.fetchOne(db, sql: "PRAGMA user_version;") ?? 0 + } - let directory = url.deletingLastPathComponent() - if !directory.path.isEmpty { - try FileManager.default.createDirectory( - at: directory, - withIntermediateDirectories: true - ) - } + func makeMigrator() -> DatabaseMigrator { + var migrator = DatabaseMigrator() - var configuration = Configuration() - configuration.foreignKeysEnabled = true - configuration.label = "CodexKit.SQLiteMemoryStore" - dbQueue = try DatabaseQueue(path: url.path, configuration: configuration) - migrator = Self.makeMigrator() + migrator.registerMigration("memory_store_v1") { db in + try db.execute(sql: """ + CREATE TABLE IF NOT EXISTS memory_records ( + namespace TEXT NOT NULL, + id TEXT NOT NULL, + scope TEXT NOT NULL, + kind TEXT NOT NULL, + summary TEXT NOT NULL, + evidence_json TEXT NOT NULL, + importance REAL NOT NULL, + created_at REAL NOT NULL, + observed_at REAL, + expires_at REAL, + tags_json TEXT NOT NULL, + related_ids_json TEXT NOT NULL, + dedupe_key TEXT, + is_pinned INTEGER NOT NULL, + attributes_json TEXT, + status TEXT NOT NULL, + PRIMARY KEY(namespace, id) + ); + """) - let existingVersion = try dbQueue.read { db in - try Self.schemaVersion(in: db) - } - if existingVersion > Self.currentSchemaVersion { - throw MemoryStoreError.unsupportedSchemaVersion(existingVersion) + try db.execute(sql: """ + CREATE UNIQUE INDEX IF NOT EXISTS memory_records_namespace_dedupe + ON memory_records(namespace, dedupe_key) + WHERE dedupe_key IS NOT NULL; + """) + try db.execute(sql: """ + CREATE INDEX IF NOT EXISTS memory_records_namespace_scope + ON memory_records(namespace, scope); + """) + try db.execute(sql: """ + CREATE INDEX IF NOT EXISTS memory_records_namespace_kind + ON memory_records(namespace, kind); + """) + try db.execute(sql: """ + CREATE INDEX IF NOT EXISTS memory_records_namespace_status + ON memory_records(namespace, status); + """) + + try db.execute(sql: """ + CREATE TABLE IF NOT EXISTS memory_tags ( + namespace TEXT NOT NULL, + record_id TEXT NOT NULL, + tag TEXT NOT NULL, + FOREIGN KEY(namespace, record_id) + REFERENCES memory_records(namespace, id) + ON DELETE CASCADE + ); + """) + try db.execute(sql: """ + CREATE INDEX IF NOT EXISTS memory_tags_lookup + ON memory_tags(namespace, tag, record_id); + """) + + try db.execute(sql: """ + CREATE TABLE IF NOT EXISTS memory_related_ids ( + namespace TEXT NOT NULL, + record_id TEXT NOT NULL, + related_id TEXT NOT NULL, + FOREIGN KEY(namespace, record_id) + REFERENCES memory_records(namespace, id) + ON DELETE CASCADE + ); + """) + try db.execute(sql: """ + CREATE INDEX IF NOT EXISTS memory_related_lookup + ON memory_related_ids(namespace, related_id, record_id); + """) + + try db.execute(sql: """ + CREATE VIRTUAL TABLE IF NOT EXISTS memory_fts + USING fts5(namespace UNINDEXED, record_id UNINDEXED, content); + """) + + try db.execute(sql: "PRAGMA user_version = \(currentVersion)") } - try migrator.migrate(dbQueue) + return migrator } +} - public func put(_ record: MemoryRecord) async throws { - try MemoryQueryEngine.validateNamespace(record.namespace) - try await writeTransaction { db in - try Self.ensureRecordIDAvailable(record.id, namespace: record.namespace, in: db) - if let dedupeKey = record.dedupeKey { - try Self.ensureDedupeKeyAvailable(dedupeKey, namespace: record.namespace, in: db) - } - try Self.upsertRecord(record, in: db) - } +private struct SQLiteMemoryStoreCodec: Sendable { + private let encoder = JSONEncoder() + private let decoder = JSONDecoder() + + func encode(_ value: T) throws -> String { + let data = try encoder.encode(value) + return String(decoding: data, as: UTF8.self) } - public func putMany(_ records: [MemoryRecord]) async throws { - try await writeTransaction { db in - for record in records { - try MemoryQueryEngine.validateNamespace(record.namespace) - try Self.ensureRecordIDAvailable(record.id, namespace: record.namespace, in: db) - if let dedupeKey = record.dedupeKey { - try Self.ensureDedupeKeyAvailable(dedupeKey, namespace: record.namespace, in: db) - } - try Self.upsertRecord(record, in: db) - } + func encodeNullable(_ value: T?) throws -> String? { + guard let value else { + return nil } + return try encode(value) } - public func upsert(_ record: MemoryRecord, dedupeKey: String) async throws { - try MemoryQueryEngine.validateNamespace(record.namespace) - try await writeTransaction { db in - try Self.deleteRecord(withDedupeKey: dedupeKey, namespace: record.namespace, in: db) - try Self.deleteRecord(id: record.id, namespace: record.namespace, in: db) - var updatedRecord = record - updatedRecord.dedupeKey = dedupeKey - try Self.upsertRecord(updatedRecord, in: db) - } + func decode(_ type: T.Type, from string: String) throws -> T { + try decoder.decode(type, from: Data(string.utf8)) } - public func query(_ query: MemoryQuery) async throws -> MemoryQueryResult { - try MemoryQueryEngine.validateNamespace(query.namespace) - let records = try await dbQueue.read { db in - try Self.loadRecords(namespace: query.namespace, in: db) - } - let rawScores = try await dbQueue.read { db in - try Self.loadFTSRawScores( - namespace: query.namespace, - queryText: query.text, - in: db - ) + func decodeNullable(_ type: T.Type, from string: String?) throws -> T? { + guard let string else { + return nil } + return try decode(type, from: string) + } - let candidates = records.map { record in - MemoryQueryEngine.Candidate( - record: record, - textScore: rawScores[record.id], - textScoreOrdering: .lowerIsBetter - ) + func makeFTSQuery(from value: String?) -> String { + let tokens = MemoryQueryEngine.tokenize(value) + guard !tokens.isEmpty else { + return "" } - - return try MemoryQueryEngine.evaluate( - candidates: candidates, - query: query - ) + return tokens.joined(separator: " OR ") } +} - public func record( - id: String, - namespace: String - ) async throws -> MemoryRecord? { - try MemoryQueryEngine.validateNamespace(namespace) - return try await dbQueue.read { db in - try Self.loadRecords(namespace: namespace, in: db).first { $0.id == id } +private struct SQLiteMemoryStoreRepository: Sendable { + let codec: SQLiteMemoryStoreCodec + + func ensureRecordIDAvailable( + _ id: String, + namespace: String, + in db: Database + ) throws { + if try recordExists(id: id, namespace: namespace, in: db) { + throw MemoryStoreError.duplicateRecordID(id) } } - public func list(_ query: MemoryRecordListQuery) async throws -> [MemoryRecord] { - try MemoryQueryEngine.validateNamespace(query.namespace) - return try await dbQueue.read { db in - try Self.loadRecords(namespace: query.namespace, in: db) - .filter { record in - if !query.includeArchived, record.status == .archived { - return false - } - if !query.scopes.isEmpty, !query.scopes.contains(record.scope) { - return false - } - if !query.kinds.isEmpty, !query.kinds.contains(record.kind) { - return false - } - return true - } - .sorted { - if $0.effectiveDate == $1.effectiveDate { - return $0.id < $1.id - } - return $0.effectiveDate > $1.effectiveDate - } - .prefix(query.limit ?? .max) - .map { $0 } + func ensureDedupeKeyAvailable( + _ dedupeKey: String, + namespace: String, + in db: Database + ) throws { + if try recordExists(dedupeKey: dedupeKey, namespace: namespace, in: db) { + throw MemoryStoreError.duplicateDedupeKey(dedupeKey) } } - public func diagnostics(namespace: String) async throws -> MemoryStoreDiagnostics { - try MemoryQueryEngine.validateNamespace(namespace) - let records = try await dbQueue.read { db in - try Self.loadRecords(namespace: namespace, in: db) - } - let schemaVersion = try await dbQueue.read { db in - try Self.schemaVersion(in: db) - } - return MemoryStoreDiagnostics( - namespace: namespace, - implementation: "sqlite", - schemaVersion: schemaVersion, - totalRecords: records.count, - activeRecords: records.filter { $0.status == .active }.count, - archivedRecords: records.filter { $0.status == .archived }.count, - countsByScope: Dictionary(grouping: records, by: \.scope).mapValues(\.count), - countsByKind: Dictionary(grouping: records, by: \.kind).mapValues(\.count) + func loadRecords( + namespace: String, + in db: Database + ) throws -> [MemoryRecord] { + let rows = try Row.fetchAll( + db, + sql: """ + SELECT + id, scope, kind, summary, evidence_json, importance, + created_at, observed_at, expires_at, tags_json, + related_ids_json, dedupe_key, is_pinned, attributes_json, status + FROM memory_records + WHERE namespace = ?; + """, + arguments: [namespace] ) - } - public func compact(_ request: MemoryCompactionRequest) async throws { - try MemoryQueryEngine.validateNamespace(request.replacement.namespace) - try await writeTransaction { db in - try Self.ensureRecordIDAvailable( - request.replacement.id, - namespace: request.replacement.namespace, - in: db - ) - if let dedupeKey = request.replacement.dedupeKey { - try Self.ensureDedupeKeyAvailable( - dedupeKey, - namespace: request.replacement.namespace, - in: db - ) - } - try Self.upsertRecord(request.replacement, in: db) - for sourceID in request.sourceIDs { - try Self.archiveRecord(id: sourceID, namespace: request.replacement.namespace, in: db) - } + return try rows.map { row in + try makeRecord(from: row, namespace: namespace) } } - public func archive(ids: [String], namespace: String) async throws { - try MemoryQueryEngine.validateNamespace(namespace) - try await writeTransaction { db in - for id in ids { - try Self.archiveRecord(id: id, namespace: namespace, in: db) - } - } - } - - public func delete(ids: [String], namespace: String) async throws { - try MemoryQueryEngine.validateNamespace(namespace) - try await writeTransaction { db in - for id in ids { - try Self.deleteRecord(id: id, namespace: namespace, in: db) - } - } - } - - @discardableResult - public func pruneExpired( - now: Date, - namespace: String - ) async throws -> Int { - try MemoryQueryEngine.validateNamespace(namespace) - let expiredIDs = try await dbQueue.read { db in - try Self.loadRecords(namespace: namespace, in: db) - .filter { record in - !record.isPinned && - record.status == .active && - (record.expiresAt?.compare(now) == .orderedAscending || - record.expiresAt?.compare(now) == .orderedSame) - } - .map(\.id) - } - - try await writeTransaction { db in - for id in expiredIDs { - try Self.deleteRecord(id: id, namespace: namespace, in: db) - } - } - return expiredIDs.count - } - - private func writeTransaction(_ operation: @escaping @Sendable (Database) throws -> Void) async throws { - try await dbQueue.writeWithoutTransaction { db in - try db.inTransaction { - try operation(db) - return .commit - } - } - } - - private static func ensureRecordIDAvailable( - _ id: String, - namespace: String, - in db: Database - ) throws { - if try recordExists(id: id, namespace: namespace, in: db) { - throw MemoryStoreError.duplicateRecordID(id) - } - } - - private static func ensureDedupeKeyAvailable( - _ dedupeKey: String, - namespace: String, - in db: Database - ) throws { - if try recordExists(dedupeKey: dedupeKey, namespace: namespace, in: db) { - throw MemoryStoreError.duplicateDedupeKey(dedupeKey) - } - } - - private static func loadRecords( - namespace: String, - in db: Database - ) throws -> [MemoryRecord] { - let rows = try Row.fetchAll( - db, - sql: """ - SELECT - id, scope, kind, summary, evidence_json, importance, - created_at, observed_at, expires_at, tags_json, - related_ids_json, dedupe_key, is_pinned, attributes_json, status - FROM memory_records - WHERE namespace = ?; - """, - arguments: [namespace] - ) - - return try rows.map { row in - MemoryRecord( - id: row["id"], - namespace: namespace, - scope: MemoryScope(rawValue: row["scope"]), - kind: row["kind"], - summary: row["summary"], - evidence: try decodeJSON([String].self, from: row["evidence_json"]), - importance: row["importance"], - createdAt: Date(timeIntervalSince1970: row["created_at"]), - observedAt: (row["observed_at"] as Double?).map(Date.init(timeIntervalSince1970:)), - expiresAt: (row["expires_at"] as Double?).map(Date.init(timeIntervalSince1970:)), - tags: try decodeJSON([String].self, from: row["tags_json"]), - relatedIDs: try decodeJSON([String].self, from: row["related_ids_json"]), - dedupeKey: row["dedupe_key"], - isPinned: (row["is_pinned"] as Int64? ?? 0) == 1, - attributes: try decodeNullableJSON(JSONValue.self, from: row["attributes_json"]), - status: MemoryRecordStatus(rawValue: row["status"]) ?? .active - ) - } - } - - private static func loadFTSRawScores( - namespace: String, - queryText: String?, - in db: Database - ) throws -> [String: Double] { - let matchQuery = Self.ftsQuery(from: queryText) - guard !matchQuery.isEmpty else { - return [:] + func loadRawFTSScores( + namespace: String, + queryText: String?, + in db: Database + ) throws -> [String: Double] { + let matchQuery = codec.makeFTSQuery(from: queryText) + guard !matchQuery.isEmpty else { + return [:] } let rows = try Row.fetchAll( @@ -315,39 +204,31 @@ public actor SQLiteMemoryStore: MemoryStoring { return scores } - private static func recordExists( + func loadRecord( id: String, namespace: String, in db: Database - ) throws -> Bool { - try Bool.fetchOne( + ) throws -> MemoryRecord? { + let row = try Row.fetchOne( db, sql: """ - SELECT EXISTS( - SELECT 1 FROM memory_records WHERE namespace = ? AND id = ? - ); + SELECT + id, scope, kind, summary, evidence_json, importance, + created_at, observed_at, expires_at, tags_json, + related_ids_json, dedupe_key, is_pinned, attributes_json, status + FROM memory_records + WHERE namespace = ? AND id = ? + LIMIT 1; """, arguments: [namespace, id] - ) ?? false - } - - private static func recordExists( - dedupeKey: String, - namespace: String, - in db: Database - ) throws -> Bool { - try Bool.fetchOne( - db, - sql: """ - SELECT EXISTS( - SELECT 1 FROM memory_records WHERE namespace = ? AND dedupe_key = ? - ); - """, - arguments: [namespace, dedupeKey] - ) ?? false + ) + guard let row else { + return nil + } + return try makeRecord(from: row, namespace: namespace) } - private static func archiveRecord( + func archiveRecord( id: String, namespace: String, in db: Database @@ -362,7 +243,7 @@ public actor SQLiteMemoryStore: MemoryStoring { ) } - private static func deleteRecord( + func deleteRecord( id: String, namespace: String, in db: Database @@ -377,7 +258,7 @@ public actor SQLiteMemoryStore: MemoryStoring { ) } - private static func deleteRecord( + func deleteRecord( withDedupeKey dedupeKey: String, namespace: String, in db: Database @@ -396,7 +277,7 @@ public actor SQLiteMemoryStore: MemoryStoring { } } - private static func upsertRecord( + func upsertRecord( _ record: MemoryRecord, in db: Database ) throws { @@ -414,16 +295,16 @@ public actor SQLiteMemoryStore: MemoryStoring { record.scope.rawValue, record.kind, record.summary, - try encodeJSON(record.evidence), + try codec.encode(record.evidence), record.importance, record.createdAt.timeIntervalSince1970, record.observedAt?.timeIntervalSince1970, record.expiresAt?.timeIntervalSince1970, - try encodeJSON(record.tags), - try encodeJSON(record.relatedIDs), + try codec.encode(record.tags), + try codec.encode(record.relatedIDs), record.dedupeKey, record.isPinned ? 1 : 0, - try encodeNullableJSON(record.attributes), + try codec.encodeNullable(record.attributes), record.status.rawValue, ] ) @@ -466,129 +347,297 @@ public actor SQLiteMemoryStore: MemoryStoring { ) } - private static func encodeJSON(_ value: T) throws -> String { - let data = try JSONEncoder().encode(value) - return String(decoding: data, as: UTF8.self) + private func makeRecord( + from row: Row, + namespace: String + ) throws -> MemoryRecord { + MemoryRecord( + id: row["id"], + namespace: namespace, + scope: MemoryScope(rawValue: row["scope"]), + kind: row["kind"], + summary: row["summary"], + evidence: try codec.decode([String].self, from: row["evidence_json"]), + importance: row["importance"], + createdAt: Date(timeIntervalSince1970: row["created_at"]), + observedAt: (row["observed_at"] as Double?).map(Date.init(timeIntervalSince1970:)), + expiresAt: (row["expires_at"] as Double?).map(Date.init(timeIntervalSince1970:)), + tags: try codec.decode([String].self, from: row["tags_json"]), + relatedIDs: try codec.decode([String].self, from: row["related_ids_json"]), + dedupeKey: row["dedupe_key"], + isPinned: (row["is_pinned"] as Int64? ?? 0) == 1, + attributes: try codec.decodeNullable(JSONValue.self, from: row["attributes_json"]), + status: MemoryRecordStatus(rawValue: row["status"]) ?? .active + ) } - private static func encodeNullableJSON(_ value: T?) throws -> String? { - guard let value else { - return nil + private func recordExists( + id: String, + namespace: String, + in db: Database + ) throws -> Bool { + try Bool.fetchOne( + db, + sql: """ + SELECT EXISTS( + SELECT 1 FROM memory_records WHERE namespace = ? AND id = ? + ); + """, + arguments: [namespace, id] + ) ?? false + } + + private func recordExists( + dedupeKey: String, + namespace: String, + in db: Database + ) throws -> Bool { + try Bool.fetchOne( + db, + sql: """ + SELECT EXISTS( + SELECT 1 FROM memory_records WHERE namespace = ? AND dedupe_key = ? + ); + """, + arguments: [namespace, dedupeKey] + ) ?? false + } +} + +public actor SQLiteMemoryStore: MemoryStoring { + private let url: URL + private let dbQueue: DatabaseQueue + private let schema: SQLiteMemoryStoreSchema + private let repository: SQLiteMemoryStoreRepository + private let migrator: DatabaseMigrator + + public init(url: URL) throws { + self.url = url + self.schema = SQLiteMemoryStoreSchema() + let codec = SQLiteMemoryStoreCodec() + self.repository = SQLiteMemoryStoreRepository(codec: codec) + + let directory = url.deletingLastPathComponent() + if !directory.path.isEmpty { + try FileManager.default.createDirectory( + at: directory, + withIntermediateDirectories: true + ) } - return try Self.encodeJSON(value) + + var configuration = Configuration() + configuration.foreignKeysEnabled = true + configuration.label = "CodexKit.SQLiteMemoryStore" + dbQueue = try DatabaseQueue(path: url.path, configuration: configuration) + migrator = schema.makeMigrator() + + let existingVersion = try dbQueue.read { db in + try schema.existingVersion(in: db) + } + if existingVersion > schema.currentVersion { + throw MemoryStoreError.unsupportedSchemaVersion(existingVersion) + } + + try migrator.migrate(dbQueue) } - private static func decodeJSON( - _ type: T.Type, - from string: String - ) throws -> T { - try JSONDecoder().decode(type, from: Data(string.utf8)) + public func put(_ record: MemoryRecord) async throws { + try MemoryQueryEngine.validateNamespace(record.namespace) + let repository = self.repository + try await writeTransaction { db in + try repository.ensureRecordIDAvailable(record.id, namespace: record.namespace, in: db) + if let dedupeKey = record.dedupeKey { + try repository.ensureDedupeKeyAvailable(dedupeKey, namespace: record.namespace, in: db) + } + try repository.upsertRecord(record, in: db) + } } - private static func decodeNullableJSON( - _ type: T.Type, - from string: String? - ) throws -> T? { - guard let string else { - return nil + public func putMany(_ records: [MemoryRecord]) async throws { + let repository = self.repository + try await writeTransaction { db in + for record in records { + try MemoryQueryEngine.validateNamespace(record.namespace) + try repository.ensureRecordIDAvailable(record.id, namespace: record.namespace, in: db) + if let dedupeKey = record.dedupeKey { + try repository.ensureDedupeKeyAvailable(dedupeKey, namespace: record.namespace, in: db) + } + try repository.upsertRecord(record, in: db) + } } - return try Self.decodeJSON(type, from: string) } - private static func ftsQuery(from value: String?) -> String { - let tokens = MemoryQueryEngine.tokenize(value) - guard !tokens.isEmpty else { - return "" + public func upsert(_ record: MemoryRecord, dedupeKey: String) async throws { + try MemoryQueryEngine.validateNamespace(record.namespace) + let repository = self.repository + try await writeTransaction { db in + try repository.deleteRecord(withDedupeKey: dedupeKey, namespace: record.namespace, in: db) + try repository.deleteRecord(id: record.id, namespace: record.namespace, in: db) + var updatedRecord = record + updatedRecord.dedupeKey = dedupeKey + try repository.upsertRecord(updatedRecord, in: db) } - return tokens.joined(separator: " OR ") } - private static func schemaVersion(in db: Database) throws -> Int { - try Int.fetchOne(db, sql: "PRAGMA user_version;") ?? 0 + public func query(_ query: MemoryQuery) async throws -> MemoryQueryResult { + try MemoryQueryEngine.validateNamespace(query.namespace) + let records = try await dbQueue.read { db in + try repository.loadRecords(namespace: query.namespace, in: db) + } + let rawScores = try await dbQueue.read { db in + try repository.loadRawFTSScores( + namespace: query.namespace, + queryText: query.text, + in: db + ) + } + + let candidates = records.map { record in + MemoryQueryEngine.Candidate( + record: record, + textScore: rawScores[record.id], + textScoreOrdering: .lowerIsBetter + ) + } + + return try MemoryQueryEngine.evaluate( + candidates: candidates, + query: query + ) } - private static func makeMigrator() -> DatabaseMigrator { - var migrator = DatabaseMigrator() + public func record( + id: String, + namespace: String + ) async throws -> MemoryRecord? { + try MemoryQueryEngine.validateNamespace(namespace) + return try await dbQueue.read { db in + try repository.loadRecord(id: id, namespace: namespace, in: db) + } + } - migrator.registerMigration("memory_store_v1") { db in - try db.execute(sql: """ - CREATE TABLE IF NOT EXISTS memory_records ( - namespace TEXT NOT NULL, - id TEXT NOT NULL, - scope TEXT NOT NULL, - kind TEXT NOT NULL, - summary TEXT NOT NULL, - evidence_json TEXT NOT NULL, - importance REAL NOT NULL, - created_at REAL NOT NULL, - observed_at REAL, - expires_at REAL, - tags_json TEXT NOT NULL, - related_ids_json TEXT NOT NULL, - dedupe_key TEXT, - is_pinned INTEGER NOT NULL, - attributes_json TEXT, - status TEXT NOT NULL, - PRIMARY KEY(namespace, id) - ); - """) + public func list(_ query: MemoryRecordListQuery) async throws -> [MemoryRecord] { + try MemoryQueryEngine.validateNamespace(query.namespace) + return try await dbQueue.read { db in + try repository.loadRecords(namespace: query.namespace, in: db) + .filter { record in + if !query.includeArchived, record.status == .archived { + return false + } + if !query.scopes.isEmpty, !query.scopes.contains(record.scope) { + return false + } + if !query.kinds.isEmpty, !query.kinds.contains(record.kind) { + return false + } + return true + } + .sorted { + if $0.effectiveDate == $1.effectiveDate { + return $0.id < $1.id + } + return $0.effectiveDate > $1.effectiveDate + } + .prefix(query.limit ?? .max) + .map { $0 } + } + } - try db.execute(sql: """ - CREATE UNIQUE INDEX IF NOT EXISTS memory_records_namespace_dedupe - ON memory_records(namespace, dedupe_key) - WHERE dedupe_key IS NOT NULL; - """) - try db.execute(sql: """ - CREATE INDEX IF NOT EXISTS memory_records_namespace_scope - ON memory_records(namespace, scope); - """) - try db.execute(sql: """ - CREATE INDEX IF NOT EXISTS memory_records_namespace_kind - ON memory_records(namespace, kind); - """) - try db.execute(sql: """ - CREATE INDEX IF NOT EXISTS memory_records_namespace_status - ON memory_records(namespace, status); - """) + public func diagnostics(namespace: String) async throws -> MemoryStoreDiagnostics { + try MemoryQueryEngine.validateNamespace(namespace) + let records = try await dbQueue.read { db in + try repository.loadRecords(namespace: namespace, in: db) + } + let schemaVersion = try await dbQueue.read { db in + try schema.existingVersion(in: db) + } + return MemoryStoreDiagnostics( + namespace: namespace, + implementation: "sqlite", + schemaVersion: schemaVersion, + totalRecords: records.count, + activeRecords: records.filter { $0.status == .active }.count, + archivedRecords: records.filter { $0.status == .archived }.count, + countsByScope: Dictionary(grouping: records, by: \.scope).mapValues(\.count), + countsByKind: Dictionary(grouping: records, by: \.kind).mapValues(\.count) + ) + } - try db.execute(sql: """ - CREATE TABLE IF NOT EXISTS memory_tags ( - namespace TEXT NOT NULL, - record_id TEXT NOT NULL, - tag TEXT NOT NULL, - FOREIGN KEY(namespace, record_id) - REFERENCES memory_records(namespace, id) - ON DELETE CASCADE - ); - """) - try db.execute(sql: """ - CREATE INDEX IF NOT EXISTS memory_tags_lookup - ON memory_tags(namespace, tag, record_id); - """) + public func compact(_ request: MemoryCompactionRequest) async throws { + try MemoryQueryEngine.validateNamespace(request.replacement.namespace) + let repository = self.repository + try await writeTransaction { db in + try repository.ensureRecordIDAvailable( + request.replacement.id, + namespace: request.replacement.namespace, + in: db + ) + if let dedupeKey = request.replacement.dedupeKey { + try repository.ensureDedupeKeyAvailable( + dedupeKey, + namespace: request.replacement.namespace, + in: db + ) + } + try repository.upsertRecord(request.replacement, in: db) + for sourceID in request.sourceIDs { + try repository.archiveRecord(id: sourceID, namespace: request.replacement.namespace, in: db) + } + } + } - try db.execute(sql: """ - CREATE TABLE IF NOT EXISTS memory_related_ids ( - namespace TEXT NOT NULL, - record_id TEXT NOT NULL, - related_id TEXT NOT NULL, - FOREIGN KEY(namespace, record_id) - REFERENCES memory_records(namespace, id) - ON DELETE CASCADE - ); - """) - try db.execute(sql: """ - CREATE INDEX IF NOT EXISTS memory_related_lookup - ON memory_related_ids(namespace, related_id, record_id); - """) + public func archive(ids: [String], namespace: String) async throws { + try MemoryQueryEngine.validateNamespace(namespace) + let repository = self.repository + try await writeTransaction { db in + for id in ids { + try repository.archiveRecord(id: id, namespace: namespace, in: db) + } + } + } - try db.execute(sql: """ - CREATE VIRTUAL TABLE IF NOT EXISTS memory_fts - USING fts5(namespace UNINDEXED, record_id UNINDEXED, content); - """) + public func delete(ids: [String], namespace: String) async throws { + try MemoryQueryEngine.validateNamespace(namespace) + let repository = self.repository + try await writeTransaction { db in + for id in ids { + try repository.deleteRecord(id: id, namespace: namespace, in: db) + } + } + } + + @discardableResult + public func pruneExpired( + now: Date, + namespace: String + ) async throws -> Int { + try MemoryQueryEngine.validateNamespace(namespace) + let repository = self.repository + let expiredIDs = try await dbQueue.read { db in + try repository.loadRecords(namespace: namespace, in: db) + .filter { record in + !record.isPinned && + record.status == .active && + (record.expiresAt?.compare(now) == .orderedAscending || + record.expiresAt?.compare(now) == .orderedSame) + } + .map(\.id) + } - try db.execute(sql: "PRAGMA user_version = \(currentSchemaVersion)") + try await writeTransaction { db in + for id in expiredIDs { + try repository.deleteRecord(id: id, namespace: namespace, in: db) + } } + return expiredIDs.count + } - return migrator + private func writeTransaction(_ operation: @escaping @Sendable (Database) throws -> Void) async throws { + try await dbQueue.writeWithoutTransaction { db in + try db.inTransaction { + try operation(db) + return .commit + } + } } } From 1afe14a943ae5a9188fef0f760f3a2c80f5fe6b7 Mon Sep 17 00:00:00 2001 From: Timothy Zelinsky Date: Tue, 24 Mar 2026 13:07:53 +1100 Subject: [PATCH 06/19] Add hidden context compaction and clean up GRDB queries --- .../Shared/AgentDemoRuntimeFactory.swift | 20 + .../Shared/AgentDemoViewModel+Messaging.swift | 2 + .../Shared/AgentDemoViewModel.swift | 48 ++ .../Shared/ThreadDetailView.swift | 100 ++++ .../CodexKit/Memory/SQLiteMemoryStore.swift | 180 ++++--- .../Runtime/AgentContextCompaction.swift | 156 ++++++ Sources/CodexKit/Runtime/AgentHistory.swift | 38 +- Sources/CodexKit/Runtime/AgentModels.swift | 14 + Sources/CodexKit/Runtime/AgentQuerying.swift | 12 + .../AgentRuntime+ContextCompaction.swift | 310 +++++++++++ .../Runtime/AgentRuntime+History.swift | 14 + .../Runtime/AgentRuntime+Messaging.swift | 70 ++- .../Runtime/AgentRuntime+Threads.swift | 1 + Sources/CodexKit/Runtime/AgentRuntime.swift | 7 +- .../CodexResponsesBackend+Compaction.swift | 119 +++++ .../CodexResponsesBackend+Models.swift | 20 + .../Runtime/CodexResponsesBackend.swift | 8 +- .../Runtime/GRDBRuntimeStateStore.swift | 496 ++++++++++-------- .../CodexKit/Runtime/RuntimeStateStore.swift | 75 ++- .../AgentRuntimeHistoryTests.swift | 237 ++++++++- 20 files changed, 1608 insertions(+), 319 deletions(-) create mode 100644 Sources/CodexKit/Runtime/AgentContextCompaction.swift create mode 100644 Sources/CodexKit/Runtime/AgentRuntime+ContextCompaction.swift create mode 100644 Sources/CodexKit/Runtime/CodexResponsesBackend+Compaction.swift diff --git a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoRuntimeFactory.swift b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoRuntimeFactory.swift index 5b3d625..4a508f7 100644 --- a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoRuntimeFactory.swift +++ b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoRuntimeFactory.swift @@ -116,6 +116,16 @@ enum AgentDemoRuntimeFactory { maxMemories: 2 ) ) + ), + contextCompaction: AgentContextCompactionConfiguration( + isEnabled: true, + mode: .automatic, + visibility: .hidden, + strategy: .preferRemoteThenLocal, + trigger: .init( + estimatedTokenThreshold: 2_000, + retryOnContextLimitError: true + ) ) )) } @@ -157,6 +167,16 @@ enum AgentDemoRuntimeFactory { maxMemories: 2 ) ) + ), + contextCompaction: AgentContextCompactionConfiguration( + isEnabled: true, + mode: .automatic, + visibility: .hidden, + strategy: .preferRemoteThenLocal, + trigger: .init( + estimatedTokenThreshold: 2_000, + retryOnContextLimitError: true + ) ) )) } diff --git a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+Messaging.swift b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+Messaging.swift index 092f635..eda56b3 100644 --- a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+Messaging.swift +++ b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+Messaging.swift @@ -323,6 +323,7 @@ extension AgentDemoViewModel { } threads = await runtime.threads() developerLog("Turn completed. threadID=\(threadID)") + await refreshThreadContextState(for: threadID) case let .turnFailed(error): flushStreamingText(force: true) @@ -330,6 +331,7 @@ extension AgentDemoViewModel { developerErrorLog( "Turn failed. threadID=\(threadID) code=\(error.code) message=\(error.message)" ) + await refreshThreadContextState(for: threadID) reportError(error.message) } } diff --git a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel.swift b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel.swift index 6b9885e..3a8c4a8 100644 --- a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel.swift +++ b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel.swift @@ -232,6 +232,8 @@ final class AgentDemoViewModel: @unchecked Sendable { var cachedAIReminderGeneratedAt: Date? var reasoningEffort: ReasoningEffort var currentAuthenticationMethod: DemoAuthenticationMethod = .deviceCode + var activeThreadContextState: AgentThreadContextState? + var isCompactingThreadContext = false let approvalInbox: ApprovalInbox let deviceCodePromptCoordinator: DeviceCodePromptCoordinator @@ -509,6 +511,7 @@ final class AgentDemoViewModel: @unchecked Sendable { activeThreadID = id setMessages(await runtime.messages(for: id)) streamingText = "" + await refreshThreadContextState(for: id) } func sendComposerText() async { @@ -641,15 +644,18 @@ final class AgentDemoViewModel: @unchecked Sendable { if let selectedThreadID, threads.contains(where: { $0.id == selectedThreadID }) { setMessages(await runtime.messages(for: selectedThreadID)) + await refreshThreadContextState(for: selectedThreadID) return } if let firstThread = threads.first { activeThreadID = firstThread.id setMessages(await runtime.messages(for: firstThread.id)) + await refreshThreadContextState(for: firstThread.id) } else { activeThreadID = nil messages = [] + activeThreadContextState = nil } } @@ -663,6 +669,48 @@ final class AgentDemoViewModel: @unchecked Sendable { isRunningSkillPolicyProbe = false skillPolicyProbeResult = nil activeThreadID = nil + activeThreadContextState = nil + } + + func refreshThreadContextState(for threadID: String? = nil) async { + guard let resolvedThreadID = threadID ?? activeThreadID else { + activeThreadContextState = nil + return + } + + do { + activeThreadContextState = try await runtime.fetchThreadContextState(id: resolvedThreadID) + } catch { + activeThreadContextState = nil + developerErrorLog("Failed to fetch thread context state. threadID=\(resolvedThreadID) error=\(error.localizedDescription)") + } + } + + func compactActiveThreadContext() async { + guard let activeThreadID else { + lastError = "Select a thread before compacting its prompt context." + return + } + guard !isCompactingThreadContext else { + return + } + + isCompactingThreadContext = true + defer { + isCompactingThreadContext = false + } + + do { + developerLog("Manual context compaction started. threadID=\(activeThreadID)") + activeThreadContextState = try await runtime.compactThreadContext(id: activeThreadID) + threads = await runtime.threads() + setMessages(await runtime.messages(for: activeThreadID)) + developerLog( + "Manual context compaction finished. threadID=\(activeThreadID) generation=\(activeThreadContextState?.generation ?? 0) effectiveMessages=\(activeThreadContextState?.effectiveMessages.count ?? 0)" + ) + } catch { + reportError(error) + } } func setMessages(_ incoming: [AgentMessage]) { diff --git a/DemoApp/AssistantRuntimeDemoApp/Shared/ThreadDetailView.swift b/DemoApp/AssistantRuntimeDemoApp/Shared/ThreadDetailView.swift index 680e718..4cd2b92 100644 --- a/DemoApp/AssistantRuntimeDemoApp/Shared/ThreadDetailView.swift +++ b/DemoApp/AssistantRuntimeDemoApp/Shared/ThreadDetailView.swift @@ -26,6 +26,7 @@ struct ThreadDetailView: View { ScrollView { VStack(alignment: .leading, spacing: 16) { threadHeaderCard + compactionCard transcriptCard } .padding(20) @@ -126,6 +127,88 @@ private extension ThreadDetailView { } } + var compactionCard: some View { + DemoSectionCard { + VStack(alignment: .leading, spacing: 8) { + Text("Context Compaction") + .font(.headline) + + Text("Preserves the visible transcript, but rewrites the runtime’s hidden effective prompt context for future turns.") + .font(.subheadline) + .foregroundStyle(.secondary) + } + + HStack(spacing: 12) { + compactionMetric( + title: "Visible Messages", + value: "\(threadMessages.count)" + ) + compactionMetric( + title: "Effective Messages", + value: "\(viewModel.activeThreadContextState?.effectiveMessages.count ?? threadMessages.count)" + ) + compactionMetric( + title: "Generation", + value: "\(viewModel.activeThreadContextState?.generation ?? 0)" + ) + } + + if let contextState = viewModel.activeThreadContextState { + VStack(alignment: .leading, spacing: 6) { + if let reason = contextState.lastCompactionReason { + Text("Last compaction: \(reason.rawValue)") + .font(.caption) + .foregroundStyle(.secondary) + } + + if let lastCompactedAt = contextState.lastCompactedAt { + Text("Updated \(lastCompactedAt.formatted(date: .abbreviated, time: .shortened))") + .font(.caption) + .foregroundStyle(.secondary) + } + + if let summaryMessage = contextState.effectiveMessages.first(where: { $0.role == .system }), + !summaryMessage.text.isEmpty { + Text(summaryMessage.text) + .font(.caption) + .foregroundStyle(.secondary) + .lineLimit(4) + .padding(.top, 2) + } + } + } else { + Text("No compacted context exists yet for this thread. Send a few messages, then compact to compare the prompt working set against the preserved transcript.") + .font(.caption) + .foregroundStyle(.secondary) + } + + HStack(spacing: 10) { + Button { + Task { + await viewModel.compactActiveThreadContext() + } + } label: { + Label( + viewModel.isCompactingThreadContext ? "Compacting..." : "Compact Context Now", + systemImage: viewModel.isCompactingThreadContext ? "hourglass" : "arrow.triangle.branch" + ) + } + .buttonStyle(.borderedProminent) + .disabled(viewModel.session == nil || activeThread == nil || viewModel.isCompactingThreadContext) + + Button { + Task { + await viewModel.refreshThreadContextState(for: threadID) + } + } label: { + Label("Refresh State", systemImage: "arrow.clockwise") + } + .buttonStyle(.bordered) + .disabled(viewModel.session == nil || activeThread == nil) + } + } + } + var composer: some View { let photoPickerIconName = isImportingPhoto ? "hourglass" : "photo.on.rectangle" @@ -270,6 +353,23 @@ private extension ThreadDetailView { return "image/jpeg" } + @ViewBuilder + func compactionMetric(title: String, value: String) -> some View { + VStack(alignment: .leading, spacing: 4) { + Text(title) + .font(.caption.weight(.medium)) + .foregroundStyle(.secondary) + Text(value) + .font(.headline.monospacedDigit()) + } + .frame(maxWidth: .infinity, alignment: .leading) + .padding(12) + .background( + RoundedRectangle(cornerRadius: 12, style: .continuous) + .fill(Color.primary.opacity(0.04)) + ) + } + } private struct ThreadStreamingBubble: View { diff --git a/Sources/CodexKit/Memory/SQLiteMemoryStore.swift b/Sources/CodexKit/Memory/SQLiteMemoryStore.swift index 74d8f22..de31094 100644 --- a/Sources/CodexKit/Memory/SQLiteMemoryStore.swift +++ b/Sources/CodexKit/Memory/SQLiteMemoryStore.swift @@ -5,7 +5,7 @@ private struct SQLiteMemoryStoreSchema: Sendable { let currentVersion = 1 func existingVersion(in db: Database) throws -> Int { - try Int.fetchOne(db, sql: "PRAGMA user_version;") ?? 0 + try MemoryUserVersionQuery().execute(in: db) } func makeMigrator() -> DatabaseMigrator { @@ -157,18 +157,9 @@ private struct SQLiteMemoryStoreRepository: Sendable { namespace: String, in db: Database ) throws -> [MemoryRecord] { - let rows = try Row.fetchAll( - db, - sql: """ - SELECT - id, scope, kind, summary, evidence_json, importance, - created_at, observed_at, expires_at, tags_json, - related_ids_json, dedupe_key, is_pinned, attributes_json, status - FROM memory_records - WHERE namespace = ?; - """, - arguments: [namespace] - ) + let rows = try MemoryRecordRow + .filter(Column("namespace") == namespace) + .fetchAll(db) return try rows.map { row in try makeRecord(from: row, namespace: namespace) @@ -185,15 +176,10 @@ private struct SQLiteMemoryStoreRepository: Sendable { return [:] } - let rows = try Row.fetchAll( - db, - sql: """ - SELECT record_id, bm25(memory_fts) AS score - FROM memory_fts - WHERE namespace = ? AND memory_fts MATCH ?; - """, - arguments: [namespace, matchQuery] - ) + let rows = try MemoryFTSScoreRowsRequest( + namespace: namespace, + matchQuery: matchQuery + ).execute(in: db) var scores: [String: Double] = [:] for row in rows { @@ -209,19 +195,10 @@ private struct SQLiteMemoryStoreRepository: Sendable { namespace: String, in db: Database ) throws -> MemoryRecord? { - let row = try Row.fetchOne( - db, - sql: """ - SELECT - id, scope, kind, summary, evidence_json, importance, - created_at, observed_at, expires_at, tags_json, - related_ids_json, dedupe_key, is_pinned, attributes_json, status - FROM memory_records - WHERE namespace = ? AND id = ? - LIMIT 1; - """, - arguments: [namespace, id] - ) + let row = try MemoryRecordRow + .filter(Column("namespace") == namespace) + .filter(Column("id") == id) + .fetchOne(db) guard let row else { return nil } @@ -263,17 +240,11 @@ private struct SQLiteMemoryStoreRepository: Sendable { namespace: String, in db: Database ) throws { - if let id = try String.fetchOne( - db, - sql: """ - SELECT id - FROM memory_records - WHERE namespace = ? AND dedupe_key = ? - LIMIT 1; - """, - arguments: [namespace, dedupeKey] - ) { - try deleteRecord(id: id, namespace: namespace, in: db) + if let row = try MemoryRecordRow + .filter(Column("namespace") == namespace) + .filter(Column("dedupe_key") == dedupeKey) + .fetchOne(db) { + try deleteRecord(id: row.id, namespace: namespace, in: db) } } @@ -348,26 +319,26 @@ private struct SQLiteMemoryStoreRepository: Sendable { } private func makeRecord( - from row: Row, + from row: MemoryRecordRow, namespace: String ) throws -> MemoryRecord { MemoryRecord( - id: row["id"], + id: row.id, namespace: namespace, - scope: MemoryScope(rawValue: row["scope"]), - kind: row["kind"], - summary: row["summary"], - evidence: try codec.decode([String].self, from: row["evidence_json"]), - importance: row["importance"], - createdAt: Date(timeIntervalSince1970: row["created_at"]), - observedAt: (row["observed_at"] as Double?).map(Date.init(timeIntervalSince1970:)), - expiresAt: (row["expires_at"] as Double?).map(Date.init(timeIntervalSince1970:)), - tags: try codec.decode([String].self, from: row["tags_json"]), - relatedIDs: try codec.decode([String].self, from: row["related_ids_json"]), - dedupeKey: row["dedupe_key"], - isPinned: (row["is_pinned"] as Int64? ?? 0) == 1, - attributes: try codec.decodeNullable(JSONValue.self, from: row["attributes_json"]), - status: MemoryRecordStatus(rawValue: row["status"]) ?? .active + scope: MemoryScope(rawValue: row.scope), + kind: row.kind, + summary: row.summary, + evidence: try codec.decode([String].self, from: row.evidenceJSON), + importance: row.importance, + createdAt: Date(timeIntervalSince1970: row.createdAt), + observedAt: row.observedAt.map(Date.init(timeIntervalSince1970:)), + expiresAt: row.expiresAt.map(Date.init(timeIntervalSince1970:)), + tags: try codec.decode([String].self, from: row.tagsJSON), + relatedIDs: try codec.decode([String].self, from: row.relatedIDsJSON), + dedupeKey: row.dedupeKey, + isPinned: row.isPinned, + attributes: try codec.decodeNullable(JSONValue.self, from: row.attributesJSON), + status: MemoryRecordStatus(rawValue: row.status) ?? .active ) } @@ -376,15 +347,10 @@ private struct SQLiteMemoryStoreRepository: Sendable { namespace: String, in db: Database ) throws -> Bool { - try Bool.fetchOne( - db, - sql: """ - SELECT EXISTS( - SELECT 1 FROM memory_records WHERE namespace = ? AND id = ? - ); - """, - arguments: [namespace, id] - ) ?? false + try MemoryRecordRow + .filter(Column("namespace") == namespace) + .filter(Column("id") == id) + .fetchCount(db) > 0 } private func recordExists( @@ -392,15 +358,73 @@ private struct SQLiteMemoryStoreRepository: Sendable { namespace: String, in db: Database ) throws -> Bool { - try Bool.fetchOne( - db, + try MemoryRecordRow + .filter(Column("namespace") == namespace) + .filter(Column("dedupe_key") == dedupeKey) + .fetchCount(db) > 0 + } +} + +private struct MemoryUserVersionQuery: Sendable { + func execute(in db: Database) throws -> Int { + // PRAGMA is SQLite-specific and doesn't map cleanly to GRDB's query interface. + let row = try SQLRequest(sql: "PRAGMA user_version;").fetchOne(db) + return row?[0] ?? 0 + } +} + +private struct MemoryFTSScoreRowsRequest: Sendable { + let namespace: String + let matchQuery: String + + func execute(in db: Database) throws -> [Row] { + // FTS5 MATCH and bm25() are much clearer and more direct in raw SQL than in GRDB's query interface. + try SQLRequest( sql: """ - SELECT EXISTS( - SELECT 1 FROM memory_records WHERE namespace = ? AND dedupe_key = ? - ); + SELECT record_id, bm25(memory_fts) AS score + FROM memory_fts + WHERE namespace = ? AND memory_fts MATCH ?; """, - arguments: [namespace, dedupeKey] - ) ?? false + arguments: [namespace, matchQuery] + ).fetchAll(db) + } +} + +private struct MemoryRecordRow: FetchableRecord, TableRecord { + static let databaseTableName = "memory_records" + + let id: String + let scope: String + let kind: String + let summary: String + let evidenceJSON: String + let importance: Double + let createdAt: Double + let observedAt: Double? + let expiresAt: Double? + let tagsJSON: String + let relatedIDsJSON: String + let dedupeKey: String? + let isPinned: Bool + let attributesJSON: String? + let status: String + + init(row: Row) { + id = row["id"] + scope = row["scope"] + kind = row["kind"] + summary = row["summary"] + evidenceJSON = row["evidence_json"] + importance = row["importance"] + createdAt = row["created_at"] + observedAt = row["observed_at"] + expiresAt = row["expires_at"] + tagsJSON = row["tags_json"] + relatedIDsJSON = row["related_ids_json"] + dedupeKey = row["dedupe_key"] + isPinned = row["is_pinned"] as Bool? ?? false + attributesJSON = row["attributes_json"] + status = row["status"] } } diff --git a/Sources/CodexKit/Runtime/AgentContextCompaction.swift b/Sources/CodexKit/Runtime/AgentContextCompaction.swift new file mode 100644 index 0000000..eb73d42 --- /dev/null +++ b/Sources/CodexKit/Runtime/AgentContextCompaction.swift @@ -0,0 +1,156 @@ +import Foundation + +public enum AgentContextCompactionMode: String, Codable, Hashable, Sendable { + case manual + case automatic + + var supportsManual: Bool { + true + } + + var supportsAutomatic: Bool { + self == .automatic + } +} + +public enum AgentContextCompactionVisibility: String, Codable, Hashable, Sendable { + case hidden + case debugVisible +} + +public enum AgentContextCompactionStrategy: String, Codable, Hashable, Sendable { + case preferRemoteThenLocal + case remoteOnly + case localOnly +} + +public struct AgentContextCompactionTrigger: Codable, Hashable, Sendable { + public var estimatedTokenThreshold: Int + public var retryOnContextLimitError: Bool + + public init( + estimatedTokenThreshold: Int = 16_000, + retryOnContextLimitError: Bool = true + ) { + self.estimatedTokenThreshold = estimatedTokenThreshold + self.retryOnContextLimitError = retryOnContextLimitError + } +} + +public struct AgentContextCompactionConfiguration: Codable, Hashable, Sendable { + public var isEnabled: Bool + public var mode: AgentContextCompactionMode + public var visibility: AgentContextCompactionVisibility + public var strategy: AgentContextCompactionStrategy + public var trigger: AgentContextCompactionTrigger + + public init( + isEnabled: Bool = false, + mode: AgentContextCompactionMode = .automatic, + visibility: AgentContextCompactionVisibility = .hidden, + strategy: AgentContextCompactionStrategy = .preferRemoteThenLocal, + trigger: AgentContextCompactionTrigger = AgentContextCompactionTrigger() + ) { + self.isEnabled = isEnabled + self.mode = mode + self.visibility = visibility + self.strategy = strategy + self.trigger = trigger + } +} + +public enum AgentContextCompactionReason: String, Codable, Hashable, Sendable { + case manual + case automaticPreTurn + case automaticRetry + case modelChange +} + +public struct AgentContextCompactionMarker: Codable, Hashable, Sendable { + public let generation: Int + public let reason: AgentContextCompactionReason + public let effectiveMessageCountBefore: Int + public let effectiveMessageCountAfter: Int + public let debugSummaryPreview: String? + + public init( + generation: Int, + reason: AgentContextCompactionReason, + effectiveMessageCountBefore: Int, + effectiveMessageCountAfter: Int, + debugSummaryPreview: String? = nil + ) { + self.generation = generation + self.reason = reason + self.effectiveMessageCountBefore = effectiveMessageCountBefore + self.effectiveMessageCountAfter = effectiveMessageCountAfter + self.debugSummaryPreview = debugSummaryPreview + } +} + +public struct AgentThreadContextState: Codable, Hashable, Sendable { + public let threadID: String + public let effectiveMessages: [AgentMessage] + public let generation: Int + public let lastCompactedAt: Date? + public let lastCompactionReason: AgentContextCompactionReason? + public let latestMarkerID: String? + + public init( + threadID: String, + effectiveMessages: [AgentMessage], + generation: Int = 0, + lastCompactedAt: Date? = nil, + lastCompactionReason: AgentContextCompactionReason? = nil, + latestMarkerID: String? = nil + ) { + self.threadID = threadID + self.effectiveMessages = effectiveMessages + self.generation = generation + self.lastCompactedAt = lastCompactedAt + self.lastCompactionReason = lastCompactionReason + self.latestMarkerID = latestMarkerID + } +} + +public struct AgentCompactionResult: Codable, Hashable, Sendable { + public let effectiveMessages: [AgentMessage] + public let summaryPreview: String? + + public init( + effectiveMessages: [AgentMessage], + summaryPreview: String? = nil + ) { + self.effectiveMessages = effectiveMessages + self.summaryPreview = summaryPreview + } +} + +public protocol AgentBackendContextCompacting: Sendable { + func compactContext( + thread: AgentThread, + effectiveHistory: [AgentMessage], + instructions: String, + tools: [ToolDefinition], + session: ChatGPTSession + ) async throws -> AgentCompactionResult +} + +public struct ThreadContextStateQuery: AgentQuerySpec { + public typealias Result = [AgentThreadContextState] + + public var threadIDs: Set? + public var limit: Int? + + public init( + threadIDs: Set? = nil, + limit: Int? = nil + ) { + self.threadIDs = threadIDs + self.limit = limit + } + + public func execute(in state: StoredRuntimeState) throws -> [AgentThreadContextState] { + state.execute(self) + } +} diff --git a/Sources/CodexKit/Runtime/AgentHistory.swift b/Sources/CodexKit/Runtime/AgentHistory.swift index bd75dc5..e2ce06f 100644 --- a/Sources/CodexKit/Runtime/AgentHistory.swift +++ b/Sources/CodexKit/Runtime/AgentHistory.swift @@ -109,6 +109,7 @@ public struct AgentHistoryFilter: Sendable, Hashable { public var includeStructuredOutputs: Bool public var includeApprovals: Bool public var includeSystemEvents: Bool + public var includeCompactionEvents: Bool public init( includeMessages: Bool = true, @@ -116,7 +117,8 @@ public struct AgentHistoryFilter: Sendable, Hashable { includeToolResults: Bool = true, includeStructuredOutputs: Bool = true, includeApprovals: Bool = true, - includeSystemEvents: Bool = true + includeSystemEvents: Bool = true, + includeCompactionEvents: Bool = false ) { self.includeMessages = includeMessages self.includeToolCalls = includeToolCalls @@ -124,6 +126,7 @@ public struct AgentHistoryFilter: Sendable, Hashable { self.includeStructuredOutputs = includeStructuredOutputs self.includeApprovals = includeApprovals self.includeSystemEvents = includeSystemEvents + self.includeCompactionEvents = includeCompactionEvents } } @@ -174,6 +177,7 @@ public protocol AgentRuntimeThreadInspecting: Sendable { query: AgentHistoryQuery ) async throws -> AgentThreadHistoryPage func fetchLatestStructuredOutputMetadata(id: String) async throws -> AgentStructuredOutputMetadata? + func fetchThreadContextState(id: String) async throws -> AgentThreadContextState? } public extension AgentThreadSummary { @@ -288,6 +292,7 @@ public enum AgentSystemEventType: String, Codable, Hashable, Sendable { case turnStarted case turnCompleted case turnFailed + case contextCompacted } public struct AgentSystemEventRecord: Codable, Hashable, Sendable { @@ -297,6 +302,7 @@ public struct AgentSystemEventRecord: Codable, Hashable, Sendable { public let status: AgentThreadStatus? public let turnSummary: AgentTurnSummary? public let error: AgentRuntimeError? + public let compaction: AgentContextCompactionMarker? public let occurredAt: Date public init( @@ -306,6 +312,7 @@ public struct AgentSystemEventRecord: Codable, Hashable, Sendable { status: AgentThreadStatus? = nil, turnSummary: AgentTurnSummary? = nil, error: AgentRuntimeError? = nil, + compaction: AgentContextCompactionMarker? = nil, occurredAt: Date = Date() ) { self.type = type @@ -314,6 +321,7 @@ public struct AgentSystemEventRecord: Codable, Hashable, Sendable { self.status = status self.turnSummary = turnSummary self.error = error + self.compaction = compaction self.occurredAt = occurredAt } } @@ -566,17 +574,20 @@ extension AgentHistoryFilter { func matches(_ item: AgentHistoryItem) -> Bool { switch item { case .message: - includeMessages + return includeMessages case .toolCall: - includeToolCalls + return includeToolCalls case .toolResult: - includeToolResults + return includeToolResults case .structuredOutput: - includeStructuredOutputs + return includeStructuredOutputs case .approval: - includeApprovals - case .systemEvent: - includeSystemEvents + return includeApprovals + case let .systemEvent(record): + if record.type == .contextCompacted { + return includeSystemEvents && includeCompactionEvents + } + return includeSystemEvents } } } @@ -629,9 +640,20 @@ extension AgentHistoryItem { case let .approval(record): return "approval:\(record.request?.id ?? record.resolution?.requestID ?? UUID().uuidString)" case let .systemEvent(record): + if record.type == .contextCompacted, + let generation = record.compaction?.generation { + return "systemEvent:\(record.type.rawValue):\(record.threadID):\(generation)" + } return "systemEvent:\(record.type.rawValue):\(record.turnID ?? record.threadID)" } } + + public var isCompactionMarker: Bool { + guard case let .systemEvent(record) = self else { + return false + } + return record.type == .contextCompacted + } } extension AgentHistoryRecord { diff --git a/Sources/CodexKit/Runtime/AgentModels.swift b/Sources/CodexKit/Runtime/AgentModels.swift index 8ec4c27..bbdddf0 100644 --- a/Sources/CodexKit/Runtime/AgentModels.swift +++ b/Sources/CodexKit/Runtime/AgentModels.swift @@ -154,6 +154,20 @@ public struct AgentRuntimeError: Error, LocalizedError, Equatable, Hashable, Sen message: "The active skill policy requires tool calls that did not occur: \(toolNames.sorted().joined(separator: ", "))." ) } + + public static func contextCompactionDisabled() -> AgentRuntimeError { + AgentRuntimeError( + code: "context_compaction_disabled", + message: "Context compaction is not enabled for this runtime." + ) + } + + public static func contextCompactionUnsupported() -> AgentRuntimeError { + AgentRuntimeError( + code: "context_compaction_unsupported", + message: "Context compaction could not be performed with the active backend and strategy." + ) + } } public enum AgentRole: String, Codable, Hashable, Sendable { diff --git a/Sources/CodexKit/Runtime/AgentQuerying.swift b/Sources/CodexKit/Runtime/AgentQuerying.swift index 392ccf8..a6f6ae7 100644 --- a/Sources/CodexKit/Runtime/AgentQuerying.swift +++ b/Sources/CodexKit/Runtime/AgentQuerying.swift @@ -131,6 +131,7 @@ public struct HistoryItemsQuery: AgentQuerySpec { public var createdAtRange: ClosedRange? public var turnID: String? public var includeRedacted: Bool + public var includeCompactionEvents: Bool public var sort: AgentHistorySort public var page: AgentQueryPage? @@ -140,6 +141,7 @@ public struct HistoryItemsQuery: AgentQuerySpec { createdAtRange: ClosedRange? = nil, turnID: String? = nil, includeRedacted: Bool = true, + includeCompactionEvents: Bool = false, sort: AgentHistorySort = .sequence(.ascending), page: AgentQueryPage? = nil ) { @@ -148,6 +150,7 @@ public struct HistoryItemsQuery: AgentQuerySpec { self.createdAtRange = createdAtRange self.turnID = turnID self.includeRedacted = includeRedacted + self.includeCompactionEvents = includeCompactionEvents self.sort = sort self.page = page } @@ -404,6 +407,9 @@ public enum AgentStoreWriteOperation: Sendable, Hashable { case upsertThread(AgentThread) case upsertSummary(threadID: String, summary: AgentThreadSummary) case appendHistoryItems(threadID: String, items: [AgentHistoryRecord]) + case appendCompactionMarker(threadID: String, marker: AgentHistoryRecord) + case upsertThreadContextState(threadID: String, state: AgentThreadContextState?) + case deleteThreadContextState(threadID: String) case setPendingState(threadID: String, state: AgentThreadPendingState?) case setPartialStructuredSnapshot(threadID: String, snapshot: AgentPartialStructuredOutputSnapshot?) case upsertToolSession(threadID: String, session: AgentToolSessionRecord) @@ -427,6 +433,12 @@ extension AgentStoreWriteOperation { threadID case let .appendHistoryItems(threadID, _): threadID + case let .appendCompactionMarker(threadID, _): + threadID + case let .upsertThreadContextState(threadID, _): + threadID + case let .deleteThreadContextState(threadID): + threadID case let .setPendingState(threadID, _): threadID case let .setPartialStructuredSnapshot(threadID, _): diff --git a/Sources/CodexKit/Runtime/AgentRuntime+ContextCompaction.swift b/Sources/CodexKit/Runtime/AgentRuntime+ContextCompaction.swift new file mode 100644 index 0000000..aff780c --- /dev/null +++ b/Sources/CodexKit/Runtime/AgentRuntime+ContextCompaction.swift @@ -0,0 +1,310 @@ +import Foundation + +extension AgentRuntime { + func effectiveHistory(for threadID: String) -> [AgentMessage] { + state.contextStateByThread[threadID]?.effectiveMessages + ?? state.messagesByThread[threadID] + ?? [] + } + + func shouldUseCompaction() -> Bool { + contextCompactionConfiguration.isEnabled + } + + func appendEffectiveMessage(_ message: AgentMessage) { + guard shouldUseCompaction() || state.contextStateByThread[message.threadID] != nil else { + return + } + + let current = state.contextStateByThread[message.threadID] + ?? AgentThreadContextState( + threadID: message.threadID, + effectiveMessages: state.messagesByThread[message.threadID] ?? [] + ) + let updated = AgentThreadContextState( + threadID: current.threadID, + effectiveMessages: current.effectiveMessages + [message], + generation: current.generation, + lastCompactedAt: current.lastCompactedAt, + lastCompactionReason: current.lastCompactionReason, + latestMarkerID: current.latestMarkerID + ) + state.contextStateByThread[message.threadID] = updated + enqueueStoreOperation(.upsertThreadContextState(threadID: message.threadID, state: updated)) + } + + func maybeCompactThreadContextBeforeTurn( + thread: AgentThread, + request: UserMessageRequest, + instructions: String, + tools: [ToolDefinition], + session: ChatGPTSession + ) async throws { + guard shouldUseCompaction(), + contextCompactionConfiguration.mode.supportsAutomatic + else { + return + } + + let threshold = max(1, contextCompactionConfiguration.trigger.estimatedTokenThreshold) + let estimatedTokens = approximateTokenCount( + for: effectiveHistory(for: thread.id), + pendingMessage: request, + instructions: instructions + ) + guard estimatedTokens > threshold else { + return + } + + _ = try await compactThreadContext( + id: thread.id, + reason: .automaticPreTurn, + instructions: instructions, + tools: tools, + session: session + ) + } + + func maybeCompactThreadContextAfterContextFailure( + thread: AgentThread, + request: UserMessageRequest, + instructions: String, + tools: [ToolDefinition], + session: ChatGPTSession, + error: Error + ) async throws -> Bool { + guard shouldUseCompaction(), + contextCompactionConfiguration.mode.supportsAutomatic, + contextCompactionConfiguration.trigger.retryOnContextLimitError, + isContextPressureError(error) + else { + return false + } + + _ = try await compactThreadContext( + id: thread.id, + reason: .automaticRetry, + instructions: instructions, + tools: tools, + session: session + ) + return true + } + + @discardableResult + public func compactThreadContext(id threadID: String) async throws -> AgentThreadContextState { + guard shouldUseCompaction(), + contextCompactionConfiguration.mode.supportsManual + else { + throw AgentRuntimeError.contextCompactionDisabled() + } + + guard let thread = thread(for: threadID) else { + throw AgentRuntimeError.threadNotFound(threadID) + } + + let session = try await sessionManager.requireSession() + let tools = await toolRegistry.allDefinitions() + let resolvedInstructions = await resolveInstructions( + thread: thread, + message: UserMessageRequest(text: "", images: []), + resolvedTurnSkills: ResolvedTurnSkills( + threadSkills: [], + turnSkills: [], + compiledToolPolicy: CompiledSkillToolPolicy( + allowedToolNames: nil, + requiredToolNames: [], + toolSequence: nil, + maxToolCalls: nil + ) + ) + ) + return try await compactThreadContext( + id: threadID, + reason: .manual, + instructions: resolvedInstructions, + tools: tools, + session: session + ) + } + + @discardableResult + func compactThreadContext( + id threadID: String, + reason: AgentContextCompactionReason, + instructions: String, + tools: [ToolDefinition], + session: ChatGPTSession + ) async throws -> AgentThreadContextState { + guard shouldUseCompaction() else { + throw AgentRuntimeError.contextCompactionDisabled() + } + guard let thread = thread(for: threadID) else { + throw AgentRuntimeError.threadNotFound(threadID) + } + + let current = state.contextStateByThread[threadID] + ?? AgentThreadContextState( + threadID: threadID, + effectiveMessages: state.messagesByThread[threadID] ?? [] + ) + let result = try await performCompaction( + thread: thread, + effectiveHistory: current.effectiveMessages, + instructions: instructions, + tools: tools, + session: session + ) + + let markerTime = Date() + let nextGeneration = current.generation + 1 + let markerPayload = AgentContextCompactionMarker( + generation: nextGeneration, + reason: reason, + effectiveMessageCountBefore: current.effectiveMessages.count, + effectiveMessageCountAfter: result.effectiveMessages.count, + debugSummaryPreview: result.summaryPreview + ) + let markerRecord = AgentHistoryRecord( + sequenceNumber: state.nextHistorySequenceByThread[threadID] + ?? ((state.historyByThread[threadID]?.last?.sequenceNumber ?? 0) + 1), + createdAt: markerTime, + item: .systemEvent( + AgentSystemEventRecord( + type: .contextCompacted, + threadID: threadID, + compaction: markerPayload, + occurredAt: markerTime + ) + ) + ) + + let updated = AgentThreadContextState( + threadID: threadID, + effectiveMessages: result.effectiveMessages, + generation: nextGeneration, + lastCompactedAt: markerTime, + lastCompactionReason: reason, + latestMarkerID: markerRecord.id + ) + state.contextStateByThread[threadID] = updated + enqueueStoreOperation(.upsertThreadContextState(threadID: threadID, state: updated)) + state.historyByThread[threadID, default: []].append(markerRecord) + state.nextHistorySequenceByThread[threadID] = nextGenerationSequence(afterAppendingTo: threadID) + enqueueStoreOperation(.appendCompactionMarker(threadID: threadID, marker: markerRecord)) + try await persistState() + return updated + } + + private func performCompaction( + thread: AgentThread, + effectiveHistory: [AgentMessage], + instructions: String, + tools: [ToolDefinition], + session: ChatGPTSession + ) async throws -> AgentCompactionResult { + switch contextCompactionConfiguration.strategy { + case .preferRemoteThenLocal: + if let compactingBackend = backend as? any AgentBackendContextCompacting, + let result = try? await compactingBackend.compactContext( + thread: thread, + effectiveHistory: effectiveHistory, + instructions: instructions, + tools: tools, + session: session + ) { + return result + } + return localCompactionResult(for: thread.id, from: effectiveHistory) + + case .remoteOnly: + guard let compactingBackend = backend as? any AgentBackendContextCompacting else { + throw AgentRuntimeError.contextCompactionUnsupported() + } + return try await compactingBackend.compactContext( + thread: thread, + effectiveHistory: effectiveHistory, + instructions: instructions, + tools: tools, + session: session + ) + + case .localOnly: + return localCompactionResult(for: thread.id, from: effectiveHistory) + } + } + + private func localCompactionResult( + for threadID: String, + from history: [AgentMessage] + ) -> AgentCompactionResult { + guard history.count > 2 else { + return AgentCompactionResult( + effectiveMessages: history, + summaryPreview: history.last?.displayText + ) + } + + let lastUser = history.last(where: { $0.role == .user }) + let lastAssistant = history.last(where: { $0.role == .assistant }) + let preservedIDs = Set([lastUser?.id, lastAssistant?.id].compactMap { $0 }) + let summarized = history.filter { !preservedIDs.contains($0.id) } + + let summaryLines = summarized.prefix(12).map { message in + let role = message.role.rawValue.capitalized + let text = message.text.trimmingCharacters(in: .whitespacesAndNewlines) + if !text.isEmpty { + return "\(role): \(String(text.prefix(240)))" + } + if !message.images.isEmpty { + return "\(role): [\(message.images.count) image attachment(s)]" + } + return "\(role): [empty]" + } + let summaryText = """ + Compacted conversation summary: + \(summaryLines.joined(separator: "\n")) + """ + let summaryMessage = AgentMessage( + threadID: threadID, + role: .system, + text: summaryText + ) + + var effectiveMessages = [summaryMessage] + if let lastUser { + effectiveMessages.append(lastUser) + } + if let lastAssistant, lastAssistant.id != lastUser?.id { + effectiveMessages.append(lastAssistant) + } + + return AgentCompactionResult( + effectiveMessages: effectiveMessages, + summaryPreview: summaryLines.first + ) + } + + func approximateTokenCount( + for history: [AgentMessage], + pendingMessage: UserMessageRequest?, + instructions: String + ) -> Int { + let historyCharacters = history.reduce(0) { partialResult, message in + partialResult + message.text.count + (message.images.count * 512) + } + let pendingCharacters = (pendingMessage?.text.count ?? 0) + ((pendingMessage?.images.count ?? 0) * 512) + return max(1, (historyCharacters + pendingCharacters + instructions.count) / 4) + } + + func isContextPressureError(_ error: Error) -> Bool { + let message = ((error as? AgentRuntimeError)?.message ?? error.localizedDescription).lowercased() + return message.contains("context") && message.contains("limit") + || message.contains("maximum context length") + || message.contains("too many tokens") + } + + private func nextGenerationSequence(afterAppendingTo threadID: String) -> Int { + (state.historyByThread[threadID]?.last?.sequenceNumber ?? 0) + 1 + } +} diff --git a/Sources/CodexKit/Runtime/AgentRuntime+History.swift b/Sources/CodexKit/Runtime/AgentRuntime+History.swift index 85f745a..4eb444c 100644 --- a/Sources/CodexKit/Runtime/AgentRuntime+History.swift +++ b/Sources/CodexKit/Runtime/AgentRuntime+History.swift @@ -40,6 +40,7 @@ extension AgentRuntime: AgentRuntimeQueryable, AgentRuntimeThreadInspecting { threadID: id, kinds: query.filter?.includedKinds, includeRedacted: true, + includeCompactionEvents: query.filter?.includeCompactionEvents ?? false, sort: query.direction == .forward ? .sequence(.ascending) : .sequence(.descending), page: AgentQueryPage(limit: query.limit, cursor: query.cursor) ) @@ -70,6 +71,19 @@ extension AgentRuntime: AgentRuntimeQueryable, AgentRuntimeThreadInspecting { return records.first?.metadata } + public func fetchThreadContextState(id: String) async throws -> AgentThreadContextState? { + if let inspectingStore = stateStore as? any RuntimeStateInspecting { + return try await inspectingStore.fetchThreadContextState(id: id) + } + + return try await execute( + ThreadContextStateQuery( + threadIDs: [id], + limit: 1 + ) + ).first + } + public func fetchLatestStructuredOutput( id: String, as outputType: Output.Type, diff --git a/Sources/CodexKit/Runtime/AgentRuntime+Messaging.swift b/Sources/CodexKit/Runtime/AgentRuntime+Messaging.swift index 791e48d..8a41056 100644 --- a/Sources/CodexKit/Runtime/AgentRuntime+Messaging.swift +++ b/Sources/CodexKit/Runtime/AgentRuntime+Messaging.swift @@ -55,7 +55,6 @@ extension AgentRuntime { text: request.text, images: request.images ) - let priorMessages = state.messagesByThread[threadID] ?? [] let resolvedTurnSkills = try resolveTurnSkills( thread: thread, message: request @@ -70,9 +69,16 @@ extension AgentRuntime { try await setThreadStatus(.streaming, for: threadID) let tools = await toolRegistry.allDefinitions() + try await maybeCompactThreadContextBeforeTurn( + thread: thread, + request: request, + instructions: resolvedInstructions, + tools: tools, + session: session + ) let turnStart = try await beginTurnWithUnauthorizedRecovery( thread: thread, - history: priorMessages, + history: effectiveHistory(for: threadID), message: request, instructions: resolvedInstructions, responseFormat: nil, @@ -183,7 +189,6 @@ extension AgentRuntime { text: request.text, images: request.images ) - let priorMessages = state.messagesByThread[threadID] ?? [] let resolvedTurnSkills = try resolveTurnSkills( thread: thread, message: request @@ -198,9 +203,16 @@ extension AgentRuntime { try await setThreadStatus(.streaming, for: threadID) let tools = await toolRegistry.allDefinitions() + try await maybeCompactThreadContextBeforeTurn( + thread: thread, + request: request, + instructions: resolvedInstructions, + tools: tools, + session: session + ) let turnStart = try await beginTurnWithUnauthorizedRecovery( thread: thread, - history: priorMessages, + history: effectiveHistory(for: threadID), message: request, instructions: resolvedInstructions, responseFormat: responseFormat, @@ -241,21 +253,51 @@ extension AgentRuntime { turnStream: any AgentTurnStreaming, session: ChatGPTSession ) { - let beginTurn = try await withUnauthorizedRecovery( - initialSession: session - ) { session in - try await backend.beginTurn( + do { + let beginTurn = try await withUnauthorizedRecovery( + initialSession: session + ) { session in + try await backend.beginTurn( + thread: thread, + history: history, + message: message, + instructions: instructions, + responseFormat: responseFormat, + streamedStructuredOutput: streamedStructuredOutput, + tools: tools, + session: session + ) + } + return (beginTurn.result, beginTurn.session) + } catch { + let compacted = try await maybeCompactThreadContextAfterContextFailure( thread: thread, - history: history, - message: message, + request: message, instructions: instructions, - responseFormat: responseFormat, - streamedStructuredOutput: streamedStructuredOutput, tools: tools, - session: session + session: session, + error: error ) + guard compacted else { + throw error + } + + let beginTurn = try await withUnauthorizedRecovery( + initialSession: session + ) { session in + try await backend.beginTurn( + thread: thread, + history: self.effectiveHistory(for: thread.id), + message: message, + instructions: instructions, + responseFormat: responseFormat, + streamedStructuredOutput: streamedStructuredOutput, + tools: tools, + session: session + ) + } + return (beginTurn.result, beginTurn.session) } - return (beginTurn.result, beginTurn.session) } // MARK: - Previews diff --git a/Sources/CodexKit/Runtime/AgentRuntime+Threads.swift b/Sources/CodexKit/Runtime/AgentRuntime+Threads.swift index 3392db2..1623053 100644 --- a/Sources/CodexKit/Runtime/AgentRuntime+Threads.swift +++ b/Sources/CodexKit/Runtime/AgentRuntime+Threads.swift @@ -200,6 +200,7 @@ extension AgentRuntime { func appendMessage(_ message: AgentMessage) async throws { state.messagesByThread[message.threadID, default: []].append(message) + appendEffectiveMessage(message) appendHistoryItem( .message(message), threadID: message.threadID, diff --git a/Sources/CodexKit/Runtime/AgentRuntime.swift b/Sources/CodexKit/Runtime/AgentRuntime.swift index a8faa73..9408427 100644 --- a/Sources/CodexKit/Runtime/AgentRuntime.swift +++ b/Sources/CodexKit/Runtime/AgentRuntime.swift @@ -27,6 +27,7 @@ public actor AgentRuntime { public let tools: [ToolRegistration] public let skills: [AgentSkill] public let definitionSourceLoader: AgentDefinitionSourceLoader + public let contextCompaction: AgentContextCompactionConfiguration public init( authProvider: any ChatGPTAuthProviding, @@ -38,7 +39,8 @@ public actor AgentRuntime { baseInstructions: String? = nil, tools: [ToolRegistration] = [], skills: [AgentSkill] = [], - definitionSourceLoader: AgentDefinitionSourceLoader = AgentDefinitionSourceLoader() + definitionSourceLoader: AgentDefinitionSourceLoader = AgentDefinitionSourceLoader(), + contextCompaction: AgentContextCompactionConfiguration = AgentContextCompactionConfiguration() ) { self.authProvider = authProvider self.secureStore = secureStore @@ -50,6 +52,7 @@ public actor AgentRuntime { self.tools = tools self.skills = skills self.definitionSourceLoader = definitionSourceLoader + self.contextCompaction = contextCompaction } } @@ -61,6 +64,7 @@ public actor AgentRuntime { let memoryConfiguration: AgentMemoryConfiguration? let baseInstructions: String? let definitionSourceLoader: AgentDefinitionSourceLoader + let contextCompactionConfiguration: AgentContextCompactionConfiguration var skillsByID: [String: AgentSkill] var state: StoredRuntimeState = .empty @@ -165,6 +169,7 @@ public actor AgentRuntime { self.memoryConfiguration = configuration.memory self.baseInstructions = configuration.baseInstructions ?? configuration.backend.baseInstructions self.definitionSourceLoader = configuration.definitionSourceLoader + self.contextCompactionConfiguration = configuration.contextCompaction self.skillsByID = try Self.validatedSkills(from: configuration.skills) } diff --git a/Sources/CodexKit/Runtime/CodexResponsesBackend+Compaction.swift b/Sources/CodexKit/Runtime/CodexResponsesBackend+Compaction.swift new file mode 100644 index 0000000..9be1f36 --- /dev/null +++ b/Sources/CodexKit/Runtime/CodexResponsesBackend+Compaction.swift @@ -0,0 +1,119 @@ +import Foundation + +extension CodexResponsesBackend: AgentBackendContextCompacting { + public func compactContext( + thread: AgentThread, + effectiveHistory: [AgentMessage], + instructions: String, + tools: [ToolDefinition], + session: ChatGPTSession + ) async throws -> AgentCompactionResult { + let requestBody = ResponsesCompactRequestBody( + model: configuration.model, + reasoning: .init(effort: configuration.reasoningEffort), + instructions: instructions, + text: .init(format: .init(responseFormat: nil)), + input: effectiveHistory.map { WorkingHistoryItem.visibleMessage($0).jsonValue }, + tools: CodexResponsesTurnSession.responsesTools( + from: tools, + enableWebSearch: configuration.enableWebSearch + ), + parallelToolCalls: false + ) + + var request = URLRequest(url: configuration.baseURL.appendingPathComponent("responses/compact")) + request.httpMethod = "POST" + request.httpBody = try encoder.encode(requestBody) + request.setValue("application/json", forHTTPHeaderField: "Content-Type") + request.setValue("application/json", forHTTPHeaderField: "Accept") + request.setValue("Bearer \(session.accessToken)", forHTTPHeaderField: "Authorization") + request.setValue(session.account.id, forHTTPHeaderField: "ChatGPT-Account-ID") + request.setValue(thread.id, forHTTPHeaderField: "session_id") + request.setValue(thread.id, forHTTPHeaderField: "x-client-request-id") + request.setValue(configuration.originator, forHTTPHeaderField: "originator") + + for (header, value) in configuration.extraHeaders { + request.setValue(value, forHTTPHeaderField: header) + } + + let (data, response) = try await urlSession.data(for: request) + guard let httpResponse = response as? HTTPURLResponse else { + throw AgentRuntimeError( + code: "responses_compact_invalid_response", + message: "The ChatGPT compact endpoint returned an invalid response." + ) + } + guard (200 ..< 300).contains(httpResponse.statusCode) else { + let body = String(data: data, encoding: .utf8) ?? "" + throw AgentRuntimeError( + code: "responses_compact_failed", + message: "The ChatGPT compact endpoint failed with status \(httpResponse.statusCode): \(body)" + ) + } + + let payload = try decoder.decode(JSONValue.self, from: data) + let output = payload.objectValue?["output"]?.arrayValue ?? [] + let messages = output.compactMap { item in + Self.compactedMessage(from: item, threadID: thread.id) + } + guard !messages.isEmpty else { + throw AgentRuntimeError.contextCompactionUnsupported() + } + + return AgentCompactionResult( + effectiveMessages: messages, + summaryPreview: messages.first?.displayText + ) + } + + private static func compactedMessage( + from value: JSONValue, + threadID: String + ) -> AgentMessage? { + guard let object = value.objectValue, + let type = object["type"]?.stringValue + else { + return nil + } + + if type == "compaction", + let summary = object["encrypted_content"]?.stringValue { + return AgentMessage( + threadID: threadID, + role: .system, + text: summary + ) + } + + guard type == "message", + let roleRaw = object["role"]?.stringValue + else { + return nil + } + + let role: AgentRole + switch roleRaw { + case "assistant": + role = .assistant + case "system", "developer": + role = .system + case "user": + role = .user + default: + return nil + } + + let text = (object["content"]?.arrayValue ?? []).compactMap { item -> String? in + guard let content = item.objectValue else { + return nil + } + return content["text"]?.stringValue + }.joined(separator: "\n") + + return AgentMessage( + threadID: threadID, + role: role, + text: text + ) + } +} diff --git a/Sources/CodexKit/Runtime/CodexResponsesBackend+Models.swift b/Sources/CodexKit/Runtime/CodexResponsesBackend+Models.swift index 5f098ee..590294d 100644 --- a/Sources/CodexKit/Runtime/CodexResponsesBackend+Models.swift +++ b/Sources/CodexKit/Runtime/CodexResponsesBackend+Models.swift @@ -30,6 +30,26 @@ struct ResponsesRequestBody: Encodable { } } +struct ResponsesCompactRequestBody: Encodable { + let model: String + let reasoning: ResponsesReasoningConfiguration + let instructions: String + let text: ResponsesTextConfiguration + let input: [JSONValue] + let tools: [JSONValue] + let parallelToolCalls: Bool + + enum CodingKeys: String, CodingKey { + case model + case reasoning + case instructions + case text + case input + case tools + case parallelToolCalls = "parallel_tool_calls" + } +} + struct ResponsesReasoningConfiguration: Encodable { let effort: String diff --git a/Sources/CodexKit/Runtime/CodexResponsesBackend.swift b/Sources/CodexKit/Runtime/CodexResponsesBackend.swift index 54e6519..587bb12 100644 --- a/Sources/CodexKit/Runtime/CodexResponsesBackend.swift +++ b/Sources/CodexKit/Runtime/CodexResponsesBackend.swift @@ -39,10 +39,10 @@ public struct CodexResponsesBackendConfiguration: Sendable { public actor CodexResponsesBackend: AgentBackend { public nonisolated let baseInstructions: String? - private let configuration: CodexResponsesBackendConfiguration - private let urlSession: URLSession - private let encoder = JSONEncoder() - private let decoder = JSONDecoder() + let configuration: CodexResponsesBackendConfiguration + let urlSession: URLSession + let encoder = JSONEncoder() + let decoder = JSONDecoder() public init( configuration: CodexResponsesBackendConfiguration = CodexResponsesBackendConfiguration(), diff --git a/Sources/CodexKit/Runtime/GRDBRuntimeStateStore.swift b/Sources/CodexKit/Runtime/GRDBRuntimeStateStore.swift index 3e5c8a3..d82b409 100644 --- a/Sources/CodexKit/Runtime/GRDBRuntimeStateStore.swift +++ b/Sources/CodexKit/Runtime/GRDBRuntimeStateStore.swift @@ -2,7 +2,7 @@ import Foundation import GRDB public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, AgentRuntimeQueryableStore { - private static let currentStoreSchemaVersion = 1 + private static let currentStoreSchemaVersion = 2 private let url: URL private let legacyStateURL: URL? @@ -72,6 +72,7 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, let threadRows = try RuntimeThreadRow.fetchAll(db) let summaryRows = try RuntimeSummaryRow.fetchAll(db) let historyRows = try RuntimeHistoryRow.fetchAll(db) + let contextRows = try RuntimeContextStateRow.fetchAll(db) let threads = try threadRows.map { try Self.decodeThread(from: $0) } let summariesByThread = try Dictionary( @@ -83,11 +84,17 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, try Self.decodeHistoryRecord(from: $0, attachmentStore: attachmentStore) } let historyByThread = Dictionary(grouping: decodedHistoryRows, by: { $0.item.threadID }) + let contextStateByThread = try Dictionary( + uniqueKeysWithValues: contextRows.map { row in + (row.threadID, try Self.decodeContextState(from: row)) + } + ) return StoredRuntimeState( threads: threads, historyByThread: historyByThread, - summariesByThread: summariesByThread + summariesByThread: summariesByThread, + contextStateByThread: contextStateByThread ) } } @@ -179,6 +186,19 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, return summary.latestStructuredOutputMetadata } + public func fetchThreadContextState(id: String) async throws -> AgentThreadContextState? { + try await ensurePrepared() + return try await dbQueue.read { db in + guard try RuntimeThreadRow.fetchOne(db, key: id) != nil else { + throw AgentRuntimeError.threadNotFound(id) + } + guard let row = try RuntimeContextStateRow.fetchOne(db, key: id) else { + return nil + } + return try Self.decodeContextState(from: row) + } + } + public func execute(_ query: Query) async throws -> Query.Result { try await ensurePrepared() @@ -197,6 +217,9 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, if let snapshotQuery = query as? ThreadSnapshotQuery { return try await executeThreadSnapshotQuery(snapshotQuery) as! Query.Result } + if let contextQuery = query as? ThreadContextStateQuery { + return try await executeThreadContextStateQuery(contextQuery) as! Query.Result + } let state = try await loadState() return try query.execute(in: state) @@ -241,6 +264,7 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, createdAtRange: query.createdAtRange, turnID: query.turnID, includeRedacted: query.includeRedacted, + includeCompactionEvents: query.includeCompactionEvents, in: db, attachmentStore: attachmentStore ) @@ -255,22 +279,77 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, private func executeThreadQuery(_ query: ThreadMetadataQuery) async throws -> [AgentThread] { try await dbQueue.read { db in - let rows = try RuntimeThreadRow.fetchAll( - db, - sql: Self.threadMetadataSQL(for: query), - arguments: StatementArguments(Self.threadMetadataArguments(for: query)) - ) + var request = RuntimeThreadRow.all() + if let threadIDs = query.threadIDs, !threadIDs.isEmpty { + request = request.filter(threadIDs.contains(Column("threadID"))) + } + if let statuses = query.statuses, !statuses.isEmpty { + request = request.filter(statuses.map(\.rawValue).contains(Column("status"))) + } + if let range = query.updatedAtRange { + request = request.filter(Column("updatedAt") >= range.lowerBound.timeIntervalSince1970) + request = request.filter(Column("updatedAt") <= range.upperBound.timeIntervalSince1970) + } + + switch query.sort { + case let .updatedAt(order): + request = order == .ascending + ? request.order(Column("updatedAt").asc, Column("threadID").asc) + : request.order(Column("updatedAt").desc, Column("threadID").asc) + case let .createdAt(order): + request = order == .ascending + ? request.order(Column("createdAt").asc, Column("threadID").asc) + : request.order(Column("createdAt").desc, Column("threadID").asc) + } + + if let limit = query.limit { + request = request.limit(max(0, limit)) + } + + let rows = try request.fetchAll(db) return try rows.map { try Self.decodeThread(from: $0) } } } + private func executeThreadContextStateQuery(_ query: ThreadContextStateQuery) async throws -> [AgentThreadContextState] { + try await dbQueue.read { db in + var request = RuntimeContextStateRow.all() + if let threadIDs = query.threadIDs, !threadIDs.isEmpty { + request = request.filter(threadIDs.contains(Column("threadID"))) + } + request = request.order(Column("generation").desc, Column("threadID").asc) + if let limit = query.limit { + request = request.limit(max(0, limit)) + } + + return try request.fetchAll(db).map { try Self.decodeContextState(from: $0) } + } + } + private func executePendingStateQuery(_ query: PendingStateQuery) async throws -> [AgentPendingStateRecord] { try await dbQueue.read { db in - let summaries = try RuntimeSummaryRow.fetchAll( - db, - sql: Self.pendingStateSQL(for: query), - arguments: StatementArguments(Self.pendingStateArguments(for: query)) - ) + var request = RuntimeSummaryRow + .filter(Column("pendingStateKind") != nil) + + if let threadIDs = query.threadIDs, !threadIDs.isEmpty { + request = request.filter(threadIDs.contains(Column("threadID"))) + } + if let kinds = query.kinds, !kinds.isEmpty { + request = request.filter(kinds.map(\.rawValue).contains(Column("pendingStateKind"))) + } + + switch query.sort { + case let .updatedAt(order): + request = order == .ascending + ? request.order(Column("updatedAt").asc) + : request.order(Column("updatedAt").desc) + } + + if let limit = query.limit { + request = request.limit(max(0, limit)) + } + + let summaries = try request.fetchAll(db) let records = try summaries.compactMap { row -> AgentPendingStateRecord? in let summary = try Self.decodeSummary(from: row) guard let pendingState = summary.pendingState else { @@ -288,11 +367,27 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, private func executeStructuredOutputQuery(_ query: StructuredOutputQuery) async throws -> [AgentStructuredOutputRecord] { try await dbQueue.read { db in - var records = try RuntimeStructuredOutputRow.fetchAll( - db, - sql: Self.structuredOutputSQL(for: query), - arguments: StatementArguments(Self.structuredOutputArguments(for: query)) - ).map { try Self.decodeStructuredOutputRecord(from: $0) } + var request = RuntimeStructuredOutputRow.all() + if let threadIDs = query.threadIDs, !threadIDs.isEmpty { + request = request.filter(threadIDs.contains(Column("threadID"))) + } + if let formatNames = query.formatNames, !formatNames.isEmpty { + request = request.filter(formatNames.contains(Column("formatName"))) + } + + switch query.sort { + case let .committedAt(order): + request = order == .ascending + ? request.order(Column("committedAt").asc) + : request.order(Column("committedAt").desc) + } + + if let limit = query.limit, !query.latestOnly { + request = request.limit(max(0, limit)) + } + + var records = try request.fetchAll(db) + .map { try Self.decodeStructuredOutputRecord(from: $0) } if query.latestOnly { var seen = Set() @@ -308,11 +403,27 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, private func executeThreadSnapshotQuery(_ query: ThreadSnapshotQuery) async throws -> [AgentThreadSnapshot] { try await dbQueue.read { db in - let snapshots = try RuntimeSummaryRow.fetchAll( - db, - sql: Self.threadSnapshotSQL(for: query), - arguments: StatementArguments(Self.threadSnapshotArguments(for: query)) - ) + var request = RuntimeSummaryRow.all() + if let threadIDs = query.threadIDs, !threadIDs.isEmpty { + request = request.filter(threadIDs.contains(Column("threadID"))) + } + + switch query.sort { + case let .updatedAt(order): + request = order == .ascending + ? request.order(Column("updatedAt").asc, Column("threadID").asc) + : request.order(Column("updatedAt").desc, Column("threadID").asc) + case let .createdAt(order): + request = order == .ascending + ? request.order(Column("createdAt").asc, Column("threadID").asc) + : request.order(Column("createdAt").desc, Column("threadID").asc) + } + + if let limit = query.limit { + request = request.limit(max(0, limit)) + } + + let snapshots = try request.fetchAll(db) .map { try Self.decodeSummary(from: $0) } .map(\.snapshot) return snapshots @@ -335,7 +446,9 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, .flatMap { $0 } .map { try Self.makeHistoryRow(from: $0, attachmentStore: attachmentStore) } let structuredOutputRows = try Self.structuredOutputRows(from: normalized.historyByThread) + let contextRows = try normalized.contextStateByThread.values.map(Self.makeContextStateRow) + try RuntimeContextStateRow.deleteAll(db) try RuntimeStructuredOutputRow.deleteAll(db) try RuntimeHistoryRow.deleteAll(db) try RuntimeSummaryRow.deleteAll(db) @@ -353,6 +466,9 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, for row in structuredOutputRows { try row.insert(db) } + for row in contextRows { + try row.insert(db) + } } private func shouldImportLegacyState() async throws -> Bool { @@ -370,7 +486,7 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, } let threadCount = try await dbQueue.read { db in - try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM \(RuntimeThreadRow.databaseTableName)") ?? 0 + try RuntimeThreadCountQuery().execute(in: db) } return threadCount == 0 } @@ -406,26 +522,25 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, } let ids = Array(threadIDs) - let placeholders = Self.sqlPlaceholders(count: ids.count) - let threadRows = try RuntimeThreadRow.fetchAll( - db, - sql: "SELECT * FROM \(RuntimeThreadRow.databaseTableName) WHERE threadID IN \(placeholders)", - arguments: StatementArguments(ids) - ) - let summaryRows = try RuntimeSummaryRow.fetchAll( - db, - sql: "SELECT * FROM \(RuntimeSummaryRow.databaseTableName) WHERE threadID IN \(placeholders)", - arguments: StatementArguments(ids) - ) - let historyRows = try RuntimeHistoryRow.fetchAll( - db, + let threadRows = try RuntimeThreadRow + .filter(ids.contains(Column("threadID"))) + .fetchAll(db) + let summaryRows = try RuntimeSummaryRow + .filter(ids.contains(Column("threadID"))) + .fetchAll(db) + // History loading keeps raw SQL here so we can preserve a deterministic + // thread + sequence ordering across multiple thread IDs in one fetch. + let historyRows = try RuntimeHistoryRowsRequest( sql: """ SELECT * FROM \(RuntimeHistoryRow.databaseTableName) - WHERE threadID IN \(placeholders) + WHERE threadID IN \(Self.sqlPlaceholders(count: ids.count)) ORDER BY threadID ASC, sequenceNumber ASC """, arguments: StatementArguments(ids) - ) + ).execute(in: db) + let contextRows = try RuntimeContextStateRow + .filter(ids.contains(Column("threadID"))) + .fetchAll(db) let threads = try threadRows.map { try Self.decodeThread(from: $0) } let summaries = try Dictionary( @@ -435,12 +550,16 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, try Self.decodeHistoryRecord(from: $0, attachmentStore: attachmentStore) } let history = Dictionary(grouping: decodedHistoryRows, by: { $0.item.threadID }) + let contextState = try Dictionary( + uniqueKeysWithValues: contextRows.map { ($0.threadID, try Self.decodeContextState(from: $0)) } + ) let nextSequence = history.mapValues { ($0.last?.sequenceNumber ?? 0) + 1 } return StoredRuntimeState( threads: threads, historyByThread: history, summariesByThread: summaries, + contextStateByThread: contextState, nextHistorySequenceByThread: nextSequence ) } @@ -462,6 +581,9 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, if let summary = normalized.summariesByThread[thread.id] { try Self.makeSummaryRow(from: summary).insert(db) } + if let contextState = normalized.contextStateByThread[thread.id] { + try Self.makeContextStateRow(from: contextState).insert(db) + } for record in normalized.historyByThread[thread.id] ?? [] { try Self.makeHistoryRow(from: record, attachmentStore: attachmentStore).insert(db) } @@ -478,10 +600,7 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, _ threadID: String, in db: Database ) throws { - try db.execute( - sql: "DELETE FROM \(RuntimeThreadRow.databaseTableName) WHERE threadID = ?", - arguments: [threadID] - ) + _ = try RuntimeThreadRow.deleteOne(db, key: threadID) } private static func fetchHistoryRows( @@ -490,6 +609,7 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, createdAtRange: ClosedRange?, turnID: String?, includeRedacted: Bool, + includeCompactionEvents: Bool, in db: Database, attachmentStore: RuntimeAttachmentStore ) throws -> [AgentHistoryRecord] { @@ -513,17 +633,21 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, if !includeRedacted { clauses.append("isRedacted = 0") } + if !includeCompactionEvents { + clauses.append("isCompactionMarker = 0") + } + // This stays in SQL because the history query shape is highly dynamic and + // we always want sequence-ordered reads for restore/query replay semantics. let sql = """ SELECT * FROM \(RuntimeHistoryRow.databaseTableName) WHERE \(clauses.joined(separator: " AND ")) ORDER BY sequenceNumber ASC """ - return try RuntimeHistoryRow.fetchAll( - db, + return try RuntimeHistoryRowsRequest( sql: sql, arguments: StatementArguments(arguments) - ).map { try Self.decodeHistoryRecord(from: $0, attachmentStore: attachmentStore) } + ).execute(in: db).map { try Self.decodeHistoryRecord(from: $0, attachmentStore: attachmentStore) } } private static func fetchHistoryPage( @@ -534,6 +658,7 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, ) throws -> AgentThreadHistoryPage { let limit = max(1, query.limit) let kinds = historyKinds(from: query.filter) + let includeCompactionEvents = query.filter?.includeCompactionEvents ?? false let anchor = try decodeCursorSequence(query.cursor, expectedThreadID: threadID) switch query.direction { @@ -548,18 +673,22 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, clauses.append("sequenceNumber < ?") arguments.append(anchor) } + if !includeCompactionEvents { + clauses.append("isCompactionMarker = 0") + } + // Cursor paging is kept as raw SQL because the descending window + overfetch + // pattern is much clearer here than trying to express it through chained requests. let sql = """ SELECT * FROM \(RuntimeHistoryRow.databaseTableName) WHERE \(clauses.joined(separator: " AND ")) ORDER BY sequenceNumber DESC LIMIT \(limit + 1) """ - let fetched = try RuntimeHistoryRow.fetchAll( - db, + let fetched = try RuntimeHistoryRowsRequest( sql: sql, arguments: StatementArguments(arguments) - ) + ).execute(in: db) let hasMoreBefore = fetched.count > limit let pageRowsDescending = Array(fetched.prefix(limit)) let pageRecords = try pageRowsDescending @@ -571,6 +700,7 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, hasMoreAfter = try historyRecordExists( threadID: threadID, kinds: kinds, + includeCompactionEvents: includeCompactionEvents, comparator: "sequenceNumber >= ?", value: anchor, in: db @@ -599,18 +729,22 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, clauses.append("sequenceNumber > ?") arguments.append(anchor) } + if !includeCompactionEvents { + clauses.append("isCompactionMarker = 0") + } + // Forward paging mirrors the backward cursor window and stays in SQL for the + // same reason: explicit sequence bounds and overfetch are easier to verify here. let sql = """ SELECT * FROM \(RuntimeHistoryRow.databaseTableName) WHERE \(clauses.joined(separator: " AND ")) ORDER BY sequenceNumber ASC LIMIT \(limit + 1) """ - let fetched = try RuntimeHistoryRow.fetchAll( - db, + let fetched = try RuntimeHistoryRowsRequest( sql: sql, arguments: StatementArguments(arguments) - ) + ).execute(in: db) let hasMoreAfter = fetched.count > limit let pageRows = Array(fetched.prefix(limit)) let pageRecords = try pageRows.map { @@ -622,6 +756,7 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, hasMoreBefore = try historyRecordExists( threadID: threadID, kinds: kinds, + includeCompactionEvents: includeCompactionEvents, comparator: "sequenceNumber <= ?", value: anchor, in: db @@ -644,6 +779,7 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, private static func historyRecordExists( threadID: String, kinds: Set?, + includeCompactionEvents: Bool, comparator: String, value: Int, in db: Database @@ -654,139 +790,22 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, clauses.append("kind IN \(sqlPlaceholders(count: kinds.count))") for kind in kinds { arguments.append(kind.rawValue) } } + if !includeCompactionEvents { + clauses.append("isCompactionMarker = 0") + } + // EXISTS is one of the few cases where the raw SQL is both shorter and more obvious + // than the equivalent GRDB request composition for cursor-bound history checks. let sql = """ SELECT EXISTS( SELECT 1 FROM \(RuntimeHistoryRow.databaseTableName) WHERE \(clauses.joined(separator: " AND ")) ) """ - return try Bool.fetchOne(db, sql: sql, arguments: StatementArguments(arguments)) ?? false - } - - private static func threadMetadataSQL(for query: ThreadMetadataQuery) -> String { - var sql = "SELECT * FROM \(RuntimeThreadRow.databaseTableName)" - var clauses: [String] = [] - if let threadIDs = query.threadIDs, !threadIDs.isEmpty { - clauses.append("threadID IN \(sqlPlaceholders(count: threadIDs.count))") - } - if let statuses = query.statuses, !statuses.isEmpty { - clauses.append("status IN \(sqlPlaceholders(count: statuses.count))") - } - if query.updatedAtRange != nil { - clauses.append("updatedAt >= ?") - clauses.append("updatedAt <= ?") - } - if !clauses.isEmpty { - sql += " WHERE " + clauses.joined(separator: " AND ") - } - sql += " ORDER BY " + threadSortClause(query.sort) - if let limit = query.limit { - sql += " LIMIT \(max(0, limit))" - } - return sql - } - - private static func threadMetadataArguments(for query: ThreadMetadataQuery) -> [any DatabaseValueConvertible] { - var arguments: [any DatabaseValueConvertible] = [] - if let threadIDs = query.threadIDs, !threadIDs.isEmpty { - for threadID in threadIDs { - arguments.append(threadID) - } - } - if let statuses = query.statuses, !statuses.isEmpty { - for status in statuses { - arguments.append(status.rawValue) - } - } - if let range = query.updatedAtRange { - arguments.append(range.lowerBound.timeIntervalSince1970) - arguments.append(range.upperBound.timeIntervalSince1970) - } - return arguments - } - - private static func pendingStateSQL(for query: PendingStateQuery) -> String { - var sql = "SELECT * FROM \(RuntimeSummaryRow.databaseTableName) WHERE pendingStateKind IS NOT NULL" - if let threadIDs = query.threadIDs, !threadIDs.isEmpty { - sql += " AND threadID IN \(sqlPlaceholders(count: threadIDs.count))" - } - if let kinds = query.kinds, !kinds.isEmpty { - sql += " AND pendingStateKind IN \(sqlPlaceholders(count: kinds.count))" - } - sql += " ORDER BY updatedAt " + sortDirection(for: query.sort) - if let limit = query.limit { - sql += " LIMIT \(max(0, limit))" - } - return sql - } - - private static func pendingStateArguments(for query: PendingStateQuery) -> [any DatabaseValueConvertible] { - var arguments: [any DatabaseValueConvertible] = [] - if let threadIDs = query.threadIDs, !threadIDs.isEmpty { - for threadID in threadIDs { - arguments.append(threadID) - } - } - if let kinds = query.kinds, !kinds.isEmpty { - for kind in kinds { - arguments.append(kind.rawValue) - } - } - return arguments - } - - private static func structuredOutputSQL(for query: StructuredOutputQuery) -> String { - var sql = "SELECT * FROM \(RuntimeStructuredOutputRow.databaseTableName)" - var clauses: [String] = [] - if let threadIDs = query.threadIDs, !threadIDs.isEmpty { - clauses.append("threadID IN \(sqlPlaceholders(count: threadIDs.count))") - } - if let formatNames = query.formatNames, !formatNames.isEmpty { - clauses.append("formatName IN \(sqlPlaceholders(count: formatNames.count))") - } - if !clauses.isEmpty { - sql += " WHERE " + clauses.joined(separator: " AND ") - } - sql += " ORDER BY committedAt " + sortDirection(for: query.sort) - if let limit = query.limit, !query.latestOnly { - sql += " LIMIT \(max(0, limit))" - } - return sql - } - - private static func structuredOutputArguments(for query: StructuredOutputQuery) -> [any DatabaseValueConvertible] { - var arguments: [any DatabaseValueConvertible] = [] - if let threadIDs = query.threadIDs, !threadIDs.isEmpty { - for threadID in threadIDs { - arguments.append(threadID) - } - } - if let formatNames = query.formatNames, !formatNames.isEmpty { - for formatName in formatNames { - arguments.append(formatName) - } - } - return arguments - } - - private static func threadSnapshotSQL(for query: ThreadSnapshotQuery) -> String { - var sql = "SELECT * FROM \(RuntimeSummaryRow.databaseTableName)" - if let threadIDs = query.threadIDs, !threadIDs.isEmpty { - sql += " WHERE threadID IN \(sqlPlaceholders(count: threadIDs.count))" - } - sql += " ORDER BY " + snapshotSortClause(query.sort) - if let limit = query.limit { - sql += " LIMIT \(max(0, limit))" - } - return sql - } - - private static func threadSnapshotArguments(for query: ThreadSnapshotQuery) -> [any DatabaseValueConvertible] { - guard let threadIDs = query.threadIDs, !threadIDs.isEmpty else { - return [] - } - return threadIDs.map { $0 as any DatabaseValueConvertible } + return try RuntimeHistoryExistenceQuery( + sql: sql, + arguments: StatementArguments(arguments) + ).execute(in: db) } private static func defaultLegacyImportURL(for url: URL) -> URL { @@ -797,38 +816,6 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, "(" + Array(repeating: "?", count: count).joined(separator: ", ") + ")" } - private static func threadSortClause(_ sort: AgentThreadMetadataSort) -> String { - switch sort { - case let .updatedAt(order): - "updatedAt \(order == .ascending ? "ASC" : "DESC"), threadID ASC" - case let .createdAt(order): - "createdAt \(order == .ascending ? "ASC" : "DESC"), threadID ASC" - } - } - - private static func snapshotSortClause(_ sort: AgentThreadSnapshotSort) -> String { - switch sort { - case let .updatedAt(order): - "updatedAt \(order == .ascending ? "ASC" : "DESC"), threadID ASC" - case let .createdAt(order): - "createdAt \(order == .ascending ? "ASC" : "DESC"), threadID ASC" - } - } - - private static func sortDirection(for sort: AgentPendingStateSort) -> String { - switch sort { - case let .updatedAt(order): - order == .ascending ? "ASC" : "DESC" - } - } - - private static func sortDirection(for sort: AgentStructuredOutputSort) -> String { - switch sort { - case let .committedAt(order): - order == .ascending ? "ASC" : "DESC" - } - } - private static func historyKinds(from filter: AgentHistoryFilter?) -> Set? { guard let filter else { return nil @@ -926,11 +913,20 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, createdAt: record.createdAt.timeIntervalSince1970, kind: record.item.kind.rawValue, turnID: record.item.turnID, + isCompactionMarker: record.item.isCompactionMarker, isRedacted: record.redaction != nil, encodedRecord: try JSONEncoder().encode(persisted) ) } + private static func makeContextStateRow(from state: AgentThreadContextState) throws -> RuntimeContextStateRow { + RuntimeContextStateRow( + threadID: state.threadID, + generation: state.generation, + encodedState: try JSONEncoder().encode(state) + ) + } + private static func structuredOutputRows( from historyByThread: [String: [AgentHistoryRecord]] ) throws -> [RuntimeStructuredOutputRow] { @@ -986,6 +982,10 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, try JSONDecoder().decode(AgentThreadSummary.self, from: row.encodedSummary) } + private static func decodeContextState(from row: RuntimeContextStateRow) throws -> AgentThreadContextState { + try JSONDecoder().decode(AgentThreadContextState.self, from: row.encodedState) + } + private static func decodeHistoryRecord( from row: RuntimeHistoryRow, attachmentStore: RuntimeAttachmentStore @@ -1036,6 +1036,7 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, table.column("createdAt", .double).notNull() table.column("kind", .text).notNull() table.column("turnID", .text) + table.column("isCompactionMarker", .boolean).notNull().defaults(to: false) table.column("isRedacted", .boolean).notNull().defaults(to: false) table.column("encodedRecord", .blob).notNull() } @@ -1058,6 +1059,35 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, try db.create(index: "runtime_structured_outputs_thread_committed_at", on: RuntimeStructuredOutputRow.databaseTableName, columns: ["threadID", "committedAt"]) try db.create(index: "runtime_structured_outputs_format_name", on: RuntimeStructuredOutputRow.databaseTableName, columns: ["formatName"]) + try db.create(table: RuntimeContextStateRow.databaseTableName) { table in + table.column("threadID", .text) + .primaryKey() + .references(RuntimeThreadRow.databaseTableName, onDelete: .cascade) + table.column("generation", .integer).notNull() + table.column("encodedState", .blob).notNull() + } + + try db.execute(sql: "PRAGMA user_version = \(currentStoreSchemaVersion)") + } + + migrator.registerMigration("runtime_store_v2_compaction_state") { db in + let historyColumns = try db.columns(in: RuntimeHistoryRow.databaseTableName).map(\.name) + if !historyColumns.contains("isCompactionMarker") { + try db.alter(table: RuntimeHistoryRow.databaseTableName) { table in + table.add(column: "isCompactionMarker", .boolean).notNull().defaults(to: false) + } + } + + if try !db.tableExists(RuntimeContextStateRow.databaseTableName) { + try db.create(table: RuntimeContextStateRow.databaseTableName) { table in + table.column("threadID", .text) + .primaryKey() + .references(RuntimeThreadRow.databaseTableName, onDelete: .cascade) + table.column("generation", .integer).notNull() + table.column("encodedState", .blob).notNull() + } + } + try db.execute(sql: "PRAGMA user_version = \(currentStoreSchemaVersion)") } @@ -1116,7 +1146,7 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, private func readUserVersion() async throws -> Int { try await dbQueue.read { db in - try Int.fetchOne(db, sql: "PRAGMA user_version;") ?? 0 + try RuntimeUserVersionQuery().execute(in: db) } } } @@ -1131,6 +1161,16 @@ private struct RuntimeThreadRow: Codable, FetchableRecord, PersistableRecord, Ta let encodedThread: Data } +private struct RuntimeThreadCountQuery { + func execute(in db: Database) throws -> Int { + let row = try SQLRequest( + sql: "SELECT COUNT(*) AS thread_count FROM \(RuntimeThreadRow.databaseTableName)" + ).fetchOne(db) + let count: Int? = row?["thread_count"] + return count ?? 0 + } +} + private struct RuntimeSummaryRow: Codable, FetchableRecord, PersistableRecord, TableRecord { static let databaseTableName = "runtime_summaries" @@ -1154,10 +1194,31 @@ private struct RuntimeHistoryRow: Codable, FetchableRecord, PersistableRecord, T let createdAt: Double let kind: String let turnID: String? + let isCompactionMarker: Bool let isRedacted: Bool let encodedRecord: Data } +private struct RuntimeHistoryRowsRequest { + let sql: String + let arguments: StatementArguments + + func execute(in db: Database) throws -> [RuntimeHistoryRow] { + try SQLRequest(sql: sql, arguments: arguments).fetchAll(db) + } +} + +private struct RuntimeHistoryExistenceQuery { + let sql: String + let arguments: StatementArguments + + func execute(in db: Database) throws -> Bool { + let row = try SQLRequest(sql: sql, arguments: arguments).fetchOne(db) + let exists: Bool? = row?[0] + return exists ?? false + } +} + private struct RuntimeStructuredOutputRow: Codable, FetchableRecord, PersistableRecord, TableRecord { static let databaseTableName = "runtime_structured_outputs" @@ -1168,6 +1229,21 @@ private struct RuntimeStructuredOutputRow: Codable, FetchableRecord, Persistable let encodedRecord: Data } +private struct RuntimeContextStateRow: Codable, FetchableRecord, PersistableRecord, TableRecord { + static let databaseTableName = "runtime_context_states" + + let threadID: String + let generation: Int + let encodedState: Data +} + +private struct RuntimeUserVersionQuery { + func execute(in db: Database) throws -> Int { + let row = try SQLRequest(sql: "PRAGMA user_version;").fetchOne(db) + return row?[0] ?? 0 + } +} + private struct GRDBHistoryCursorPayload: Codable { let version: Int let threadID: String diff --git a/Sources/CodexKit/Runtime/RuntimeStateStore.swift b/Sources/CodexKit/Runtime/RuntimeStateStore.swift index 4c1c02d..9184157 100644 --- a/Sources/CodexKit/Runtime/RuntimeStateStore.swift +++ b/Sources/CodexKit/Runtime/RuntimeStateStore.swift @@ -5,6 +5,7 @@ public struct StoredRuntimeState: Codable, Hashable, Sendable { public var messagesByThread: [String: [AgentMessage]] public var historyByThread: [String: [AgentHistoryRecord]] public var summariesByThread: [String: AgentThreadSummary] + public var contextStateByThread: [String: AgentThreadContextState] public var nextHistorySequenceByThread: [String: Int] public init( @@ -12,6 +13,7 @@ public struct StoredRuntimeState: Codable, Hashable, Sendable { messagesByThread: [String: [AgentMessage]] = [:], historyByThread: [String: [AgentHistoryRecord]] = [:], summariesByThread: [String: AgentThreadSummary] = [:], + contextStateByThread: [String: AgentThreadContextState] = [:], nextHistorySequenceByThread: [String: Int] = [:] ) { self.init( @@ -19,6 +21,7 @@ public struct StoredRuntimeState: Codable, Hashable, Sendable { messagesByThread: messagesByThread, historyByThread: historyByThread, summariesByThread: summariesByThread, + contextStateByThread: contextStateByThread, nextHistorySequenceByThread: nextHistorySequenceByThread, normalizeState: false ) @@ -30,6 +33,7 @@ public struct StoredRuntimeState: Codable, Hashable, Sendable { messagesByThread: [String: [AgentMessage]], historyByThread: [String: [AgentHistoryRecord]], summariesByThread: [String: AgentThreadSummary], + contextStateByThread: [String: AgentThreadContextState], nextHistorySequenceByThread: [String: Int], normalizeState: Bool ) { @@ -37,6 +41,7 @@ public struct StoredRuntimeState: Codable, Hashable, Sendable { self.messagesByThread = messagesByThread self.historyByThread = historyByThread self.summariesByThread = summariesByThread + self.contextStateByThread = contextStateByThread self.nextHistorySequenceByThread = nextHistorySequenceByThread if normalizeState { self = normalized() @@ -50,6 +55,7 @@ public struct StoredRuntimeState: Codable, Hashable, Sendable { case messagesByThread case historyByThread case summariesByThread + case contextStateByThread case nextHistorySequenceByThread } @@ -60,6 +66,7 @@ public struct StoredRuntimeState: Codable, Hashable, Sendable { messagesByThread: try container.decodeIfPresent([String: [AgentMessage]].self, forKey: .messagesByThread) ?? [:], historyByThread: try container.decodeIfPresent([String: [AgentHistoryRecord]].self, forKey: .historyByThread) ?? [:], summariesByThread: try container.decodeIfPresent([String: AgentThreadSummary].self, forKey: .summariesByThread) ?? [:], + contextStateByThread: try container.decodeIfPresent([String: AgentThreadContextState].self, forKey: .contextStateByThread) ?? [:], nextHistorySequenceByThread: try container.decodeIfPresent([String: Int].self, forKey: .nextHistorySequenceByThread) ?? [:] ) } @@ -80,6 +87,7 @@ public protocol RuntimeStateInspecting: Sendable { query: AgentHistoryQuery ) async throws -> AgentThreadHistoryPage func fetchLatestStructuredOutputMetadata(id: String) async throws -> AgentStructuredOutputMetadata? + func fetchThreadContextState(id: String) async throws -> AgentThreadContextState? } public extension RuntimeStateStoring { @@ -159,6 +167,10 @@ public actor InMemoryRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspect public func fetchLatestStructuredOutputMetadata(id: String) async throws -> AgentStructuredOutputMetadata? { try state.threadSummary(id: id).latestStructuredOutputMetadata } + + public func fetchThreadContextState(id: String) async throws -> AgentThreadContextState? { + state.contextStateByThread[id] + } } public actor FileRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, AgentRuntimeQueryableStore { @@ -232,6 +244,7 @@ public actor FileRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, threads: manifest.threads, historyByThread: [id: history], summariesByThread: manifest.summariesByThread, + contextStateByThread: manifest.contextStateByThread, nextHistorySequenceByThread: manifest.nextHistorySequenceByThread ) return try state.threadHistoryPage(id: id, query: query) @@ -245,6 +258,17 @@ public actor FileRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, return summary.latestStructuredOutputMetadata } + public func fetchThreadContextState(id: String) async throws -> AgentThreadContextState? { + if let manifest = try loadManifest() { + guard manifest.threads.contains(where: { $0.id == id }) else { + throw AgentRuntimeError.threadNotFound(id) + } + return manifest.contextStateByThread[id] + } + + return try loadNormalizedStateMigratingIfNeeded().contextStateByThread[id] + } + private func loadNormalizedStateMigratingIfNeeded() throws -> StoredRuntimeState { guard fileManager.fileExists(atPath: url.path) else { return .empty @@ -279,6 +303,7 @@ public actor FileRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, threads: manifest.threads, historyByThread: historyByThread, summariesByThread: manifest.summariesByThread, + contextStateByThread: manifest.contextStateByThread, nextHistorySequenceByThread: manifest.nextHistorySequenceByThread ) } @@ -328,6 +353,7 @@ public actor FileRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, let manifest = FileRuntimeStateManifest( threads: normalized.threads, summariesByThread: normalized.summariesByThread, + contextStateByThread: normalized.contextStateByThread, nextHistorySequenceByThread: normalized.nextHistorySequenceByThread ) let manifestData = try encoder.encode(manifest) @@ -354,16 +380,19 @@ private struct FileRuntimeStateManifest: Codable { let storageVersion: Int let threads: [AgentThread] let summariesByThread: [String: AgentThreadSummary] + let contextStateByThread: [String: AgentThreadContextState] let nextHistorySequenceByThread: [String: Int] init( threads: [AgentThread], summariesByThread: [String: AgentThreadSummary], + contextStateByThread: [String: AgentThreadContextState], nextHistorySequenceByThread: [String: Int] ) { self.storageVersion = 1 self.threads = threads self.summariesByThread = summariesByThread + self.contextStateByThread = contextStateByThread self.nextHistorySequenceByThread = nextHistorySequenceByThread } } @@ -408,6 +437,7 @@ extension StoredRuntimeState { } var normalizedSummaries: [String: AgentThreadSummary] = [:] + var normalizedContextState = contextStateByThread for thread in sortedThreads { let history = normalizedHistory[thread.id] ?? [] normalizedSummaries[thread.id] = Self.rebuildSummary( @@ -415,6 +445,16 @@ extension StoredRuntimeState { history: history, existing: summariesByThread[thread.id] ) + if let existing = normalizedContextState[thread.id] { + normalizedContextState[thread.id] = AgentThreadContextState( + threadID: thread.id, + effectiveMessages: existing.effectiveMessages, + generation: existing.generation, + lastCompactedAt: existing.lastCompactedAt, + lastCompactionReason: existing.lastCompactionReason, + latestMarkerID: existing.latestMarkerID + ) + } } return StoredRuntimeState( @@ -422,6 +462,7 @@ extension StoredRuntimeState { messagesByThread: normalizedMessages, historyByThread: normalizedHistory, summariesByThread: normalizedSummaries, + contextStateByThread: normalizedContextState, nextHistorySequenceByThread: normalizedNextSequence, normalizeState: false ) @@ -509,6 +550,17 @@ extension StoredRuntimeState { let nextSequence = (updated.historyByThread[threadID]?.last?.sequenceNumber ?? 0) + 1 updated.nextHistorySequenceByThread[threadID] = nextSequence + case let .appendCompactionMarker(threadID, marker): + updated.historyByThread[threadID, default: []].append(marker) + let nextSequence = (updated.historyByThread[threadID]?.last?.sequenceNumber ?? 0) + 1 + updated.nextHistorySequenceByThread[threadID] = nextSequence + + case let .upsertThreadContextState(threadID, state): + updated.contextStateByThread[threadID] = state + + case let .deleteThreadContextState(threadID): + updated.contextStateByThread.removeValue(forKey: threadID) + case let .setPendingState(threadID, state): if let thread = updated.threads.first(where: { $0.id == threadID }) { let current = updated.summariesByThread[threadID] ?? updated.threadSummaryFallback(for: thread) @@ -603,6 +655,7 @@ extension StoredRuntimeState { updated.messagesByThread.removeValue(forKey: threadID) updated.historyByThread.removeValue(forKey: threadID) updated.summariesByThread.removeValue(forKey: threadID) + updated.contextStateByThread.removeValue(forKey: threadID) updated.nextHistorySequenceByThread.removeValue(forKey: threadID) } } @@ -635,6 +688,9 @@ extension StoredRuntimeState { if !query.includeRedacted { records = records.filter { $0.redaction == nil } } + if !query.includeCompactionEvents { + records = records.filter { !$0.item.isCompactionMarker } + } records = sort(records, using: query.sort) let page = try page(records, threadID: query.threadID, with: query.page, sort: query.sort) @@ -746,6 +802,23 @@ extension StoredRuntimeState { return snapshots } + func execute(_ query: ThreadContextStateQuery) -> [AgentThreadContextState] { + var records = Array(contextStateByThread.values) + if let threadIDs = query.threadIDs { + records = records.filter { threadIDs.contains($0.threadID) } + } + records.sort { lhs, rhs in + if lhs.generation == rhs.generation { + return lhs.threadID < rhs.threadID + } + return lhs.generation > rhs.generation + } + if let limit = query.limit { + records = Array(records.prefix(max(0, limit))) + } + return records + } + private static func syntheticHistory(from messages: [AgentMessage]) -> [AgentHistoryRecord] { let orderedMessages = messages.enumerated().sorted { lhs, rhs in let left = lhs.element @@ -813,7 +886,7 @@ extension StoredRuntimeState { latestTurnStatus = .completed case .turnFailed: latestTurnStatus = .failed - case .threadCreated, .threadResumed, .threadStatusChanged: + case .threadCreated, .threadResumed, .threadStatusChanged, .contextCompacted: break } } diff --git a/Tests/CodexKitTests/AgentRuntimeHistoryTests.swift b/Tests/CodexKitTests/AgentRuntimeHistoryTests.swift index f63d1f4..e0f0556 100644 --- a/Tests/CodexKitTests/AgentRuntimeHistoryTests.swift +++ b/Tests/CodexKitTests/AgentRuntimeHistoryTests.swift @@ -481,7 +481,7 @@ extension AgentRuntimeTests { let metadata = try await reloadedRuntime.prepareStore() XCTAssertEqual(metadata.storeKind, "GRDBRuntimeStateStore") - XCTAssertEqual(metadata.storeSchemaVersion, 1) + XCTAssertEqual(metadata.storeSchemaVersion, 2) let summary = try await reloadedRuntime.fetchThreadSummary(id: thread.id) XCTAssertEqual(summary.latestTurnStatus, .completed) @@ -671,13 +671,149 @@ extension AgentRuntimeTests { let databaseData = try Data(contentsOf: url) XCTAssertNil(databaseData.range(of: imageData.base64EncodedData())) } + + func testManualCompactionPreservesVisibleHistoryAndHidesMarkersByDefault() async throws { + let backend = CompactingTestBackend() + let runtime = try makeHistoryRuntime( + backend: backend, + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore(), + contextCompaction: AgentContextCompactionConfiguration( + isEnabled: true, + mode: .automatic + ) + ) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(title: "Compaction") + _ = try await runtime.sendMessage(UserMessageRequest(text: "one"), in: thread.id) + _ = try await runtime.sendMessage(UserMessageRequest(text: "two"), in: thread.id) + _ = try await runtime.sendMessage(UserMessageRequest(text: "three"), in: thread.id) + + let visibleBefore = await runtime.messages(for: thread.id) + XCTAssertEqual(visibleBefore.count, 6) + + let contextState = try await runtime.compactThreadContext(id: thread.id) + XCTAssertEqual(contextState.generation, 1) + XCTAssertLessThan(contextState.effectiveMessages.count, visibleBefore.count) + + let visibleAfter = await runtime.messages(for: thread.id) + XCTAssertEqual(visibleAfter, visibleBefore) + + let hiddenHistory = try await runtime.execute( + HistoryItemsQuery( + threadID: thread.id, + kinds: [.systemEvent] + ) + ) + XCTAssertFalse(hiddenHistory.records.contains(where: { $0.item.isCompactionMarker })) + + let debugHistory = try await runtime.execute( + HistoryItemsQuery( + threadID: thread.id, + kinds: [.systemEvent], + includeCompactionEvents: true + ) + ) + XCTAssertTrue(debugHistory.records.contains(where: { $0.item.isCompactionMarker })) + } + + func testAutomaticRetryCompactionRecoversFromContextLimitError() async throws { + let backend = CompactingTestBackend(failOnHistoryCountAbove: 2) + let runtime = try makeHistoryRuntime( + backend: backend, + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore(), + contextCompaction: AgentContextCompactionConfiguration( + isEnabled: true, + mode: .automatic, + trigger: AgentContextCompactionTrigger( + estimatedTokenThreshold: 100_000, + retryOnContextLimitError: true + ) + ) + ) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(title: "Retry Compact") + _ = try await runtime.sendMessage(UserMessageRequest(text: "one"), in: thread.id) + _ = try await runtime.sendMessage(UserMessageRequest(text: "two"), in: thread.id) + _ = try await runtime.sendMessage(UserMessageRequest(text: "three"), in: thread.id) + let reply = try await runtime.sendMessage(UserMessageRequest(text: "four"), in: thread.id) + + XCTAssertEqual(reply, "Echo: four") + let compactCallCount = await backend.compactCallCount() + let beginTurnHistoryCounts = await backend.beginTurnHistoryCounts() + XCTAssertEqual(compactCallCount, 1) + XCTAssertGreaterThanOrEqual(beginTurnHistoryCounts.count, 5) + } + + func testContextStatePersistsAcrossGRDBReload() async throws { + let url = temporaryRuntimeSQLiteURL() + defer { try? FileManager.default.removeItem(at: url) } + + let backend = CompactingTestBackend() + let runtime = try makeHistoryRuntime( + backend: backend, + approvalPresenter: AutoApprovalPresenter(), + stateStore: try GRDBRuntimeStateStore(url: url), + contextCompaction: AgentContextCompactionConfiguration( + isEnabled: true, + mode: .automatic + ) + ) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(title: "Persisted Context") + _ = try await runtime.sendMessage(UserMessageRequest(text: "alpha"), in: thread.id) + _ = try await runtime.sendMessage(UserMessageRequest(text: "beta"), in: thread.id) + _ = try await runtime.compactThreadContext(id: thread.id) + + let reloadedRuntime = try makeHistoryRuntime( + backend: backend, + approvalPresenter: AutoApprovalPresenter(), + stateStore: try GRDBRuntimeStateStore(url: url), + contextCompaction: AgentContextCompactionConfiguration( + isEnabled: true, + mode: .automatic + ) + ) + + let restoredContext = try await reloadedRuntime.fetchThreadContextState(id: thread.id) + XCTAssertEqual(restoredContext?.generation, 1) + XCTAssertFalse(restoredContext?.effectiveMessages.isEmpty ?? true) + } + + func testContextCompactionConfigurationDefaultsAndCodableShape() throws { + let configuration = AgentContextCompactionConfiguration() + XCTAssertFalse(configuration.isEnabled) + XCTAssertEqual(configuration.mode, .automatic) + + let encoded = try JSONEncoder().encode( + AgentContextCompactionConfiguration(isEnabled: true, mode: .manual) + ) + let decoded = try JSONDecoder().decode( + AgentContextCompactionConfiguration.self, + from: encoded + ) + + XCTAssertTrue(decoded.isEnabled) + XCTAssertEqual(decoded.mode, .manual) + } } private func makeHistoryRuntime( backend: any AgentBackend, approvalPresenter: any ApprovalPresenting, stateStore: any RuntimeStateStoring, - tools: [AgentRuntime.ToolRegistration] = [] + tools: [AgentRuntime.ToolRegistration] = [], + contextCompaction: AgentContextCompactionConfiguration = AgentContextCompactionConfiguration() ) throws -> AgentRuntime { try AgentRuntime(configuration: .init( authProvider: DemoChatGPTAuthProvider(), @@ -688,7 +824,8 @@ private func makeHistoryRuntime( backend: backend, approvalPresenter: approvalPresenter, stateStore: stateStore, - tools: tools + tools: tools, + contextCompaction: contextCompaction )) } @@ -786,6 +923,100 @@ private actor PartialEmissionGate { } } +private actor CompactingTestBackend: AgentBackend, AgentBackendContextCompacting { + nonisolated let baseInstructions: String? = nil + + private let failOnHistoryCountAbove: Int? + private var threads: [String: AgentThread] = [:] + private var compactCalls = 0 + private var historyCounts: [Int] = [] + + init(failOnHistoryCountAbove: Int? = nil) { + self.failOnHistoryCountAbove = failOnHistoryCountAbove + } + + func createThread(session _: ChatGPTSession) async throws -> AgentThread { + let thread = AgentThread(id: UUID().uuidString) + threads[thread.id] = thread + return thread + } + + func resumeThread(id: String, session _: ChatGPTSession) async throws -> AgentThread { + if let thread = threads[id] { + return thread + } + let thread = AgentThread(id: id) + threads[id] = thread + return thread + } + + func beginTurn( + thread: AgentThread, + history: [AgentMessage], + message: UserMessageRequest, + instructions _: String, + responseFormat _: AgentStructuredOutputFormat?, + streamedStructuredOutput _: AgentStreamedStructuredOutputRequest?, + tools _: [ToolDefinition], + session _: ChatGPTSession + ) async throws -> any AgentTurnStreaming { + historyCounts.append(history.count) + if let failOnHistoryCountAbove, + history.count > failOnHistoryCountAbove, + !history.contains(where: { $0.role == .system && $0.text.contains("Compacted conversation summary") }) { + throw AgentRuntimeError( + code: "context_limit_exceeded", + message: "Maximum context length exceeded." + ) + } + + return MockAgentTurnSession( + thread: thread, + message: message, + selectedTool: nil, + structuredResponseText: nil, + streamedStructuredOutput: nil + ) + } + + func compactContext( + thread: AgentThread, + effectiveHistory: [AgentMessage], + instructions _: String, + tools _: [ToolDefinition], + session _: ChatGPTSession + ) async throws -> AgentCompactionResult { + compactCalls += 1 + let lastUser = effectiveHistory.last(where: { $0.role == .user }) + let lastAssistant = effectiveHistory.last(where: { $0.role == .assistant }) + var compacted = [ + AgentMessage( + threadID: thread.id, + role: .system, + text: "Compacted conversation summary" + ), + ] + if let lastUser { + compacted.append(lastUser) + } + if let lastAssistant { + compacted.append(lastAssistant) + } + return AgentCompactionResult( + effectiveMessages: compacted, + summaryPreview: "Compacted conversation summary" + ) + } + + func compactCallCount() -> Int { + compactCalls + } + + func beginTurnHistoryCounts() -> [Int] { + historyCounts + } +} + private actor BlockingStructuredPartialBackend: AgentBackend { private let gate = PartialEmissionGate() private var latestThreadID: String? From ae83ba06471e9e0ad91ca0fa4fa034d6e7885f75 Mon Sep 17 00:00:00 2001 From: Timothy Zelinsky Date: Tue, 24 Mar 2026 13:16:06 +1100 Subject: [PATCH 07/19] Refresh docs for GRDB storage and context compaction --- DemoApp/README.md | 9 +++++++++ README.md | 43 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/DemoApp/README.md b/DemoApp/README.md index 16ad85f..27f82b9 100644 --- a/DemoApp/README.md +++ b/DemoApp/README.md @@ -22,6 +22,7 @@ The Xcode project is the source of truth for the demo app. Edit it directly in X - lets you attach a photo from the library and send it with or without text - renders attached user images in the transcript - streams assistant output into the UI +- includes a thread-level `Context Compaction` card so you can compact effective prompt state without removing visible transcript history - supports approval prompts for host-defined tools that opt into `requiresApproval` - demonstrates thread-pinned personas and one-turn persona overrides - includes first-class framework skill examples for `health_coach` and `travel_planner` @@ -49,6 +50,14 @@ The demo uses the new configuration-first surface: The app links `CodexKit` and `CodexKitUI` from the repo's local `Package.swift`, so it exercises the same SPM integration path a host app would use. Runtime state is stored in `runtime-state.sqlite`, memory is stored in `memory.sqlite`, and the GRDB-backed runtime store will import an older sibling `runtime-state.json` file automatically on first launch if one exists. +The checked-in demo enables context compaction in automatic mode. In a thread detail screen, the `Context Compaction` card shows: + +- visible transcript message count +- effective prompt message count +- compaction generation +- last compaction reason/time +- a `Compact Context Now` action for manual testing + ## Files - `DemoApp/AssistantRuntimeDemoApp/AssistantRuntimeDemoApp.swift` diff --git a/README.md b/README.md index 8e4d375..9e2ace3 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ Use `CodexKit` if you are building a SwiftUI/iOS app and want: - typed one-shot text and structured completions - host-defined tools with approval gates - persona- and skill-aware agent behavior +- hidden runtime context compaction with preserved user-visible history - share/import-friendly message construction The SDK stays tool-agnostic. Your app defines the tool surface and runtime UX. @@ -166,6 +167,11 @@ Bundled runtime-state stores now include: - `InMemoryRuntimeStateStore` Useful for previews and tests. +The bundled memory store is: + +- `SQLiteMemoryStore` + Uses SQLite through GRDB for persisted memory records. Ordinary record reads/writes use GRDB requests directly; the remaining raw SQL is limited to SQLite-specific `PRAGMA` and FTS `MATCH` / `bm25()` paths. + If you are migrating from the older file-backed store, `GRDBRuntimeStateStore(url:)` automatically imports a sibling `*.json` runtime state file on first open. For example, `runtime-state.sqlite` will import from `runtime-state.json` if it exists and the SQLite store is still empty. `ChatGPTAuthProvider` supports: @@ -229,6 +235,7 @@ Available values: - use `GRDBRuntimeStateStore` for persisted production state - use `fetchThreadHistory(id:query:)` and `fetchLatestStructuredOutputMetadata(id:)` for common thread inspection - use the typed `execute(_:)` query surface when you need more control over filtering, sorting, paging, or cross-thread reads +- use hidden context compaction when you want to optimize future turns without removing preserved thread history from UI or inspection APIs ```swift let stateStore = try GRDBRuntimeStateStore( @@ -259,6 +266,41 @@ let snapshots = try await runtime.execute( This path also supports explicit history redaction and whole-thread deletion without forcing hosts to replay raw event streams themselves. +## Effective Context Compaction + +`CodexKit` can compact the runtime's effective prompt context without mutating canonical thread history. + +- visible history stays intact for `messages(for:)`, `fetchThreadHistory(...)`, and normal thread UI +- compacted effective context is used only for future turns +- compaction markers are persisted for audit/debug semantics and hidden from normal history reads by default +- manual compaction is always available when the feature is enabled; `.automatic` additionally lets the runtime compact pre-turn or after a context-limit retry path + +```swift +let runtime = try AgentRuntime(configuration: .init( + authProvider: authProvider, + secureStore: secureStore, + backend: backend, + approvalPresenter: approvalPresenter, + stateStore: stateStore, + contextCompaction: .init( + isEnabled: true, + mode: .automatic + ) +)) + +let contextState = try await runtime.compactThreadContext(id: thread.id) +print(contextState.generation) +``` + +For debug tooling or host inspection, you can also read the compacted effective context directly: + +```swift +let contextState = try await runtime.fetchThreadContextState(id: thread.id) +let contexts = try await runtime.execute( + ThreadContextStateQuery(threadIDs: [thread.id]) +) +``` + ## Typed Completions For most apps, there are now three common send paths: @@ -693,6 +735,7 @@ The demo app exercises: - Responses web search in checked-in configuration - thread-pinned personas and one-turn overrides - a one-tap skill policy probe that compares tool behavior in normal vs skill-constrained threads +- a thread-level `Context Compaction` card that shows visible-vs-effective message counts and lets you trigger manual compaction - a Health Coach tab with HealthKit steps, AI-generated coaching, local reminders, and tone switching - GRDB-backed runtime persistence with automatic import from older `runtime-state.json` state on first launch From 1a38b78a65b286409b5da6f2e8ef96d14e21908c Mon Sep 17 00:00:00 2001 From: Timothy Zelinsky Date: Tue, 24 Mar 2026 13:22:03 +1100 Subject: [PATCH 08/19] Fix demo shortcuts persona reference --- DemoApp/AssistantRuntimeDemoApp/Shared/DemoAppShortcuts.swift | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/DemoApp/AssistantRuntimeDemoApp/Shared/DemoAppShortcuts.swift b/DemoApp/AssistantRuntimeDemoApp/Shared/DemoAppShortcuts.swift index 3e53100..540f73b 100644 --- a/DemoApp/AssistantRuntimeDemoApp/Shared/DemoAppShortcuts.swift +++ b/DemoApp/AssistantRuntimeDemoApp/Shared/DemoAppShortcuts.swift @@ -3,6 +3,8 @@ import Foundation #if os(iOS) import AppIntents +private let shortcutCatalog = DemoCatalog() + enum DemoShareImportExamples { static func articleSummaryRequest( excerpt: String, @@ -93,7 +95,7 @@ struct DraftShippingSupportReplyIntent: AppIntent { let thread = try await runtime.createThread( title: "Shortcut Support", - personaStack: AgentDemoViewModel.supportPersona + personaStack: shortcutCatalog.supportPersona ) let draft = try await runtime.sendMessage( UserMessageRequest( From 64ec785fbb5eda6362599b97438c157c73d79544 Mon Sep 17 00:00:00 2001 From: Timothy Zelinsky Date: Tue, 24 Mar 2026 15:04:37 +1100 Subject: [PATCH 09/19] Refactor runtime modules and add Combine observations --- .../project.pbxproj | 8 + .../AgentDemoViewModel+ComposerState.swift | 104 ++ .../AgentDemoViewModel+ThreadState.swift | 306 +++++ .../Shared/AgentDemoViewModel.swift | 400 ------- .../CodexKit/Auth/ChatGPTOAuthProvider.swift | 387 ------ .../CodexKit/Auth/ChatGPTOAuthSupport.swift | 262 +++++ .../Auth/ChatGPTOAuthWebAuthentication.swift | 138 +++ .../CodexKit/Memory/SQLiteMemoryStore.swift | 456 +------- .../Memory/SQLiteMemoryStoreRepository.swift | 328 ++++++ .../Memory/SQLiteMemoryStoreSchema.swift | 103 ++ .../Runtime/AgentHistory+Serialization.swift | 314 +++++ Sources/CodexKit/Runtime/AgentHistory.swift | 580 --------- .../CodexKit/Runtime/AgentHistoryItems.swift | 129 ++ .../Runtime/AgentHistoryPendingState.swift | 140 +++ .../Runtime/AgentRuntime+Messaging.swift | 233 ++-- ...entRuntime+StructuredTurnConsumption.swift | 310 +++++ ...gentRuntime+ToolInvocationResolution.swift | 192 +++ .../AgentRuntime+TurnConsumption.swift | 629 ---------- Sources/CodexKit/Runtime/AgentRuntime.swift | 60 + .../Runtime/AgentRuntimeObservation.swift | 25 + .../Runtime/FileRuntimeStateStore.swift | 225 ++++ .../GRDBRuntimeStateStore+Persistence.swift | 403 +++++++ .../GRDBRuntimeStateStore+Queries.swift | 459 ++++++++ .../Runtime/GRDBRuntimeStateStore.swift | 1042 +---------------- .../Runtime/GRDBRuntimeStateStoreRows.swift | 120 ++ .../Runtime/InMemoryRuntimeStateStore.swift | 56 + .../CodexKit/Runtime/RuntimeStateStore.swift | 1029 ---------------- .../StoredRuntimeState+Execution.swift | 368 ++++++ .../Runtime/StoredRuntimeState+Queries.swift | 385 ++++++ .../AgentRuntimeHistoryCompactionTests.swift | 94 ++ .../AgentRuntimeHistoryGRDBTests.swift | 136 +++ .../AgentRuntimeHistoryTestSupport.swift | 323 +++++ .../AgentRuntimeHistoryTests.swift | 659 ----------- .../AgentRuntimeMessageBehaviorTests.swift | 263 +++++ .../AgentRuntimeMessageTests.swift | 304 ----- .../AgentRuntimePersonaSkillPolicyTests.swift | 109 ++ .../AgentRuntimePersonaSkillTests.swift | 261 ----- .../CodexResponsesBackendRetryTests.swift | 98 ++ ...CodexResponsesBackendStructuredTests.swift | 92 ++ .../CodexResponsesBackendTests.swift | 340 ------ 40 files changed, 5724 insertions(+), 6146 deletions(-) create mode 100644 DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+ComposerState.swift create mode 100644 DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+ThreadState.swift create mode 100644 Sources/CodexKit/Auth/ChatGPTOAuthSupport.swift create mode 100644 Sources/CodexKit/Auth/ChatGPTOAuthWebAuthentication.swift create mode 100644 Sources/CodexKit/Memory/SQLiteMemoryStoreRepository.swift create mode 100644 Sources/CodexKit/Memory/SQLiteMemoryStoreSchema.swift create mode 100644 Sources/CodexKit/Runtime/AgentHistory+Serialization.swift create mode 100644 Sources/CodexKit/Runtime/AgentHistoryItems.swift create mode 100644 Sources/CodexKit/Runtime/AgentHistoryPendingState.swift create mode 100644 Sources/CodexKit/Runtime/AgentRuntime+StructuredTurnConsumption.swift create mode 100644 Sources/CodexKit/Runtime/AgentRuntime+ToolInvocationResolution.swift create mode 100644 Sources/CodexKit/Runtime/AgentRuntimeObservation.swift create mode 100644 Sources/CodexKit/Runtime/FileRuntimeStateStore.swift create mode 100644 Sources/CodexKit/Runtime/GRDBRuntimeStateStore+Persistence.swift create mode 100644 Sources/CodexKit/Runtime/GRDBRuntimeStateStore+Queries.swift create mode 100644 Sources/CodexKit/Runtime/GRDBRuntimeStateStoreRows.swift create mode 100644 Sources/CodexKit/Runtime/InMemoryRuntimeStateStore.swift create mode 100644 Sources/CodexKit/Runtime/StoredRuntimeState+Execution.swift create mode 100644 Sources/CodexKit/Runtime/StoredRuntimeState+Queries.swift create mode 100644 Tests/CodexKitTests/AgentRuntimeHistoryCompactionTests.swift create mode 100644 Tests/CodexKitTests/AgentRuntimeHistoryGRDBTests.swift create mode 100644 Tests/CodexKitTests/AgentRuntimeHistoryTestSupport.swift create mode 100644 Tests/CodexKitTests/AgentRuntimeMessageBehaviorTests.swift create mode 100644 Tests/CodexKitTests/AgentRuntimePersonaSkillPolicyTests.swift create mode 100644 Tests/CodexKitTests/CodexResponsesBackendRetryTests.swift create mode 100644 Tests/CodexKitTests/CodexResponsesBackendStructuredTests.swift diff --git a/DemoApp/AssistantRuntimeDemoApp.xcodeproj/project.pbxproj b/DemoApp/AssistantRuntimeDemoApp.xcodeproj/project.pbxproj index 39fd65e..bfb29e4 100644 --- a/DemoApp/AssistantRuntimeDemoApp.xcodeproj/project.pbxproj +++ b/DemoApp/AssistantRuntimeDemoApp.xcodeproj/project.pbxproj @@ -25,6 +25,8 @@ 1A2B3C4D5E6F700000000010 /* DemoMemoryExamples.swift in Sources */ = {isa = PBXBuildFile; fileRef = 2A2B3C4D5E6F700000000010 /* DemoMemoryExamples.swift */; }; 1A2B3C4D5E6F700000000011 /* AgentDemoViewModel+Memory.swift in Sources */ = {isa = PBXBuildFile; fileRef = 2A2B3C4D5E6F700000000011 /* AgentDemoViewModel+Memory.swift */; }; 1A2B3C4D5E6F700000000012 /* MemoryDemoView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 2A2B3C4D5E6F700000000012 /* MemoryDemoView.swift */; }; + 1A2B3C4D5E6F700000000013 /* AgentDemoViewModel+ComposerState.swift in Sources */ = {isa = PBXBuildFile; fileRef = 2A2B3C4D5E6F700000000013 /* AgentDemoViewModel+ComposerState.swift */; }; + 1A2B3C4D5E6F700000000014 /* AgentDemoViewModel+ThreadState.swift in Sources */ = {isa = PBXBuildFile; fileRef = 2A2B3C4D5E6F700000000014 /* AgentDemoViewModel+ThreadState.swift */; }; 7482123BC63AC10F104DE092 /* AssistantRuntimeDemoApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5A6999E6475919476E726E8C /* AssistantRuntimeDemoApp.swift */; }; 84726927B752451499D9257F /* Foundation.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 906A95007C8ECB92CFC2CE15 /* Foundation.framework */; }; B060448C6464C41789B56EED /* AgentDemoView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 3CA22585116A120BA97F76B8 /* AgentDemoView.swift */; }; @@ -53,6 +55,8 @@ 2A2B3C4D5E6F700000000010 /* DemoMemoryExamples.swift */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.swift; path = DemoMemoryExamples.swift; sourceTree = ""; }; 2A2B3C4D5E6F700000000011 /* AgentDemoViewModel+Memory.swift */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.swift; path = "AgentDemoViewModel+Memory.swift"; sourceTree = ""; }; 2A2B3C4D5E6F700000000012 /* MemoryDemoView.swift */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.swift; path = MemoryDemoView.swift; sourceTree = ""; }; + 2A2B3C4D5E6F700000000013 /* AgentDemoViewModel+ComposerState.swift */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.swift; path = "AgentDemoViewModel+ComposerState.swift"; sourceTree = ""; }; + 2A2B3C4D5E6F700000000014 /* AgentDemoViewModel+ThreadState.swift */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.swift; path = "AgentDemoViewModel+ThreadState.swift"; sourceTree = ""; }; 2481147A958D00EB4A70C928 /* AgentDemoViewModel.swift */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.swift; path = AgentDemoViewModel.swift; sourceTree = ""; }; 3CA22585116A120BA97F76B8 /* AgentDemoView.swift */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.swift; path = AgentDemoView.swift; sourceTree = ""; }; 5A6999E6475919476E726E8C /* AssistantRuntimeDemoApp.swift */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.swift; path = AssistantRuntimeDemoApp.swift; sourceTree = ""; }; @@ -133,6 +137,8 @@ 2A2B3C4D5E6F700000000004 /* DeviceCodePromptView.swift */, 2A2B3C4D5E6F700000000005 /* Image+PlatformInit.swift */, 2481147A958D00EB4A70C928 /* AgentDemoViewModel.swift */, + 2A2B3C4D5E6F700000000013 /* AgentDemoViewModel+ComposerState.swift */, + 2A2B3C4D5E6F700000000014 /* AgentDemoViewModel+ThreadState.swift */, 2A2B3C4D5E6F700000000006 /* AgentDemoViewModel+Messaging.swift */, 2A2B3C4D5E6F700000000007 /* AgentDemoViewModel+Tools.swift */, 2A2B3C4D5E6F700000000008 /* AgentDemoViewModel+HealthCoach.swift */, @@ -240,6 +246,8 @@ 1A2B3C4D5E6F700000000010 /* DemoMemoryExamples.swift in Sources */, 1A2B3C4D5E6F700000000011 /* AgentDemoViewModel+Memory.swift in Sources */, 1A2B3C4D5E6F700000000012 /* MemoryDemoView.swift in Sources */, + 1A2B3C4D5E6F700000000013 /* AgentDemoViewModel+ComposerState.swift in Sources */, + 1A2B3C4D5E6F700000000014 /* AgentDemoViewModel+ThreadState.swift in Sources */, 7482123BC63AC10F104DE092 /* AssistantRuntimeDemoApp.swift in Sources */, BB4F38E64D1EBBB3821AC4E3 /* AgentDemoRuntimeFactory.swift in Sources */, B060448C6464C41789B56EED /* AgentDemoView.swift in Sources */, diff --git a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+ComposerState.swift b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+ComposerState.swift new file mode 100644 index 0000000..26a617a --- /dev/null +++ b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+ComposerState.swift @@ -0,0 +1,104 @@ +import CodexKit +import Foundation + +@MainActor +extension AgentDemoViewModel { + func sendComposerText() async { + let outgoingText = composerText.trimmingCharacters(in: .whitespacesAndNewlines) + let outgoingImages = pendingComposerImages + + guard !outgoingText.isEmpty || !outgoingImages.isEmpty else { + return + } + + composerText = "" + pendingComposerImages = [] + await sendMessageInternal( + outgoingText, + images: outgoingImages + ) + } + + func queueComposerImage( + data: Data, + mimeType: String + ) { + pendingComposerImages.append( + AgentImageAttachment( + mimeType: mimeType, + data: data + ) + ) + } + + func removePendingComposerImage(id: String) { + pendingComposerImages.removeAll { $0.id == id } + } + + func reportError(_ message: String) { + developerErrorLog(message) + lastError = message + } + + func reportError(_ error: Error) { + guard !diagnostics.isCancellationError(error) else { + developerLog("Ignoring CancellationError from async UI task.") + return + } + developerErrorLog(error.localizedDescription) + lastError = error.localizedDescription + } + + func approvePendingRequest() { + approvalInbox.approveCurrent() + } + + func denyPendingRequest() { + approvalInbox.denyCurrent() + } + + func dismissError() { + lastError = nil + } + + func developerLog(_ message: String) { + guard developerLoggingEnabled else { + return + } + diagnostics.log(message) + } + + func developerErrorLog(_ message: String) { + guard developerLoggingEnabled else { + return + } + diagnostics.error(message) + } + + func setMessages(_ incoming: [AgentMessage]) { + messages = deduplicatedMessages(incoming) + } + + func upsertMessage(_ message: AgentMessage) { + if let existingIndex = messages.firstIndex(where: { $0.id == message.id }) { + messages[existingIndex] = message + return + } + messages.append(message) + } + + private func deduplicatedMessages(_ incoming: [AgentMessage]) -> [AgentMessage] { + var seen = Set() + var reversedUnique: [AgentMessage] = [] + reversedUnique.reserveCapacity(incoming.count) + + for message in incoming.reversed() { + guard seen.insert(message.id).inserted else { + continue + } + reversedUnique.append(message) + } + + return reversedUnique.reversed() + } +} diff --git a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+ThreadState.swift b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+ThreadState.swift new file mode 100644 index 0000000..e9b5f6f --- /dev/null +++ b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+ThreadState.swift @@ -0,0 +1,306 @@ +import CodexKit +import Foundation + +@MainActor +extension AgentDemoViewModel { + func restore() async { + developerLog( + "Restore started. store=\(resolvedStateURL.path) legacyJSONPresent=\(FileManager.default.fileExists(atPath: legacyStateURL.path))" + ) + do { + _ = try await runtime.restore() + await registerDemoTool() + await registerDemoSkills() + await refreshSnapshot() + developerLog( + "Restore finished. sessionPresent=\(session != nil) threadCount=\(threads.count)" + ) + } catch { + reportError(error) + } + } + + func signIn(using authenticationMethod: DemoAuthenticationMethod) async { + guard !isAuthenticating else { + return + } + + isAuthenticating = true + lastError = nil + currentAuthenticationMethod = authenticationMethod + developerLog("Sign-in started. method=\(authenticationMethod.rawValue)") + runtime = AgentDemoRuntimeFactory.makeRuntime( + authenticationMethod: authenticationMethod, + model: model, + enableWebSearch: enableWebSearch, + reasoningEffort: reasoningEffort, + stateURL: stateURL, + keychainAccount: keychainAccount, + approvalInbox: approvalInbox, + deviceCodePromptCoordinator: deviceCodePromptCoordinator + ) + + defer { + isAuthenticating = false + } + + do { + _ = try await runtime.restore() + await registerDemoTool() + await registerDemoSkills() + session = try await runtime.signIn() + await refreshSnapshot() + if healthCoachInitialized { + await refreshHealthCoachProgress() + } + developerLog( + "Sign-in finished. account=\(session?.account.email ?? "") threadCount=\(threads.count)" + ) + } catch { + await deviceCodePromptCoordinator.clear() + await refreshSnapshot() + reportError(error) + } + } + + func updateReasoningEffort(_ reasoningEffort: ReasoningEffort) async { + guard self.reasoningEffort != reasoningEffort else { + return + } + + guard canReconfigureRuntime else { + lastError = "Wait for the current turn to finish before switching thinking level." + return + } + + self.reasoningEffort = reasoningEffort + developerLog("Reconfiguring runtime. reasoningEffort=\(reasoningEffort.rawValue)") + let preservedActiveThreadID = activeThreadID + let preservedHealthCoachThreadID = healthCoachThreadID + + runtime = AgentDemoRuntimeFactory.makeRuntime( + authenticationMethod: currentAuthenticationMethod, + model: model, + enableWebSearch: enableWebSearch, + reasoningEffort: reasoningEffort, + stateURL: stateURL, + keychainAccount: keychainAccount, + approvalInbox: approvalInbox, + deviceCodePromptCoordinator: deviceCodePromptCoordinator + ) + + do { + _ = try await runtime.restore() + await registerDemoTool() + await refreshSnapshot() + + if let preservedActiveThreadID, + threads.contains(where: { $0.id == preservedActiveThreadID }) { + activeThreadID = preservedActiveThreadID + messages = await runtime.messages(for: preservedActiveThreadID) + } + + if let preservedHealthCoachThreadID, + threads.contains(where: { $0.id == preservedHealthCoachThreadID }) { + healthCoachThreadID = preservedHealthCoachThreadID + } + developerLog( + "Runtime reconfigured. reasoningEffort=\(reasoningEffort.rawValue) threadCount=\(threads.count)" + ) + } catch { + reportError(error) + } + } + + func createThread() async { + await createThreadInternal( + title: nil, + personaStack: nil + ) + } + + func createSupportPersonaThread() async { + await createThreadInternal( + title: "Support Persona Demo", + personaStack: catalog.supportPersona + ) + } + + func setPlannerPersonaOnActiveThread() async { + guard let activeThreadID else { + lastError = "Create or select a thread before swapping personas." + return + } + + do { + try await runtime.setPersonaStack( + catalog.plannerPersona, + for: activeThreadID + ) + threads = await runtime.threads() + } catch { + reportError(error) + } + } + + func sendReviewerOverrideExample() async { + if activeThreadID == nil { + await createSupportPersonaThread() + } + + await sendMessageInternal( + "Review this conversation setup and tell me the biggest risks first.", + personaOverride: catalog.reviewerOverridePersona + ) + } + + func createHealthCoachSkillThread() async { + await createThreadInternal( + title: "Skill Demo: Health Coach", + personaStack: nil, + skillIDs: [catalog.healthCoachSkill.id] + ) + } + + func createTravelPlannerSkillThread() async { + await createThreadInternal( + title: "Skill Demo: Travel Planner", + personaStack: nil, + skillIDs: [catalog.travelPlannerSkill.id] + ) + } + + func activateThread(id: String) async { + activeThreadID = id + setMessages(await runtime.messages(for: id)) + streamingText = "" + await refreshThreadContextState(for: id) + } + + func signOut() async { + do { + try await runtime.signOut() + await deviceCodePromptCoordinator.clear() + session = nil + threads = [] + messages = [] + streamingText = "" + composerText = "" + pendingComposerImages = [] + lastResolvedInstructions = nil + lastResolvedInstructionsThreadTitle = nil + isRunningSkillPolicyProbe = false + skillPolicyProbeResult = nil + isRunningStructuredOutputDemo = false + structuredShippingReplyResult = nil + structuredImportedSummaryResult = nil + isRunningMemoryDemo = false + automaticMemoryResult = nil + automaticPolicyMemoryResult = nil + guidedMemoryResult = nil + rawMemoryResult = nil + memoryPreviewResult = nil + activeThreadID = nil + healthCoachThreadID = nil + healthCoachFeedback = "Set a step goal, then start moving." + healthLastUpdatedAt = nil + healthKitAuthorized = false + notificationAuthorized = false + healthCoachInitialized = false + cachedAICoachFeedbackKey = nil + cachedAICoachFeedbackGeneratedAt = nil + cachedAIReminderBody = nil + cachedAIReminderKey = nil + cachedAIReminderGeneratedAt = nil + lastError = nil + } catch { + reportError(error) + } + } + + func refreshSnapshot() async { + session = await runtime.currentSession() + guard session != nil else { + clearConversationSnapshot() + developerLog("Snapshot refreshed with no active session.") + return + } + + threads = await runtime.threads() + developerLog( + "Snapshot refreshed. session=\(session?.account.email ?? "") threadCount=\(threads.count)" + ) + + let selectedThreadID = activeThreadID + if let selectedThreadID, + threads.contains(where: { $0.id == selectedThreadID }) { + setMessages(await runtime.messages(for: selectedThreadID)) + await refreshThreadContextState(for: selectedThreadID) + return + } + + if let firstThread = threads.first { + activeThreadID = firstThread.id + setMessages(await runtime.messages(for: firstThread.id)) + await refreshThreadContextState(for: firstThread.id) + } else { + activeThreadID = nil + messages = [] + activeThreadContextState = nil + } + } + + func clearConversationSnapshot() { + threads = [] + messages = [] + streamingText = "" + pendingComposerImages = [] + lastResolvedInstructions = nil + lastResolvedInstructionsThreadTitle = nil + isRunningSkillPolicyProbe = false + skillPolicyProbeResult = nil + activeThreadID = nil + activeThreadContextState = nil + } + + func refreshThreadContextState(for threadID: String? = nil) async { + guard let resolvedThreadID = threadID ?? activeThreadID else { + activeThreadContextState = nil + return + } + + do { + activeThreadContextState = try await runtime.fetchThreadContextState(id: resolvedThreadID) + } catch { + activeThreadContextState = nil + developerErrorLog("Failed to fetch thread context state. threadID=\(resolvedThreadID) error=\(error.localizedDescription)") + } + } + + func compactActiveThreadContext() async { + guard let activeThreadID else { + lastError = "Select a thread before compacting its prompt context." + return + } + guard !isCompactingThreadContext else { + return + } + + isCompactingThreadContext = true + defer { + isCompactingThreadContext = false + } + + do { + developerLog("Manual context compaction started. threadID=\(activeThreadID)") + activeThreadContextState = try await runtime.compactThreadContext(id: activeThreadID) + threads = await runtime.threads() + setMessages(await runtime.messages(for: activeThreadID)) + developerLog( + "Manual context compaction finished. threadID=\(activeThreadID) generation=\(activeThreadContextState?.generation ?? 0) effectiveMessages=\(activeThreadContextState?.effectiveMessages.count ?? 0)" + ) + } catch { + reportError(error) + } + } +} diff --git a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel.swift b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel.swift index 3a8c4a8..016eec3 100644 --- a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel.swift +++ b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel.swift @@ -325,173 +325,6 @@ final class AgentDemoViewModel: @unchecked Sendable { } } - func restore() async { - developerLog( - "Restore started. store=\(resolvedStateURL.path) legacyJSONPresent=\(FileManager.default.fileExists(atPath: legacyStateURL.path))" - ) - do { - _ = try await runtime.restore() - await registerDemoTool() - await registerDemoSkills() - await refreshSnapshot() - developerLog( - "Restore finished. sessionPresent=\(session != nil) threadCount=\(threads.count)" - ) - } catch { - reportError(error) - } - } - - func signIn(using authenticationMethod: DemoAuthenticationMethod) async { - guard !isAuthenticating else { - return - } - - isAuthenticating = true - lastError = nil - currentAuthenticationMethod = authenticationMethod - developerLog("Sign-in started. method=\(authenticationMethod.rawValue)") - runtime = AgentDemoRuntimeFactory.makeRuntime( - authenticationMethod: authenticationMethod, - model: model, - enableWebSearch: enableWebSearch, - reasoningEffort: reasoningEffort, - stateURL: stateURL, - keychainAccount: keychainAccount, - approvalInbox: approvalInbox, - deviceCodePromptCoordinator: deviceCodePromptCoordinator - ) - - defer { - isAuthenticating = false - } - - do { - _ = try await runtime.restore() - await registerDemoTool() - await registerDemoSkills() - session = try await runtime.signIn() - await refreshSnapshot() - if healthCoachInitialized { - await refreshHealthCoachProgress() - } - developerLog( - "Sign-in finished. account=\(session?.account.email ?? "") threadCount=\(threads.count)" - ) - } catch { - await deviceCodePromptCoordinator.clear() - await refreshSnapshot() - reportError(error) - } - } - - func updateReasoningEffort(_ reasoningEffort: ReasoningEffort) async { - guard self.reasoningEffort != reasoningEffort else { - return - } - - guard canReconfigureRuntime else { - lastError = "Wait for the current turn to finish before switching thinking level." - return - } - - self.reasoningEffort = reasoningEffort - developerLog("Reconfiguring runtime. reasoningEffort=\(reasoningEffort.rawValue)") - let preservedActiveThreadID = activeThreadID - let preservedHealthCoachThreadID = healthCoachThreadID - - runtime = AgentDemoRuntimeFactory.makeRuntime( - authenticationMethod: currentAuthenticationMethod, - model: model, - enableWebSearch: enableWebSearch, - reasoningEffort: reasoningEffort, - stateURL: stateURL, - keychainAccount: keychainAccount, - approvalInbox: approvalInbox, - deviceCodePromptCoordinator: deviceCodePromptCoordinator - ) - - do { - _ = try await runtime.restore() - await registerDemoTool() - await refreshSnapshot() - - if let preservedActiveThreadID, - threads.contains(where: { $0.id == preservedActiveThreadID }) { - activeThreadID = preservedActiveThreadID - messages = await runtime.messages(for: preservedActiveThreadID) - } - - if let preservedHealthCoachThreadID, - threads.contains(where: { $0.id == preservedHealthCoachThreadID }) { - healthCoachThreadID = preservedHealthCoachThreadID - } - developerLog( - "Runtime reconfigured. reasoningEffort=\(reasoningEffort.rawValue) threadCount=\(threads.count)" - ) - } catch { - reportError(error) - } - } - - func createThread() async { - await createThreadInternal( - title: nil, - personaStack: nil - ) - } - - func createSupportPersonaThread() async { - await createThreadInternal( - title: "Support Persona Demo", - personaStack: catalog.supportPersona - ) - } - - func setPlannerPersonaOnActiveThread() async { - guard let activeThreadID else { - lastError = "Create or select a thread before swapping personas." - return - } - - do { - try await runtime.setPersonaStack( - catalog.plannerPersona, - for: activeThreadID - ) - threads = await runtime.threads() - } catch { - reportError(error) - } - } - - func sendReviewerOverrideExample() async { - if activeThreadID == nil { - await createSupportPersonaThread() - } - - await sendMessageInternal( - "Review this conversation setup and tell me the biggest risks first.", - personaOverride: catalog.reviewerOverridePersona - ) - } - - func createHealthCoachSkillThread() async { - await createThreadInternal( - title: "Skill Demo: Health Coach", - personaStack: nil, - skillIDs: [catalog.healthCoachSkill.id] - ) - } - - func createTravelPlannerSkillThread() async { - await createThreadInternal( - title: "Skill Demo: Travel Planner", - personaStack: nil, - skillIDs: [catalog.travelPlannerSkill.id] - ) - } - func personaSummary(for thread: AgentThread?) -> String? { guard let thread else { return nil @@ -506,237 +339,4 @@ final class AgentDemoViewModel: @unchecked Sendable { guard !sections.isEmpty else { return nil } return sections.joined(separator: " | ") } - - func activateThread(id: String) async { - activeThreadID = id - setMessages(await runtime.messages(for: id)) - streamingText = "" - await refreshThreadContextState(for: id) - } - - func sendComposerText() async { - let outgoingText = composerText.trimmingCharacters(in: .whitespacesAndNewlines) - let outgoingImages = pendingComposerImages - - guard !outgoingText.isEmpty || !outgoingImages.isEmpty else { - return - } - - composerText = "" - pendingComposerImages = [] - await sendMessageInternal( - outgoingText, - images: outgoingImages - ) - } - - func queueComposerImage( - data: Data, - mimeType: String - ) { - pendingComposerImages.append( - AgentImageAttachment( - mimeType: mimeType, - data: data - ) - ) - } - - func removePendingComposerImage(id: String) { - pendingComposerImages.removeAll { $0.id == id } - } - - func reportError(_ message: String) { - developerErrorLog(message) - lastError = message - } - - func reportError(_ error: Error) { - guard !diagnostics.isCancellationError(error) else { - developerLog("Ignoring CancellationError from async UI task.") - return - } - developerErrorLog(error.localizedDescription) - lastError = error.localizedDescription - } - - func approvePendingRequest() { - approvalInbox.approveCurrent() - } - - func denyPendingRequest() { - approvalInbox.denyCurrent() - } - - func dismissError() { - lastError = nil - } - - func developerLog(_ message: String) { - guard developerLoggingEnabled else { - return - } - diagnostics.log(message) - } - - func developerErrorLog(_ message: String) { - guard developerLoggingEnabled else { - return - } - diagnostics.error(message) - } - - func signOut() async { - do { - try await runtime.signOut() - await deviceCodePromptCoordinator.clear() - session = nil - threads = [] - messages = [] - streamingText = "" - composerText = "" - pendingComposerImages = [] - lastResolvedInstructions = nil - lastResolvedInstructionsThreadTitle = nil - isRunningSkillPolicyProbe = false - skillPolicyProbeResult = nil - isRunningStructuredOutputDemo = false - structuredShippingReplyResult = nil - structuredImportedSummaryResult = nil - isRunningMemoryDemo = false - automaticMemoryResult = nil - automaticPolicyMemoryResult = nil - guidedMemoryResult = nil - rawMemoryResult = nil - memoryPreviewResult = nil - activeThreadID = nil - healthCoachThreadID = nil - healthCoachFeedback = "Set a step goal, then start moving." - healthLastUpdatedAt = nil - healthKitAuthorized = false - notificationAuthorized = false - healthCoachInitialized = false - cachedAICoachFeedbackKey = nil - cachedAICoachFeedbackGeneratedAt = nil - cachedAIReminderBody = nil - cachedAIReminderKey = nil - cachedAIReminderGeneratedAt = nil - lastError = nil - } catch { - reportError(error) - } - } - - func refreshSnapshot() async { - session = await runtime.currentSession() - guard session != nil else { - clearConversationSnapshot() - developerLog("Snapshot refreshed with no active session.") - return - } - - threads = await runtime.threads() - developerLog( - "Snapshot refreshed. session=\(session?.account.email ?? "") threadCount=\(threads.count)" - ) - - let selectedThreadID = activeThreadID - if let selectedThreadID, - threads.contains(where: { $0.id == selectedThreadID }) { - setMessages(await runtime.messages(for: selectedThreadID)) - await refreshThreadContextState(for: selectedThreadID) - return - } - - if let firstThread = threads.first { - activeThreadID = firstThread.id - setMessages(await runtime.messages(for: firstThread.id)) - await refreshThreadContextState(for: firstThread.id) - } else { - activeThreadID = nil - messages = [] - activeThreadContextState = nil - } - } - - func clearConversationSnapshot() { - threads = [] - messages = [] - streamingText = "" - pendingComposerImages = [] - lastResolvedInstructions = nil - lastResolvedInstructionsThreadTitle = nil - isRunningSkillPolicyProbe = false - skillPolicyProbeResult = nil - activeThreadID = nil - activeThreadContextState = nil - } - - func refreshThreadContextState(for threadID: String? = nil) async { - guard let resolvedThreadID = threadID ?? activeThreadID else { - activeThreadContextState = nil - return - } - - do { - activeThreadContextState = try await runtime.fetchThreadContextState(id: resolvedThreadID) - } catch { - activeThreadContextState = nil - developerErrorLog("Failed to fetch thread context state. threadID=\(resolvedThreadID) error=\(error.localizedDescription)") - } - } - - func compactActiveThreadContext() async { - guard let activeThreadID else { - lastError = "Select a thread before compacting its prompt context." - return - } - guard !isCompactingThreadContext else { - return - } - - isCompactingThreadContext = true - defer { - isCompactingThreadContext = false - } - - do { - developerLog("Manual context compaction started. threadID=\(activeThreadID)") - activeThreadContextState = try await runtime.compactThreadContext(id: activeThreadID) - threads = await runtime.threads() - setMessages(await runtime.messages(for: activeThreadID)) - developerLog( - "Manual context compaction finished. threadID=\(activeThreadID) generation=\(activeThreadContextState?.generation ?? 0) effectiveMessages=\(activeThreadContextState?.effectiveMessages.count ?? 0)" - ) - } catch { - reportError(error) - } - } - - func setMessages(_ incoming: [AgentMessage]) { - messages = deduplicatedMessages(incoming) - } - - func upsertMessage(_ message: AgentMessage) { - if let existingIndex = messages.firstIndex(where: { $0.id == message.id }) { - messages[existingIndex] = message - return - } - messages.append(message) - } - - private func deduplicatedMessages(_ incoming: [AgentMessage]) -> [AgentMessage] { - var seen = Set() - var reversedUnique: [AgentMessage] = [] - reversedUnique.reserveCapacity(incoming.count) - - for message in incoming.reversed() { - guard seen.insert(message.id).inserted else { - continue - } - reversedUnique.append(message) - } - - return reversedUnique.reversed() - } } diff --git a/Sources/CodexKit/Auth/ChatGPTOAuthProvider.swift b/Sources/CodexKit/Auth/ChatGPTOAuthProvider.swift index a26f0ac..45817f7 100644 --- a/Sources/CodexKit/Auth/ChatGPTOAuthProvider.swift +++ b/Sources/CodexKit/Auth/ChatGPTOAuthProvider.swift @@ -83,122 +83,6 @@ public protocol ChatGPTWebAuthenticationProviding: Sendable { ) async throws -> URL } -#if canImport(AuthenticationServices) -@available(iOS 13.0, macOS 10.15, *) -public final class SystemChatGPTWebAuthenticationProvider: NSObject, ChatGPTWebAuthenticationProviding, @unchecked Sendable { - private var activeSession: ASWebAuthenticationSession? - private var activePresentationContextProvider: PresentationContextProvider? - private let presentationAnchorProvider: @MainActor @Sendable () -> ASPresentationAnchor? - - public override convenience init() { - self.init(presentationAnchorProvider: { - defaultPresentationAnchor() - }) - } - - public init( - presentationAnchorProvider: @escaping @MainActor @Sendable () -> ASPresentationAnchor? - ) { - self.presentationAnchorProvider = presentationAnchorProvider - super.init() - } - - public func authenticate( - authorizeURL: URL, - callbackScheme: String - ) async throws -> URL { - let anchor = try await MainActor.run { () throws -> ASPresentationAnchor in - guard let anchor = presentationAnchorProvider() else { - throw AgentRuntimeError( - code: "oauth_presentation_anchor_unavailable", - message: "The ChatGPT sign-in sheet could not be presented because no active window was available." - ) - } - return anchor - } - - return try await withCheckedThrowingContinuation { continuation in - Task { @MainActor [weak self] in - let session = ASWebAuthenticationSession( - url: authorizeURL, - callbackURLScheme: callbackScheme - ) { callbackURL, error in - self?.activeSession = nil - self?.activePresentationContextProvider = nil - - if let callbackURL { - continuation.resume(returning: callbackURL) - return - } - - continuation.resume( - throwing: error ?? AgentRuntimeError( - code: "oauth_authentication_cancelled", - message: "The ChatGPT sign-in flow did not complete." - ) - ) - } - let contextProvider = PresentationContextProvider(anchor: anchor) - session.presentationContextProvider = contextProvider - #if os(iOS) - session.prefersEphemeralWebBrowserSession = false - #endif - self?.activeSession = session - self?.activePresentationContextProvider = contextProvider - - guard session.start() else { - self?.activeSession = nil - self?.activePresentationContextProvider = nil - continuation.resume( - throwing: AgentRuntimeError( - code: "oauth_authentication_start_failed", - message: "The ChatGPT sign-in flow could not be started." - ) - ) - return - } - } - } - } -} - -@available(iOS 13.0, macOS 10.15, *) -private final class PresentationContextProvider: NSObject, ASWebAuthenticationPresentationContextProviding { - private let anchor: ASPresentationAnchor - - init(anchor: ASPresentationAnchor) { - self.anchor = anchor - } - - func presentationAnchor(for _: ASWebAuthenticationSession) -> ASPresentationAnchor { - anchor - } -} - -@MainActor -@available(iOS 13.0, macOS 10.15, *) -private func defaultPresentationAnchor() -> ASPresentationAnchor? { - #if canImport(UIKit) - let scenes = UIApplication.shared.connectedScenes - .compactMap { $0 as? UIWindowScene } - - if let keyWindow = scenes - .flatMap(\.windows) - .first(where: \.isKeyWindow) { - return keyWindow - } - - return scenes - .flatMap(\.windows) - .first(where: { !$0.isHidden }) - #elseif canImport(AppKit) - return NSApp.keyWindow ?? NSApp.mainWindow - #else - return nil - #endif -} -#endif - public final class ChatGPTOAuthProvider: ChatGPTAuthProviding, @unchecked Sendable { private let configuration: ChatGPTOAuthConfiguration private let urlSession: URLSession @@ -414,274 +298,3 @@ public final class ChatGPTOAuthProvider: ChatGPTAuthProviding, @unchecked Sendab ) } } - -private struct UnsupportedChatGPTWebAuthenticationProvider: ChatGPTWebAuthenticationProviding { - func authenticate( - authorizeURL _: URL, - callbackScheme _: String - ) async throws -> URL { - throw AgentRuntimeError( - code: "oauth_authentication_unsupported", - message: "Browser-based ChatGPT sign-in is not supported on this platform." - ) - } -} - -private extension URL { - var isLoopbackOAuthRedirect: Bool { - guard let scheme = scheme?.lowercased(), - scheme == "http" || scheme == "https", - let host = host?.lowercased(), - host == "localhost" || host == "127.0.0.1", - port != nil else { - return false - } - - return true - } -} - -struct AuthorizationCodeExchangeRequest: Encodable { - let clientID: String - let grantType: String - let code: String - let redirectURI: String - let codeVerifier: String - - enum CodingKeys: String, CodingKey { - case clientID = "client_id" - case grantType = "grant_type" - case code - case redirectURI = "redirect_uri" - case codeVerifier = "code_verifier" - } -} - -struct RefreshTokenRequest: Encodable { - let clientID: String - let grantType: String - let refreshToken: String - - enum CodingKeys: String, CodingKey { - case clientID = "client_id" - case grantType = "grant_type" - case refreshToken = "refresh_token" - } -} - -struct TokenResponse: Decodable { - let idToken: String - let accessToken: String - let refreshToken: String? - - enum CodingKeys: String, CodingKey { - case idToken = "id_token" - case accessToken = "access_token" - case refreshToken = "refresh_token" - } -} - -private struct OAuthCallback { - let code: String - let state: String - - init(url: URL) throws { - guard let components = URLComponents(url: url, resolvingAgainstBaseURL: false) else { - throw AgentRuntimeError( - code: "oauth_callback_invalid", - message: "The ChatGPT sign-in callback URL could not be parsed." - ) - } - - let queryItems = components.queryItems ?? [] - - if let errorDescription = queryItems.first(where: { $0.name == "error_description" || $0.name == "error" })?.value { - throw AgentRuntimeError( - code: "oauth_callback_failed", - message: "ChatGPT sign-in failed: \(errorDescription)" - ) - } - - guard let code = queryItems.first(where: { $0.name == "code" })?.value, !code.isEmpty else { - throw AgentRuntimeError( - code: "oauth_callback_missing_code", - message: "The ChatGPT sign-in callback did not include an authorization code." - ) - } - guard let state = queryItems.first(where: { $0.name == "state" })?.value, !state.isEmpty else { - throw AgentRuntimeError( - code: "oauth_callback_missing_state", - message: "The ChatGPT sign-in callback did not include a state parameter." - ) - } - - self.code = code - self.state = state - } -} - -private struct OAuthState { - let value: String - - static func generate() -> OAuthState { - OAuthState(value: randomData(count: 32).base64URLEncodedString()) - } -} - -private struct PKCECodes { - let codeVerifier: String - let codeChallenge: String - - init() { - let verifierData = randomData(count: 64) - codeVerifier = verifierData.base64URLEncodedString() - codeChallenge = Data(SHA256.hash(data: Data(codeVerifier.utf8))).base64URLEncodedString() - } -} - -struct JWTClaims: Decodable { - let email: String? - let chatGPTAccountID: String? - let planType: String? - let issuedAtSeconds: TimeInterval? - let expiresAtSeconds: TimeInterval? - - enum CodingKeys: String, CodingKey { - case email - case chatGPTAccountID = "chatgpt_account_id" - case planType = "chatgpt_plan_type" - case issuedAtSeconds = "iat" - case expiresAtSeconds = "exp" - } - - var issuedAt: Date? { - issuedAtSeconds.map(Date.init(timeIntervalSince1970:)) - } - - var expiresAt: Date? { - expiresAtSeconds.map(Date.init(timeIntervalSince1970:)) - } - - static func decode(from jwt: String) throws -> JWTClaims { - let parts = jwt.split(separator: ".") - guard parts.count >= 2 else { - throw AgentRuntimeError( - code: "jwt_invalid", - message: "A ChatGPT token could not be decoded." - ) - } - - let payload = try Data(base64URLString: String(parts[1])) - return try JSONDecoder().decode(JWTClaims.self, from: payload) - } -} - -func buildCodexLikeUserAgent( - originator: String, - product: String -) -> String { - let version = Bundle.main.infoDictionary?["CFBundleShortVersionString"] as? String - ?? Bundle.main.infoDictionary?["CFBundleVersion"] as? String - ?? "0.1" - let platform = currentPlatformDescription() - return "\(originator)/\(version) (\(platform)) \(product)" -} - -private func currentPlatformDescription() -> String { - let osVersion = ProcessInfo.processInfo.operatingSystemVersion - let version = "\(osVersion.majorVersion).\(osVersion.minorVersion).\(osVersion.patchVersion)" - #if os(iOS) - let system = "iOS" - #elseif os(macOS) - let system = "macOS" - #elseif os(tvOS) - let system = "tvOS" - #elseif os(watchOS) - let system = "watchOS" - #elseif os(visionOS) - let system = "visionOS" - #else - let system = "Apple" - #endif - return "\(system) \(version); \(currentArchitecture())" -} - -private func currentArchitecture() -> String { - var systemInfo = utsname() - uname(&systemInfo) - let mirror = Mirror(reflecting: systemInfo.machine) - let identifier = mirror.children.reduce(into: "") { partial, element in - guard let value = element.value as? Int8, value != 0 else { - return - } - partial.append(Character(UnicodeScalar(UInt8(value)))) - } - return identifier.isEmpty ? "unknown" : identifier -} - -func urlEncodedFormBody(_ items: [(String, String)]) -> Data { - let body = items - .map { key, value in - "\(percentEncodeFormComponent(key))=\(percentEncodeFormComponent(value))" - } - .joined(separator: "&") - return Data(body.utf8) -} - -private func percentEncodeFormComponent(_ value: String) -> String { - let allowed = CharacterSet(charactersIn: "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~") - return value.addingPercentEncoding(withAllowedCharacters: allowed) ?? value -} - -func simplifyAuthErrorBody(_ data: Data) -> String { - let body = String(data: data, encoding: .utf8)?.trimmingCharacters(in: .whitespacesAndNewlines) - guard let body, !body.isEmpty else { - return "Unknown error" - } - - if body.localizedCaseInsensitiveContains(" String { - base64EncodedString() - .replacingOccurrences(of: "+", with: "-") - .replacingOccurrences(of: "/", with: "_") - .replacingOccurrences(of: "=", with: "") - } -} - -private func randomData(count: Int) -> Data { - var bytes = [UInt8](repeating: 0, count: count) - let status = SecRandomCopyBytes(kSecRandomDefault, bytes.count, &bytes) - if status == errSecSuccess { - return Data(bytes) - } - return Data((0 ..< count).map { _ in UInt8.random(in: .min ... .max) }) -} diff --git a/Sources/CodexKit/Auth/ChatGPTOAuthSupport.swift b/Sources/CodexKit/Auth/ChatGPTOAuthSupport.swift new file mode 100644 index 0000000..b6e3c75 --- /dev/null +++ b/Sources/CodexKit/Auth/ChatGPTOAuthSupport.swift @@ -0,0 +1,262 @@ +import CryptoKit +import Foundation +import Security + +extension URL { + var isLoopbackOAuthRedirect: Bool { + guard let scheme = scheme?.lowercased(), + scheme == "http" || scheme == "https", + let host = host?.lowercased(), + host == "localhost" || host == "127.0.0.1", + port != nil else { + return false + } + + return true + } +} + +struct AuthorizationCodeExchangeRequest: Encodable { + let clientID: String + let grantType: String + let code: String + let redirectURI: String + let codeVerifier: String + + enum CodingKeys: String, CodingKey { + case clientID = "client_id" + case grantType = "grant_type" + case code + case redirectURI = "redirect_uri" + case codeVerifier = "code_verifier" + } +} + +struct RefreshTokenRequest: Encodable { + let clientID: String + let grantType: String + let refreshToken: String + + enum CodingKeys: String, CodingKey { + case clientID = "client_id" + case grantType = "grant_type" + case refreshToken = "refresh_token" + } +} + +struct TokenResponse: Decodable { + let idToken: String + let accessToken: String + let refreshToken: String? + + enum CodingKeys: String, CodingKey { + case idToken = "id_token" + case accessToken = "access_token" + case refreshToken = "refresh_token" + } +} + +struct OAuthCallback { + let code: String + let state: String + + init(url: URL) throws { + guard let components = URLComponents(url: url, resolvingAgainstBaseURL: false) else { + throw AgentRuntimeError( + code: "oauth_callback_invalid", + message: "The ChatGPT sign-in callback URL could not be parsed." + ) + } + + let queryItems = components.queryItems ?? [] + + if let errorDescription = queryItems.first(where: { $0.name == "error_description" || $0.name == "error" })?.value { + throw AgentRuntimeError( + code: "oauth_callback_failed", + message: "ChatGPT sign-in failed: \(errorDescription)" + ) + } + + guard let code = queryItems.first(where: { $0.name == "code" })?.value, !code.isEmpty else { + throw AgentRuntimeError( + code: "oauth_callback_missing_code", + message: "The ChatGPT sign-in callback did not include an authorization code." + ) + } + guard let state = queryItems.first(where: { $0.name == "state" })?.value, !state.isEmpty else { + throw AgentRuntimeError( + code: "oauth_callback_missing_state", + message: "The ChatGPT sign-in callback did not include a state parameter." + ) + } + + self.code = code + self.state = state + } +} + +struct OAuthState { + let value: String + + static func generate() -> OAuthState { + OAuthState(value: randomData(count: 32).base64URLEncodedString()) + } +} + +struct PKCECodes { + let codeVerifier: String + let codeChallenge: String + + init() { + let verifierData = randomData(count: 64) + codeVerifier = verifierData.base64URLEncodedString() + codeChallenge = Data(SHA256.hash(data: Data(codeVerifier.utf8))).base64URLEncodedString() + } +} + +struct JWTClaims: Decodable { + let email: String? + let chatGPTAccountID: String? + let planType: String? + let issuedAtSeconds: TimeInterval? + let expiresAtSeconds: TimeInterval? + + enum CodingKeys: String, CodingKey { + case email + case chatGPTAccountID = "chatgpt_account_id" + case planType = "chatgpt_plan_type" + case issuedAtSeconds = "iat" + case expiresAtSeconds = "exp" + } + + var issuedAt: Date? { + issuedAtSeconds.map(Date.init(timeIntervalSince1970:)) + } + + var expiresAt: Date? { + expiresAtSeconds.map(Date.init(timeIntervalSince1970:)) + } + + static func decode(from jwt: String) throws -> JWTClaims { + let parts = jwt.split(separator: ".") + guard parts.count >= 2 else { + throw AgentRuntimeError( + code: "jwt_invalid", + message: "A ChatGPT token could not be decoded." + ) + } + + let payload = try Data(base64URLString: String(parts[1])) + return try JSONDecoder().decode(JWTClaims.self, from: payload) + } +} + +func buildCodexLikeUserAgent( + originator: String, + product: String +) -> String { + let version = Bundle.main.infoDictionary?["CFBundleShortVersionString"] as? String + ?? Bundle.main.infoDictionary?["CFBundleVersion"] as? String + ?? "0.1" + let platform = currentPlatformDescription() + return "\(originator)/\(version) (\(platform)) \(product)" +} + +private func currentPlatformDescription() -> String { + let osVersion = ProcessInfo.processInfo.operatingSystemVersion + let version = "\(osVersion.majorVersion).\(osVersion.minorVersion).\(osVersion.patchVersion)" + #if os(iOS) + let system = "iOS" + #elseif os(macOS) + let system = "macOS" + #elseif os(tvOS) + let system = "tvOS" + #elseif os(watchOS) + let system = "watchOS" + #elseif os(visionOS) + let system = "visionOS" + #else + let system = "Apple" + #endif + return "\(system) \(version); \(currentArchitecture())" +} + +private func currentArchitecture() -> String { + var systemInfo = utsname() + uname(&systemInfo) + let mirror = Mirror(reflecting: systemInfo.machine) + let identifier = mirror.children.reduce(into: "") { partial, element in + guard let value = element.value as? Int8, value != 0 else { + return + } + partial.append(Character(UnicodeScalar(UInt8(value)))) + } + return identifier.isEmpty ? "unknown" : identifier +} + +func urlEncodedFormBody(_ items: [(String, String)]) -> Data { + let body = items + .map { key, value in + "\(percentEncodeFormComponent(key))=\(percentEncodeFormComponent(value))" + } + .joined(separator: "&") + return Data(body.utf8) +} + +private func percentEncodeFormComponent(_ value: String) -> String { + let allowed = CharacterSet(charactersIn: "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~") + return value.addingPercentEncoding(withAllowedCharacters: allowed) ?? value +} + +func simplifyAuthErrorBody(_ data: Data) -> String { + let body = String(data: data, encoding: .utf8)?.trimmingCharacters(in: .whitespacesAndNewlines) + guard let body, !body.isEmpty else { + return "Unknown error" + } + + if body.localizedCaseInsensitiveContains(" String { + base64EncodedString() + .replacingOccurrences(of: "+", with: "-") + .replacingOccurrences(of: "/", with: "_") + .replacingOccurrences(of: "=", with: "") + } +} + +private func randomData(count: Int) -> Data { + var bytes = [UInt8](repeating: 0, count: count) + let status = SecRandomCopyBytes(kSecRandomDefault, bytes.count, &bytes) + if status == errSecSuccess { + return Data(bytes) + } + return Data((0 ..< count).map { _ in UInt8.random(in: .min ... .max) }) +} diff --git a/Sources/CodexKit/Auth/ChatGPTOAuthWebAuthentication.swift b/Sources/CodexKit/Auth/ChatGPTOAuthWebAuthentication.swift new file mode 100644 index 0000000..a31d483 --- /dev/null +++ b/Sources/CodexKit/Auth/ChatGPTOAuthWebAuthentication.swift @@ -0,0 +1,138 @@ +import Foundation +#if canImport(AuthenticationServices) +import AuthenticationServices +#endif +#if canImport(UIKit) +import UIKit +#endif +#if canImport(AppKit) +import AppKit +#endif + +#if canImport(AuthenticationServices) +@available(iOS 13.0, macOS 10.15, *) +public final class SystemChatGPTWebAuthenticationProvider: NSObject, ChatGPTWebAuthenticationProviding, @unchecked Sendable { + private var activeSession: ASWebAuthenticationSession? + private var activePresentationContextProvider: PresentationContextProvider? + private let presentationAnchorProvider: @MainActor @Sendable () -> ASPresentationAnchor? + + public override convenience init() { + self.init(presentationAnchorProvider: { + defaultPresentationAnchor() + }) + } + + public init( + presentationAnchorProvider: @escaping @MainActor @Sendable () -> ASPresentationAnchor? + ) { + self.presentationAnchorProvider = presentationAnchorProvider + super.init() + } + + public func authenticate( + authorizeURL: URL, + callbackScheme: String + ) async throws -> URL { + let anchor = try await MainActor.run { () throws -> ASPresentationAnchor in + guard let anchor = presentationAnchorProvider() else { + throw AgentRuntimeError( + code: "oauth_presentation_anchor_unavailable", + message: "The ChatGPT sign-in sheet could not be presented because no active window was available." + ) + } + return anchor + } + + return try await withCheckedThrowingContinuation { continuation in + Task { @MainActor [weak self] in + let session = ASWebAuthenticationSession( + url: authorizeURL, + callbackURLScheme: callbackScheme + ) { callbackURL, error in + self?.activeSession = nil + self?.activePresentationContextProvider = nil + + if let callbackURL { + continuation.resume(returning: callbackURL) + return + } + + continuation.resume( + throwing: error ?? AgentRuntimeError( + code: "oauth_authentication_cancelled", + message: "The ChatGPT sign-in flow did not complete." + ) + ) + } + let contextProvider = PresentationContextProvider(anchor: anchor) + session.presentationContextProvider = contextProvider + #if os(iOS) + session.prefersEphemeralWebBrowserSession = false + #endif + self?.activeSession = session + self?.activePresentationContextProvider = contextProvider + + guard session.start() else { + self?.activeSession = nil + self?.activePresentationContextProvider = nil + continuation.resume( + throwing: AgentRuntimeError( + code: "oauth_authentication_start_failed", + message: "The ChatGPT sign-in flow could not be started." + ) + ) + return + } + } + } + } +} + +@available(iOS 13.0, macOS 10.15, *) +private final class PresentationContextProvider: NSObject, ASWebAuthenticationPresentationContextProviding { + private let anchor: ASPresentationAnchor + + init(anchor: ASPresentationAnchor) { + self.anchor = anchor + } + + func presentationAnchor(for _: ASWebAuthenticationSession) -> ASPresentationAnchor { + anchor + } +} + +@MainActor +@available(iOS 13.0, macOS 10.15, *) +private func defaultPresentationAnchor() -> ASPresentationAnchor? { + #if canImport(UIKit) + let scenes = UIApplication.shared.connectedScenes + .compactMap { $0 as? UIWindowScene } + + if let keyWindow = scenes + .flatMap(\.windows) + .first(where: \.isKeyWindow) { + return keyWindow + } + + return scenes + .flatMap(\.windows) + .first(where: { !$0.isHidden }) + #elseif canImport(AppKit) + return NSApp.keyWindow ?? NSApp.mainWindow + #else + return nil + #endif +} +#endif + +private struct UnsupportedChatGPTWebAuthenticationProvider: ChatGPTWebAuthenticationProviding { + func authenticate( + authorizeURL _: URL, + callbackScheme _: String + ) async throws -> URL { + throw AgentRuntimeError( + code: "oauth_authentication_unsupported", + message: "Browser-based ChatGPT sign-in is not supported on this platform." + ) + } +} diff --git a/Sources/CodexKit/Memory/SQLiteMemoryStore.swift b/Sources/CodexKit/Memory/SQLiteMemoryStore.swift index de31094..0b6ee32 100644 --- a/Sources/CodexKit/Memory/SQLiteMemoryStore.swift +++ b/Sources/CodexKit/Memory/SQLiteMemoryStore.swift @@ -1,433 +1,6 @@ import Foundation import GRDB -private struct SQLiteMemoryStoreSchema: Sendable { - let currentVersion = 1 - - func existingVersion(in db: Database) throws -> Int { - try MemoryUserVersionQuery().execute(in: db) - } - - func makeMigrator() -> DatabaseMigrator { - var migrator = DatabaseMigrator() - - migrator.registerMigration("memory_store_v1") { db in - try db.execute(sql: """ - CREATE TABLE IF NOT EXISTS memory_records ( - namespace TEXT NOT NULL, - id TEXT NOT NULL, - scope TEXT NOT NULL, - kind TEXT NOT NULL, - summary TEXT NOT NULL, - evidence_json TEXT NOT NULL, - importance REAL NOT NULL, - created_at REAL NOT NULL, - observed_at REAL, - expires_at REAL, - tags_json TEXT NOT NULL, - related_ids_json TEXT NOT NULL, - dedupe_key TEXT, - is_pinned INTEGER NOT NULL, - attributes_json TEXT, - status TEXT NOT NULL, - PRIMARY KEY(namespace, id) - ); - """) - - try db.execute(sql: """ - CREATE UNIQUE INDEX IF NOT EXISTS memory_records_namespace_dedupe - ON memory_records(namespace, dedupe_key) - WHERE dedupe_key IS NOT NULL; - """) - try db.execute(sql: """ - CREATE INDEX IF NOT EXISTS memory_records_namespace_scope - ON memory_records(namespace, scope); - """) - try db.execute(sql: """ - CREATE INDEX IF NOT EXISTS memory_records_namespace_kind - ON memory_records(namespace, kind); - """) - try db.execute(sql: """ - CREATE INDEX IF NOT EXISTS memory_records_namespace_status - ON memory_records(namespace, status); - """) - - try db.execute(sql: """ - CREATE TABLE IF NOT EXISTS memory_tags ( - namespace TEXT NOT NULL, - record_id TEXT NOT NULL, - tag TEXT NOT NULL, - FOREIGN KEY(namespace, record_id) - REFERENCES memory_records(namespace, id) - ON DELETE CASCADE - ); - """) - try db.execute(sql: """ - CREATE INDEX IF NOT EXISTS memory_tags_lookup - ON memory_tags(namespace, tag, record_id); - """) - - try db.execute(sql: """ - CREATE TABLE IF NOT EXISTS memory_related_ids ( - namespace TEXT NOT NULL, - record_id TEXT NOT NULL, - related_id TEXT NOT NULL, - FOREIGN KEY(namespace, record_id) - REFERENCES memory_records(namespace, id) - ON DELETE CASCADE - ); - """) - try db.execute(sql: """ - CREATE INDEX IF NOT EXISTS memory_related_lookup - ON memory_related_ids(namespace, related_id, record_id); - """) - - try db.execute(sql: """ - CREATE VIRTUAL TABLE IF NOT EXISTS memory_fts - USING fts5(namespace UNINDEXED, record_id UNINDEXED, content); - """) - - try db.execute(sql: "PRAGMA user_version = \(currentVersion)") - } - - return migrator - } -} - -private struct SQLiteMemoryStoreCodec: Sendable { - private let encoder = JSONEncoder() - private let decoder = JSONDecoder() - - func encode(_ value: T) throws -> String { - let data = try encoder.encode(value) - return String(decoding: data, as: UTF8.self) - } - - func encodeNullable(_ value: T?) throws -> String? { - guard let value else { - return nil - } - return try encode(value) - } - - func decode(_ type: T.Type, from string: String) throws -> T { - try decoder.decode(type, from: Data(string.utf8)) - } - - func decodeNullable(_ type: T.Type, from string: String?) throws -> T? { - guard let string else { - return nil - } - return try decode(type, from: string) - } - - func makeFTSQuery(from value: String?) -> String { - let tokens = MemoryQueryEngine.tokenize(value) - guard !tokens.isEmpty else { - return "" - } - return tokens.joined(separator: " OR ") - } -} - -private struct SQLiteMemoryStoreRepository: Sendable { - let codec: SQLiteMemoryStoreCodec - - func ensureRecordIDAvailable( - _ id: String, - namespace: String, - in db: Database - ) throws { - if try recordExists(id: id, namespace: namespace, in: db) { - throw MemoryStoreError.duplicateRecordID(id) - } - } - - func ensureDedupeKeyAvailable( - _ dedupeKey: String, - namespace: String, - in db: Database - ) throws { - if try recordExists(dedupeKey: dedupeKey, namespace: namespace, in: db) { - throw MemoryStoreError.duplicateDedupeKey(dedupeKey) - } - } - - func loadRecords( - namespace: String, - in db: Database - ) throws -> [MemoryRecord] { - let rows = try MemoryRecordRow - .filter(Column("namespace") == namespace) - .fetchAll(db) - - return try rows.map { row in - try makeRecord(from: row, namespace: namespace) - } - } - - func loadRawFTSScores( - namespace: String, - queryText: String?, - in db: Database - ) throws -> [String: Double] { - let matchQuery = codec.makeFTSQuery(from: queryText) - guard !matchQuery.isEmpty else { - return [:] - } - - let rows = try MemoryFTSScoreRowsRequest( - namespace: namespace, - matchQuery: matchQuery - ).execute(in: db) - - var scores: [String: Double] = [:] - for row in rows { - let recordID: String = row["record_id"] - let score: Double = row["score"] - scores[recordID] = score - } - return scores - } - - func loadRecord( - id: String, - namespace: String, - in db: Database - ) throws -> MemoryRecord? { - let row = try MemoryRecordRow - .filter(Column("namespace") == namespace) - .filter(Column("id") == id) - .fetchOne(db) - guard let row else { - return nil - } - return try makeRecord(from: row, namespace: namespace) - } - - func archiveRecord( - id: String, - namespace: String, - in db: Database - ) throws { - try db.execute( - sql: """ - UPDATE memory_records - SET status = ? - WHERE namespace = ? AND id = ?; - """, - arguments: [MemoryRecordStatus.archived.rawValue, namespace, id] - ) - } - - func deleteRecord( - id: String, - namespace: String, - in db: Database - ) throws { - try db.execute( - sql: "DELETE FROM memory_fts WHERE namespace = ? AND record_id = ?;", - arguments: [namespace, id] - ) - try db.execute( - sql: "DELETE FROM memory_records WHERE namespace = ? AND id = ?;", - arguments: [namespace, id] - ) - } - - func deleteRecord( - withDedupeKey dedupeKey: String, - namespace: String, - in db: Database - ) throws { - if let row = try MemoryRecordRow - .filter(Column("namespace") == namespace) - .filter(Column("dedupe_key") == dedupeKey) - .fetchOne(db) { - try deleteRecord(id: row.id, namespace: namespace, in: db) - } - } - - func upsertRecord( - _ record: MemoryRecord, - in db: Database - ) throws { - try db.execute( - sql: """ - INSERT OR REPLACE INTO memory_records ( - namespace, id, scope, kind, summary, evidence_json, importance, - created_at, observed_at, expires_at, tags_json, related_ids_json, - dedupe_key, is_pinned, attributes_json, status - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?); - """, - arguments: [ - record.namespace, - record.id, - record.scope.rawValue, - record.kind, - record.summary, - try codec.encode(record.evidence), - record.importance, - record.createdAt.timeIntervalSince1970, - record.observedAt?.timeIntervalSince1970, - record.expiresAt?.timeIntervalSince1970, - try codec.encode(record.tags), - try codec.encode(record.relatedIDs), - record.dedupeKey, - record.isPinned ? 1 : 0, - try codec.encodeNullable(record.attributes), - record.status.rawValue, - ] - ) - - try db.execute( - sql: "DELETE FROM memory_tags WHERE namespace = ? AND record_id = ?;", - arguments: [record.namespace, record.id] - ) - try db.execute( - sql: "DELETE FROM memory_related_ids WHERE namespace = ? AND record_id = ?;", - arguments: [record.namespace, record.id] - ) - try db.execute( - sql: "DELETE FROM memory_fts WHERE namespace = ? AND record_id = ?;", - arguments: [record.namespace, record.id] - ) - - for tag in record.tags { - try db.execute( - sql: "INSERT INTO memory_tags(namespace, record_id, tag) VALUES (?, ?, ?);", - arguments: [record.namespace, record.id, tag] - ) - } - - for relatedID in record.relatedIDs { - try db.execute( - sql: """ - INSERT INTO memory_related_ids(namespace, record_id, related_id) - VALUES (?, ?, ?); - """, - arguments: [record.namespace, record.id, relatedID] - ) - } - - let ftsContent = ([record.summary] + record.evidence + record.tags + [record.kind]) - .joined(separator: " ") - try db.execute( - sql: "INSERT INTO memory_fts(namespace, record_id, content) VALUES (?, ?, ?);", - arguments: [record.namespace, record.id, ftsContent] - ) - } - - private func makeRecord( - from row: MemoryRecordRow, - namespace: String - ) throws -> MemoryRecord { - MemoryRecord( - id: row.id, - namespace: namespace, - scope: MemoryScope(rawValue: row.scope), - kind: row.kind, - summary: row.summary, - evidence: try codec.decode([String].self, from: row.evidenceJSON), - importance: row.importance, - createdAt: Date(timeIntervalSince1970: row.createdAt), - observedAt: row.observedAt.map(Date.init(timeIntervalSince1970:)), - expiresAt: row.expiresAt.map(Date.init(timeIntervalSince1970:)), - tags: try codec.decode([String].self, from: row.tagsJSON), - relatedIDs: try codec.decode([String].self, from: row.relatedIDsJSON), - dedupeKey: row.dedupeKey, - isPinned: row.isPinned, - attributes: try codec.decodeNullable(JSONValue.self, from: row.attributesJSON), - status: MemoryRecordStatus(rawValue: row.status) ?? .active - ) - } - - private func recordExists( - id: String, - namespace: String, - in db: Database - ) throws -> Bool { - try MemoryRecordRow - .filter(Column("namespace") == namespace) - .filter(Column("id") == id) - .fetchCount(db) > 0 - } - - private func recordExists( - dedupeKey: String, - namespace: String, - in db: Database - ) throws -> Bool { - try MemoryRecordRow - .filter(Column("namespace") == namespace) - .filter(Column("dedupe_key") == dedupeKey) - .fetchCount(db) > 0 - } -} - -private struct MemoryUserVersionQuery: Sendable { - func execute(in db: Database) throws -> Int { - // PRAGMA is SQLite-specific and doesn't map cleanly to GRDB's query interface. - let row = try SQLRequest(sql: "PRAGMA user_version;").fetchOne(db) - return row?[0] ?? 0 - } -} - -private struct MemoryFTSScoreRowsRequest: Sendable { - let namespace: String - let matchQuery: String - - func execute(in db: Database) throws -> [Row] { - // FTS5 MATCH and bm25() are much clearer and more direct in raw SQL than in GRDB's query interface. - try SQLRequest( - sql: """ - SELECT record_id, bm25(memory_fts) AS score - FROM memory_fts - WHERE namespace = ? AND memory_fts MATCH ?; - """, - arguments: [namespace, matchQuery] - ).fetchAll(db) - } -} - -private struct MemoryRecordRow: FetchableRecord, TableRecord { - static let databaseTableName = "memory_records" - - let id: String - let scope: String - let kind: String - let summary: String - let evidenceJSON: String - let importance: Double - let createdAt: Double - let observedAt: Double? - let expiresAt: Double? - let tagsJSON: String - let relatedIDsJSON: String - let dedupeKey: String? - let isPinned: Bool - let attributesJSON: String? - let status: String - - init(row: Row) { - id = row["id"] - scope = row["scope"] - kind = row["kind"] - summary = row["summary"] - evidenceJSON = row["evidence_json"] - importance = row["importance"] - createdAt = row["created_at"] - observedAt = row["observed_at"] - expiresAt = row["expires_at"] - tagsJSON = row["tags_json"] - relatedIDsJSON = row["related_ids_json"] - dedupeKey = row["dedupe_key"] - isPinned = row["is_pinned"] as Bool? ?? false - attributesJSON = row["attributes_json"] - status = row["status"] - } -} - public actor SQLiteMemoryStore: MemoryStoring { private let url: URL private let dbQueue: DatabaseQueue @@ -545,7 +118,7 @@ public actor SQLiteMemoryStore: MemoryStoring { return try await dbQueue.read { db in try repository.loadRecords(namespace: query.namespace, in: db) .filter { record in - if !query.includeArchived, record.status == .archived { + if !query.includeArchived, record.status == MemoryRecordStatus.archived { return false } if !query.scopes.isEmpty, !query.scopes.contains(record.scope) { @@ -580,10 +153,10 @@ public actor SQLiteMemoryStore: MemoryStoring { implementation: "sqlite", schemaVersion: schemaVersion, totalRecords: records.count, - activeRecords: records.filter { $0.status == .active }.count, - archivedRecords: records.filter { $0.status == .archived }.count, - countsByScope: Dictionary(grouping: records, by: \.scope).mapValues(\.count), - countsByKind: Dictionary(grouping: records, by: \.kind).mapValues(\.count) + activeRecords: records.filter { $0.status == MemoryRecordStatus.active }.count, + archivedRecords: records.filter { $0.status == MemoryRecordStatus.archived }.count, + countsByScope: Dictionary(grouping: records, by: { $0.scope }).mapValues { $0.count }, + countsByKind: Dictionary(grouping: records, by: { $0.kind }).mapValues { $0.count } ) } @@ -638,14 +211,19 @@ public actor SQLiteMemoryStore: MemoryStoring { try MemoryQueryEngine.validateNamespace(namespace) let repository = self.repository let expiredIDs = try await dbQueue.read { db in - try repository.loadRecords(namespace: namespace, in: db) - .filter { record in - !record.isPinned && - record.status == .active && - (record.expiresAt?.compare(now) == .orderedAscending || - record.expiresAt?.compare(now) == .orderedSame) + let records = try repository.loadRecords(namespace: namespace, in: db) + return records.compactMap { record -> String? in + guard !record.isPinned else { + return nil } - .map(\.id) + guard record.status == MemoryRecordStatus.active else { + return nil + } + guard let expiresAt = record.expiresAt, expiresAt <= now else { + return nil + } + return record.id + } } try await writeTransaction { db in diff --git a/Sources/CodexKit/Memory/SQLiteMemoryStoreRepository.swift b/Sources/CodexKit/Memory/SQLiteMemoryStoreRepository.swift new file mode 100644 index 0000000..db2d02f --- /dev/null +++ b/Sources/CodexKit/Memory/SQLiteMemoryStoreRepository.swift @@ -0,0 +1,328 @@ +import Foundation +import GRDB + +struct SQLiteMemoryStoreCodec: Sendable { + private let encoder = JSONEncoder() + private let decoder = JSONDecoder() + + func encode(_ value: T) throws -> String { + let data = try encoder.encode(value) + return String(decoding: data, as: UTF8.self) + } + + func encodeNullable(_ value: T?) throws -> String? { + guard let value else { + return nil + } + return try encode(value) + } + + func decode(_ type: T.Type, from string: String) throws -> T { + try decoder.decode(type, from: Data(string.utf8)) + } + + func decodeNullable(_ type: T.Type, from string: String?) throws -> T? { + guard let string else { + return nil + } + return try decode(type, from: string) + } + + func makeFTSQuery(from value: String?) -> String { + let tokens = MemoryQueryEngine.tokenize(value) + guard !tokens.isEmpty else { + return "" + } + return tokens.joined(separator: " OR ") + } +} + +struct SQLiteMemoryStoreRepository: Sendable { + let codec: SQLiteMemoryStoreCodec + + func ensureRecordIDAvailable( + _ id: String, + namespace: String, + in db: Database + ) throws { + if try recordExists(id: id, namespace: namespace, in: db) { + throw MemoryStoreError.duplicateRecordID(id) + } + } + + func ensureDedupeKeyAvailable( + _ dedupeKey: String, + namespace: String, + in db: Database + ) throws { + if try recordExists(dedupeKey: dedupeKey, namespace: namespace, in: db) { + throw MemoryStoreError.duplicateDedupeKey(dedupeKey) + } + } + + func loadRecords( + namespace: String, + in db: Database + ) throws -> [MemoryRecord] { + let rows = try MemoryRecordRow + .filter(Column("namespace") == namespace) + .fetchAll(db) + + return try rows.map { row in + try makeRecord(from: row, namespace: namespace) + } + } + + func loadRawFTSScores( + namespace: String, + queryText: String?, + in db: Database + ) throws -> [String: Double] { + let matchQuery = codec.makeFTSQuery(from: queryText) + guard !matchQuery.isEmpty else { + return [:] + } + + let rows = try MemoryFTSScoreRowsRequest( + namespace: namespace, + matchQuery: matchQuery + ).execute(in: db) + + var scores: [String: Double] = [:] + for row in rows { + let recordID: String = row["record_id"] + let score: Double = row["score"] + scores[recordID] = score + } + return scores + } + + func loadRecord( + id: String, + namespace: String, + in db: Database + ) throws -> MemoryRecord? { + let row = try MemoryRecordRow + .filter(Column("namespace") == namespace) + .filter(Column("id") == id) + .fetchOne(db) + guard let row else { + return nil + } + return try makeRecord(from: row, namespace: namespace) + } + + func archiveRecord( + id: String, + namespace: String, + in db: Database + ) throws { + try db.execute( + sql: """ + UPDATE memory_records + SET status = ? + WHERE namespace = ? AND id = ?; + """, + arguments: [MemoryRecordStatus.archived.rawValue, namespace, id] + ) + } + + func deleteRecord( + id: String, + namespace: String, + in db: Database + ) throws { + try db.execute( + sql: "DELETE FROM memory_fts WHERE namespace = ? AND record_id = ?;", + arguments: [namespace, id] + ) + try db.execute( + sql: "DELETE FROM memory_records WHERE namespace = ? AND id = ?;", + arguments: [namespace, id] + ) + } + + func deleteRecord( + withDedupeKey dedupeKey: String, + namespace: String, + in db: Database + ) throws { + if let row = try MemoryRecordRow + .filter(Column("namespace") == namespace) + .filter(Column("dedupe_key") == dedupeKey) + .fetchOne(db) { + try deleteRecord(id: row.id, namespace: namespace, in: db) + } + } + + func upsertRecord( + _ record: MemoryRecord, + in db: Database + ) throws { + try db.execute( + sql: """ + INSERT OR REPLACE INTO memory_records ( + namespace, id, scope, kind, summary, evidence_json, importance, + created_at, observed_at, expires_at, tags_json, related_ids_json, + dedupe_key, is_pinned, attributes_json, status + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?); + """, + arguments: [ + record.namespace, + record.id, + record.scope.rawValue, + record.kind, + record.summary, + try codec.encode(record.evidence), + record.importance, + record.createdAt.timeIntervalSince1970, + record.observedAt?.timeIntervalSince1970, + record.expiresAt?.timeIntervalSince1970, + try codec.encode(record.tags), + try codec.encode(record.relatedIDs), + record.dedupeKey, + record.isPinned ? 1 : 0, + try codec.encodeNullable(record.attributes), + record.status.rawValue, + ] + ) + + try db.execute( + sql: "DELETE FROM memory_tags WHERE namespace = ? AND record_id = ?;", + arguments: [record.namespace, record.id] + ) + try db.execute( + sql: "DELETE FROM memory_related_ids WHERE namespace = ? AND record_id = ?;", + arguments: [record.namespace, record.id] + ) + try db.execute( + sql: "DELETE FROM memory_fts WHERE namespace = ? AND record_id = ?;", + arguments: [record.namespace, record.id] + ) + + for tag in record.tags { + try db.execute( + sql: "INSERT INTO memory_tags(namespace, record_id, tag) VALUES (?, ?, ?);", + arguments: [record.namespace, record.id, tag] + ) + } + + for relatedID in record.relatedIDs { + try db.execute( + sql: """ + INSERT INTO memory_related_ids(namespace, record_id, related_id) + VALUES (?, ?, ?); + """, + arguments: [record.namespace, record.id, relatedID] + ) + } + + let ftsContent = ([record.summary] + record.evidence + record.tags + [record.kind]) + .joined(separator: " ") + try db.execute( + sql: "INSERT INTO memory_fts(namespace, record_id, content) VALUES (?, ?, ?);", + arguments: [record.namespace, record.id, ftsContent] + ) + } + + private func makeRecord( + from row: MemoryRecordRow, + namespace: String + ) throws -> MemoryRecord { + MemoryRecord( + id: row.id, + namespace: namespace, + scope: MemoryScope(rawValue: row.scope), + kind: row.kind, + summary: row.summary, + evidence: try codec.decode([String].self, from: row.evidenceJSON), + importance: row.importance, + createdAt: Date(timeIntervalSince1970: row.createdAt), + observedAt: row.observedAt.map(Date.init(timeIntervalSince1970:)), + expiresAt: row.expiresAt.map(Date.init(timeIntervalSince1970:)), + tags: try codec.decode([String].self, from: row.tagsJSON), + relatedIDs: try codec.decode([String].self, from: row.relatedIDsJSON), + dedupeKey: row.dedupeKey, + isPinned: row.isPinned, + attributes: try codec.decodeNullable(JSONValue.self, from: row.attributesJSON), + status: MemoryRecordStatus(rawValue: row.status) ?? .active + ) + } + + private func recordExists( + id: String, + namespace: String, + in db: Database + ) throws -> Bool { + try MemoryRecordRow + .filter(Column("namespace") == namespace) + .filter(Column("id") == id) + .fetchCount(db) > 0 + } + + private func recordExists( + dedupeKey: String, + namespace: String, + in db: Database + ) throws -> Bool { + try MemoryRecordRow + .filter(Column("namespace") == namespace) + .filter(Column("dedupe_key") == dedupeKey) + .fetchCount(db) > 0 + } +} + +struct MemoryFTSScoreRowsRequest: Sendable { + let namespace: String + let matchQuery: String + + func execute(in db: Database) throws -> [Row] { + // FTS5 MATCH and bm25() are much clearer and more direct in raw SQL than in GRDB's query interface. + try SQLRequest( + sql: """ + SELECT record_id, bm25(memory_fts) AS score + FROM memory_fts + WHERE namespace = ? AND memory_fts MATCH ?; + """, + arguments: [namespace, matchQuery] + ).fetchAll(db) + } +} + +struct MemoryRecordRow: FetchableRecord, TableRecord { + static let databaseTableName = "memory_records" + + let id: String + let scope: String + let kind: String + let summary: String + let evidenceJSON: String + let importance: Double + let createdAt: Double + let observedAt: Double? + let expiresAt: Double? + let tagsJSON: String + let relatedIDsJSON: String + let dedupeKey: String? + let isPinned: Bool + let attributesJSON: String? + let status: String + + init(row: Row) { + id = row["id"] + scope = row["scope"] + kind = row["kind"] + summary = row["summary"] + evidenceJSON = row["evidence_json"] + importance = row["importance"] + createdAt = row["created_at"] + observedAt = row["observed_at"] + expiresAt = row["expires_at"] + tagsJSON = row["tags_json"] + relatedIDsJSON = row["related_ids_json"] + dedupeKey = row["dedupe_key"] + isPinned = row["is_pinned"] as Bool? ?? false + attributesJSON = row["attributes_json"] + status = row["status"] + } +} diff --git a/Sources/CodexKit/Memory/SQLiteMemoryStoreSchema.swift b/Sources/CodexKit/Memory/SQLiteMemoryStoreSchema.swift new file mode 100644 index 0000000..208d9f1 --- /dev/null +++ b/Sources/CodexKit/Memory/SQLiteMemoryStoreSchema.swift @@ -0,0 +1,103 @@ +import Foundation +import GRDB + +struct SQLiteMemoryStoreSchema: Sendable { + let currentVersion = 1 + + func existingVersion(in db: Database) throws -> Int { + try MemoryUserVersionQuery().execute(in: db) + } + + func makeMigrator() -> DatabaseMigrator { + var migrator = DatabaseMigrator() + + migrator.registerMigration("memory_store_v1") { db in + try db.execute(sql: """ + CREATE TABLE IF NOT EXISTS memory_records ( + namespace TEXT NOT NULL, + id TEXT NOT NULL, + scope TEXT NOT NULL, + kind TEXT NOT NULL, + summary TEXT NOT NULL, + evidence_json TEXT NOT NULL, + importance REAL NOT NULL, + created_at REAL NOT NULL, + observed_at REAL, + expires_at REAL, + tags_json TEXT NOT NULL, + related_ids_json TEXT NOT NULL, + dedupe_key TEXT, + is_pinned INTEGER NOT NULL, + attributes_json TEXT, + status TEXT NOT NULL, + PRIMARY KEY(namespace, id) + ); + """) + + try db.execute(sql: """ + CREATE UNIQUE INDEX IF NOT EXISTS memory_records_namespace_dedupe + ON memory_records(namespace, dedupe_key) + WHERE dedupe_key IS NOT NULL; + """) + try db.execute(sql: """ + CREATE INDEX IF NOT EXISTS memory_records_namespace_scope + ON memory_records(namespace, scope); + """) + try db.execute(sql: """ + CREATE INDEX IF NOT EXISTS memory_records_namespace_kind + ON memory_records(namespace, kind); + """) + try db.execute(sql: """ + CREATE INDEX IF NOT EXISTS memory_records_namespace_status + ON memory_records(namespace, status); + """) + + try db.execute(sql: """ + CREATE TABLE IF NOT EXISTS memory_tags ( + namespace TEXT NOT NULL, + record_id TEXT NOT NULL, + tag TEXT NOT NULL, + FOREIGN KEY(namespace, record_id) + REFERENCES memory_records(namespace, id) + ON DELETE CASCADE + ); + """) + try db.execute(sql: """ + CREATE INDEX IF NOT EXISTS memory_tags_lookup + ON memory_tags(namespace, tag, record_id); + """) + + try db.execute(sql: """ + CREATE TABLE IF NOT EXISTS memory_related_ids ( + namespace TEXT NOT NULL, + record_id TEXT NOT NULL, + related_id TEXT NOT NULL, + FOREIGN KEY(namespace, record_id) + REFERENCES memory_records(namespace, id) + ON DELETE CASCADE + ); + """) + try db.execute(sql: """ + CREATE INDEX IF NOT EXISTS memory_related_lookup + ON memory_related_ids(namespace, related_id, record_id); + """) + + try db.execute(sql: """ + CREATE VIRTUAL TABLE IF NOT EXISTS memory_fts + USING fts5(namespace UNINDEXED, record_id UNINDEXED, content); + """) + + try db.execute(sql: "PRAGMA user_version = \(currentVersion)") + } + + return migrator + } +} + +struct MemoryUserVersionQuery: Sendable { + func execute(in db: Database) throws -> Int { + // PRAGMA is SQLite-specific and doesn't map cleanly to GRDB's query interface. + let row = try SQLRequest(sql: "PRAGMA user_version;").fetchOne(db) + return row?[0] ?? 0 + } +} diff --git a/Sources/CodexKit/Runtime/AgentHistory+Serialization.swift b/Sources/CodexKit/Runtime/AgentHistory+Serialization.swift new file mode 100644 index 0000000..aff356e --- /dev/null +++ b/Sources/CodexKit/Runtime/AgentHistory+Serialization.swift @@ -0,0 +1,314 @@ +import Foundation + +extension AgentThreadPendingState: Codable { + private enum CodingKeys: String, CodingKey { + case kind + case approval + case userInput + case toolWait + } + + private enum Kind: String, Codable { + case approval + case userInput + case toolWait + } + + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + switch try container.decode(Kind.self, forKey: .kind) { + case .approval: + self = .approval(try container.decode(AgentPendingApprovalState.self, forKey: .approval)) + case .userInput: + self = .userInput(try container.decode(AgentPendingUserInputState.self, forKey: .userInput)) + case .toolWait: + self = .toolWait(try container.decode(AgentPendingToolWaitState.self, forKey: .toolWait)) + } + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + switch self { + case let .approval(state): + try container.encode(Kind.approval, forKey: .kind) + try container.encode(state, forKey: .approval) + case let .userInput(state): + try container.encode(Kind.userInput, forKey: .kind) + try container.encode(state, forKey: .userInput) + case let .toolWait(state): + try container.encode(Kind.toolWait, forKey: .kind) + try container.encode(state, forKey: .toolWait) + } + } +} + +extension AgentHistoryItem: Codable { + private enum CodingKeys: String, CodingKey { + case kind + case message + case toolCall + case toolResult + case structuredOutput + case approval + case systemEvent + } + + private enum Kind: String, Codable { + case message + case toolCall + case toolResult + case structuredOutput + case approval + case systemEvent + } + + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + switch try container.decode(Kind.self, forKey: .kind) { + case .message: + self = .message(try container.decode(AgentMessage.self, forKey: .message)) + case .toolCall: + self = .toolCall(try container.decode(AgentToolCallRecord.self, forKey: .toolCall)) + case .toolResult: + self = .toolResult(try container.decode(AgentToolResultRecord.self, forKey: .toolResult)) + case .structuredOutput: + self = .structuredOutput(try container.decode(AgentStructuredOutputRecord.self, forKey: .structuredOutput)) + case .approval: + self = .approval(try container.decode(AgentApprovalRecord.self, forKey: .approval)) + case .systemEvent: + self = .systemEvent(try container.decode(AgentSystemEventRecord.self, forKey: .systemEvent)) + } + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + switch self { + case let .message(message): + try container.encode(Kind.message, forKey: .kind) + try container.encode(message, forKey: .message) + case let .toolCall(record): + try container.encode(Kind.toolCall, forKey: .kind) + try container.encode(record, forKey: .toolCall) + case let .toolResult(record): + try container.encode(Kind.toolResult, forKey: .kind) + try container.encode(record, forKey: .toolResult) + case let .structuredOutput(record): + try container.encode(Kind.structuredOutput, forKey: .kind) + try container.encode(record, forKey: .structuredOutput) + case let .approval(record): + try container.encode(Kind.approval, forKey: .kind) + try container.encode(record, forKey: .approval) + case let .systemEvent(record): + try container.encode(Kind.systemEvent, forKey: .kind) + try container.encode(record, forKey: .systemEvent) + } + } +} + +extension AgentHistoryFilter { + func matches(_ item: AgentHistoryItem) -> Bool { + switch item { + case .message: + return includeMessages + case .toolCall: + return includeToolCalls + case .toolResult: + return includeToolResults + case .structuredOutput: + return includeStructuredOutputs + case .approval: + return includeApprovals + case let .systemEvent(record): + if record.type == .contextCompacted { + return includeSystemEvents && includeCompactionEvents + } + return includeSystemEvents + } + } +} + +extension AgentHistoryItem { + var kind: AgentHistoryItemKind { + switch self { + case .message: + .message + case .toolCall: + .toolCall + case .toolResult: + .toolResult + case .structuredOutput: + .structuredOutput + case .approval: + .approval + case .systemEvent: + .systemEvent + } + } + + var turnID: String? { + switch self { + case let .message(message): + return message.structuredOutput == nil ? nil : nil + case let .toolCall(record): + return record.invocation.turnID + case let .toolResult(record): + return record.turnID + case let .structuredOutput(record): + return record.turnID + case let .approval(record): + return record.request?.turnID ?? record.resolution?.turnID + case let .systemEvent(record): + return record.turnID + } + } + + var defaultRecordID: String { + switch self { + case let .message(message): + return "message:\(message.id)" + case let .toolCall(record): + return "toolCall:\(record.invocation.id)" + case let .toolResult(record): + return "toolResult:\(record.result.invocationID)" + case let .structuredOutput(record): + return "structuredOutput:\(record.messageID ?? record.turnID)" + case let .approval(record): + return "approval:\(record.request?.id ?? record.resolution?.requestID ?? UUID().uuidString)" + case let .systemEvent(record): + if record.type == .contextCompacted, + let generation = record.compaction?.generation { + return "systemEvent:\(record.type.rawValue):\(record.threadID):\(generation)" + } + return "systemEvent:\(record.type.rawValue):\(record.turnID ?? record.threadID)" + } + } + + public var isCompactionMarker: Bool { + guard case let .systemEvent(record) = self else { + return false + } + return record.type == .contextCompacted + } +} + +extension AgentHistoryRecord { + func redacted(reason: AgentRedactionReason?) -> AgentHistoryRecord { + AgentHistoryRecord( + id: id, + sequenceNumber: sequenceNumber, + createdAt: createdAt, + item: item.redactedPayload(), + redaction: AgentHistoryRedaction(reason: reason) + ) + } +} + +private extension AgentHistoryItem { + func redactedPayload() -> AgentHistoryItem { + switch self { + case let .message(message): + return .message( + AgentMessage( + id: message.id, + threadID: message.threadID, + role: message.role, + text: "[Redacted]", + images: [], + structuredOutput: message.structuredOutput.map { + AgentStructuredOutputMetadata( + formatName: $0.formatName, + payload: .object(["redacted": .bool(true)]) + ) + }, + createdAt: message.createdAt + ) + ) + + case let .toolCall(record): + return .toolCall( + AgentToolCallRecord( + invocation: ToolInvocation( + id: record.invocation.id, + threadID: record.invocation.threadID, + turnID: record.invocation.turnID, + toolName: record.invocation.toolName, + arguments: .object(["redacted": .bool(true)]) + ), + requestedAt: record.requestedAt + ) + ) + + case let .toolResult(record): + return .toolResult( + AgentToolResultRecord( + threadID: record.threadID, + turnID: record.turnID, + result: ToolResultEnvelope( + invocationID: record.result.invocationID, + toolName: record.result.toolName, + success: record.result.success, + content: [.text("[Redacted]")], + errorMessage: record.result.errorMessage == nil ? nil : "[Redacted]", + session: record.result.session + ), + completedAt: record.completedAt + ) + ) + + case let .structuredOutput(record): + return .structuredOutput( + AgentStructuredOutputRecord( + threadID: record.threadID, + turnID: record.turnID, + messageID: record.messageID, + metadata: AgentStructuredOutputMetadata( + formatName: record.metadata.formatName, + payload: .object(["redacted": .bool(true)]) + ), + committedAt: record.committedAt + ) + ) + + case let .approval(record): + let request = record.request.map { + ApprovalRequest( + id: $0.id, + threadID: $0.threadID, + turnID: $0.turnID, + toolInvocation: ToolInvocation( + id: $0.toolInvocation.id, + threadID: $0.toolInvocation.threadID, + turnID: $0.toolInvocation.turnID, + toolName: $0.toolInvocation.toolName, + arguments: .object(["redacted": .bool(true)]) + ), + title: "[Redacted]", + message: "[Redacted]" + ) + } + return .approval( + AgentApprovalRecord( + kind: record.kind, + request: request, + resolution: record.resolution, + occurredAt: record.occurredAt + ) + ) + + case let .systemEvent(record): + return .systemEvent( + AgentSystemEventRecord( + type: record.type, + threadID: record.threadID, + turnID: record.turnID, + status: record.status, + turnSummary: record.turnSummary, + error: record.error.map { + AgentRuntimeError(code: $0.code, message: "[Redacted]") + }, + occurredAt: record.occurredAt + ) + ) + } + } +} diff --git a/Sources/CodexKit/Runtime/AgentHistory.swift b/Sources/CodexKit/Runtime/AgentHistory.swift index e2ce06f..e27d831 100644 --- a/Sources/CodexKit/Runtime/AgentHistory.swift +++ b/Sources/CodexKit/Runtime/AgentHistory.swift @@ -197,583 +197,3 @@ public extension AgentThreadSummary { ) } } - -public enum AgentHistoryItem: Hashable, Sendable { - case message(AgentMessage) - case toolCall(AgentToolCallRecord) - case toolResult(AgentToolResultRecord) - case structuredOutput(AgentStructuredOutputRecord) - case approval(AgentApprovalRecord) - case systemEvent(AgentSystemEventRecord) -} - -public struct AgentToolCallRecord: Codable, Hashable, Sendable { - public let invocation: ToolInvocation - public let requestedAt: Date - - public init( - invocation: ToolInvocation, - requestedAt: Date = Date() - ) { - self.invocation = invocation - self.requestedAt = requestedAt - } -} - -public struct AgentToolResultRecord: Codable, Hashable, Sendable { - public let threadID: String - public let turnID: String - public let result: ToolResultEnvelope - public let completedAt: Date - - public init( - threadID: String, - turnID: String, - result: ToolResultEnvelope, - completedAt: Date = Date() - ) { - self.threadID = threadID - self.turnID = turnID - self.result = result - self.completedAt = completedAt - } -} - -public struct AgentStructuredOutputRecord: Codable, Hashable, Sendable { - public let threadID: String - public let turnID: String - public let messageID: String? - public let metadata: AgentStructuredOutputMetadata - public let committedAt: Date - - public init( - threadID: String, - turnID: String, - messageID: String? = nil, - metadata: AgentStructuredOutputMetadata, - committedAt: Date = Date() - ) { - self.threadID = threadID - self.turnID = turnID - self.messageID = messageID - self.metadata = metadata - self.committedAt = committedAt - } -} - -public enum AgentApprovalEventKind: String, Codable, Hashable, Sendable { - case requested - case resolved -} - -public struct AgentApprovalRecord: Codable, Hashable, Sendable { - public let kind: AgentApprovalEventKind - public let request: ApprovalRequest? - public let resolution: ApprovalResolution? - public let occurredAt: Date - - public init( - kind: AgentApprovalEventKind, - request: ApprovalRequest? = nil, - resolution: ApprovalResolution? = nil, - occurredAt: Date = Date() - ) { - self.kind = kind - self.request = request - self.resolution = resolution - self.occurredAt = occurredAt - } -} - -public enum AgentSystemEventType: String, Codable, Hashable, Sendable { - case threadCreated - case threadResumed - case threadStatusChanged - case turnStarted - case turnCompleted - case turnFailed - case contextCompacted -} - -public struct AgentSystemEventRecord: Codable, Hashable, Sendable { - public let type: AgentSystemEventType - public let threadID: String - public let turnID: String? - public let status: AgentThreadStatus? - public let turnSummary: AgentTurnSummary? - public let error: AgentRuntimeError? - public let compaction: AgentContextCompactionMarker? - public let occurredAt: Date - - public init( - type: AgentSystemEventType, - threadID: String, - turnID: String? = nil, - status: AgentThreadStatus? = nil, - turnSummary: AgentTurnSummary? = nil, - error: AgentRuntimeError? = nil, - compaction: AgentContextCompactionMarker? = nil, - occurredAt: Date = Date() - ) { - self.type = type - self.threadID = threadID - self.turnID = turnID - self.status = status - self.turnSummary = turnSummary - self.error = error - self.compaction = compaction - self.occurredAt = occurredAt - } -} - -public enum AgentThreadPendingState: Hashable, Sendable { - case approval(AgentPendingApprovalState) - case userInput(AgentPendingUserInputState) - case toolWait(AgentPendingToolWaitState) -} - -public struct AgentPendingApprovalState: Codable, Hashable, Sendable { - public let request: ApprovalRequest - public let requestedAt: Date - - public init( - request: ApprovalRequest, - requestedAt: Date = Date() - ) { - self.request = request - self.requestedAt = requestedAt - } -} - -public struct AgentPendingUserInputState: Codable, Hashable, Sendable { - public let requestID: String - public let turnID: String - public let title: String - public let message: String - public let requestedAt: Date - - public init( - requestID: String, - turnID: String, - title: String, - message: String, - requestedAt: Date = Date() - ) { - self.requestID = requestID - self.turnID = turnID - self.title = title - self.message = message - self.requestedAt = requestedAt - } -} - -public struct AgentPendingToolWaitState: Codable, Hashable, Sendable { - public let invocationID: String - public let turnID: String - public let toolName: String - public let startedAt: Date - public let sessionID: String? - public let sessionStatus: String? - public let metadata: JSONValue? - public let resumable: Bool - - public init( - invocationID: String, - turnID: String, - toolName: String, - startedAt: Date = Date(), - sessionID: String? = nil, - sessionStatus: String? = nil, - metadata: JSONValue? = nil, - resumable: Bool = false - ) { - self.invocationID = invocationID - self.turnID = turnID - self.toolName = toolName - self.startedAt = startedAt - self.sessionID = sessionID - self.sessionStatus = sessionStatus - self.metadata = metadata - self.resumable = resumable - } -} - -public enum AgentToolSessionStatus: String, Codable, Hashable, Sendable { - case waiting - case running - case completed - case failed - case denied -} - -public struct AgentLatestToolState: Codable, Hashable, Sendable { - public let invocationID: String - public let turnID: String - public let toolName: String - public let status: AgentToolSessionStatus - public let success: Bool? - public let sessionID: String? - public let sessionStatus: String? - public let metadata: JSONValue? - public let resumable: Bool - public let updatedAt: Date - public let resultPreview: String? - - public init( - invocationID: String, - turnID: String, - toolName: String, - status: AgentToolSessionStatus, - success: Bool? = nil, - sessionID: String? = nil, - sessionStatus: String? = nil, - metadata: JSONValue? = nil, - resumable: Bool = false, - updatedAt: Date = Date(), - resultPreview: String? = nil - ) { - self.invocationID = invocationID - self.turnID = turnID - self.toolName = toolName - self.status = status - self.success = success - self.sessionID = sessionID - self.sessionStatus = sessionStatus - self.metadata = metadata - self.resumable = resumable - self.updatedAt = updatedAt - self.resultPreview = resultPreview - } -} - -public struct AgentPartialStructuredOutputSnapshot: Codable, Hashable, Sendable { - public let turnID: String - public let formatName: String - public let payload: JSONValue - public let updatedAt: Date - - public init( - turnID: String, - formatName: String, - payload: JSONValue, - updatedAt: Date = Date() - ) { - self.turnID = turnID - self.formatName = formatName - self.payload = payload - self.updatedAt = updatedAt - } -} - -extension AgentThreadPendingState: Codable { - private enum CodingKeys: String, CodingKey { - case kind - case approval - case userInput - case toolWait - } - - private enum Kind: String, Codable { - case approval - case userInput - case toolWait - } - - public init(from decoder: Decoder) throws { - let container = try decoder.container(keyedBy: CodingKeys.self) - switch try container.decode(Kind.self, forKey: .kind) { - case .approval: - self = .approval(try container.decode(AgentPendingApprovalState.self, forKey: .approval)) - case .userInput: - self = .userInput(try container.decode(AgentPendingUserInputState.self, forKey: .userInput)) - case .toolWait: - self = .toolWait(try container.decode(AgentPendingToolWaitState.self, forKey: .toolWait)) - } - } - - public func encode(to encoder: Encoder) throws { - var container = encoder.container(keyedBy: CodingKeys.self) - switch self { - case let .approval(state): - try container.encode(Kind.approval, forKey: .kind) - try container.encode(state, forKey: .approval) - case let .userInput(state): - try container.encode(Kind.userInput, forKey: .kind) - try container.encode(state, forKey: .userInput) - case let .toolWait(state): - try container.encode(Kind.toolWait, forKey: .kind) - try container.encode(state, forKey: .toolWait) - } - } -} - -extension AgentHistoryItem: Codable { - private enum CodingKeys: String, CodingKey { - case kind - case message - case toolCall - case toolResult - case structuredOutput - case approval - case systemEvent - } - - private enum Kind: String, Codable { - case message - case toolCall - case toolResult - case structuredOutput - case approval - case systemEvent - } - - public init(from decoder: Decoder) throws { - let container = try decoder.container(keyedBy: CodingKeys.self) - switch try container.decode(Kind.self, forKey: .kind) { - case .message: - self = .message(try container.decode(AgentMessage.self, forKey: .message)) - case .toolCall: - self = .toolCall(try container.decode(AgentToolCallRecord.self, forKey: .toolCall)) - case .toolResult: - self = .toolResult(try container.decode(AgentToolResultRecord.self, forKey: .toolResult)) - case .structuredOutput: - self = .structuredOutput(try container.decode(AgentStructuredOutputRecord.self, forKey: .structuredOutput)) - case .approval: - self = .approval(try container.decode(AgentApprovalRecord.self, forKey: .approval)) - case .systemEvent: - self = .systemEvent(try container.decode(AgentSystemEventRecord.self, forKey: .systemEvent)) - } - } - - public func encode(to encoder: Encoder) throws { - var container = encoder.container(keyedBy: CodingKeys.self) - switch self { - case let .message(message): - try container.encode(Kind.message, forKey: .kind) - try container.encode(message, forKey: .message) - case let .toolCall(record): - try container.encode(Kind.toolCall, forKey: .kind) - try container.encode(record, forKey: .toolCall) - case let .toolResult(record): - try container.encode(Kind.toolResult, forKey: .kind) - try container.encode(record, forKey: .toolResult) - case let .structuredOutput(record): - try container.encode(Kind.structuredOutput, forKey: .kind) - try container.encode(record, forKey: .structuredOutput) - case let .approval(record): - try container.encode(Kind.approval, forKey: .kind) - try container.encode(record, forKey: .approval) - case let .systemEvent(record): - try container.encode(Kind.systemEvent, forKey: .kind) - try container.encode(record, forKey: .systemEvent) - } - } -} - -extension AgentHistoryFilter { - func matches(_ item: AgentHistoryItem) -> Bool { - switch item { - case .message: - return includeMessages - case .toolCall: - return includeToolCalls - case .toolResult: - return includeToolResults - case .structuredOutput: - return includeStructuredOutputs - case .approval: - return includeApprovals - case let .systemEvent(record): - if record.type == .contextCompacted { - return includeSystemEvents && includeCompactionEvents - } - return includeSystemEvents - } - } -} - -extension AgentHistoryItem { - var kind: AgentHistoryItemKind { - switch self { - case .message: - .message - case .toolCall: - .toolCall - case .toolResult: - .toolResult - case .structuredOutput: - .structuredOutput - case .approval: - .approval - case .systemEvent: - .systemEvent - } - } - - var turnID: String? { - switch self { - case let .message(message): - return message.structuredOutput == nil ? nil : nil - case let .toolCall(record): - return record.invocation.turnID - case let .toolResult(record): - return record.turnID - case let .structuredOutput(record): - return record.turnID - case let .approval(record): - return record.request?.turnID ?? record.resolution?.turnID - case let .systemEvent(record): - return record.turnID - } - } - - var defaultRecordID: String { - switch self { - case let .message(message): - return "message:\(message.id)" - case let .toolCall(record): - return "toolCall:\(record.invocation.id)" - case let .toolResult(record): - return "toolResult:\(record.result.invocationID)" - case let .structuredOutput(record): - return "structuredOutput:\(record.messageID ?? record.turnID)" - case let .approval(record): - return "approval:\(record.request?.id ?? record.resolution?.requestID ?? UUID().uuidString)" - case let .systemEvent(record): - if record.type == .contextCompacted, - let generation = record.compaction?.generation { - return "systemEvent:\(record.type.rawValue):\(record.threadID):\(generation)" - } - return "systemEvent:\(record.type.rawValue):\(record.turnID ?? record.threadID)" - } - } - - public var isCompactionMarker: Bool { - guard case let .systemEvent(record) = self else { - return false - } - return record.type == .contextCompacted - } -} - -extension AgentHistoryRecord { - func redacted(reason: AgentRedactionReason?) -> AgentHistoryRecord { - AgentHistoryRecord( - id: id, - sequenceNumber: sequenceNumber, - createdAt: createdAt, - item: item.redactedPayload(), - redaction: AgentHistoryRedaction(reason: reason) - ) - } -} - -private extension AgentHistoryItem { - func redactedPayload() -> AgentHistoryItem { - switch self { - case let .message(message): - return .message( - AgentMessage( - id: message.id, - threadID: message.threadID, - role: message.role, - text: "[Redacted]", - images: [], - structuredOutput: message.structuredOutput.map { - AgentStructuredOutputMetadata( - formatName: $0.formatName, - payload: .object(["redacted": .bool(true)]) - ) - }, - createdAt: message.createdAt - ) - ) - - case let .toolCall(record): - return .toolCall( - AgentToolCallRecord( - invocation: ToolInvocation( - id: record.invocation.id, - threadID: record.invocation.threadID, - turnID: record.invocation.turnID, - toolName: record.invocation.toolName, - arguments: .object(["redacted": .bool(true)]) - ), - requestedAt: record.requestedAt - ) - ) - - case let .toolResult(record): - return .toolResult( - AgentToolResultRecord( - threadID: record.threadID, - turnID: record.turnID, - result: ToolResultEnvelope( - invocationID: record.result.invocationID, - toolName: record.result.toolName, - success: record.result.success, - content: [.text("[Redacted]")], - errorMessage: record.result.errorMessage == nil ? nil : "[Redacted]", - session: record.result.session - ), - completedAt: record.completedAt - ) - ) - - case let .structuredOutput(record): - return .structuredOutput( - AgentStructuredOutputRecord( - threadID: record.threadID, - turnID: record.turnID, - messageID: record.messageID, - metadata: AgentStructuredOutputMetadata( - formatName: record.metadata.formatName, - payload: .object(["redacted": .bool(true)]) - ), - committedAt: record.committedAt - ) - ) - - case let .approval(record): - let request = record.request.map { - ApprovalRequest( - id: $0.id, - threadID: $0.threadID, - turnID: $0.turnID, - toolInvocation: ToolInvocation( - id: $0.toolInvocation.id, - threadID: $0.toolInvocation.threadID, - turnID: $0.toolInvocation.turnID, - toolName: $0.toolInvocation.toolName, - arguments: .object(["redacted": .bool(true)]) - ), - title: "[Redacted]", - message: "[Redacted]" - ) - } - return .approval( - AgentApprovalRecord( - kind: record.kind, - request: request, - resolution: record.resolution, - occurredAt: record.occurredAt - ) - ) - - case let .systemEvent(record): - return .systemEvent( - AgentSystemEventRecord( - type: record.type, - threadID: record.threadID, - turnID: record.turnID, - status: record.status, - turnSummary: record.turnSummary, - error: record.error.map { - AgentRuntimeError(code: $0.code, message: "[Redacted]") - }, - occurredAt: record.occurredAt - ) - ) - } - } -} diff --git a/Sources/CodexKit/Runtime/AgentHistoryItems.swift b/Sources/CodexKit/Runtime/AgentHistoryItems.swift new file mode 100644 index 0000000..c81fc7f --- /dev/null +++ b/Sources/CodexKit/Runtime/AgentHistoryItems.swift @@ -0,0 +1,129 @@ +import Foundation + +public enum AgentHistoryItem: Hashable, Sendable { + case message(AgentMessage) + case toolCall(AgentToolCallRecord) + case toolResult(AgentToolResultRecord) + case structuredOutput(AgentStructuredOutputRecord) + case approval(AgentApprovalRecord) + case systemEvent(AgentSystemEventRecord) +} + +public struct AgentToolCallRecord: Codable, Hashable, Sendable { + public let invocation: ToolInvocation + public let requestedAt: Date + + public init( + invocation: ToolInvocation, + requestedAt: Date = Date() + ) { + self.invocation = invocation + self.requestedAt = requestedAt + } +} + +public struct AgentToolResultRecord: Codable, Hashable, Sendable { + public let threadID: String + public let turnID: String + public let result: ToolResultEnvelope + public let completedAt: Date + + public init( + threadID: String, + turnID: String, + result: ToolResultEnvelope, + completedAt: Date = Date() + ) { + self.threadID = threadID + self.turnID = turnID + self.result = result + self.completedAt = completedAt + } +} + +public struct AgentStructuredOutputRecord: Codable, Hashable, Sendable { + public let threadID: String + public let turnID: String + public let messageID: String? + public let metadata: AgentStructuredOutputMetadata + public let committedAt: Date + + public init( + threadID: String, + turnID: String, + messageID: String? = nil, + metadata: AgentStructuredOutputMetadata, + committedAt: Date = Date() + ) { + self.threadID = threadID + self.turnID = turnID + self.messageID = messageID + self.metadata = metadata + self.committedAt = committedAt + } +} + +public enum AgentApprovalEventKind: String, Codable, Hashable, Sendable { + case requested + case resolved +} + +public struct AgentApprovalRecord: Codable, Hashable, Sendable { + public let kind: AgentApprovalEventKind + public let request: ApprovalRequest? + public let resolution: ApprovalResolution? + public let occurredAt: Date + + public init( + kind: AgentApprovalEventKind, + request: ApprovalRequest? = nil, + resolution: ApprovalResolution? = nil, + occurredAt: Date = Date() + ) { + self.kind = kind + self.request = request + self.resolution = resolution + self.occurredAt = occurredAt + } +} + +public enum AgentSystemEventType: String, Codable, Hashable, Sendable { + case threadCreated + case threadResumed + case threadStatusChanged + case turnStarted + case turnCompleted + case turnFailed + case contextCompacted +} + +public struct AgentSystemEventRecord: Codable, Hashable, Sendable { + public let type: AgentSystemEventType + public let threadID: String + public let turnID: String? + public let status: AgentThreadStatus? + public let turnSummary: AgentTurnSummary? + public let error: AgentRuntimeError? + public let compaction: AgentContextCompactionMarker? + public let occurredAt: Date + + public init( + type: AgentSystemEventType, + threadID: String, + turnID: String? = nil, + status: AgentThreadStatus? = nil, + turnSummary: AgentTurnSummary? = nil, + error: AgentRuntimeError? = nil, + compaction: AgentContextCompactionMarker? = nil, + occurredAt: Date = Date() + ) { + self.type = type + self.threadID = threadID + self.turnID = turnID + self.status = status + self.turnSummary = turnSummary + self.error = error + self.compaction = compaction + self.occurredAt = occurredAt + } +} diff --git a/Sources/CodexKit/Runtime/AgentHistoryPendingState.swift b/Sources/CodexKit/Runtime/AgentHistoryPendingState.swift new file mode 100644 index 0000000..eb97e6a --- /dev/null +++ b/Sources/CodexKit/Runtime/AgentHistoryPendingState.swift @@ -0,0 +1,140 @@ +import Foundation + +public enum AgentThreadPendingState: Hashable, Sendable { + case approval(AgentPendingApprovalState) + case userInput(AgentPendingUserInputState) + case toolWait(AgentPendingToolWaitState) +} + +public struct AgentPendingApprovalState: Codable, Hashable, Sendable { + public let request: ApprovalRequest + public let requestedAt: Date + + public init( + request: ApprovalRequest, + requestedAt: Date = Date() + ) { + self.request = request + self.requestedAt = requestedAt + } +} + +public struct AgentPendingUserInputState: Codable, Hashable, Sendable { + public let requestID: String + public let turnID: String + public let title: String + public let message: String + public let requestedAt: Date + + public init( + requestID: String, + turnID: String, + title: String, + message: String, + requestedAt: Date = Date() + ) { + self.requestID = requestID + self.turnID = turnID + self.title = title + self.message = message + self.requestedAt = requestedAt + } +} + +public struct AgentPendingToolWaitState: Codable, Hashable, Sendable { + public let invocationID: String + public let turnID: String + public let toolName: String + public let startedAt: Date + public let sessionID: String? + public let sessionStatus: String? + public let metadata: JSONValue? + public let resumable: Bool + + public init( + invocationID: String, + turnID: String, + toolName: String, + startedAt: Date = Date(), + sessionID: String? = nil, + sessionStatus: String? = nil, + metadata: JSONValue? = nil, + resumable: Bool = false + ) { + self.invocationID = invocationID + self.turnID = turnID + self.toolName = toolName + self.startedAt = startedAt + self.sessionID = sessionID + self.sessionStatus = sessionStatus + self.metadata = metadata + self.resumable = resumable + } +} + +public enum AgentToolSessionStatus: String, Codable, Hashable, Sendable { + case waiting + case running + case completed + case failed + case denied +} + +public struct AgentLatestToolState: Codable, Hashable, Sendable { + public let invocationID: String + public let turnID: String + public let toolName: String + public let status: AgentToolSessionStatus + public let success: Bool? + public let sessionID: String? + public let sessionStatus: String? + public let metadata: JSONValue? + public let resumable: Bool + public let updatedAt: Date + public let resultPreview: String? + + public init( + invocationID: String, + turnID: String, + toolName: String, + status: AgentToolSessionStatus, + success: Bool? = nil, + sessionID: String? = nil, + sessionStatus: String? = nil, + metadata: JSONValue? = nil, + resumable: Bool = false, + updatedAt: Date = Date(), + resultPreview: String? = nil + ) { + self.invocationID = invocationID + self.turnID = turnID + self.toolName = toolName + self.status = status + self.success = success + self.sessionID = sessionID + self.sessionStatus = sessionStatus + self.metadata = metadata + self.resumable = resumable + self.updatedAt = updatedAt + self.resultPreview = resultPreview + } +} + +public struct AgentPartialStructuredOutputSnapshot: Codable, Hashable, Sendable { + public let turnID: String + public let formatName: String + public let payload: JSONValue + public let updatedAt: Date + + public init( + turnID: String, + formatName: String, + payload: JSONValue, + updatedAt: Date = Date() + ) { + self.turnID = turnID + self.formatName = formatName + self.payload = payload + self.updatedAt = updatedAt + } +} diff --git a/Sources/CodexKit/Runtime/AgentRuntime+Messaging.swift b/Sources/CodexKit/Runtime/AgentRuntime+Messaging.swift index 8a41056..7dfa20b 100644 --- a/Sources/CodexKit/Runtime/AgentRuntime+Messaging.swift +++ b/Sources/CodexKit/Runtime/AgentRuntime+Messaging.swift @@ -48,67 +48,72 @@ extension AgentRuntime { throw AgentRuntimeError.threadNotFound(threadID) } - let session = try await sessionManager.requireSession() let userMessage = AgentMessage( threadID: threadID, role: .user, text: request.text, images: request.images ) - let resolvedTurnSkills = try resolveTurnSkills( - thread: thread, - message: request - ) - let resolvedInstructions = await resolveInstructions( - thread: thread, - message: request, - resolvedTurnSkills: resolvedTurnSkills - ) try await appendMessage(userMessage) try await setThreadStatus(.streaming, for: threadID) - let tools = await toolRegistry.allDefinitions() - try await maybeCompactThreadContextBeforeTurn( - thread: thread, - request: request, - instructions: resolvedInstructions, - tools: tools, - session: session - ) - let turnStart = try await beginTurnWithUnauthorizedRecovery( - thread: thread, - history: effectiveHistory(for: threadID), - message: request, - instructions: resolvedInstructions, - responseFormat: nil, - streamedStructuredOutput: AgentStreamedStructuredOutputRequest( - responseFormat: responseFormat, - options: options - ), - tools: tools, - session: session - ) - let turnStream = turnStart.turnStream - let turnSession = turnStart.session - return AsyncThrowingStream { continuation in continuation.yield(.messageCommitted(userMessage)) continuation.yield(.threadStatusChanged(threadID: threadID, status: .streaming)) Task { - await self.consumeStructuredTurnStream( - turnStream, - for: threadID, - userMessage: userMessage, - session: turnSession, - resolvedTurnSkills: resolvedTurnSkills, - responseFormat: responseFormat, - options: options, - decoder: decoder, - outputType: outputType, - continuation: continuation - ) + do { + let session = try await self.sessionManager.requireSession() + let resolvedTurnSkills = try self.resolveTurnSkills( + thread: thread, + message: request + ) + let resolvedInstructions = await self.resolveInstructions( + thread: thread, + message: request, + resolvedTurnSkills: resolvedTurnSkills + ) + let tools = await self.toolRegistry.allDefinitions() + try await self.maybeCompactThreadContextBeforeTurn( + thread: thread, + request: request, + instructions: resolvedInstructions, + tools: tools, + session: session + ) + let turnStart = try await self.beginTurnWithUnauthorizedRecovery( + thread: thread, + history: self.effectiveHistory(for: threadID), + message: request, + instructions: resolvedInstructions, + responseFormat: nil, + streamedStructuredOutput: AgentStreamedStructuredOutputRequest( + responseFormat: responseFormat, + options: options + ), + tools: tools, + session: session + ) + await self.consumeStructuredTurnStream( + turnStart.turnStream, + for: threadID, + userMessage: userMessage, + session: turnStart.session, + resolvedTurnSkills: resolvedTurnSkills, + responseFormat: responseFormat, + options: options, + decoder: decoder, + outputType: outputType, + continuation: continuation + ) + } catch { + await self.handleStructuredTurnStartupFailure( + error, + for: threadID, + continuation: continuation + ) + } } } } @@ -182,60 +187,65 @@ extension AgentRuntime { throw AgentRuntimeError.threadNotFound(threadID) } - let session = try await sessionManager.requireSession() let userMessage = AgentMessage( threadID: threadID, role: .user, text: request.text, images: request.images ) - let resolvedTurnSkills = try resolveTurnSkills( - thread: thread, - message: request - ) - let resolvedInstructions = await resolveInstructions( - thread: thread, - message: request, - resolvedTurnSkills: resolvedTurnSkills - ) try await appendMessage(userMessage) try await setThreadStatus(.streaming, for: threadID) - let tools = await toolRegistry.allDefinitions() - try await maybeCompactThreadContextBeforeTurn( - thread: thread, - request: request, - instructions: resolvedInstructions, - tools: tools, - session: session - ) - let turnStart = try await beginTurnWithUnauthorizedRecovery( - thread: thread, - history: effectiveHistory(for: threadID), - message: request, - instructions: resolvedInstructions, - responseFormat: responseFormat, - streamedStructuredOutput: streamedStructuredOutput, - tools: tools, - session: session - ) - let turnStream = turnStart.turnStream - let turnSession = turnStart.session - return AsyncThrowingStream { continuation in continuation.yield(.messageCommitted(userMessage)) continuation.yield(.threadStatusChanged(threadID: threadID, status: .streaming)) Task { - await self.consumeTurnStream( - turnStream, - for: threadID, - userMessage: userMessage, - session: turnSession, - resolvedTurnSkills: resolvedTurnSkills, - continuation: continuation - ) + do { + let session = try await self.sessionManager.requireSession() + let resolvedTurnSkills = try self.resolveTurnSkills( + thread: thread, + message: request + ) + let resolvedInstructions = await self.resolveInstructions( + thread: thread, + message: request, + resolvedTurnSkills: resolvedTurnSkills + ) + let tools = await self.toolRegistry.allDefinitions() + try await self.maybeCompactThreadContextBeforeTurn( + thread: thread, + request: request, + instructions: resolvedInstructions, + tools: tools, + session: session + ) + let turnStart = try await self.beginTurnWithUnauthorizedRecovery( + thread: thread, + history: self.effectiveHistory(for: threadID), + message: request, + instructions: resolvedInstructions, + responseFormat: responseFormat, + streamedStructuredOutput: streamedStructuredOutput, + tools: tools, + session: session + ) + await self.consumeTurnStream( + turnStart.turnStream, + for: threadID, + userMessage: userMessage, + session: turnStart.session, + resolvedTurnSkills: resolvedTurnSkills, + continuation: continuation + ) + } catch { + await self.handleTurnStartupFailure( + error, + for: threadID, + continuation: continuation + ) + } } } } @@ -321,4 +331,57 @@ extension AgentRuntime { resolvedTurnSkills: resolvedTurnSkills ) } + + private func runtimeError(for error: Error) -> AgentRuntimeError { + (error as? AgentRuntimeError) + ?? AgentRuntimeError( + code: "turn_failed", + message: error.localizedDescription + ) + } + + private func recordTurnStartupFailure( + _ error: Error, + for threadID: String + ) async -> AgentRuntimeError { + let runtimeError = runtimeError(for: error) + appendHistoryItem( + .systemEvent( + AgentSystemEventRecord( + type: .turnFailed, + threadID: threadID, + error: runtimeError, + occurredAt: Date() + ) + ), + threadID: threadID, + createdAt: Date() + ) + try? setLatestTurnStatus(.failed, for: threadID) + try? setLatestPartialStructuredOutput(nil, for: threadID) + try? await setThreadStatus(.failed, for: threadID) + return runtimeError + } + + private func handleTurnStartupFailure( + _ error: Error, + for threadID: String, + continuation: AsyncThrowingStream.Continuation + ) async { + let runtimeError = await recordTurnStartupFailure(error, for: threadID) + continuation.yield(.threadStatusChanged(threadID: threadID, status: .failed)) + continuation.yield(.turnFailed(runtimeError)) + continuation.finish(throwing: error) + } + + private func handleStructuredTurnStartupFailure( + _ error: Error, + for threadID: String, + continuation: AsyncThrowingStream, Error>.Continuation + ) async { + let runtimeError = await recordTurnStartupFailure(error, for: threadID) + continuation.yield(.threadStatusChanged(threadID: threadID, status: .failed)) + continuation.yield(.turnFailed(runtimeError)) + continuation.finish(throwing: error) + } } diff --git a/Sources/CodexKit/Runtime/AgentRuntime+StructuredTurnConsumption.swift b/Sources/CodexKit/Runtime/AgentRuntime+StructuredTurnConsumption.swift new file mode 100644 index 0000000..bafa5be --- /dev/null +++ b/Sources/CodexKit/Runtime/AgentRuntime+StructuredTurnConsumption.swift @@ -0,0 +1,310 @@ +import Foundation + +extension AgentRuntime { + func consumeStructuredTurnStream( + _ turnStream: any AgentTurnStreaming, + for threadID: String, + userMessage: AgentMessage, + session: ChatGPTSession, + resolvedTurnSkills: ResolvedTurnSkills, + responseFormat: AgentStructuredOutputFormat, + options: AgentStructuredStreamingOptions, + decoder: JSONDecoder, + outputType: Output.Type, + continuation: AsyncThrowingStream, Error>.Continuation + ) async { + let policyTracker: TurnSkillPolicyTracker? = if resolvedTurnSkills.compiledToolPolicy.hasConstraints { + TurnSkillPolicyTracker(policy: resolvedTurnSkills.compiledToolPolicy) + } else { + nil + } + var assistantMessages: [AgentMessage] = [] + var sawStructuredCommit = false + var currentTurnID: String? + + do { + for try await backendEvent in turnStream.events { + switch backendEvent { + case let .turnStarted(turn): + currentTurnID = turn.id + appendHistoryItem( + .systemEvent( + AgentSystemEventRecord( + type: .turnStarted, + threadID: threadID, + turnID: turn.id, + occurredAt: turn.startedAt + ) + ), + threadID: threadID, + createdAt: turn.startedAt + ) + try setLatestTurnStatus(.running, for: threadID) + updateThreadTimestamp(turn.startedAt, for: threadID) + try await persistState() + continuation.yield(.turnStarted(turn)) + + case let .assistantMessageDelta(threadID, turnID, delta): + continuation.yield( + .assistantMessageDelta( + threadID: threadID, + turnID: turnID, + delta: delta + ) + ) + + case let .assistantMessageCompleted(message): + try await appendMessage(message) + if message.role == .assistant { + assistantMessages.append(message) + } + continuation.yield(.messageCommitted(message)) + + case let .structuredOutputPartial(value): + do { + let decoded = try decodeStructuredValue( + value, + as: outputType, + decoder: decoder + ) + if let currentTurnID { + try setLatestPartialStructuredOutput( + AgentPartialStructuredOutputSnapshot( + turnID: currentTurnID, + formatName: responseFormat.name, + payload: value, + updatedAt: Date() + ), + for: threadID + ) + updateThreadTimestamp(Date(), for: threadID) + try await persistState() + } + if options.emitPartials { + continuation.yield(.structuredOutputPartial(decoded)) + } + } catch { + continuation.yield( + .structuredOutputValidationFailed( + AgentStructuredOutputValidationFailure( + stage: .partial, + message: error.localizedDescription, + rawPayload: value.prettyJSONString + ) + ) + ) + } + + case let .structuredOutputCommitted(value): + do { + let decoded = try decodeStructuredValue( + value, + as: outputType, + decoder: decoder + ) + sawStructuredCommit = true + let metadata = AgentStructuredOutputMetadata( + formatName: responseFormat.name, + payload: value + ) + try setLatestStructuredOutputMetadata(metadata, for: threadID) + try setLatestPartialStructuredOutput(nil, for: threadID) + appendHistoryItem( + .structuredOutput( + AgentStructuredOutputRecord( + threadID: threadID, + turnID: currentTurnID ?? "", + metadata: metadata, + committedAt: Date() + ) + ), + threadID: threadID, + createdAt: Date() + ) + updateThreadTimestamp(Date(), for: threadID) + try await persistState() + continuation.yield(.structuredOutputCommitted(decoded)) + } catch { + let validationFailure = AgentStructuredOutputValidationFailure( + stage: .committed, + message: error.localizedDescription, + rawPayload: value.prettyJSONString + ) + let runtimeError = AgentRuntimeError.structuredOutputInvalid( + stage: validationFailure.stage, + underlyingMessage: validationFailure.message + ) + try? setLatestPartialStructuredOutput(nil, for: threadID) + appendHistoryItem( + .systemEvent( + AgentSystemEventRecord( + type: .turnFailed, + threadID: threadID, + turnID: currentTurnID, + error: runtimeError, + occurredAt: Date() + ) + ), + threadID: threadID, + createdAt: Date() + ) + try? setLatestTurnStatus(.failed, for: threadID) + try await setThreadStatus(.failed, for: threadID) + continuation.yield(.threadStatusChanged(threadID: threadID, status: .failed)) + continuation.yield(.structuredOutputValidationFailed(validationFailure)) + continuation.yield(.turnFailed(runtimeError)) + continuation.finish(throwing: runtimeError) + return + } + + case let .structuredOutputValidationFailed(validationFailure): + try? setLatestPartialStructuredOutput(nil, for: threadID) + try? await persistState() + continuation.yield(.structuredOutputValidationFailed(validationFailure)) + + case let .toolCallRequested(invocation): + appendHistoryItem( + .toolCall( + AgentToolCallRecord( + invocation: invocation, + requestedAt: Date() + ) + ), + threadID: invocation.threadID, + createdAt: Date() + ) + try setLatestToolState( + latestToolState(for: invocation, result: nil, updatedAt: Date()), + for: invocation.threadID + ) + updateThreadTimestamp(Date(), for: invocation.threadID) + try await persistState() + continuation.yield(.toolCallStarted(invocation)) + + let result: ToolResultEnvelope + if let policyTracker, + let validationError = policyTracker.validate(toolName: invocation.toolName) { + result = .failure( + invocation: invocation, + message: validationError.message + ) + } else { + let resolvedResult = try await resolveToolInvocation( + invocation, + session: session, + continuation: continuation + ) + result = resolvedResult + policyTracker?.recordAccepted(toolName: invocation.toolName) + } + + try await turnStream.submitToolResult(result, for: invocation.id) + continuation.yield(.toolCallFinished(result)) + try await setThreadStatus(.streaming, for: threadID) + continuation.yield(.threadStatusChanged(threadID: threadID, status: .streaming)) + + case let .turnCompleted(summary): + if let completionError = policyTracker?.completionError() { + appendHistoryItem( + .systemEvent( + AgentSystemEventRecord( + type: .turnFailed, + threadID: threadID, + turnID: currentTurnID, + error: completionError, + occurredAt: Date() + ) + ), + threadID: threadID, + createdAt: Date() + ) + try setLatestTurnStatus(.failed, for: threadID) + try setLatestPartialStructuredOutput(nil, for: threadID) + try await setThreadStatus(.failed, for: threadID) + continuation.yield(.threadStatusChanged(threadID: threadID, status: .failed)) + continuation.yield(.turnFailed(completionError)) + continuation.finish(throwing: completionError) + return + } + + if options.required, !sawStructuredCommit { + let runtimeError = AgentRuntimeError.structuredOutputMissing( + formatName: responseFormat.name + ) + appendHistoryItem( + .systemEvent( + AgentSystemEventRecord( + type: .turnFailed, + threadID: threadID, + turnID: currentTurnID, + error: runtimeError, + occurredAt: Date() + ) + ), + threadID: threadID, + createdAt: Date() + ) + try setLatestTurnStatus(.failed, for: threadID) + try setLatestPartialStructuredOutput(nil, for: threadID) + try await setThreadStatus(.failed, for: threadID) + continuation.yield(.threadStatusChanged(threadID: threadID, status: .failed)) + continuation.yield(.turnFailed(runtimeError)) + continuation.finish(throwing: runtimeError) + return + } + + appendHistoryItem( + .systemEvent( + AgentSystemEventRecord( + type: .turnCompleted, + threadID: threadID, + turnID: summary.turnID, + turnSummary: summary, + occurredAt: summary.completedAt + ) + ), + threadID: threadID, + createdAt: summary.completedAt + ) + try setLatestTurnStatus(.completed, for: threadID) + try setLatestPartialStructuredOutput(nil, for: threadID) + try await setThreadStatus(.idle, for: threadID) + await automaticallyCaptureMemoriesIfConfigured( + for: threadID, + userMessage: userMessage, + assistantMessages: assistantMessages + ) + continuation.yield(.threadStatusChanged(threadID: threadID, status: .idle)) + continuation.yield(.turnCompleted(summary)) + } + } + + continuation.finish() + } catch { + let runtimeError = (error as? AgentRuntimeError) + ?? AgentRuntimeError( + code: "turn_failed", + message: error.localizedDescription + ) + appendHistoryItem( + .systemEvent( + AgentSystemEventRecord( + type: .turnFailed, + threadID: threadID, + turnID: currentTurnID, + error: runtimeError, + occurredAt: Date() + ) + ), + threadID: threadID, + createdAt: Date() + ) + try? setLatestTurnStatus(.failed, for: threadID) + try? setLatestPartialStructuredOutput(nil, for: threadID) + try? await setThreadStatus(.failed, for: threadID) + continuation.yield(.threadStatusChanged(threadID: threadID, status: .failed)) + continuation.yield(.turnFailed(runtimeError)) + continuation.finish(throwing: error) + } + } +} diff --git a/Sources/CodexKit/Runtime/AgentRuntime+ToolInvocationResolution.swift b/Sources/CodexKit/Runtime/AgentRuntime+ToolInvocationResolution.swift new file mode 100644 index 0000000..970e111 --- /dev/null +++ b/Sources/CodexKit/Runtime/AgentRuntime+ToolInvocationResolution.swift @@ -0,0 +1,192 @@ +import Foundation + +extension AgentRuntime { + func resolveToolInvocation( + _ invocation: ToolInvocation, + session: ChatGPTSession, + continuation: AsyncThrowingStream.Continuation + ) async throws -> ToolResultEnvelope { + try await resolveToolInvocationImpl( + invocation, + session: session, + yieldThreadStatusChanged: { threadID, status in + continuation.yield(.threadStatusChanged(threadID: threadID, status: status)) + }, + yieldApprovalRequested: { approval in + continuation.yield(.approvalRequested(approval)) + }, + yieldApprovalResolved: { resolution in + continuation.yield(.approvalResolved(resolution)) + } + ) + } + + func resolveToolInvocation( + _ invocation: ToolInvocation, + session: ChatGPTSession, + continuation: AsyncThrowingStream, Error>.Continuation + ) async throws -> ToolResultEnvelope { + try await resolveToolInvocationImpl( + invocation, + session: session, + yieldThreadStatusChanged: { threadID, status in + continuation.yield(.threadStatusChanged(threadID: threadID, status: status)) + }, + yieldApprovalRequested: { approval in + continuation.yield(.approvalRequested(approval)) + }, + yieldApprovalResolved: { resolution in + continuation.yield(.approvalResolved(resolution)) + } + ) + } + + private func resolveToolInvocationImpl( + _ invocation: ToolInvocation, + session: ChatGPTSession, + yieldThreadStatusChanged: (String, AgentThreadStatus) -> Void, + yieldApprovalRequested: (ApprovalRequest) -> Void, + yieldApprovalResolved: (ApprovalResolution) -> Void + ) async throws -> ToolResultEnvelope { + if let definition = await toolRegistry.definition(named: invocation.toolName), + definition.approvalPolicy == .requiresApproval { + let approval = ApprovalRequest( + threadID: invocation.threadID, + turnID: invocation.turnID, + toolInvocation: invocation, + title: "Approve \(invocation.toolName)?", + message: definition.approvalMessage + ?? "This tool requires explicit approval before it can run." + ) + + appendHistoryItem( + .approval( + AgentApprovalRecord( + kind: .requested, + request: approval, + occurredAt: Date() + ) + ), + threadID: invocation.threadID, + createdAt: Date() + ) + try setPendingState( + .approval( + AgentPendingApprovalState( + request: approval, + requestedAt: Date() + ) + ), + for: invocation.threadID + ) + try await setThreadStatus(.waitingForApproval, for: invocation.threadID) + yieldThreadStatusChanged(invocation.threadID, .waitingForApproval) + yieldApprovalRequested(approval) + + let decision = try await approvalCoordinator.requestApproval(approval) + let resolution = ApprovalResolution( + requestID: approval.id, + threadID: approval.threadID, + turnID: approval.turnID, + decision: decision + ) + appendHistoryItem( + .approval( + AgentApprovalRecord( + kind: .resolved, + request: approval, + resolution: resolution, + occurredAt: resolution.decidedAt + ) + ), + threadID: invocation.threadID, + createdAt: resolution.decidedAt + ) + try setPendingState(nil, for: invocation.threadID) + yieldApprovalResolved(resolution) + + guard decision == .approved else { + let denied = ToolResultEnvelope.denied(invocation: invocation) + try setLatestToolState( + latestToolState(for: invocation, result: denied, updatedAt: resolution.decidedAt), + for: invocation.threadID + ) + appendHistoryItem( + .toolResult( + AgentToolResultRecord( + threadID: invocation.threadID, + turnID: invocation.turnID, + result: denied, + completedAt: resolution.decidedAt + ) + ), + threadID: invocation.threadID, + createdAt: resolution.decidedAt + ) + updateThreadTimestamp(resolution.decidedAt, for: invocation.threadID) + try await persistState() + return denied + } + } + + let toolWaitStartedAt = Date() + try setPendingState( + .toolWait( + AgentPendingToolWaitState( + invocationID: invocation.id, + turnID: invocation.turnID, + toolName: invocation.toolName, + startedAt: toolWaitStartedAt + ) + ), + for: invocation.threadID + ) + try setLatestToolState( + latestToolState(for: invocation, result: nil, updatedAt: toolWaitStartedAt), + for: invocation.threadID + ) + try await setThreadStatus(.waitingForToolResult, for: invocation.threadID) + yieldThreadStatusChanged(invocation.threadID, .waitingForToolResult) + + let result = await toolRegistry.execute(invocation, session: session) + let resultDate = Date() + try setLatestToolState( + latestToolState(for: invocation, result: result, updatedAt: resultDate), + for: invocation.threadID + ) + if let session = result.session, !session.isTerminal { + try setPendingState( + .toolWait( + AgentPendingToolWaitState( + invocationID: invocation.id, + turnID: invocation.turnID, + toolName: invocation.toolName, + startedAt: toolWaitStartedAt, + sessionID: session.sessionID, + sessionStatus: session.status, + metadata: session.metadata, + resumable: session.resumable + ) + ), + for: invocation.threadID + ) + } else { + try setPendingState(nil, for: invocation.threadID) + appendHistoryItem( + .toolResult( + AgentToolResultRecord( + threadID: invocation.threadID, + turnID: invocation.turnID, + result: result, + completedAt: resultDate + ) + ), + threadID: invocation.threadID, + createdAt: resultDate + ) + } + updateThreadTimestamp(resultDate, for: invocation.threadID) + try await persistState() + return result + } +} diff --git a/Sources/CodexKit/Runtime/AgentRuntime+TurnConsumption.swift b/Sources/CodexKit/Runtime/AgentRuntime+TurnConsumption.swift index 7985428..89cdd5f 100644 --- a/Sources/CodexKit/Runtime/AgentRuntime+TurnConsumption.swift +++ b/Sources/CodexKit/Runtime/AgentRuntime+TurnConsumption.swift @@ -181,633 +181,4 @@ extension AgentRuntime { continuation.finish(throwing: error) } } - - func consumeStructuredTurnStream( - _ turnStream: any AgentTurnStreaming, - for threadID: String, - userMessage: AgentMessage, - session: ChatGPTSession, - resolvedTurnSkills: ResolvedTurnSkills, - responseFormat: AgentStructuredOutputFormat, - options: AgentStructuredStreamingOptions, - decoder: JSONDecoder, - outputType: Output.Type, - continuation: AsyncThrowingStream, Error>.Continuation - ) async { - let policyTracker: TurnSkillPolicyTracker? = if resolvedTurnSkills.compiledToolPolicy.hasConstraints { - TurnSkillPolicyTracker(policy: resolvedTurnSkills.compiledToolPolicy) - } else { - nil - } - var assistantMessages: [AgentMessage] = [] - var sawStructuredCommit = false - var currentTurnID: String? - - do { - for try await backendEvent in turnStream.events { - switch backendEvent { - case let .turnStarted(turn): - currentTurnID = turn.id - appendHistoryItem( - .systemEvent( - AgentSystemEventRecord( - type: .turnStarted, - threadID: threadID, - turnID: turn.id, - occurredAt: turn.startedAt - ) - ), - threadID: threadID, - createdAt: turn.startedAt - ) - try setLatestTurnStatus(.running, for: threadID) - updateThreadTimestamp(turn.startedAt, for: threadID) - try await persistState() - continuation.yield(.turnStarted(turn)) - - case let .assistantMessageDelta(threadID, turnID, delta): - continuation.yield( - .assistantMessageDelta( - threadID: threadID, - turnID: turnID, - delta: delta - ) - ) - - case let .assistantMessageCompleted(message): - try await appendMessage(message) - if message.role == .assistant { - assistantMessages.append(message) - } - continuation.yield(.messageCommitted(message)) - - case let .structuredOutputPartial(value): - do { - let decoded = try decodeStructuredValue( - value, - as: outputType, - decoder: decoder - ) - if let currentTurnID { - try setLatestPartialStructuredOutput( - AgentPartialStructuredOutputSnapshot( - turnID: currentTurnID, - formatName: responseFormat.name, - payload: value, - updatedAt: Date() - ), - for: threadID - ) - updateThreadTimestamp(Date(), for: threadID) - try await persistState() - } - if options.emitPartials { - continuation.yield(.structuredOutputPartial(decoded)) - } - } catch { - continuation.yield( - .structuredOutputValidationFailed( - AgentStructuredOutputValidationFailure( - stage: .partial, - message: error.localizedDescription, - rawPayload: value.prettyJSONString - ) - ) - ) - } - - case let .structuredOutputCommitted(value): - do { - let decoded = try decodeStructuredValue( - value, - as: outputType, - decoder: decoder - ) - sawStructuredCommit = true - let metadata = AgentStructuredOutputMetadata( - formatName: responseFormat.name, - payload: value - ) - try setLatestStructuredOutputMetadata(metadata, for: threadID) - try setLatestPartialStructuredOutput(nil, for: threadID) - appendHistoryItem( - .structuredOutput( - AgentStructuredOutputRecord( - threadID: threadID, - turnID: currentTurnID ?? "", - metadata: metadata, - committedAt: Date() - ) - ), - threadID: threadID, - createdAt: Date() - ) - updateThreadTimestamp(Date(), for: threadID) - try await persistState() - continuation.yield(.structuredOutputCommitted(decoded)) - } catch { - let validationFailure = AgentStructuredOutputValidationFailure( - stage: .committed, - message: error.localizedDescription, - rawPayload: value.prettyJSONString - ) - let runtimeError = AgentRuntimeError.structuredOutputInvalid( - stage: validationFailure.stage, - underlyingMessage: validationFailure.message - ) - try? setLatestPartialStructuredOutput(nil, for: threadID) - appendHistoryItem( - .systemEvent( - AgentSystemEventRecord( - type: .turnFailed, - threadID: threadID, - turnID: currentTurnID, - error: runtimeError, - occurredAt: Date() - ) - ), - threadID: threadID, - createdAt: Date() - ) - try? setLatestTurnStatus(.failed, for: threadID) - try await setThreadStatus(.failed, for: threadID) - continuation.yield(.threadStatusChanged(threadID: threadID, status: .failed)) - continuation.yield(.structuredOutputValidationFailed(validationFailure)) - continuation.yield(.turnFailed(runtimeError)) - continuation.finish(throwing: runtimeError) - return - } - - case let .structuredOutputValidationFailed(validationFailure): - try? setLatestPartialStructuredOutput(nil, for: threadID) - try? await persistState() - continuation.yield(.structuredOutputValidationFailed(validationFailure)) - - case let .toolCallRequested(invocation): - appendHistoryItem( - .toolCall( - AgentToolCallRecord( - invocation: invocation, - requestedAt: Date() - ) - ), - threadID: invocation.threadID, - createdAt: Date() - ) - try setLatestToolState( - latestToolState(for: invocation, result: nil, updatedAt: Date()), - for: invocation.threadID - ) - updateThreadTimestamp(Date(), for: invocation.threadID) - try await persistState() - continuation.yield(.toolCallStarted(invocation)) - - let result: ToolResultEnvelope - if let policyTracker, - let validationError = policyTracker.validate(toolName: invocation.toolName) { - result = .failure( - invocation: invocation, - message: validationError.message - ) - } else { - let resolvedResult = try await resolveToolInvocation( - invocation, - session: session, - continuation: continuation - ) - result = resolvedResult - policyTracker?.recordAccepted(toolName: invocation.toolName) - } - - try await turnStream.submitToolResult(result, for: invocation.id) - continuation.yield(.toolCallFinished(result)) - try await setThreadStatus(.streaming, for: threadID) - continuation.yield(.threadStatusChanged(threadID: threadID, status: .streaming)) - - case let .turnCompleted(summary): - if let completionError = policyTracker?.completionError() { - appendHistoryItem( - .systemEvent( - AgentSystemEventRecord( - type: .turnFailed, - threadID: threadID, - turnID: currentTurnID, - error: completionError, - occurredAt: Date() - ) - ), - threadID: threadID, - createdAt: Date() - ) - try setLatestTurnStatus(.failed, for: threadID) - try setLatestPartialStructuredOutput(nil, for: threadID) - try await setThreadStatus(.failed, for: threadID) - continuation.yield(.threadStatusChanged(threadID: threadID, status: .failed)) - continuation.yield(.turnFailed(completionError)) - continuation.finish(throwing: completionError) - return - } - - if options.required, !sawStructuredCommit { - let runtimeError = AgentRuntimeError.structuredOutputMissing( - formatName: responseFormat.name - ) - appendHistoryItem( - .systemEvent( - AgentSystemEventRecord( - type: .turnFailed, - threadID: threadID, - turnID: currentTurnID, - error: runtimeError, - occurredAt: Date() - ) - ), - threadID: threadID, - createdAt: Date() - ) - try setLatestTurnStatus(.failed, for: threadID) - try setLatestPartialStructuredOutput(nil, for: threadID) - try await setThreadStatus(.failed, for: threadID) - continuation.yield(.threadStatusChanged(threadID: threadID, status: .failed)) - continuation.yield(.turnFailed(runtimeError)) - continuation.finish(throwing: runtimeError) - return - } - - appendHistoryItem( - .systemEvent( - AgentSystemEventRecord( - type: .turnCompleted, - threadID: threadID, - turnID: summary.turnID, - turnSummary: summary, - occurredAt: summary.completedAt - ) - ), - threadID: threadID, - createdAt: summary.completedAt - ) - try setLatestTurnStatus(.completed, for: threadID) - try setLatestPartialStructuredOutput(nil, for: threadID) - try await setThreadStatus(.idle, for: threadID) - await automaticallyCaptureMemoriesIfConfigured( - for: threadID, - userMessage: userMessage, - assistantMessages: assistantMessages - ) - continuation.yield(.threadStatusChanged(threadID: threadID, status: .idle)) - continuation.yield(.turnCompleted(summary)) - } - } - - continuation.finish() - } catch { - let runtimeError = (error as? AgentRuntimeError) - ?? AgentRuntimeError( - code: "turn_failed", - message: error.localizedDescription - ) - appendHistoryItem( - .systemEvent( - AgentSystemEventRecord( - type: .turnFailed, - threadID: threadID, - turnID: currentTurnID, - error: runtimeError, - occurredAt: Date() - ) - ), - threadID: threadID, - createdAt: Date() - ) - try? setLatestTurnStatus(.failed, for: threadID) - try? setLatestPartialStructuredOutput(nil, for: threadID) - try? await setThreadStatus(.failed, for: threadID) - continuation.yield(.threadStatusChanged(threadID: threadID, status: .failed)) - continuation.yield(.turnFailed(runtimeError)) - continuation.finish(throwing: error) - } - } - - func resolveToolInvocation( - _ invocation: ToolInvocation, - session: ChatGPTSession, - continuation: AsyncThrowingStream.Continuation - ) async throws -> ToolResultEnvelope { - if let definition = await toolRegistry.definition(named: invocation.toolName), - definition.approvalPolicy == .requiresApproval { - let approval = ApprovalRequest( - threadID: invocation.threadID, - turnID: invocation.turnID, - toolInvocation: invocation, - title: "Approve \(invocation.toolName)?", - message: definition.approvalMessage - ?? "This tool requires explicit approval before it can run." - ) - - appendHistoryItem( - .approval( - AgentApprovalRecord( - kind: .requested, - request: approval, - occurredAt: Date() - ) - ), - threadID: invocation.threadID, - createdAt: Date() - ) - try setPendingState( - .approval( - AgentPendingApprovalState( - request: approval, - requestedAt: Date() - ) - ), - for: invocation.threadID - ) - try await setThreadStatus(.waitingForApproval, for: invocation.threadID) - continuation.yield( - .threadStatusChanged( - threadID: invocation.threadID, - status: .waitingForApproval - ) - ) - continuation.yield(.approvalRequested(approval)) - - let decision = try await approvalCoordinator.requestApproval(approval) - let resolution = ApprovalResolution( - requestID: approval.id, - threadID: approval.threadID, - turnID: approval.turnID, - decision: decision - ) - appendHistoryItem( - .approval( - AgentApprovalRecord( - kind: .resolved, - request: approval, - resolution: resolution, - occurredAt: resolution.decidedAt - ) - ), - threadID: invocation.threadID, - createdAt: resolution.decidedAt - ) - try setPendingState(nil, for: invocation.threadID) - continuation.yield( - .approvalResolved( - resolution - ) - ) - - guard decision == .approved else { - let denied = ToolResultEnvelope.denied(invocation: invocation) - try setLatestToolState( - latestToolState(for: invocation, result: denied, updatedAt: resolution.decidedAt), - for: invocation.threadID - ) - appendHistoryItem( - .toolResult( - AgentToolResultRecord( - threadID: invocation.threadID, - turnID: invocation.turnID, - result: denied, - completedAt: resolution.decidedAt - ) - ), - threadID: invocation.threadID, - createdAt: resolution.decidedAt - ) - updateThreadTimestamp(resolution.decidedAt, for: invocation.threadID) - try await persistState() - return denied - } - } - - let toolWaitStartedAt = Date() - try setPendingState( - .toolWait( - AgentPendingToolWaitState( - invocationID: invocation.id, - turnID: invocation.turnID, - toolName: invocation.toolName, - startedAt: toolWaitStartedAt - ) - ), - for: invocation.threadID - ) - try setLatestToolState( - latestToolState(for: invocation, result: nil, updatedAt: toolWaitStartedAt), - for: invocation.threadID - ) - try await setThreadStatus(.waitingForToolResult, for: invocation.threadID) - continuation.yield( - .threadStatusChanged( - threadID: invocation.threadID, - status: .waitingForToolResult - ) - ) - - let result = await toolRegistry.execute(invocation, session: session) - let resultDate = Date() - try setLatestToolState( - latestToolState(for: invocation, result: result, updatedAt: resultDate), - for: invocation.threadID - ) - if let session = result.session, !session.isTerminal { - try setPendingState( - .toolWait( - AgentPendingToolWaitState( - invocationID: invocation.id, - turnID: invocation.turnID, - toolName: invocation.toolName, - startedAt: toolWaitStartedAt, - sessionID: session.sessionID, - sessionStatus: session.status, - metadata: session.metadata, - resumable: session.resumable - ) - ), - for: invocation.threadID - ) - } else { - try setPendingState(nil, for: invocation.threadID) - appendHistoryItem( - .toolResult( - AgentToolResultRecord( - threadID: invocation.threadID, - turnID: invocation.turnID, - result: result, - completedAt: resultDate - ) - ), - threadID: invocation.threadID, - createdAt: resultDate - ) - } - updateThreadTimestamp(resultDate, for: invocation.threadID) - try await persistState() - return result - } - - func resolveToolInvocation( - _ invocation: ToolInvocation, - session: ChatGPTSession, - continuation: AsyncThrowingStream, Error>.Continuation - ) async throws -> ToolResultEnvelope { - if let definition = await toolRegistry.definition(named: invocation.toolName), - definition.approvalPolicy == .requiresApproval { - let approval = ApprovalRequest( - threadID: invocation.threadID, - turnID: invocation.turnID, - toolInvocation: invocation, - title: "Approve \(invocation.toolName)?", - message: definition.approvalMessage - ?? "This tool requires explicit approval before it can run." - ) - - appendHistoryItem( - .approval( - AgentApprovalRecord( - kind: .requested, - request: approval, - occurredAt: Date() - ) - ), - threadID: invocation.threadID, - createdAt: Date() - ) - try setPendingState( - .approval( - AgentPendingApprovalState( - request: approval, - requestedAt: Date() - ) - ), - for: invocation.threadID - ) - try await setThreadStatus(.waitingForApproval, for: invocation.threadID) - continuation.yield( - .threadStatusChanged( - threadID: invocation.threadID, - status: .waitingForApproval - ) - ) - continuation.yield(.approvalRequested(approval)) - - let decision = try await approvalCoordinator.requestApproval(approval) - let resolution = ApprovalResolution( - requestID: approval.id, - threadID: approval.threadID, - turnID: approval.turnID, - decision: decision - ) - appendHistoryItem( - .approval( - AgentApprovalRecord( - kind: .resolved, - request: approval, - resolution: resolution, - occurredAt: resolution.decidedAt - ) - ), - threadID: invocation.threadID, - createdAt: resolution.decidedAt - ) - try setPendingState(nil, for: invocation.threadID) - continuation.yield( - .approvalResolved( - resolution - ) - ) - - guard decision == .approved else { - let denied = ToolResultEnvelope.denied(invocation: invocation) - try setLatestToolState( - latestToolState(for: invocation, result: denied, updatedAt: resolution.decidedAt), - for: invocation.threadID - ) - appendHistoryItem( - .toolResult( - AgentToolResultRecord( - threadID: invocation.threadID, - turnID: invocation.turnID, - result: denied, - completedAt: resolution.decidedAt - ) - ), - threadID: invocation.threadID, - createdAt: resolution.decidedAt - ) - updateThreadTimestamp(resolution.decidedAt, for: invocation.threadID) - try await persistState() - return denied - } - } - - let toolWaitStartedAt = Date() - try setPendingState( - .toolWait( - AgentPendingToolWaitState( - invocationID: invocation.id, - turnID: invocation.turnID, - toolName: invocation.toolName, - startedAt: toolWaitStartedAt - ) - ), - for: invocation.threadID - ) - try setLatestToolState( - latestToolState(for: invocation, result: nil, updatedAt: toolWaitStartedAt), - for: invocation.threadID - ) - try await setThreadStatus(.waitingForToolResult, for: invocation.threadID) - continuation.yield( - .threadStatusChanged( - threadID: invocation.threadID, - status: .waitingForToolResult - ) - ) - - let result = await toolRegistry.execute(invocation, session: session) - let resultDate = Date() - try setLatestToolState( - latestToolState(for: invocation, result: result, updatedAt: resultDate), - for: invocation.threadID - ) - if let session = result.session, !session.isTerminal { - try setPendingState( - .toolWait( - AgentPendingToolWaitState( - invocationID: invocation.id, - turnID: invocation.turnID, - toolName: invocation.toolName, - startedAt: toolWaitStartedAt, - sessionID: session.sessionID, - sessionStatus: session.status, - metadata: session.metadata, - resumable: session.resumable - ) - ), - for: invocation.threadID - ) - } else { - try setPendingState(nil, for: invocation.threadID) - appendHistoryItem( - .toolResult( - AgentToolResultRecord( - threadID: invocation.threadID, - turnID: invocation.turnID, - result: result, - completedAt: resultDate - ) - ), - threadID: invocation.threadID, - createdAt: resultDate - ) - } - updateThreadTimestamp(resultDate, for: invocation.threadID) - try await persistState() - return result - } } diff --git a/Sources/CodexKit/Runtime/AgentRuntime.swift b/Sources/CodexKit/Runtime/AgentRuntime.swift index 9408427..aeda489 100644 --- a/Sources/CodexKit/Runtime/AgentRuntime.swift +++ b/Sources/CodexKit/Runtime/AgentRuntime.swift @@ -1,3 +1,4 @@ +import Combine import Foundation public actor AgentRuntime { @@ -65,6 +66,7 @@ public actor AgentRuntime { let baseInstructions: String? let definitionSourceLoader: AgentDefinitionSourceLoader let contextCompactionConfiguration: AgentContextCompactionConfiguration + nonisolated let observationCenter: AgentRuntimeObservationCenter var skillsByID: [String: AgentSkill] var state: StoredRuntimeState = .empty @@ -170,15 +172,21 @@ public actor AgentRuntime { self.baseInstructions = configuration.baseInstructions ?? configuration.backend.baseInstructions self.definitionSourceLoader = configuration.definitionSourceLoader self.contextCompactionConfiguration = configuration.contextCompaction + self.observationCenter = AgentRuntimeObservationCenter() self.skillsByID = try Self.validatedSkills(from: configuration.skills) } + public nonisolated var observations: AnyPublisher { + observationCenter.publisher + } + @discardableResult public func restore() async throws -> StoredRuntimeState { _ = try await sessionManager.restore() _ = try await stateStore.prepare() state = try await stateStore.loadState() pendingStoreOperations.removeAll() + publishAllObservations() return state } @@ -227,18 +235,70 @@ public actor AgentRuntime { state = state.normalized() guard !pendingStoreOperations.isEmpty else { try await stateStore.saveState(state) + publishAllObservations() return } let operations = pendingStoreOperations try await stateStore.apply(operations) pendingStoreOperations.removeAll() + publishObservations(for: operations) } func enqueueStoreOperation(_ operation: AgentStoreWriteOperation) { pendingStoreOperations.append(operation) } + func publishAllObservations() { + observationCenter.send(.threadsChanged(threads())) + for thread in state.threads { + publishThreadObservations(for: thread.id) + } + } + + func publishObservations(for operations: [AgentStoreWriteOperation]) { + let deletedThreadIDs = Set(operations.compactMap { operation -> String? in + guard case let .deleteThread(threadID) = operation else { + return nil + } + return threadID + }) + let affectedThreadIDs = Set(operations.map(\.affectedThreadID)) + + observationCenter.send(.threadsChanged(threads())) + for threadID in deletedThreadIDs { + observationCenter.send(.threadDeleted(threadID: threadID)) + } + for threadID in affectedThreadIDs.subtracting(deletedThreadIDs) { + publishThreadObservations(for: threadID) + } + } + + func publishThreadObservations(for threadID: String) { + guard let thread = thread(for: threadID) else { + return + } + + observationCenter.send(.threadChanged(thread)) + observationCenter.send( + .messagesChanged( + threadID: threadID, + messages: state.messagesByThread[threadID] ?? [] + ) + ) + observationCenter.send( + .threadSummaryChanged( + state.summariesByThread[threadID] ?? state.threadSummaryFallback(for: thread) + ) + ) + observationCenter.send( + .threadContextStateChanged( + threadID: threadID, + state: state.contextStateByThread[threadID] + ) + ) + } + func resolveInstructions( thread: AgentThread, message: UserMessageRequest, diff --git a/Sources/CodexKit/Runtime/AgentRuntimeObservation.swift b/Sources/CodexKit/Runtime/AgentRuntimeObservation.swift new file mode 100644 index 0000000..78565a5 --- /dev/null +++ b/Sources/CodexKit/Runtime/AgentRuntimeObservation.swift @@ -0,0 +1,25 @@ +import Combine +import Foundation + +public enum AgentRuntimeObservation: Sendable { + case threadsChanged([AgentThread]) + case threadChanged(AgentThread) + case messagesChanged(threadID: String, messages: [AgentMessage]) + case threadSummaryChanged(AgentThreadSummary) + case threadContextStateChanged(threadID: String, state: AgentThreadContextState?) + case threadDeleted(threadID: String) +} + +public final class AgentRuntimeObservationCenter: @unchecked Sendable { + private let subject = PassthroughSubject() + + public init() {} + + public var publisher: AnyPublisher { + subject.eraseToAnyPublisher() + } + + func send(_ observation: AgentRuntimeObservation) { + subject.send(observation) + } +} diff --git a/Sources/CodexKit/Runtime/FileRuntimeStateStore.swift b/Sources/CodexKit/Runtime/FileRuntimeStateStore.swift new file mode 100644 index 0000000..001d5b9 --- /dev/null +++ b/Sources/CodexKit/Runtime/FileRuntimeStateStore.swift @@ -0,0 +1,225 @@ +import Foundation + +public actor FileRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, AgentRuntimeQueryableStore { + private let url: URL + private let encoder = JSONEncoder() + private let decoder = JSONDecoder() + private let fileManager = FileManager.default + private let attachmentStore: RuntimeAttachmentStore + + public init(url: URL) { + self.url = url + let basename = url.deletingPathExtension().lastPathComponent + self.attachmentStore = RuntimeAttachmentStore( + rootURL: url.deletingLastPathComponent() + .appendingPathComponent("\(basename).codexkit-state", isDirectory: true) + .appendingPathComponent("attachments", isDirectory: true) + ) + } + + public func loadState() async throws -> StoredRuntimeState { + try loadNormalizedStateMigratingIfNeeded() + } + + public func saveState(_ state: StoredRuntimeState) async throws { + try persistLayout(for: state.normalized()) + } + + public func prepare() async throws -> AgentStoreMetadata { + _ = try loadNormalizedStateMigratingIfNeeded() + return try await readMetadata() + } + + public func readMetadata() async throws -> AgentStoreMetadata { + AgentStoreMetadata( + logicalSchemaVersion: .v1, + storeSchemaVersion: 1, + capabilities: AgentStoreCapabilities( + supportsPushdownQueries: false, + supportsCrossThreadQueries: false, + supportsSorting: true, + supportsFiltering: true, + supportsMigrations: true + ), + storeKind: "FileRuntimeStateStore" + ) + } + + public func fetchThreadSummary(id: String) async throws -> AgentThreadSummary { + if let manifest = try loadManifest() { + guard let thread = manifest.threads.first(where: { $0.id == id }) else { + throw AgentRuntimeError.threadNotFound(id) + } + return manifest.summariesByThread[id] + ?? StoredRuntimeState(threads: [thread]).threadSummaryFallback(for: thread) + } + + return try loadNormalizedStateMigratingIfNeeded().threadSummary(id: id) + } + + public func fetchThreadHistory( + id: String, + query: AgentHistoryQuery + ) async throws -> AgentThreadHistoryPage { + if let manifest = try loadManifest() { + guard manifest.threads.contains(where: { $0.id == id }) else { + throw AgentRuntimeError.threadNotFound(id) + } + + let history = try loadHistory(for: id) + let state = StoredRuntimeState( + threads: manifest.threads, + historyByThread: [id: history], + summariesByThread: manifest.summariesByThread, + contextStateByThread: manifest.contextStateByThread, + nextHistorySequenceByThread: manifest.nextHistorySequenceByThread + ) + return try state.threadHistoryPage(id: id, query: query) + } + + return try loadNormalizedStateMigratingIfNeeded().threadHistoryPage(id: id, query: query) + } + + public func fetchLatestStructuredOutputMetadata(id: String) async throws -> AgentStructuredOutputMetadata? { + let summary = try await fetchThreadSummary(id: id) + return summary.latestStructuredOutputMetadata + } + + public func fetchThreadContextState(id: String) async throws -> AgentThreadContextState? { + if let manifest = try loadManifest() { + guard manifest.threads.contains(where: { $0.id == id }) else { + throw AgentRuntimeError.threadNotFound(id) + } + return manifest.contextStateByThread[id] + } + + return try loadNormalizedStateMigratingIfNeeded().contextStateByThread[id] + } + + private func loadNormalizedStateMigratingIfNeeded() throws -> StoredRuntimeState { + guard fileManager.fileExists(atPath: url.path) else { + return .empty + } + + if let manifest = try loadManifest() { + return try state(from: manifest) + } + + let data = try Data(contentsOf: url) + let legacy = try decoder.decode(StoredRuntimeState.self, from: data).normalized() + try persistLayout(for: legacy) + return legacy + } + + private func loadManifest() throws -> FileRuntimeStateManifest? { + guard fileManager.fileExists(atPath: url.path) else { + return nil + } + + let data = try Data(contentsOf: url) + return try? decoder.decode(FileRuntimeStateManifest.self, from: data) + } + + private func state(from manifest: FileRuntimeStateManifest) throws -> StoredRuntimeState { + var historyByThread: [String: [AgentHistoryRecord]] = [:] + for thread in manifest.threads { + historyByThread[thread.id] = try loadHistory(for: thread.id) + } + + return StoredRuntimeState( + threads: manifest.threads, + historyByThread: historyByThread, + summariesByThread: manifest.summariesByThread, + contextStateByThread: manifest.contextStateByThread, + nextHistorySequenceByThread: manifest.nextHistorySequenceByThread + ) + } + + private func loadHistory(for threadID: String) throws -> [AgentHistoryRecord] { + let historyURL = historyFileURL(for: threadID) + guard fileManager.fileExists(atPath: historyURL.path) else { + return [] + } + + let data = try Data(contentsOf: historyURL) + if let persisted = try? decoder.decode([PersistedAgentHistoryRecord].self, from: data) { + return try persisted.map { try $0.decode(using: attachmentStore) } + } + return try decoder.decode([AgentHistoryRecord].self, from: data) + } + + private func persistLayout(for state: StoredRuntimeState) throws { + let normalized = state.normalized() + let directory = url.deletingLastPathComponent() + if !directory.path.isEmpty { + try fileManager.createDirectory( + at: directory, + withIntermediateDirectories: true + ) + } + + try fileManager.createDirectory( + at: historyDirectoryURL, + withIntermediateDirectories: true + ) + try attachmentStore.reset() + + for thread in normalized.threads { + let historyURL = historyFileURL(for: thread.id) + let history = normalized.historyByThread[thread.id] ?? [] + let persisted = try history.map { + try PersistedAgentHistoryRecord( + record: $0, + attachmentStore: attachmentStore + ) + } + let data = try encoder.encode(persisted) + try data.write(to: historyURL, options: .atomic) + } + + let manifest = FileRuntimeStateManifest( + threads: normalized.threads, + summariesByThread: normalized.summariesByThread, + contextStateByThread: normalized.contextStateByThread, + nextHistorySequenceByThread: normalized.nextHistorySequenceByThread + ) + let manifestData = try encoder.encode(manifest) + try manifestData.write(to: url, options: .atomic) + } + + private var historyDirectoryURL: URL { + let basename = url.deletingPathExtension().lastPathComponent + return url.deletingLastPathComponent() + .appendingPathComponent("\(basename).codexkit-state", isDirectory: true) + .appendingPathComponent("threads", isDirectory: true) + } + + private func historyFileURL(for threadID: String) -> URL { + historyDirectoryURL.appendingPathComponent(safeThreadFilename(threadID)).appendingPathExtension("json") + } + + private func safeThreadFilename(_ threadID: String) -> String { + threadID.addingPercentEncoding(withAllowedCharacters: .alphanumerics) ?? threadID + } +} + +private struct FileRuntimeStateManifest: Codable { + let storageVersion: Int + let threads: [AgentThread] + let summariesByThread: [String: AgentThreadSummary] + let contextStateByThread: [String: AgentThreadContextState] + let nextHistorySequenceByThread: [String: Int] + + init( + threads: [AgentThread], + summariesByThread: [String: AgentThreadSummary], + contextStateByThread: [String: AgentThreadContextState], + nextHistorySequenceByThread: [String: Int] + ) { + self.storageVersion = 1 + self.threads = threads + self.summariesByThread = summariesByThread + self.contextStateByThread = contextStateByThread + self.nextHistorySequenceByThread = nextHistorySequenceByThread + } +} diff --git a/Sources/CodexKit/Runtime/GRDBRuntimeStateStore+Persistence.swift b/Sources/CodexKit/Runtime/GRDBRuntimeStateStore+Persistence.swift new file mode 100644 index 0000000..7caceb6 --- /dev/null +++ b/Sources/CodexKit/Runtime/GRDBRuntimeStateStore+Persistence.swift @@ -0,0 +1,403 @@ +import Foundation +import GRDB + +extension GRDBRuntimeStateStore { + func shouldImportLegacyState() async throws -> Bool { + guard let legacyStateURL else { + return false + } + guard legacyStateURL != url else { + return false + } + guard FileManager.default.fileExists(atPath: legacyStateURL.path) else { + return false + } + guard !databaseExistedAtInitialization else { + return false + } + + let threadCount = try await dbQueue.read { db in + try RuntimeThreadCountQuery().execute(in: db) + } + return threadCount == 0 + } + + func importLegacyState() async throws { + guard let legacyStateURL else { + return + } + + let legacyStore = FileRuntimeStateStore(url: legacyStateURL) + let state = try await legacyStore.loadState().normalized() + guard !state.threads.isEmpty || !state.historyByThread.isEmpty else { + return + } + + try await dbQueue.write { db in + try attachmentStore.reset() + try Self.replaceDatabaseContents( + with: state, + in: db, + attachmentStore: attachmentStore + ) + } + } + + func readUserVersion() async throws -> Int { + try await dbQueue.read { db in + try RuntimeUserVersionQuery().execute(in: db) + } + } + + static func replaceDatabaseContents( + with normalized: StoredRuntimeState, + in db: Database, + attachmentStore: RuntimeAttachmentStore + ) throws { + let threadRows = try normalized.threads.map(Self.makeThreadRow) + let summaryRows = try normalized.threads.compactMap { thread -> RuntimeSummaryRow? in + guard let summary = normalized.summariesByThread[thread.id] else { + return nil + } + return try Self.makeSummaryRow(from: summary) + } + let historyRows = try normalized.historyByThread.values + .flatMap { $0 } + .map { try Self.makeHistoryRow(from: $0, attachmentStore: attachmentStore) } + let structuredOutputRows = try Self.structuredOutputRows(from: normalized.historyByThread) + let contextRows = try normalized.contextStateByThread.values.map(Self.makeContextStateRow) + + try RuntimeContextStateRow.deleteAll(db) + try RuntimeStructuredOutputRow.deleteAll(db) + try RuntimeHistoryRow.deleteAll(db) + try RuntimeSummaryRow.deleteAll(db) + try RuntimeThreadRow.deleteAll(db) + + for row in threadRows { + try row.insert(db) + } + for row in summaryRows { + try row.insert(db) + } + for row in historyRows { + try row.insert(db) + } + for row in structuredOutputRows { + try row.insert(db) + } + for row in contextRows { + try row.insert(db) + } + } + + static func loadPartialState( + for threadIDs: Set, + from db: Database, + attachmentStore: RuntimeAttachmentStore + ) throws -> StoredRuntimeState { + guard !threadIDs.isEmpty else { + return .empty + } + + let ids = Array(threadIDs) + let threadRows = try RuntimeThreadRow + .filter(ids.contains(Column("threadID"))) + .fetchAll(db) + let summaryRows = try RuntimeSummaryRow + .filter(ids.contains(Column("threadID"))) + .fetchAll(db) + // History loading keeps raw SQL here so we can preserve a deterministic + // thread + sequence ordering across multiple thread IDs in one fetch. + let historyRows = try RuntimeHistoryRowsRequest( + sql: """ + SELECT * FROM \(RuntimeHistoryRow.databaseTableName) + WHERE threadID IN \(Self.sqlPlaceholders(count: ids.count)) + ORDER BY threadID ASC, sequenceNumber ASC + """, + arguments: StatementArguments(ids) + ).execute(in: db) + let contextRows = try RuntimeContextStateRow + .filter(ids.contains(Column("threadID"))) + .fetchAll(db) + + let threads = try threadRows.map { try Self.decodeThread(from: $0) } + let summaries = try Dictionary( + uniqueKeysWithValues: summaryRows.map { ($0.threadID, try Self.decodeSummary(from: $0)) } + ) + let decodedHistoryRows = try historyRows.map { + try Self.decodeHistoryRecord(from: $0, attachmentStore: attachmentStore) + } + let history = Dictionary(grouping: decodedHistoryRows, by: { $0.item.threadID }) + let contextState = try Dictionary( + uniqueKeysWithValues: contextRows.map { ($0.threadID, try Self.decodeContextState(from: $0)) } + ) + let nextSequence = history.mapValues { ($0.last?.sequenceNumber ?? 0) + 1 } + + return StoredRuntimeState( + threads: threads, + historyByThread: history, + summariesByThread: summaries, + contextStateByThread: contextState, + nextHistorySequenceByThread: nextSequence + ) + } + + static func persistThreads( + ids threadIDs: Set, + from state: StoredRuntimeState, + in db: Database, + attachmentStore: RuntimeAttachmentStore + ) throws { + let normalized = state.normalized() + let threads = normalized.threads.filter { threadIDs.contains($0.id) } + guard !threads.isEmpty else { + return + } + + for thread in threads { + try Self.makeThreadRow(from: thread).insert(db) + if let summary = normalized.summariesByThread[thread.id] { + try Self.makeSummaryRow(from: summary).insert(db) + } + if let contextState = normalized.contextStateByThread[thread.id] { + try Self.makeContextStateRow(from: contextState).insert(db) + } + for record in normalized.historyByThread[thread.id] ?? [] { + try Self.makeHistoryRow(from: record, attachmentStore: attachmentStore).insert(db) + } + } + + for row in try Self.structuredOutputRows( + from: normalized.historyByThread.filter { threadIDs.contains($0.key) } + ) { + try row.insert(db) + } + } + + static func deletePersistedThread( + _ threadID: String, + in db: Database + ) throws { + _ = try RuntimeThreadRow.deleteOne(db, key: threadID) + } + + static func makeThreadRow(from thread: AgentThread) throws -> RuntimeThreadRow { + RuntimeThreadRow( + threadID: thread.id, + createdAt: thread.createdAt.timeIntervalSince1970, + updatedAt: thread.updatedAt.timeIntervalSince1970, + status: thread.status.rawValue, + encodedThread: try JSONEncoder().encode(thread) + ) + } + + static func makeSummaryRow(from summary: AgentThreadSummary) throws -> RuntimeSummaryRow { + RuntimeSummaryRow( + threadID: summary.threadID, + createdAt: summary.createdAt.timeIntervalSince1970, + updatedAt: summary.updatedAt.timeIntervalSince1970, + latestItemAt: summary.latestItemAt?.timeIntervalSince1970, + itemCount: summary.itemCount, + pendingStateKind: summary.pendingState?.kind.rawValue, + latestStructuredOutputFormatName: summary.latestStructuredOutputMetadata?.formatName, + encodedSummary: try JSONEncoder().encode(summary) + ) + } + + static func makeHistoryRow( + from record: AgentHistoryRecord, + attachmentStore: RuntimeAttachmentStore + ) throws -> RuntimeHistoryRow { + let persisted = try PersistedAgentHistoryRecord( + record: record, + attachmentStore: attachmentStore + ) + return RuntimeHistoryRow( + storageID: "\(record.item.threadID):\(record.sequenceNumber)", + recordID: record.id, + threadID: record.item.threadID, + sequenceNumber: record.sequenceNumber, + createdAt: record.createdAt.timeIntervalSince1970, + kind: record.item.kind.rawValue, + turnID: record.item.turnID, + isCompactionMarker: record.item.isCompactionMarker, + isRedacted: record.redaction != nil, + encodedRecord: try JSONEncoder().encode(persisted) + ) + } + + static func makeContextStateRow(from state: AgentThreadContextState) throws -> RuntimeContextStateRow { + RuntimeContextStateRow( + threadID: state.threadID, + generation: state.generation, + encodedState: try JSONEncoder().encode(state) + ) + } + + static func structuredOutputRows( + from historyByThread: [String: [AgentHistoryRecord]] + ) throws -> [RuntimeStructuredOutputRow] { + try historyByThread.values + .flatMap { $0 } + .compactMap { record -> RuntimeStructuredOutputRow? in + switch record.item { + case let .structuredOutput(output): + return try Self.makeStructuredOutputRow( + id: "structured:\(record.id)", + record: output + ) + + case let .message(message): + guard let metadata = message.structuredOutput else { + return nil + } + return try Self.makeStructuredOutputRow( + id: "message:\(message.id)", + record: AgentStructuredOutputRecord( + threadID: message.threadID, + turnID: "", + messageID: message.id, + metadata: metadata, + committedAt: message.createdAt + ) + ) + + default: + return nil + } + } + } + + static func makeStructuredOutputRow( + id: String, + record: AgentStructuredOutputRecord + ) throws -> RuntimeStructuredOutputRow { + RuntimeStructuredOutputRow( + outputID: id, + threadID: record.threadID, + formatName: record.metadata.formatName, + committedAt: record.committedAt.timeIntervalSince1970, + encodedRecord: try JSONEncoder().encode(record) + ) + } + + static func decodeThread(from row: RuntimeThreadRow) throws -> AgentThread { + try JSONDecoder().decode(AgentThread.self, from: row.encodedThread) + } + + static func decodeSummary(from row: RuntimeSummaryRow) throws -> AgentThreadSummary { + try JSONDecoder().decode(AgentThreadSummary.self, from: row.encodedSummary) + } + + static func decodeContextState(from row: RuntimeContextStateRow) throws -> AgentThreadContextState { + try JSONDecoder().decode(AgentThreadContextState.self, from: row.encodedState) + } + + static func decodeHistoryRecord( + from row: RuntimeHistoryRow, + attachmentStore: RuntimeAttachmentStore + ) throws -> AgentHistoryRecord { + let decoder = JSONDecoder() + if let persisted = try? decoder.decode(PersistedAgentHistoryRecord.self, from: row.encodedRecord) { + return try persisted.decode(using: attachmentStore) + } + return try decoder.decode(AgentHistoryRecord.self, from: row.encodedRecord) + } + + static func decodeStructuredOutputRecord(from row: RuntimeStructuredOutputRow) throws -> AgentStructuredOutputRecord { + try JSONDecoder().decode(AgentStructuredOutputRecord.self, from: row.encodedRecord) + } + + static func makeMigrator() -> DatabaseMigrator { + var migrator = DatabaseMigrator() + + migrator.registerMigration("runtime_store_v1") { db in + try db.create(table: RuntimeThreadRow.databaseTableName) { table in + table.column("threadID", .text).primaryKey() + table.column("createdAt", .double).notNull() + table.column("updatedAt", .double).notNull() + table.column("status", .text).notNull() + table.column("encodedThread", .blob).notNull() + } + + try db.create(table: RuntimeSummaryRow.databaseTableName) { table in + table.column("threadID", .text) + .primaryKey() + .references(RuntimeThreadRow.databaseTableName, onDelete: .cascade) + table.column("createdAt", .double).notNull() + table.column("updatedAt", .double).notNull() + table.column("latestItemAt", .double) + table.column("itemCount", .integer) + table.column("pendingStateKind", .text) + table.column("latestStructuredOutputFormatName", .text) + table.column("encodedSummary", .blob).notNull() + } + + try db.create(table: RuntimeHistoryRow.databaseTableName) { table in + table.column("storageID", .text).primaryKey() + table.column("recordID", .text).notNull() + table.column("threadID", .text) + .notNull() + .references(RuntimeThreadRow.databaseTableName, onDelete: .cascade) + table.column("sequenceNumber", .integer).notNull() + table.column("createdAt", .double).notNull() + table.column("kind", .text).notNull() + table.column("turnID", .text) + table.column("isCompactionMarker", .boolean).notNull().defaults(to: false) + table.column("isRedacted", .boolean).notNull().defaults(to: false) + table.column("encodedRecord", .blob).notNull() + } + + try db.create(index: "runtime_history_thread_sequence", on: RuntimeHistoryRow.databaseTableName, columns: ["threadID", "sequenceNumber"], unique: true) + try db.create(index: "runtime_history_thread_created_at", on: RuntimeHistoryRow.databaseTableName, columns: ["threadID", "createdAt"]) + try db.create(index: "runtime_history_thread_kind", on: RuntimeHistoryRow.databaseTableName, columns: ["threadID", "kind"]) + try db.create(index: "runtime_history_thread_record_id", on: RuntimeHistoryRow.databaseTableName, columns: ["threadID", "recordID"]) + + try db.create(table: RuntimeStructuredOutputRow.databaseTableName) { table in + table.column("outputID", .text).primaryKey() + table.column("threadID", .text) + .notNull() + .references(RuntimeThreadRow.databaseTableName, onDelete: .cascade) + table.column("formatName", .text).notNull() + table.column("committedAt", .double).notNull() + table.column("encodedRecord", .blob).notNull() + } + + try db.create(index: "runtime_structured_outputs_thread_committed_at", on: RuntimeStructuredOutputRow.databaseTableName, columns: ["threadID", "committedAt"]) + try db.create(index: "runtime_structured_outputs_format_name", on: RuntimeStructuredOutputRow.databaseTableName, columns: ["formatName"]) + + try db.create(table: RuntimeContextStateRow.databaseTableName) { table in + table.column("threadID", .text) + .primaryKey() + .references(RuntimeThreadRow.databaseTableName, onDelete: .cascade) + table.column("generation", .integer).notNull() + table.column("encodedState", .blob).notNull() + } + + try db.execute(sql: "PRAGMA user_version = \(currentStoreSchemaVersion)") + } + + migrator.registerMigration("runtime_store_v2_compaction_state") { db in + let historyColumns = try db.columns(in: RuntimeHistoryRow.databaseTableName).map(\.name) + if !historyColumns.contains("isCompactionMarker") { + try db.alter(table: RuntimeHistoryRow.databaseTableName) { table in + table.add(column: "isCompactionMarker", .boolean).notNull().defaults(to: false) + } + } + + if try !db.tableExists(RuntimeContextStateRow.databaseTableName) { + try db.create(table: RuntimeContextStateRow.databaseTableName) { table in + table.column("threadID", .text) + .primaryKey() + .references(RuntimeThreadRow.databaseTableName, onDelete: .cascade) + table.column("generation", .integer).notNull() + table.column("encodedState", .blob).notNull() + } + } + + try db.execute(sql: "PRAGMA user_version = \(currentStoreSchemaVersion)") + } + + return migrator + } +} diff --git a/Sources/CodexKit/Runtime/GRDBRuntimeStateStore+Queries.swift b/Sources/CodexKit/Runtime/GRDBRuntimeStateStore+Queries.swift new file mode 100644 index 0000000..a47511f --- /dev/null +++ b/Sources/CodexKit/Runtime/GRDBRuntimeStateStore+Queries.swift @@ -0,0 +1,459 @@ +import Foundation +import GRDB + +extension GRDBRuntimeStateStore { + func executeHistoryQuery(_ query: HistoryItemsQuery) async throws -> AgentHistoryQueryResult { + try await dbQueue.read { db in + guard let threadRow = try RuntimeThreadRow.fetchOne(db, key: query.threadID) else { + return AgentHistoryQueryResult( + threadID: query.threadID, + records: [], + nextCursor: nil, + previousCursor: nil, + hasMoreBefore: false, + hasMoreAfter: false + ) + } + + let thread = try Self.decodeThread(from: threadRow) + let history = try Self.fetchHistoryRows( + threadID: query.threadID, + kinds: query.kinds, + createdAtRange: query.createdAtRange, + turnID: query.turnID, + includeRedacted: query.includeRedacted, + includeCompactionEvents: query.includeCompactionEvents, + in: db, + attachmentStore: attachmentStore + ) + + let state = StoredRuntimeState( + threads: [thread], + historyByThread: [query.threadID: history] + ) + return try state.execute(query) + } + } + + func executeThreadQuery(_ query: ThreadMetadataQuery) async throws -> [AgentThread] { + try await dbQueue.read { db in + var request = RuntimeThreadRow.all() + if let threadIDs = query.threadIDs, !threadIDs.isEmpty { + request = request.filter(threadIDs.contains(Column("threadID"))) + } + if let statuses = query.statuses, !statuses.isEmpty { + request = request.filter(statuses.map(\.rawValue).contains(Column("status"))) + } + if let range = query.updatedAtRange { + request = request.filter(Column("updatedAt") >= range.lowerBound.timeIntervalSince1970) + request = request.filter(Column("updatedAt") <= range.upperBound.timeIntervalSince1970) + } + + switch query.sort { + case let .updatedAt(order): + request = order == .ascending + ? request.order(Column("updatedAt").asc, Column("threadID").asc) + : request.order(Column("updatedAt").desc, Column("threadID").asc) + case let .createdAt(order): + request = order == .ascending + ? request.order(Column("createdAt").asc, Column("threadID").asc) + : request.order(Column("createdAt").desc, Column("threadID").asc) + } + + if let limit = query.limit { + request = request.limit(max(0, limit)) + } + + let rows = try request.fetchAll(db) + return try rows.map { try Self.decodeThread(from: $0) } + } + } + + func executeThreadContextStateQuery(_ query: ThreadContextStateQuery) async throws -> [AgentThreadContextState] { + try await dbQueue.read { db in + var request = RuntimeContextStateRow.all() + if let threadIDs = query.threadIDs, !threadIDs.isEmpty { + request = request.filter(threadIDs.contains(Column("threadID"))) + } + request = request.order(Column("generation").desc, Column("threadID").asc) + if let limit = query.limit { + request = request.limit(max(0, limit)) + } + + return try request.fetchAll(db).map { try Self.decodeContextState(from: $0) } + } + } + + func executePendingStateQuery(_ query: PendingStateQuery) async throws -> [AgentPendingStateRecord] { + try await dbQueue.read { db in + var request = RuntimeSummaryRow + .filter(Column("pendingStateKind") != nil) + + if let threadIDs = query.threadIDs, !threadIDs.isEmpty { + request = request.filter(threadIDs.contains(Column("threadID"))) + } + if let kinds = query.kinds, !kinds.isEmpty { + request = request.filter(kinds.map(\.rawValue).contains(Column("pendingStateKind"))) + } + + switch query.sort { + case let .updatedAt(order): + request = order == .ascending + ? request.order(Column("updatedAt").asc) + : request.order(Column("updatedAt").desc) + } + + if let limit = query.limit { + request = request.limit(max(0, limit)) + } + + let summaries = try request.fetchAll(db) + return try summaries.compactMap { row -> AgentPendingStateRecord? in + let summary = try Self.decodeSummary(from: row) + guard let pendingState = summary.pendingState else { + return nil + } + return AgentPendingStateRecord( + threadID: summary.threadID, + pendingState: pendingState, + updatedAt: summary.updatedAt + ) + } + } + } + + func executeStructuredOutputQuery(_ query: StructuredOutputQuery) async throws -> [AgentStructuredOutputRecord] { + try await dbQueue.read { db in + var request = RuntimeStructuredOutputRow.all() + if let threadIDs = query.threadIDs, !threadIDs.isEmpty { + request = request.filter(threadIDs.contains(Column("threadID"))) + } + if let formatNames = query.formatNames, !formatNames.isEmpty { + request = request.filter(formatNames.contains(Column("formatName"))) + } + + switch query.sort { + case let .committedAt(order): + request = order == .ascending + ? request.order(Column("committedAt").asc) + : request.order(Column("committedAt").desc) + } + + if let limit = query.limit, !query.latestOnly { + request = request.limit(max(0, limit)) + } + + var records = try request.fetchAll(db) + .map { try Self.decodeStructuredOutputRecord(from: $0) } + + if query.latestOnly { + var seen = Set() + records = records.filter { seen.insert($0.threadID).inserted } + } + + if let limit = query.limit { + records = Array(records.prefix(max(0, limit))) + } + return records + } + } + + func executeThreadSnapshotQuery(_ query: ThreadSnapshotQuery) async throws -> [AgentThreadSnapshot] { + try await dbQueue.read { db in + var request = RuntimeSummaryRow.all() + if let threadIDs = query.threadIDs, !threadIDs.isEmpty { + request = request.filter(threadIDs.contains(Column("threadID"))) + } + + switch query.sort { + case let .updatedAt(order): + request = order == .ascending + ? request.order(Column("updatedAt").asc, Column("threadID").asc) + : request.order(Column("updatedAt").desc, Column("threadID").asc) + case let .createdAt(order): + request = order == .ascending + ? request.order(Column("createdAt").asc, Column("threadID").asc) + : request.order(Column("createdAt").desc, Column("threadID").asc) + } + + if let limit = query.limit { + request = request.limit(max(0, limit)) + } + + return try request.fetchAll(db) + .map { try Self.decodeSummary(from: $0) } + .map(\.snapshot) + } + } + + static func fetchHistoryRows( + threadID: String, + kinds: Set?, + createdAtRange: ClosedRange?, + turnID: String?, + includeRedacted: Bool, + includeCompactionEvents: Bool, + in db: Database, + attachmentStore: RuntimeAttachmentStore + ) throws -> [AgentHistoryRecord] { + var clauses = ["threadID = ?"] + var arguments: [any DatabaseValueConvertible] = [threadID] + + if let kinds, !kinds.isEmpty { + clauses.append("kind IN \(sqlPlaceholders(count: kinds.count))") + arguments.append(contentsOf: kinds.map(\.rawValue)) + } + if let createdAtRange { + clauses.append("createdAt >= ?") + clauses.append("createdAt <= ?") + arguments.append(createdAtRange.lowerBound.timeIntervalSince1970) + arguments.append(createdAtRange.upperBound.timeIntervalSince1970) + } + if let turnID { + clauses.append("turnID = ?") + arguments.append(turnID) + } + if !includeRedacted { + clauses.append("isRedacted = 0") + } + if !includeCompactionEvents { + clauses.append("isCompactionMarker = 0") + } + + // This stays in SQL because the history query shape is highly dynamic and + // we always want sequence-ordered reads for restore/query replay semantics. + let sql = """ + SELECT * FROM \(RuntimeHistoryRow.databaseTableName) + WHERE \(clauses.joined(separator: " AND ")) + ORDER BY sequenceNumber ASC + """ + return try RuntimeHistoryRowsRequest( + sql: sql, + arguments: StatementArguments(arguments) + ).execute(in: db).map { try Self.decodeHistoryRecord(from: $0, attachmentStore: attachmentStore) } + } + + static func fetchHistoryPage( + threadID: String, + query: AgentHistoryQuery, + in db: Database, + attachmentStore: RuntimeAttachmentStore + ) throws -> AgentThreadHistoryPage { + let limit = max(1, query.limit) + let kinds = historyKinds(from: query.filter) + let includeCompactionEvents = query.filter?.includeCompactionEvents ?? false + let anchor = try decodeCursorSequence(query.cursor, expectedThreadID: threadID) + + switch query.direction { + case .backward: + var clauses = ["threadID = ?"] + var arguments: [any DatabaseValueConvertible] = [threadID] + if let kinds, !kinds.isEmpty { + clauses.append("kind IN \(sqlPlaceholders(count: kinds.count))") + for kind in kinds { arguments.append(kind.rawValue) } + } + if let anchor { + clauses.append("sequenceNumber < ?") + arguments.append(anchor) + } + if !includeCompactionEvents { + clauses.append("isCompactionMarker = 0") + } + + // Cursor paging is kept as raw SQL because the descending window + overfetch + // pattern is much clearer here than trying to express it through chained requests. + let sql = """ + SELECT * FROM \(RuntimeHistoryRow.databaseTableName) + WHERE \(clauses.joined(separator: " AND ")) + ORDER BY sequenceNumber DESC + LIMIT \(limit + 1) + """ + let fetched = try RuntimeHistoryRowsRequest( + sql: sql, + arguments: StatementArguments(arguments) + ).execute(in: db) + let hasMoreBefore = fetched.count > limit + let pageRowsDescending = Array(fetched.prefix(limit)) + let pageRecords = try pageRowsDescending + .map { try Self.decodeHistoryRecord(from: $0, attachmentStore: attachmentStore) } + .reversed() + + let hasMoreAfter: Bool + if let anchor { + hasMoreAfter = try historyRecordExists( + threadID: threadID, + kinds: kinds, + includeCompactionEvents: includeCompactionEvents, + comparator: "sequenceNumber >= ?", + value: anchor, + in: db + ) + } else { + hasMoreAfter = false + } + + return AgentThreadHistoryPage( + threadID: threadID, + items: pageRecords.map(\.item), + nextCursor: hasMoreBefore ? makeCursor(threadID: threadID, sequenceNumber: pageRecords.first?.sequenceNumber) : nil, + previousCursor: hasMoreAfter ? makeCursor(threadID: threadID, sequenceNumber: pageRecords.last?.sequenceNumber) : nil, + hasMoreBefore: hasMoreBefore, + hasMoreAfter: hasMoreAfter + ) + + case .forward: + var clauses = ["threadID = ?"] + var arguments: [any DatabaseValueConvertible] = [threadID] + if let kinds, !kinds.isEmpty { + clauses.append("kind IN \(sqlPlaceholders(count: kinds.count))") + for kind in kinds { arguments.append(kind.rawValue) } + } + if let anchor { + clauses.append("sequenceNumber > ?") + arguments.append(anchor) + } + if !includeCompactionEvents { + clauses.append("isCompactionMarker = 0") + } + + // Forward paging mirrors the backward cursor window and stays in SQL for the + // same reason: explicit sequence bounds and overfetch are easier to verify here. + let sql = """ + SELECT * FROM \(RuntimeHistoryRow.databaseTableName) + WHERE \(clauses.joined(separator: " AND ")) + ORDER BY sequenceNumber ASC + LIMIT \(limit + 1) + """ + let fetched = try RuntimeHistoryRowsRequest( + sql: sql, + arguments: StatementArguments(arguments) + ).execute(in: db) + let hasMoreAfter = fetched.count > limit + let pageRows = Array(fetched.prefix(limit)) + let pageRecords = try pageRows.map { + try Self.decodeHistoryRecord(from: $0, attachmentStore: attachmentStore) + } + + let hasMoreBefore: Bool + if let anchor { + hasMoreBefore = try historyRecordExists( + threadID: threadID, + kinds: kinds, + includeCompactionEvents: includeCompactionEvents, + comparator: "sequenceNumber <= ?", + value: anchor, + in: db + ) + } else { + hasMoreBefore = false + } + + return AgentThreadHistoryPage( + threadID: threadID, + items: pageRecords.map(\.item), + nextCursor: hasMoreAfter ? makeCursor(threadID: threadID, sequenceNumber: pageRecords.last?.sequenceNumber) : nil, + previousCursor: hasMoreBefore ? makeCursor(threadID: threadID, sequenceNumber: pageRecords.first?.sequenceNumber) : nil, + hasMoreBefore: hasMoreBefore, + hasMoreAfter: hasMoreAfter + ) + } + } + + static func historyRecordExists( + threadID: String, + kinds: Set?, + includeCompactionEvents: Bool, + comparator: String, + value: Int, + in db: Database + ) throws -> Bool { + var clauses = ["threadID = ?", comparator] + var arguments: [any DatabaseValueConvertible] = [threadID, value] + if let kinds, !kinds.isEmpty { + clauses.append("kind IN \(sqlPlaceholders(count: kinds.count))") + for kind in kinds { arguments.append(kind.rawValue) } + } + if !includeCompactionEvents { + clauses.append("isCompactionMarker = 0") + } + + // EXISTS is one of the few cases where the raw SQL is both shorter and more obvious + // than the equivalent GRDB request composition for cursor-bound history checks. + let sql = """ + SELECT EXISTS( + SELECT 1 FROM \(RuntimeHistoryRow.databaseTableName) + WHERE \(clauses.joined(separator: " AND ")) + ) + """ + return try RuntimeHistoryExistenceQuery( + sql: sql, + arguments: StatementArguments(arguments) + ).execute(in: db) + } + + static func historyKinds(from filter: AgentHistoryFilter?) -> Set? { + guard let filter else { + return nil + } + + var kinds: Set = [] + if filter.includeMessages { kinds.insert(.message) } + if filter.includeToolCalls { kinds.insert(.toolCall) } + if filter.includeToolResults { kinds.insert(.toolResult) } + if filter.includeStructuredOutputs { kinds.insert(.structuredOutput) } + if filter.includeApprovals { kinds.insert(.approval) } + if filter.includeSystemEvents { kinds.insert(.systemEvent) } + return kinds + } + + static func makeCursor(threadID: String, sequenceNumber: Int?) -> AgentHistoryCursor? { + guard let sequenceNumber else { + return nil + } + + let payload = GRDBHistoryCursorPayload( + version: 1, + threadID: threadID, + sequenceNumber: sequenceNumber + ) + let data = (try? JSONEncoder().encode(payload)) ?? Data() + let base64 = data.base64EncodedString() + .replacingOccurrences(of: "+", with: "-") + .replacingOccurrences(of: "/", with: "_") + .replacingOccurrences(of: "=", with: "") + return AgentHistoryCursor(rawValue: base64) + } + + static func decodeCursorSequence( + _ cursor: AgentHistoryCursor?, + expectedThreadID: String + ) throws -> Int? { + guard let cursor else { + return nil + } + + let padded = cursor.rawValue + .replacingOccurrences(of: "-", with: "+") + .replacingOccurrences(of: "_", with: "/") + let remainder = padded.count % 4 + let adjusted = padded + String(repeating: "=", count: remainder == 0 ? 0 : 4 - remainder) + + guard let data = Data(base64Encoded: adjusted) else { + throw AgentRuntimeError.invalidHistoryCursor() + } + + let payload = try JSONDecoder().decode(GRDBHistoryCursorPayload.self, from: data) + guard payload.threadID == expectedThreadID else { + throw AgentRuntimeError.invalidHistoryCursor() + } + return payload.sequenceNumber + } + + static func defaultLegacyImportURL(for url: URL) -> URL { + url.deletingPathExtension().appendingPathExtension("json") + } + + static func sqlPlaceholders(count: Int) -> String { + "(" + Array(repeating: "?", count: count).joined(separator: ", ") + ")" + } +} diff --git a/Sources/CodexKit/Runtime/GRDBRuntimeStateStore.swift b/Sources/CodexKit/Runtime/GRDBRuntimeStateStore.swift index d82b409..08941cf 100644 --- a/Sources/CodexKit/Runtime/GRDBRuntimeStateStore.swift +++ b/Sources/CodexKit/Runtime/GRDBRuntimeStateStore.swift @@ -2,15 +2,15 @@ import Foundation import GRDB public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, AgentRuntimeQueryableStore { - private static let currentStoreSchemaVersion = 2 + static let currentStoreSchemaVersion = 2 - private let url: URL - private let legacyStateURL: URL? - private let attachmentStore: RuntimeAttachmentStore - private let databaseExistedAtInitialization: Bool - private let dbQueue: DatabaseQueue - private let migrator: DatabaseMigrator - private var isPrepared = false + let url: URL + let legacyStateURL: URL? + let attachmentStore: RuntimeAttachmentStore + let databaseExistedAtInitialization: Bool + let dbQueue: DatabaseQueue + let migrator: DatabaseMigrator + var isPrepared = false public init( url: URL, @@ -225,7 +225,7 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, return try query.execute(in: state) } - private func ensurePrepared() async throws { + func ensurePrepared() async throws { if isPrepared { return } @@ -243,1028 +243,4 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, } isPrepared = true } - - private func executeHistoryQuery(_ query: HistoryItemsQuery) async throws -> AgentHistoryQueryResult { - try await dbQueue.read { db in - guard let threadRow = try RuntimeThreadRow.fetchOne(db, key: query.threadID) else { - return AgentHistoryQueryResult( - threadID: query.threadID, - records: [], - nextCursor: nil, - previousCursor: nil, - hasMoreBefore: false, - hasMoreAfter: false - ) - } - - let thread = try Self.decodeThread(from: threadRow) - let history = try Self.fetchHistoryRows( - threadID: query.threadID, - kinds: query.kinds, - createdAtRange: query.createdAtRange, - turnID: query.turnID, - includeRedacted: query.includeRedacted, - includeCompactionEvents: query.includeCompactionEvents, - in: db, - attachmentStore: attachmentStore - ) - - let state = StoredRuntimeState( - threads: [thread], - historyByThread: [query.threadID: history] - ) - return try state.execute(query) - } - } - - private func executeThreadQuery(_ query: ThreadMetadataQuery) async throws -> [AgentThread] { - try await dbQueue.read { db in - var request = RuntimeThreadRow.all() - if let threadIDs = query.threadIDs, !threadIDs.isEmpty { - request = request.filter(threadIDs.contains(Column("threadID"))) - } - if let statuses = query.statuses, !statuses.isEmpty { - request = request.filter(statuses.map(\.rawValue).contains(Column("status"))) - } - if let range = query.updatedAtRange { - request = request.filter(Column("updatedAt") >= range.lowerBound.timeIntervalSince1970) - request = request.filter(Column("updatedAt") <= range.upperBound.timeIntervalSince1970) - } - - switch query.sort { - case let .updatedAt(order): - request = order == .ascending - ? request.order(Column("updatedAt").asc, Column("threadID").asc) - : request.order(Column("updatedAt").desc, Column("threadID").asc) - case let .createdAt(order): - request = order == .ascending - ? request.order(Column("createdAt").asc, Column("threadID").asc) - : request.order(Column("createdAt").desc, Column("threadID").asc) - } - - if let limit = query.limit { - request = request.limit(max(0, limit)) - } - - let rows = try request.fetchAll(db) - return try rows.map { try Self.decodeThread(from: $0) } - } - } - - private func executeThreadContextStateQuery(_ query: ThreadContextStateQuery) async throws -> [AgentThreadContextState] { - try await dbQueue.read { db in - var request = RuntimeContextStateRow.all() - if let threadIDs = query.threadIDs, !threadIDs.isEmpty { - request = request.filter(threadIDs.contains(Column("threadID"))) - } - request = request.order(Column("generation").desc, Column("threadID").asc) - if let limit = query.limit { - request = request.limit(max(0, limit)) - } - - return try request.fetchAll(db).map { try Self.decodeContextState(from: $0) } - } - } - - private func executePendingStateQuery(_ query: PendingStateQuery) async throws -> [AgentPendingStateRecord] { - try await dbQueue.read { db in - var request = RuntimeSummaryRow - .filter(Column("pendingStateKind") != nil) - - if let threadIDs = query.threadIDs, !threadIDs.isEmpty { - request = request.filter(threadIDs.contains(Column("threadID"))) - } - if let kinds = query.kinds, !kinds.isEmpty { - request = request.filter(kinds.map(\.rawValue).contains(Column("pendingStateKind"))) - } - - switch query.sort { - case let .updatedAt(order): - request = order == .ascending - ? request.order(Column("updatedAt").asc) - : request.order(Column("updatedAt").desc) - } - - if let limit = query.limit { - request = request.limit(max(0, limit)) - } - - let summaries = try request.fetchAll(db) - let records = try summaries.compactMap { row -> AgentPendingStateRecord? in - let summary = try Self.decodeSummary(from: row) - guard let pendingState = summary.pendingState else { - return nil - } - return AgentPendingStateRecord( - threadID: summary.threadID, - pendingState: pendingState, - updatedAt: summary.updatedAt - ) - } - return records - } - } - - private func executeStructuredOutputQuery(_ query: StructuredOutputQuery) async throws -> [AgentStructuredOutputRecord] { - try await dbQueue.read { db in - var request = RuntimeStructuredOutputRow.all() - if let threadIDs = query.threadIDs, !threadIDs.isEmpty { - request = request.filter(threadIDs.contains(Column("threadID"))) - } - if let formatNames = query.formatNames, !formatNames.isEmpty { - request = request.filter(formatNames.contains(Column("formatName"))) - } - - switch query.sort { - case let .committedAt(order): - request = order == .ascending - ? request.order(Column("committedAt").asc) - : request.order(Column("committedAt").desc) - } - - if let limit = query.limit, !query.latestOnly { - request = request.limit(max(0, limit)) - } - - var records = try request.fetchAll(db) - .map { try Self.decodeStructuredOutputRecord(from: $0) } - - if query.latestOnly { - var seen = Set() - records = records.filter { seen.insert($0.threadID).inserted } - } - - if let limit = query.limit { - records = Array(records.prefix(max(0, limit))) - } - return records - } - } - - private func executeThreadSnapshotQuery(_ query: ThreadSnapshotQuery) async throws -> [AgentThreadSnapshot] { - try await dbQueue.read { db in - var request = RuntimeSummaryRow.all() - if let threadIDs = query.threadIDs, !threadIDs.isEmpty { - request = request.filter(threadIDs.contains(Column("threadID"))) - } - - switch query.sort { - case let .updatedAt(order): - request = order == .ascending - ? request.order(Column("updatedAt").asc, Column("threadID").asc) - : request.order(Column("updatedAt").desc, Column("threadID").asc) - case let .createdAt(order): - request = order == .ascending - ? request.order(Column("createdAt").asc, Column("threadID").asc) - : request.order(Column("createdAt").desc, Column("threadID").asc) - } - - if let limit = query.limit { - request = request.limit(max(0, limit)) - } - - let snapshots = try request.fetchAll(db) - .map { try Self.decodeSummary(from: $0) } - .map(\.snapshot) - return snapshots - } - } - - private static func replaceDatabaseContents( - with normalized: StoredRuntimeState, - in db: Database, - attachmentStore: RuntimeAttachmentStore - ) throws { - let threadRows = try normalized.threads.map(Self.makeThreadRow) - let summaryRows = try normalized.threads.compactMap { thread -> RuntimeSummaryRow? in - guard let summary = normalized.summariesByThread[thread.id] else { - return nil - } - return try Self.makeSummaryRow(from: summary) - } - let historyRows = try normalized.historyByThread.values - .flatMap { $0 } - .map { try Self.makeHistoryRow(from: $0, attachmentStore: attachmentStore) } - let structuredOutputRows = try Self.structuredOutputRows(from: normalized.historyByThread) - let contextRows = try normalized.contextStateByThread.values.map(Self.makeContextStateRow) - - try RuntimeContextStateRow.deleteAll(db) - try RuntimeStructuredOutputRow.deleteAll(db) - try RuntimeHistoryRow.deleteAll(db) - try RuntimeSummaryRow.deleteAll(db) - try RuntimeThreadRow.deleteAll(db) - - for row in threadRows { - try row.insert(db) - } - for row in summaryRows { - try row.insert(db) - } - for row in historyRows { - try row.insert(db) - } - for row in structuredOutputRows { - try row.insert(db) - } - for row in contextRows { - try row.insert(db) - } - } - - private func shouldImportLegacyState() async throws -> Bool { - guard let legacyStateURL else { - return false - } - guard legacyStateURL != url else { - return false - } - guard FileManager.default.fileExists(atPath: legacyStateURL.path) else { - return false - } - guard !databaseExistedAtInitialization else { - return false - } - - let threadCount = try await dbQueue.read { db in - try RuntimeThreadCountQuery().execute(in: db) - } - return threadCount == 0 - } - - private func importLegacyState() async throws { - guard let legacyStateURL else { - return - } - - let legacyStore = FileRuntimeStateStore(url: legacyStateURL) - let state = try await legacyStore.loadState().normalized() - guard !state.threads.isEmpty || !state.historyByThread.isEmpty else { - return - } - - try await dbQueue.write { db in - try attachmentStore.reset() - try Self.replaceDatabaseContents( - with: state, - in: db, - attachmentStore: attachmentStore - ) - } - } - - private static func loadPartialState( - for threadIDs: Set, - from db: Database, - attachmentStore: RuntimeAttachmentStore - ) throws -> StoredRuntimeState { - guard !threadIDs.isEmpty else { - return .empty - } - - let ids = Array(threadIDs) - let threadRows = try RuntimeThreadRow - .filter(ids.contains(Column("threadID"))) - .fetchAll(db) - let summaryRows = try RuntimeSummaryRow - .filter(ids.contains(Column("threadID"))) - .fetchAll(db) - // History loading keeps raw SQL here so we can preserve a deterministic - // thread + sequence ordering across multiple thread IDs in one fetch. - let historyRows = try RuntimeHistoryRowsRequest( - sql: """ - SELECT * FROM \(RuntimeHistoryRow.databaseTableName) - WHERE threadID IN \(Self.sqlPlaceholders(count: ids.count)) - ORDER BY threadID ASC, sequenceNumber ASC - """, - arguments: StatementArguments(ids) - ).execute(in: db) - let contextRows = try RuntimeContextStateRow - .filter(ids.contains(Column("threadID"))) - .fetchAll(db) - - let threads = try threadRows.map { try Self.decodeThread(from: $0) } - let summaries = try Dictionary( - uniqueKeysWithValues: summaryRows.map { ($0.threadID, try Self.decodeSummary(from: $0)) } - ) - let decodedHistoryRows = try historyRows.map { - try Self.decodeHistoryRecord(from: $0, attachmentStore: attachmentStore) - } - let history = Dictionary(grouping: decodedHistoryRows, by: { $0.item.threadID }) - let contextState = try Dictionary( - uniqueKeysWithValues: contextRows.map { ($0.threadID, try Self.decodeContextState(from: $0)) } - ) - let nextSequence = history.mapValues { ($0.last?.sequenceNumber ?? 0) + 1 } - - return StoredRuntimeState( - threads: threads, - historyByThread: history, - summariesByThread: summaries, - contextStateByThread: contextState, - nextHistorySequenceByThread: nextSequence - ) - } - - private static func persistThreads( - ids threadIDs: Set, - from state: StoredRuntimeState, - in db: Database, - attachmentStore: RuntimeAttachmentStore - ) throws { - let normalized = state.normalized() - let threads = normalized.threads.filter { threadIDs.contains($0.id) } - guard !threads.isEmpty else { - return - } - - for thread in threads { - try Self.makeThreadRow(from: thread).insert(db) - if let summary = normalized.summariesByThread[thread.id] { - try Self.makeSummaryRow(from: summary).insert(db) - } - if let contextState = normalized.contextStateByThread[thread.id] { - try Self.makeContextStateRow(from: contextState).insert(db) - } - for record in normalized.historyByThread[thread.id] ?? [] { - try Self.makeHistoryRow(from: record, attachmentStore: attachmentStore).insert(db) - } - } - - for row in try Self.structuredOutputRows( - from: normalized.historyByThread.filter { threadIDs.contains($0.key) } - ) { - try row.insert(db) - } - } - - private static func deletePersistedThread( - _ threadID: String, - in db: Database - ) throws { - _ = try RuntimeThreadRow.deleteOne(db, key: threadID) - } - - private static func fetchHistoryRows( - threadID: String, - kinds: Set?, - createdAtRange: ClosedRange?, - turnID: String?, - includeRedacted: Bool, - includeCompactionEvents: Bool, - in db: Database, - attachmentStore: RuntimeAttachmentStore - ) throws -> [AgentHistoryRecord] { - var clauses = ["threadID = ?"] - var arguments: [any DatabaseValueConvertible] = [threadID] - - if let kinds, !kinds.isEmpty { - clauses.append("kind IN \(sqlPlaceholders(count: kinds.count))") - arguments.append(contentsOf: kinds.map(\.rawValue)) - } - if let createdAtRange { - clauses.append("createdAt >= ?") - clauses.append("createdAt <= ?") - arguments.append(createdAtRange.lowerBound.timeIntervalSince1970) - arguments.append(createdAtRange.upperBound.timeIntervalSince1970) - } - if let turnID { - clauses.append("turnID = ?") - arguments.append(turnID) - } - if !includeRedacted { - clauses.append("isRedacted = 0") - } - if !includeCompactionEvents { - clauses.append("isCompactionMarker = 0") - } - - // This stays in SQL because the history query shape is highly dynamic and - // we always want sequence-ordered reads for restore/query replay semantics. - let sql = """ - SELECT * FROM \(RuntimeHistoryRow.databaseTableName) - WHERE \(clauses.joined(separator: " AND ")) - ORDER BY sequenceNumber ASC - """ - return try RuntimeHistoryRowsRequest( - sql: sql, - arguments: StatementArguments(arguments) - ).execute(in: db).map { try Self.decodeHistoryRecord(from: $0, attachmentStore: attachmentStore) } - } - - private static func fetchHistoryPage( - threadID: String, - query: AgentHistoryQuery, - in db: Database, - attachmentStore: RuntimeAttachmentStore - ) throws -> AgentThreadHistoryPage { - let limit = max(1, query.limit) - let kinds = historyKinds(from: query.filter) - let includeCompactionEvents = query.filter?.includeCompactionEvents ?? false - let anchor = try decodeCursorSequence(query.cursor, expectedThreadID: threadID) - - switch query.direction { - case .backward: - var clauses = ["threadID = ?"] - var arguments: [any DatabaseValueConvertible] = [threadID] - if let kinds, !kinds.isEmpty { - clauses.append("kind IN \(sqlPlaceholders(count: kinds.count))") - for kind in kinds { arguments.append(kind.rawValue) } - } - if let anchor { - clauses.append("sequenceNumber < ?") - arguments.append(anchor) - } - if !includeCompactionEvents { - clauses.append("isCompactionMarker = 0") - } - - // Cursor paging is kept as raw SQL because the descending window + overfetch - // pattern is much clearer here than trying to express it through chained requests. - let sql = """ - SELECT * FROM \(RuntimeHistoryRow.databaseTableName) - WHERE \(clauses.joined(separator: " AND ")) - ORDER BY sequenceNumber DESC - LIMIT \(limit + 1) - """ - let fetched = try RuntimeHistoryRowsRequest( - sql: sql, - arguments: StatementArguments(arguments) - ).execute(in: db) - let hasMoreBefore = fetched.count > limit - let pageRowsDescending = Array(fetched.prefix(limit)) - let pageRecords = try pageRowsDescending - .map { try Self.decodeHistoryRecord(from: $0, attachmentStore: attachmentStore) } - .reversed() - - let hasMoreAfter: Bool - if let anchor { - hasMoreAfter = try historyRecordExists( - threadID: threadID, - kinds: kinds, - includeCompactionEvents: includeCompactionEvents, - comparator: "sequenceNumber >= ?", - value: anchor, - in: db - ) - } else { - hasMoreAfter = false - } - - return AgentThreadHistoryPage( - threadID: threadID, - items: pageRecords.map(\.item), - nextCursor: hasMoreBefore ? makeCursor(threadID: threadID, sequenceNumber: pageRecords.first?.sequenceNumber) : nil, - previousCursor: hasMoreAfter ? makeCursor(threadID: threadID, sequenceNumber: pageRecords.last?.sequenceNumber) : nil, - hasMoreBefore: hasMoreBefore, - hasMoreAfter: hasMoreAfter - ) - - case .forward: - var clauses = ["threadID = ?"] - var arguments: [any DatabaseValueConvertible] = [threadID] - if let kinds, !kinds.isEmpty { - clauses.append("kind IN \(sqlPlaceholders(count: kinds.count))") - for kind in kinds { arguments.append(kind.rawValue) } - } - if let anchor { - clauses.append("sequenceNumber > ?") - arguments.append(anchor) - } - if !includeCompactionEvents { - clauses.append("isCompactionMarker = 0") - } - - // Forward paging mirrors the backward cursor window and stays in SQL for the - // same reason: explicit sequence bounds and overfetch are easier to verify here. - let sql = """ - SELECT * FROM \(RuntimeHistoryRow.databaseTableName) - WHERE \(clauses.joined(separator: " AND ")) - ORDER BY sequenceNumber ASC - LIMIT \(limit + 1) - """ - let fetched = try RuntimeHistoryRowsRequest( - sql: sql, - arguments: StatementArguments(arguments) - ).execute(in: db) - let hasMoreAfter = fetched.count > limit - let pageRows = Array(fetched.prefix(limit)) - let pageRecords = try pageRows.map { - try Self.decodeHistoryRecord(from: $0, attachmentStore: attachmentStore) - } - - let hasMoreBefore: Bool - if let anchor { - hasMoreBefore = try historyRecordExists( - threadID: threadID, - kinds: kinds, - includeCompactionEvents: includeCompactionEvents, - comparator: "sequenceNumber <= ?", - value: anchor, - in: db - ) - } else { - hasMoreBefore = false - } - - return AgentThreadHistoryPage( - threadID: threadID, - items: pageRecords.map(\.item), - nextCursor: hasMoreAfter ? makeCursor(threadID: threadID, sequenceNumber: pageRecords.last?.sequenceNumber) : nil, - previousCursor: hasMoreBefore ? makeCursor(threadID: threadID, sequenceNumber: pageRecords.first?.sequenceNumber) : nil, - hasMoreBefore: hasMoreBefore, - hasMoreAfter: hasMoreAfter - ) - } - } - - private static func historyRecordExists( - threadID: String, - kinds: Set?, - includeCompactionEvents: Bool, - comparator: String, - value: Int, - in db: Database - ) throws -> Bool { - var clauses = ["threadID = ?", comparator] - var arguments: [any DatabaseValueConvertible] = [threadID, value] - if let kinds, !kinds.isEmpty { - clauses.append("kind IN \(sqlPlaceholders(count: kinds.count))") - for kind in kinds { arguments.append(kind.rawValue) } - } - if !includeCompactionEvents { - clauses.append("isCompactionMarker = 0") - } - - // EXISTS is one of the few cases where the raw SQL is both shorter and more obvious - // than the equivalent GRDB request composition for cursor-bound history checks. - let sql = """ - SELECT EXISTS( - SELECT 1 FROM \(RuntimeHistoryRow.databaseTableName) - WHERE \(clauses.joined(separator: " AND ")) - ) - """ - return try RuntimeHistoryExistenceQuery( - sql: sql, - arguments: StatementArguments(arguments) - ).execute(in: db) - } - - private static func defaultLegacyImportURL(for url: URL) -> URL { - url.deletingPathExtension().appendingPathExtension("json") - } - - private static func sqlPlaceholders(count: Int) -> String { - "(" + Array(repeating: "?", count: count).joined(separator: ", ") + ")" - } - - private static func historyKinds(from filter: AgentHistoryFilter?) -> Set? { - guard let filter else { - return nil - } - - var kinds: Set = [] - if filter.includeMessages { kinds.insert(.message) } - if filter.includeToolCalls { kinds.insert(.toolCall) } - if filter.includeToolResults { kinds.insert(.toolResult) } - if filter.includeStructuredOutputs { kinds.insert(.structuredOutput) } - if filter.includeApprovals { kinds.insert(.approval) } - if filter.includeSystemEvents { kinds.insert(.systemEvent) } - return kinds - } - - private static func makeCursor(threadID: String, sequenceNumber: Int?) -> AgentHistoryCursor? { - guard let sequenceNumber else { - return nil - } - - let payload = GRDBHistoryCursorPayload( - version: 1, - threadID: threadID, - sequenceNumber: sequenceNumber - ) - let data = (try? JSONEncoder().encode(payload)) ?? Data() - let base64 = data.base64EncodedString() - .replacingOccurrences(of: "+", with: "-") - .replacingOccurrences(of: "/", with: "_") - .replacingOccurrences(of: "=", with: "") - return AgentHistoryCursor(rawValue: base64) - } - - private static func decodeCursorSequence( - _ cursor: AgentHistoryCursor?, - expectedThreadID: String - ) throws -> Int? { - guard let cursor else { - return nil - } - - let padded = cursor.rawValue - .replacingOccurrences(of: "-", with: "+") - .replacingOccurrences(of: "_", with: "/") - let remainder = padded.count % 4 - let adjusted = padded + String(repeating: "=", count: remainder == 0 ? 0 : 4 - remainder) - - guard let data = Data(base64Encoded: adjusted) else { - throw AgentRuntimeError.invalidHistoryCursor() - } - - let payload = try JSONDecoder().decode(GRDBHistoryCursorPayload.self, from: data) - guard payload.threadID == expectedThreadID else { - throw AgentRuntimeError.invalidHistoryCursor() - } - return payload.sequenceNumber - } - - private static func makeThreadRow(from thread: AgentThread) throws -> RuntimeThreadRow { - RuntimeThreadRow( - threadID: thread.id, - createdAt: thread.createdAt.timeIntervalSince1970, - updatedAt: thread.updatedAt.timeIntervalSince1970, - status: thread.status.rawValue, - encodedThread: try JSONEncoder().encode(thread) - ) - } - - private static func makeSummaryRow(from summary: AgentThreadSummary) throws -> RuntimeSummaryRow { - RuntimeSummaryRow( - threadID: summary.threadID, - createdAt: summary.createdAt.timeIntervalSince1970, - updatedAt: summary.updatedAt.timeIntervalSince1970, - latestItemAt: summary.latestItemAt?.timeIntervalSince1970, - itemCount: summary.itemCount, - pendingStateKind: summary.pendingState?.kind.rawValue, - latestStructuredOutputFormatName: summary.latestStructuredOutputMetadata?.formatName, - encodedSummary: try JSONEncoder().encode(summary) - ) - } - - private static func makeHistoryRow( - from record: AgentHistoryRecord, - attachmentStore: RuntimeAttachmentStore - ) throws -> RuntimeHistoryRow { - let persisted = try PersistedAgentHistoryRecord( - record: record, - attachmentStore: attachmentStore - ) - return RuntimeHistoryRow( - storageID: "\(record.item.threadID):\(record.sequenceNumber)", - recordID: record.id, - threadID: record.item.threadID, - sequenceNumber: record.sequenceNumber, - createdAt: record.createdAt.timeIntervalSince1970, - kind: record.item.kind.rawValue, - turnID: record.item.turnID, - isCompactionMarker: record.item.isCompactionMarker, - isRedacted: record.redaction != nil, - encodedRecord: try JSONEncoder().encode(persisted) - ) - } - - private static func makeContextStateRow(from state: AgentThreadContextState) throws -> RuntimeContextStateRow { - RuntimeContextStateRow( - threadID: state.threadID, - generation: state.generation, - encodedState: try JSONEncoder().encode(state) - ) - } - - private static func structuredOutputRows( - from historyByThread: [String: [AgentHistoryRecord]] - ) throws -> [RuntimeStructuredOutputRow] { - try historyByThread.values - .flatMap { $0 } - .compactMap { record -> RuntimeStructuredOutputRow? in - switch record.item { - case let .structuredOutput(output): - return try Self.makeStructuredOutputRow( - id: "structured:\(record.id)", - record: output - ) - - case let .message(message): - guard let metadata = message.structuredOutput else { - return nil - } - return try Self.makeStructuredOutputRow( - id: "message:\(message.id)", - record: AgentStructuredOutputRecord( - threadID: message.threadID, - turnID: "", - messageID: message.id, - metadata: metadata, - committedAt: message.createdAt - ) - ) - - default: - return nil - } - } - } - - private static func makeStructuredOutputRow( - id: String, - record: AgentStructuredOutputRecord - ) throws -> RuntimeStructuredOutputRow { - RuntimeStructuredOutputRow( - outputID: id, - threadID: record.threadID, - formatName: record.metadata.formatName, - committedAt: record.committedAt.timeIntervalSince1970, - encodedRecord: try JSONEncoder().encode(record) - ) - } - - private static func decodeThread(from row: RuntimeThreadRow) throws -> AgentThread { - try JSONDecoder().decode(AgentThread.self, from: row.encodedThread) - } - - private static func decodeSummary(from row: RuntimeSummaryRow) throws -> AgentThreadSummary { - try JSONDecoder().decode(AgentThreadSummary.self, from: row.encodedSummary) - } - - private static func decodeContextState(from row: RuntimeContextStateRow) throws -> AgentThreadContextState { - try JSONDecoder().decode(AgentThreadContextState.self, from: row.encodedState) - } - - private static func decodeHistoryRecord( - from row: RuntimeHistoryRow, - attachmentStore: RuntimeAttachmentStore - ) throws -> AgentHistoryRecord { - let decoder = JSONDecoder() - if let persisted = try? decoder.decode(PersistedAgentHistoryRecord.self, from: row.encodedRecord) { - return try persisted.decode(using: attachmentStore) - } - return try decoder.decode(AgentHistoryRecord.self, from: row.encodedRecord) - } - - private static func decodeStructuredOutputRecord(from row: RuntimeStructuredOutputRow) throws -> AgentStructuredOutputRecord { - try JSONDecoder().decode(AgentStructuredOutputRecord.self, from: row.encodedRecord) - } - - private static func makeMigrator() -> DatabaseMigrator { - var migrator = DatabaseMigrator() - - migrator.registerMigration("runtime_store_v1") { db in - try db.create(table: RuntimeThreadRow.databaseTableName) { table in - table.column("threadID", .text).primaryKey() - table.column("createdAt", .double).notNull() - table.column("updatedAt", .double).notNull() - table.column("status", .text).notNull() - table.column("encodedThread", .blob).notNull() - } - - try db.create(table: RuntimeSummaryRow.databaseTableName) { table in - table.column("threadID", .text) - .primaryKey() - .references(RuntimeThreadRow.databaseTableName, onDelete: .cascade) - table.column("createdAt", .double).notNull() - table.column("updatedAt", .double).notNull() - table.column("latestItemAt", .double) - table.column("itemCount", .integer) - table.column("pendingStateKind", .text) - table.column("latestStructuredOutputFormatName", .text) - table.column("encodedSummary", .blob).notNull() - } - - try db.create(table: RuntimeHistoryRow.databaseTableName) { table in - table.column("storageID", .text).primaryKey() - table.column("recordID", .text).notNull() - table.column("threadID", .text) - .notNull() - .references(RuntimeThreadRow.databaseTableName, onDelete: .cascade) - table.column("sequenceNumber", .integer).notNull() - table.column("createdAt", .double).notNull() - table.column("kind", .text).notNull() - table.column("turnID", .text) - table.column("isCompactionMarker", .boolean).notNull().defaults(to: false) - table.column("isRedacted", .boolean).notNull().defaults(to: false) - table.column("encodedRecord", .blob).notNull() - } - - try db.create(index: "runtime_history_thread_sequence", on: RuntimeHistoryRow.databaseTableName, columns: ["threadID", "sequenceNumber"], unique: true) - try db.create(index: "runtime_history_thread_created_at", on: RuntimeHistoryRow.databaseTableName, columns: ["threadID", "createdAt"]) - try db.create(index: "runtime_history_thread_kind", on: RuntimeHistoryRow.databaseTableName, columns: ["threadID", "kind"]) - try db.create(index: "runtime_history_thread_record_id", on: RuntimeHistoryRow.databaseTableName, columns: ["threadID", "recordID"]) - - try db.create(table: RuntimeStructuredOutputRow.databaseTableName) { table in - table.column("outputID", .text).primaryKey() - table.column("threadID", .text) - .notNull() - .references(RuntimeThreadRow.databaseTableName, onDelete: .cascade) - table.column("formatName", .text).notNull() - table.column("committedAt", .double).notNull() - table.column("encodedRecord", .blob).notNull() - } - - try db.create(index: "runtime_structured_outputs_thread_committed_at", on: RuntimeStructuredOutputRow.databaseTableName, columns: ["threadID", "committedAt"]) - try db.create(index: "runtime_structured_outputs_format_name", on: RuntimeStructuredOutputRow.databaseTableName, columns: ["formatName"]) - - try db.create(table: RuntimeContextStateRow.databaseTableName) { table in - table.column("threadID", .text) - .primaryKey() - .references(RuntimeThreadRow.databaseTableName, onDelete: .cascade) - table.column("generation", .integer).notNull() - table.column("encodedState", .blob).notNull() - } - - try db.execute(sql: "PRAGMA user_version = \(currentStoreSchemaVersion)") - } - - migrator.registerMigration("runtime_store_v2_compaction_state") { db in - let historyColumns = try db.columns(in: RuntimeHistoryRow.databaseTableName).map(\.name) - if !historyColumns.contains("isCompactionMarker") { - try db.alter(table: RuntimeHistoryRow.databaseTableName) { table in - table.add(column: "isCompactionMarker", .boolean).notNull().defaults(to: false) - } - } - - if try !db.tableExists(RuntimeContextStateRow.databaseTableName) { - try db.create(table: RuntimeContextStateRow.databaseTableName) { table in - table.column("threadID", .text) - .primaryKey() - .references(RuntimeThreadRow.databaseTableName, onDelete: .cascade) - table.column("generation", .integer).notNull() - table.column("encodedState", .blob).notNull() - } - } - - try db.execute(sql: "PRAGMA user_version = \(currentStoreSchemaVersion)") - } - - return migrator - } - - private static func sortPendingStateRecords( - _ records: [AgentPendingStateRecord], - using sort: AgentPendingStateSort - ) -> [AgentPendingStateRecord] { - records.sorted { lhs, rhs in - switch sort { - case let .updatedAt(order): - if lhs.updatedAt == rhs.updatedAt { - return lhs.threadID < rhs.threadID - } - return order == .ascending ? lhs.updatedAt < rhs.updatedAt : lhs.updatedAt > rhs.updatedAt - } - } - } - - private static func sortStructuredOutputRecords( - _ records: [AgentStructuredOutputRecord], - using sort: AgentStructuredOutputSort - ) -> [AgentStructuredOutputRecord] { - records.sorted { lhs, rhs in - switch sort { - case let .committedAt(order): - if lhs.committedAt == rhs.committedAt { - return (lhs.messageID ?? lhs.turnID) < (rhs.messageID ?? rhs.turnID) - } - return order == .ascending ? lhs.committedAt < rhs.committedAt : lhs.committedAt > rhs.committedAt - } - } - } - - private static func sortThreadSnapshots( - _ snapshots: [AgentThreadSnapshot], - using sort: AgentThreadSnapshotSort - ) -> [AgentThreadSnapshot] { - snapshots.sorted { lhs, rhs in - switch sort { - case let .updatedAt(order): - if lhs.updatedAt == rhs.updatedAt { - return lhs.threadID < rhs.threadID - } - return order == .ascending ? lhs.updatedAt < rhs.updatedAt : lhs.updatedAt > rhs.updatedAt - case let .createdAt(order): - if lhs.createdAt == rhs.createdAt { - return lhs.threadID < rhs.threadID - } - return order == .ascending ? lhs.createdAt < rhs.createdAt : lhs.createdAt > rhs.createdAt - } - } - } - - private func readUserVersion() async throws -> Int { - try await dbQueue.read { db in - try RuntimeUserVersionQuery().execute(in: db) - } - } -} - -private struct RuntimeThreadRow: Codable, FetchableRecord, PersistableRecord, TableRecord { - static let databaseTableName = "runtime_threads" - - let threadID: String - let createdAt: Double - let updatedAt: Double - let status: String - let encodedThread: Data -} - -private struct RuntimeThreadCountQuery { - func execute(in db: Database) throws -> Int { - let row = try SQLRequest( - sql: "SELECT COUNT(*) AS thread_count FROM \(RuntimeThreadRow.databaseTableName)" - ).fetchOne(db) - let count: Int? = row?["thread_count"] - return count ?? 0 - } -} - -private struct RuntimeSummaryRow: Codable, FetchableRecord, PersistableRecord, TableRecord { - static let databaseTableName = "runtime_summaries" - - let threadID: String - let createdAt: Double - let updatedAt: Double - let latestItemAt: Double? - let itemCount: Int? - let pendingStateKind: String? - let latestStructuredOutputFormatName: String? - let encodedSummary: Data -} - -private struct RuntimeHistoryRow: Codable, FetchableRecord, PersistableRecord, TableRecord { - static let databaseTableName = "runtime_history_items" - - let storageID: String - let recordID: String - let threadID: String - let sequenceNumber: Int - let createdAt: Double - let kind: String - let turnID: String? - let isCompactionMarker: Bool - let isRedacted: Bool - let encodedRecord: Data -} - -private struct RuntimeHistoryRowsRequest { - let sql: String - let arguments: StatementArguments - - func execute(in db: Database) throws -> [RuntimeHistoryRow] { - try SQLRequest(sql: sql, arguments: arguments).fetchAll(db) - } -} - -private struct RuntimeHistoryExistenceQuery { - let sql: String - let arguments: StatementArguments - - func execute(in db: Database) throws -> Bool { - let row = try SQLRequest(sql: sql, arguments: arguments).fetchOne(db) - let exists: Bool? = row?[0] - return exists ?? false - } -} - -private struct RuntimeStructuredOutputRow: Codable, FetchableRecord, PersistableRecord, TableRecord { - static let databaseTableName = "runtime_structured_outputs" - - let outputID: String - let threadID: String - let formatName: String - let committedAt: Double - let encodedRecord: Data -} - -private struct RuntimeContextStateRow: Codable, FetchableRecord, PersistableRecord, TableRecord { - static let databaseTableName = "runtime_context_states" - - let threadID: String - let generation: Int - let encodedState: Data -} - -private struct RuntimeUserVersionQuery { - func execute(in db: Database) throws -> Int { - let row = try SQLRequest(sql: "PRAGMA user_version;").fetchOne(db) - return row?[0] ?? 0 - } -} - -private struct GRDBHistoryCursorPayload: Codable { - let version: Int - let threadID: String - let sequenceNumber: Int -} - -private extension AgentHistoryItem { - var threadID: String { - switch self { - case let .message(message): - message.threadID - case let .toolCall(record): - record.invocation.threadID - case let .toolResult(record): - record.threadID - case let .structuredOutput(record): - record.threadID - case let .approval(record): - record.request?.threadID ?? record.resolution?.threadID ?? "" - case let .systemEvent(record): - record.threadID - } - } } diff --git a/Sources/CodexKit/Runtime/GRDBRuntimeStateStoreRows.swift b/Sources/CodexKit/Runtime/GRDBRuntimeStateStoreRows.swift new file mode 100644 index 0000000..59e522e --- /dev/null +++ b/Sources/CodexKit/Runtime/GRDBRuntimeStateStoreRows.swift @@ -0,0 +1,120 @@ +import Foundation +import GRDB + +struct RuntimeThreadRow: Codable, FetchableRecord, PersistableRecord, TableRecord { + static let databaseTableName = "runtime_threads" + + let threadID: String + let createdAt: Double + let updatedAt: Double + let status: String + let encodedThread: Data +} + +struct RuntimeThreadCountQuery { + func execute(in db: Database) throws -> Int { + let row = try SQLRequest( + sql: "SELECT COUNT(*) AS thread_count FROM \(RuntimeThreadRow.databaseTableName)" + ).fetchOne(db) + let count: Int? = row?["thread_count"] + return count ?? 0 + } +} + +struct RuntimeSummaryRow: Codable, FetchableRecord, PersistableRecord, TableRecord { + static let databaseTableName = "runtime_summaries" + + let threadID: String + let createdAt: Double + let updatedAt: Double + let latestItemAt: Double? + let itemCount: Int? + let pendingStateKind: String? + let latestStructuredOutputFormatName: String? + let encodedSummary: Data +} + +struct RuntimeHistoryRow: Codable, FetchableRecord, PersistableRecord, TableRecord { + static let databaseTableName = "runtime_history_items" + + let storageID: String + let recordID: String + let threadID: String + let sequenceNumber: Int + let createdAt: Double + let kind: String + let turnID: String? + let isCompactionMarker: Bool + let isRedacted: Bool + let encodedRecord: Data +} + +struct RuntimeHistoryRowsRequest { + let sql: String + let arguments: StatementArguments + + func execute(in db: Database) throws -> [RuntimeHistoryRow] { + try SQLRequest(sql: sql, arguments: arguments).fetchAll(db) + } +} + +struct RuntimeHistoryExistenceQuery { + let sql: String + let arguments: StatementArguments + + func execute(in db: Database) throws -> Bool { + let row = try SQLRequest(sql: sql, arguments: arguments).fetchOne(db) + let exists: Bool? = row?[0] + return exists ?? false + } +} + +struct RuntimeStructuredOutputRow: Codable, FetchableRecord, PersistableRecord, TableRecord { + static let databaseTableName = "runtime_structured_outputs" + + let outputID: String + let threadID: String + let formatName: String + let committedAt: Double + let encodedRecord: Data +} + +struct RuntimeContextStateRow: Codable, FetchableRecord, PersistableRecord, TableRecord { + static let databaseTableName = "runtime_context_states" + + let threadID: String + let generation: Int + let encodedState: Data +} + +struct RuntimeUserVersionQuery { + func execute(in db: Database) throws -> Int { + let row = try SQLRequest(sql: "PRAGMA user_version;").fetchOne(db) + return row?[0] ?? 0 + } +} + +struct GRDBHistoryCursorPayload: Codable { + let version: Int + let threadID: String + let sequenceNumber: Int +} + +extension AgentHistoryItem { + var threadID: String { + switch self { + case let .message(message): + message.threadID + case let .toolCall(record): + record.invocation.threadID + case let .toolResult(record): + record.threadID + case let .structuredOutput(record): + record.threadID + case let .approval(record): + record.request?.threadID ?? record.resolution?.threadID ?? "" + case let .systemEvent(record): + record.threadID + } + } +} diff --git a/Sources/CodexKit/Runtime/InMemoryRuntimeStateStore.swift b/Sources/CodexKit/Runtime/InMemoryRuntimeStateStore.swift new file mode 100644 index 0000000..84b1f9a --- /dev/null +++ b/Sources/CodexKit/Runtime/InMemoryRuntimeStateStore.swift @@ -0,0 +1,56 @@ +import Foundation + +public actor InMemoryRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, AgentRuntimeQueryableStore { + private var state: StoredRuntimeState + + public init(initialState: StoredRuntimeState = .empty) { + state = initialState.normalized() + } + + public func loadState() async throws -> StoredRuntimeState { + state + } + + public func saveState(_ state: StoredRuntimeState) async throws { + self.state = state.normalized() + } + + public func prepare() async throws -> AgentStoreMetadata { + state = state.normalized() + return try await readMetadata() + } + + public func readMetadata() async throws -> AgentStoreMetadata { + AgentStoreMetadata( + logicalSchemaVersion: .v1, + storeSchemaVersion: 1, + capabilities: AgentStoreCapabilities( + supportsPushdownQueries: true, + supportsCrossThreadQueries: true, + supportsSorting: true, + supportsFiltering: true, + supportsMigrations: false + ), + storeKind: "InMemoryRuntimeStateStore" + ) + } + + public func fetchThreadSummary(id: String) async throws -> AgentThreadSummary { + try state.threadSummary(id: id) + } + + public func fetchThreadHistory( + id: String, + query: AgentHistoryQuery + ) async throws -> AgentThreadHistoryPage { + try state.threadHistoryPage(id: id, query: query) + } + + public func fetchLatestStructuredOutputMetadata(id: String) async throws -> AgentStructuredOutputMetadata? { + try state.threadSummary(id: id).latestStructuredOutputMetadata + } + + public func fetchThreadContextState(id: String) async throws -> AgentThreadContextState? { + state.contextStateByThread[id] + } +} diff --git a/Sources/CodexKit/Runtime/RuntimeStateStore.swift b/Sources/CodexKit/Runtime/RuntimeStateStore.swift index 9184157..5237fc0 100644 --- a/Sources/CodexKit/Runtime/RuntimeStateStore.swift +++ b/Sources/CodexKit/Runtime/RuntimeStateStore.swift @@ -117,1032 +117,3 @@ public extension RuntimeStateStoring { try await saveState(updated) } } - -public actor InMemoryRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, AgentRuntimeQueryableStore { - private var state: StoredRuntimeState - - public init(initialState: StoredRuntimeState = .empty) { - state = initialState.normalized() - } - - public func loadState() async throws -> StoredRuntimeState { - state - } - - public func saveState(_ state: StoredRuntimeState) async throws { - self.state = state.normalized() - } - - public func prepare() async throws -> AgentStoreMetadata { - state = state.normalized() - return try await readMetadata() - } - - public func readMetadata() async throws -> AgentStoreMetadata { - AgentStoreMetadata( - logicalSchemaVersion: .v1, - storeSchemaVersion: 1, - capabilities: AgentStoreCapabilities( - supportsPushdownQueries: true, - supportsCrossThreadQueries: true, - supportsSorting: true, - supportsFiltering: true, - supportsMigrations: false - ), - storeKind: "InMemoryRuntimeStateStore" - ) - } - - public func fetchThreadSummary(id: String) async throws -> AgentThreadSummary { - try state.threadSummary(id: id) - } - - public func fetchThreadHistory( - id: String, - query: AgentHistoryQuery - ) async throws -> AgentThreadHistoryPage { - try state.threadHistoryPage(id: id, query: query) - } - - public func fetchLatestStructuredOutputMetadata(id: String) async throws -> AgentStructuredOutputMetadata? { - try state.threadSummary(id: id).latestStructuredOutputMetadata - } - - public func fetchThreadContextState(id: String) async throws -> AgentThreadContextState? { - state.contextStateByThread[id] - } -} - -public actor FileRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, AgentRuntimeQueryableStore { - private let url: URL - private let encoder = JSONEncoder() - private let decoder = JSONDecoder() - private let fileManager = FileManager.default - private let attachmentStore: RuntimeAttachmentStore - - public init(url: URL) { - self.url = url - let basename = url.deletingPathExtension().lastPathComponent - self.attachmentStore = RuntimeAttachmentStore( - rootURL: url.deletingLastPathComponent() - .appendingPathComponent("\(basename).codexkit-state", isDirectory: true) - .appendingPathComponent("attachments", isDirectory: true) - ) - } - - public func loadState() async throws -> StoredRuntimeState { - try loadNormalizedStateMigratingIfNeeded() - } - - public func saveState(_ state: StoredRuntimeState) async throws { - try persistLayout(for: state.normalized()) - } - - public func prepare() async throws -> AgentStoreMetadata { - _ = try loadNormalizedStateMigratingIfNeeded() - return try await readMetadata() - } - - public func readMetadata() async throws -> AgentStoreMetadata { - AgentStoreMetadata( - logicalSchemaVersion: .v1, - storeSchemaVersion: 1, - capabilities: AgentStoreCapabilities( - supportsPushdownQueries: false, - supportsCrossThreadQueries: false, - supportsSorting: true, - supportsFiltering: true, - supportsMigrations: true - ), - storeKind: "FileRuntimeStateStore" - ) - } - - public func fetchThreadSummary(id: String) async throws -> AgentThreadSummary { - if let manifest = try loadManifest() { - guard let thread = manifest.threads.first(where: { $0.id == id }) else { - throw AgentRuntimeError.threadNotFound(id) - } - return manifest.summariesByThread[id] - ?? StoredRuntimeState(threads: [thread]).threadSummaryFallback(for: thread) - } - - return try loadNormalizedStateMigratingIfNeeded().threadSummary(id: id) - } - - public func fetchThreadHistory( - id: String, - query: AgentHistoryQuery - ) async throws -> AgentThreadHistoryPage { - if let manifest = try loadManifest() { - guard manifest.threads.contains(where: { $0.id == id }) else { - throw AgentRuntimeError.threadNotFound(id) - } - - let history = try loadHistory(for: id) - let state = StoredRuntimeState( - threads: manifest.threads, - historyByThread: [id: history], - summariesByThread: manifest.summariesByThread, - contextStateByThread: manifest.contextStateByThread, - nextHistorySequenceByThread: manifest.nextHistorySequenceByThread - ) - return try state.threadHistoryPage(id: id, query: query) - } - - return try loadNormalizedStateMigratingIfNeeded().threadHistoryPage(id: id, query: query) - } - - public func fetchLatestStructuredOutputMetadata(id: String) async throws -> AgentStructuredOutputMetadata? { - let summary = try await fetchThreadSummary(id: id) - return summary.latestStructuredOutputMetadata - } - - public func fetchThreadContextState(id: String) async throws -> AgentThreadContextState? { - if let manifest = try loadManifest() { - guard manifest.threads.contains(where: { $0.id == id }) else { - throw AgentRuntimeError.threadNotFound(id) - } - return manifest.contextStateByThread[id] - } - - return try loadNormalizedStateMigratingIfNeeded().contextStateByThread[id] - } - - private func loadNormalizedStateMigratingIfNeeded() throws -> StoredRuntimeState { - guard fileManager.fileExists(atPath: url.path) else { - return .empty - } - - if let manifest = try loadManifest() { - return try state(from: manifest) - } - - let data = try Data(contentsOf: url) - let legacy = try decoder.decode(StoredRuntimeState.self, from: data).normalized() - try persistLayout(for: legacy) - return legacy - } - - private func loadManifest() throws -> FileRuntimeStateManifest? { - guard fileManager.fileExists(atPath: url.path) else { - return nil - } - - let data = try Data(contentsOf: url) - return try? decoder.decode(FileRuntimeStateManifest.self, from: data) - } - - private func state(from manifest: FileRuntimeStateManifest) throws -> StoredRuntimeState { - var historyByThread: [String: [AgentHistoryRecord]] = [:] - for thread in manifest.threads { - historyByThread[thread.id] = try loadHistory(for: thread.id) - } - - return StoredRuntimeState( - threads: manifest.threads, - historyByThread: historyByThread, - summariesByThread: manifest.summariesByThread, - contextStateByThread: manifest.contextStateByThread, - nextHistorySequenceByThread: manifest.nextHistorySequenceByThread - ) - } - - private func loadHistory(for threadID: String) throws -> [AgentHistoryRecord] { - let historyURL = historyFileURL(for: threadID) - guard fileManager.fileExists(atPath: historyURL.path) else { - return [] - } - - let data = try Data(contentsOf: historyURL) - if let persisted = try? decoder.decode([PersistedAgentHistoryRecord].self, from: data) { - return try persisted.map { try $0.decode(using: attachmentStore) } - } - return try decoder.decode([AgentHistoryRecord].self, from: data) - } - - private func persistLayout(for state: StoredRuntimeState) throws { - let normalized = state.normalized() - let directory = url.deletingLastPathComponent() - if !directory.path.isEmpty { - try fileManager.createDirectory( - at: directory, - withIntermediateDirectories: true - ) - } - - try fileManager.createDirectory( - at: historyDirectoryURL, - withIntermediateDirectories: true - ) - try attachmentStore.reset() - - for thread in normalized.threads { - let historyURL = historyFileURL(for: thread.id) - let history = normalized.historyByThread[thread.id] ?? [] - let persisted = try history.map { - try PersistedAgentHistoryRecord( - record: $0, - attachmentStore: attachmentStore - ) - } - let data = try encoder.encode(persisted) - try data.write(to: historyURL, options: .atomic) - } - - let manifest = FileRuntimeStateManifest( - threads: normalized.threads, - summariesByThread: normalized.summariesByThread, - contextStateByThread: normalized.contextStateByThread, - nextHistorySequenceByThread: normalized.nextHistorySequenceByThread - ) - let manifestData = try encoder.encode(manifest) - try manifestData.write(to: url, options: .atomic) - } - - private var historyDirectoryURL: URL { - let basename = url.deletingPathExtension().lastPathComponent - return url.deletingLastPathComponent() - .appendingPathComponent("\(basename).codexkit-state", isDirectory: true) - .appendingPathComponent("threads", isDirectory: true) - } - - private func historyFileURL(for threadID: String) -> URL { - historyDirectoryURL.appendingPathComponent(safeThreadFilename(threadID)).appendingPathExtension("json") - } - - private func safeThreadFilename(_ threadID: String) -> String { - threadID.addingPercentEncoding(withAllowedCharacters: .alphanumerics) ?? threadID - } -} - -private struct FileRuntimeStateManifest: Codable { - let storageVersion: Int - let threads: [AgentThread] - let summariesByThread: [String: AgentThreadSummary] - let contextStateByThread: [String: AgentThreadContextState] - let nextHistorySequenceByThread: [String: Int] - - init( - threads: [AgentThread], - summariesByThread: [String: AgentThreadSummary], - contextStateByThread: [String: AgentThreadContextState], - nextHistorySequenceByThread: [String: Int] - ) { - self.storageVersion = 1 - self.threads = threads - self.summariesByThread = summariesByThread - self.contextStateByThread = contextStateByThread - self.nextHistorySequenceByThread = nextHistorySequenceByThread - } -} - -extension StoredRuntimeState { - func normalized() -> StoredRuntimeState { - let sortedThreads = threads.sorted { lhs, rhs in - if lhs.updatedAt == rhs.updatedAt { - return lhs.id < rhs.id - } - return lhs.updatedAt > rhs.updatedAt - } - - var normalizedHistory = historyByThread - .mapValues { records in - records.sorted { lhs, rhs in - if lhs.sequenceNumber == rhs.sequenceNumber { - return lhs.createdAt < rhs.createdAt - } - return lhs.sequenceNumber < rhs.sequenceNumber - } - } - - for (threadID, messages) in messagesByThread where normalizedHistory[threadID]?.isEmpty != false { - normalizedHistory[threadID] = Self.syntheticHistory(from: messages) - } - - let normalizedMessages: [String: [AgentMessage]] = normalizedHistory.mapValues { records in - records.compactMap { record -> AgentMessage? in - guard case let .message(message) = record.item else { - return nil - } - return message - } - } - - var normalizedNextSequence = nextHistorySequenceByThread - for thread in sortedThreads { - let history = normalizedHistory[thread.id] ?? [] - let nextSequence = (history.last?.sequenceNumber ?? 0) + 1 - normalizedNextSequence[thread.id] = max(normalizedNextSequence[thread.id] ?? 0, nextSequence) - } - - var normalizedSummaries: [String: AgentThreadSummary] = [:] - var normalizedContextState = contextStateByThread - for thread in sortedThreads { - let history = normalizedHistory[thread.id] ?? [] - normalizedSummaries[thread.id] = Self.rebuildSummary( - for: thread, - history: history, - existing: summariesByThread[thread.id] - ) - if let existing = normalizedContextState[thread.id] { - normalizedContextState[thread.id] = AgentThreadContextState( - threadID: thread.id, - effectiveMessages: existing.effectiveMessages, - generation: existing.generation, - lastCompactedAt: existing.lastCompactedAt, - lastCompactionReason: existing.lastCompactionReason, - latestMarkerID: existing.latestMarkerID - ) - } - } - - return StoredRuntimeState( - threads: sortedThreads, - messagesByThread: normalizedMessages, - historyByThread: normalizedHistory, - summariesByThread: normalizedSummaries, - contextStateByThread: normalizedContextState, - nextHistorySequenceByThread: normalizedNextSequence, - normalizeState: false - ) - } - - func threadSummary(id: String) throws -> AgentThreadSummary { - guard let thread = threads.first(where: { $0.id == id }) else { - throw AgentRuntimeError.threadNotFound(id) - } - - return summariesByThread[id] ?? threadSummaryFallback(for: thread) - } - - func threadSummaryFallback(for thread: AgentThread) -> AgentThreadSummary { - Self.rebuildSummary( - for: thread, - history: historyByThread[thread.id] ?? [], - existing: summariesByThread[thread.id] - ) - } - - func threadHistoryPage( - id: String, - query: AgentHistoryQuery - ) throws -> AgentThreadHistoryPage { - guard threads.contains(where: { $0.id == id }) else { - throw AgentRuntimeError.threadNotFound(id) - } - - let limit = max(1, query.limit) - let filter = query.filter ?? AgentHistoryFilter() - let records = (historyByThread[id] ?? []).filter { filter.matches($0.item) } - let anchor = try query.cursor?.decodedSequenceNumber(expectedThreadID: id) - - switch query.direction { - case .backward: - let endIndex = records.endIndexForBackward(anchor: anchor) - let startIndex = max(0, endIndex - limit) - let pageRecords = Array(records[startIndex ..< endIndex]) - let hasMoreBefore = startIndex > 0 - let hasMoreAfter = endIndex < records.count - return AgentThreadHistoryPage( - threadID: id, - items: pageRecords.map(\.item), - nextCursor: hasMoreBefore ? AgentHistoryCursor(threadID: id, sequenceNumber: pageRecords.first?.sequenceNumber) : nil, - previousCursor: hasMoreAfter ? AgentHistoryCursor(threadID: id, sequenceNumber: pageRecords.last?.sequenceNumber) : nil, - hasMoreBefore: hasMoreBefore, - hasMoreAfter: hasMoreAfter - ) - - case .forward: - let startIndex = records.startIndexForForward(anchor: anchor) - let endIndex = min(records.count, startIndex + limit) - let pageRecords = Array(records[startIndex ..< endIndex]) - let hasMoreBefore = startIndex > 0 - let hasMoreAfter = endIndex < records.count - return AgentThreadHistoryPage( - threadID: id, - items: pageRecords.map(\.item), - nextCursor: hasMoreAfter ? AgentHistoryCursor(threadID: id, sequenceNumber: pageRecords.last?.sequenceNumber) : nil, - previousCursor: hasMoreBefore ? AgentHistoryCursor(threadID: id, sequenceNumber: pageRecords.first?.sequenceNumber) : nil, - hasMoreBefore: hasMoreBefore, - hasMoreAfter: hasMoreAfter - ) - } - } - - func applying(_ operations: [AgentStoreWriteOperation]) throws -> StoredRuntimeState { - var updated = self - - for operation in operations { - switch operation { - case let .upsertThread(thread): - if let index = updated.threads.firstIndex(where: { $0.id == thread.id }) { - updated.threads[index] = thread - } else { - updated.threads.append(thread) - } - - case let .upsertSummary(threadID, summary): - updated.summariesByThread[threadID] = summary - - case let .appendHistoryItems(threadID, items): - updated.historyByThread[threadID, default: []].append(contentsOf: items) - let nextSequence = (updated.historyByThread[threadID]?.last?.sequenceNumber ?? 0) + 1 - updated.nextHistorySequenceByThread[threadID] = nextSequence - - case let .appendCompactionMarker(threadID, marker): - updated.historyByThread[threadID, default: []].append(marker) - let nextSequence = (updated.historyByThread[threadID]?.last?.sequenceNumber ?? 0) + 1 - updated.nextHistorySequenceByThread[threadID] = nextSequence - - case let .upsertThreadContextState(threadID, state): - updated.contextStateByThread[threadID] = state - - case let .deleteThreadContextState(threadID): - updated.contextStateByThread.removeValue(forKey: threadID) - - case let .setPendingState(threadID, state): - if let thread = updated.threads.first(where: { $0.id == threadID }) { - let current = updated.summariesByThread[threadID] ?? updated.threadSummaryFallback(for: thread) - updated.summariesByThread[threadID] = AgentThreadSummary( - threadID: current.threadID, - createdAt: current.createdAt, - updatedAt: current.updatedAt, - latestItemAt: current.latestItemAt, - itemCount: current.itemCount, - latestAssistantMessagePreview: current.latestAssistantMessagePreview, - latestStructuredOutputMetadata: current.latestStructuredOutputMetadata, - latestPartialStructuredOutput: current.latestPartialStructuredOutput, - latestToolState: current.latestToolState, - latestTurnStatus: current.latestTurnStatus, - pendingState: state - ) - } - - case let .setPartialStructuredSnapshot(threadID, snapshot): - if let thread = updated.threads.first(where: { $0.id == threadID }) { - let current = updated.summariesByThread[threadID] ?? updated.threadSummaryFallback(for: thread) - updated.summariesByThread[threadID] = AgentThreadSummary( - threadID: current.threadID, - createdAt: current.createdAt, - updatedAt: current.updatedAt, - latestItemAt: current.latestItemAt, - itemCount: current.itemCount, - latestAssistantMessagePreview: current.latestAssistantMessagePreview, - latestStructuredOutputMetadata: current.latestStructuredOutputMetadata, - latestPartialStructuredOutput: snapshot, - latestToolState: current.latestToolState, - latestTurnStatus: current.latestTurnStatus, - pendingState: current.pendingState - ) - } - - case let .upsertToolSession(threadID, session): - if let thread = updated.threads.first(where: { $0.id == threadID }) { - let current = updated.summariesByThread[threadID] ?? updated.threadSummaryFallback(for: thread) - let latestToolState = AgentLatestToolState( - invocationID: session.invocationID, - turnID: session.turnID, - toolName: session.toolName, - status: .running, - success: nil, - sessionID: session.sessionID, - sessionStatus: session.sessionStatus, - metadata: session.metadata, - resumable: session.resumable, - updatedAt: session.updatedAt, - resultPreview: nil - ) - updated.summariesByThread[threadID] = AgentThreadSummary( - threadID: current.threadID, - createdAt: current.createdAt, - updatedAt: current.updatedAt, - latestItemAt: current.latestItemAt, - itemCount: current.itemCount, - latestAssistantMessagePreview: current.latestAssistantMessagePreview, - latestStructuredOutputMetadata: current.latestStructuredOutputMetadata, - latestPartialStructuredOutput: current.latestPartialStructuredOutput, - latestToolState: latestToolState, - latestTurnStatus: current.latestTurnStatus, - pendingState: .toolWait( - AgentPendingToolWaitState( - invocationID: session.invocationID, - turnID: session.turnID, - toolName: session.toolName, - startedAt: session.updatedAt, - sessionID: session.sessionID, - sessionStatus: session.sessionStatus, - metadata: session.metadata, - resumable: session.resumable - ) - ) - ) - } - - case let .redactHistoryItems(threadID, itemIDs, reason): - guard !itemIDs.isEmpty else { - continue - } - updated.historyByThread[threadID] = updated.historyByThread[threadID]?.map { record in - guard itemIDs.contains(record.id) else { - return record - } - return record.redacted(reason: reason) - } - - case let .deleteThread(threadID): - updated.threads.removeAll { $0.id == threadID } - updated.messagesByThread.removeValue(forKey: threadID) - updated.historyByThread.removeValue(forKey: threadID) - updated.summariesByThread.removeValue(forKey: threadID) - updated.contextStateByThread.removeValue(forKey: threadID) - updated.nextHistorySequenceByThread.removeValue(forKey: threadID) - } - } - - return updated.normalized() - } - - func execute(_ query: HistoryItemsQuery) throws -> AgentHistoryQueryResult { - guard threads.contains(where: { $0.id == query.threadID }) else { - return AgentHistoryQueryResult( - threadID: query.threadID, - records: [], - nextCursor: nil, - previousCursor: nil, - hasMoreBefore: false, - hasMoreAfter: false - ) - } - - var records = historyByThread[query.threadID] ?? [] - if let kinds = query.kinds { - records = records.filter { kinds.contains($0.item.kind) } - } - if let createdAtRange = query.createdAtRange { - records = records.filter { createdAtRange.contains($0.createdAt) } - } - if let turnID = query.turnID { - records = records.filter { $0.item.turnID == turnID } - } - if !query.includeRedacted { - records = records.filter { $0.redaction == nil } - } - if !query.includeCompactionEvents { - records = records.filter { !$0.item.isCompactionMarker } - } - - records = sort(records, using: query.sort) - let page = try page(records, threadID: query.threadID, with: query.page, sort: query.sort) - return page - } - - func execute(_ query: ThreadMetadataQuery) -> [AgentThread] { - var filtered = threads - if let threadIDs = query.threadIDs { - filtered = filtered.filter { threadIDs.contains($0.id) } - } - if let statuses = query.statuses { - filtered = filtered.filter { statuses.contains($0.status) } - } - if let updatedAtRange = query.updatedAtRange { - filtered = filtered.filter { updatedAtRange.contains($0.updatedAt) } - } - filtered = sort(filtered, using: query.sort) - if let limit = query.limit { - filtered = Array(filtered.prefix(max(0, limit))) - } - return filtered - } - - func execute(_ query: PendingStateQuery) -> [AgentPendingStateRecord] { - var records = summariesByThread.compactMap { threadID, summary -> AgentPendingStateRecord? in - guard let pendingState = summary.pendingState else { - return nil - } - return AgentPendingStateRecord( - threadID: threadID, - pendingState: pendingState, - updatedAt: summary.updatedAt - ) - } - - if let threadIDs = query.threadIDs { - records = records.filter { threadIDs.contains($0.threadID) } - } - if let kinds = query.kinds { - records = records.filter { kinds.contains($0.pendingState.kind) } - } - records = sort(records, using: query.sort) - if let limit = query.limit { - records = Array(records.prefix(max(0, limit))) - } - return records - } - - func execute(_ query: StructuredOutputQuery) -> [AgentStructuredOutputRecord] { - var records = historyByThread.values - .flatMap { $0 } - .compactMap { record -> AgentStructuredOutputRecord? in - switch record.item { - case let .structuredOutput(structuredOutput): - return structuredOutput - - case let .message(message): - guard let metadata = message.structuredOutput else { - return nil - } - return AgentStructuredOutputRecord( - threadID: message.threadID, - turnID: "", - messageID: message.id, - metadata: metadata, - committedAt: message.createdAt - ) - - default: - return nil - } - } - - if let threadIDs = query.threadIDs { - records = records.filter { threadIDs.contains($0.threadID) } - } - if let formatNames = query.formatNames { - records = records.filter { formatNames.contains($0.metadata.formatName) } - } - - records = sort(records, using: query.sort) - - if query.latestOnly { - var seen = Set() - records = records.filter { record in - seen.insert(record.threadID).inserted - } - } - - if let limit = query.limit { - records = Array(records.prefix(max(0, limit))) - } - return records - } - - func execute(_ query: ThreadSnapshotQuery) -> [AgentThreadSnapshot] { - var snapshots = threads.compactMap { thread -> AgentThreadSnapshot? in - guard query.threadIDs?.contains(thread.id) ?? true else { - return nil - } - let summary = summariesByThread[thread.id] ?? threadSummaryFallback(for: thread) - return summary.snapshot - } - snapshots = sort(snapshots, using: query.sort) - if let limit = query.limit { - snapshots = Array(snapshots.prefix(max(0, limit))) - } - return snapshots - } - - func execute(_ query: ThreadContextStateQuery) -> [AgentThreadContextState] { - var records = Array(contextStateByThread.values) - if let threadIDs = query.threadIDs { - records = records.filter { threadIDs.contains($0.threadID) } - } - records.sort { lhs, rhs in - if lhs.generation == rhs.generation { - return lhs.threadID < rhs.threadID - } - return lhs.generation > rhs.generation - } - if let limit = query.limit { - records = Array(records.prefix(max(0, limit))) - } - return records - } - - private static func syntheticHistory(from messages: [AgentMessage]) -> [AgentHistoryRecord] { - let orderedMessages = messages.enumerated().sorted { lhs, rhs in - let left = lhs.element - let right = rhs.element - if left.createdAt == right.createdAt { - return lhs.offset < rhs.offset - } - return left.createdAt < right.createdAt - } - - return orderedMessages.enumerated().map { index, pair in - AgentHistoryRecord( - sequenceNumber: index + 1, - createdAt: pair.element.createdAt, - item: .message(pair.element) - ) - } - } - - private static func rebuildSummary( - for thread: AgentThread, - history: [AgentHistoryRecord], - existing: AgentThreadSummary? - ) -> AgentThreadSummary { - var latestAssistantMessagePreview = existing?.latestAssistantMessagePreview - var latestStructuredOutputMetadata = existing?.latestStructuredOutputMetadata - var latestToolState = existing?.latestToolState - var latestTurnStatus = existing?.latestTurnStatus - let latestPartialStructuredOutput = existing?.latestPartialStructuredOutput - let pendingState = existing?.pendingState - - for record in history { - switch record.item { - case let .message(message): - if message.role == .assistant { - latestAssistantMessagePreview = message.displayText - if let structuredOutput = message.structuredOutput { - latestStructuredOutputMetadata = structuredOutput - } - } - - case let .toolCall(toolCall): - latestToolState = AgentLatestToolState( - invocationID: toolCall.invocation.id, - turnID: toolCall.invocation.turnID, - toolName: toolCall.invocation.toolName, - status: .waiting, - updatedAt: toolCall.requestedAt - ) - - case let .toolResult(toolResult): - latestToolState = Self.latestToolState(from: toolResult) - - case let .structuredOutput(structuredOutput): - latestStructuredOutputMetadata = structuredOutput.metadata - - case .approval: - break - - case let .systemEvent(systemEvent): - switch systemEvent.type { - case .turnStarted: - latestTurnStatus = .running - case .turnCompleted: - latestTurnStatus = .completed - case .turnFailed: - latestTurnStatus = .failed - case .threadCreated, .threadResumed, .threadStatusChanged, .contextCompacted: - break - } - } - } - - return AgentThreadSummary( - threadID: thread.id, - createdAt: thread.createdAt, - updatedAt: thread.updatedAt, - latestItemAt: history.last?.createdAt, - itemCount: history.count, - latestAssistantMessagePreview: latestAssistantMessagePreview, - latestStructuredOutputMetadata: latestStructuredOutputMetadata, - latestPartialStructuredOutput: latestPartialStructuredOutput, - latestToolState: latestToolState, - latestTurnStatus: latestTurnStatus, - pendingState: pendingState - ) - } - - private static func latestToolState(from toolResult: AgentToolResultRecord) -> AgentLatestToolState { - let preview = toolResult.result.primaryText - let session = toolResult.result.session - let status: AgentToolSessionStatus - if toolResult.result.errorMessage == "Tool execution was denied by the user." { - status = .denied - } else if let session, !session.isTerminal { - status = .running - } else if toolResult.result.success { - status = .completed - } else { - status = .failed - } - - return AgentLatestToolState( - invocationID: toolResult.result.invocationID, - turnID: toolResult.turnID, - toolName: toolResult.result.toolName, - status: status, - success: toolResult.result.success, - sessionID: session?.sessionID, - sessionStatus: session?.status, - metadata: session?.metadata, - resumable: session?.resumable ?? false, - updatedAt: toolResult.completedAt, - resultPreview: preview - ) - } -} - -private extension Array where Element == AgentHistoryRecord { - func endIndexForBackward(anchor: Int?) -> Int { - guard let anchor else { - return count - } - - return firstIndex(where: { $0.sequenceNumber >= anchor }) ?? count - } - - func startIndexForForward(anchor: Int?) -> Int { - guard let anchor else { - return 0 - } - - return firstIndex(where: { $0.sequenceNumber > anchor }) ?? count - } -} - -private extension StoredRuntimeState { - func sort( - _ records: [AgentHistoryRecord], - using sort: AgentHistorySort - ) -> [AgentHistoryRecord] { - records.sorted { lhs, rhs in - switch sort { - case let .sequence(order): - if lhs.sequenceNumber == rhs.sequenceNumber { - return lhs.createdAt < rhs.createdAt - } - return order == .ascending - ? lhs.sequenceNumber < rhs.sequenceNumber - : lhs.sequenceNumber > rhs.sequenceNumber - - case let .createdAt(order): - if lhs.createdAt == rhs.createdAt { - return lhs.sequenceNumber < rhs.sequenceNumber - } - return order == .ascending - ? lhs.createdAt < rhs.createdAt - : lhs.createdAt > rhs.createdAt - } - } - } - - func page( - _ records: [AgentHistoryRecord], - threadID: String, - with page: AgentQueryPage?, - sort: AgentHistorySort - ) throws -> AgentHistoryQueryResult { - guard let page else { - let ordered = normalizePageRecords(records, sort: sort) - return AgentHistoryQueryResult( - threadID: threadID, - records: ordered, - nextCursor: nil, - previousCursor: nil, - hasMoreBefore: false, - hasMoreAfter: false - ) - } - - let limit = max(1, page.limit) - let anchor = try page.cursor?.decodedSequenceNumber(expectedThreadID: threadID) - let ascending = normalizePageRecords(records, sort: sort) - let endIndex = if let anchor { - ascending.firstIndex(where: { $0.sequenceNumber >= anchor }) ?? ascending.count - } else { - ascending.count - } - let startIndex = max(0, endIndex - limit) - let sliced = Array(ascending[startIndex ..< endIndex]) - return AgentHistoryQueryResult( - threadID: threadID, - records: sliced, - nextCursor: startIndex > 0 ? AgentHistoryCursor(threadID: threadID, sequenceNumber: sliced.first?.sequenceNumber) : nil, - previousCursor: endIndex < ascending.count ? AgentHistoryCursor(threadID: threadID, sequenceNumber: sliced.last?.sequenceNumber) : nil, - hasMoreBefore: startIndex > 0, - hasMoreAfter: endIndex < ascending.count - ) - } - - func normalizePageRecords( - _ records: [AgentHistoryRecord], - sort: AgentHistorySort - ) -> [AgentHistoryRecord] { - switch sort { - case .sequence(.ascending), .createdAt(.ascending): - return records - case .sequence(.descending), .createdAt(.descending): - return records.reversed() - } - } - - func sort( - _ threads: [AgentThread], - using sort: AgentThreadMetadataSort - ) -> [AgentThread] { - threads.sorted { lhs, rhs in - switch sort { - case let .updatedAt(order): - if lhs.updatedAt == rhs.updatedAt { - return lhs.id < rhs.id - } - return order == .ascending ? lhs.updatedAt < rhs.updatedAt : lhs.updatedAt > rhs.updatedAt - case let .createdAt(order): - if lhs.createdAt == rhs.createdAt { - return lhs.id < rhs.id - } - return order == .ascending ? lhs.createdAt < rhs.createdAt : lhs.createdAt > rhs.createdAt - } - } - } - - func sort( - _ records: [AgentPendingStateRecord], - using sort: AgentPendingStateSort - ) -> [AgentPendingStateRecord] { - records.sorted { lhs, rhs in - switch sort { - case let .updatedAt(order): - if lhs.updatedAt == rhs.updatedAt { - return lhs.threadID < rhs.threadID - } - return order == .ascending ? lhs.updatedAt < rhs.updatedAt : lhs.updatedAt > rhs.updatedAt - } - } - } - - func sort( - _ records: [AgentStructuredOutputRecord], - using sort: AgentStructuredOutputSort - ) -> [AgentStructuredOutputRecord] { - records.sorted { lhs, rhs in - switch sort { - case let .committedAt(order): - if lhs.committedAt == rhs.committedAt { - return lhs.threadID < rhs.threadID - } - return order == .ascending ? lhs.committedAt < rhs.committedAt : lhs.committedAt > rhs.committedAt - } - } - } - - func sort( - _ records: [AgentThreadSnapshot], - using sort: AgentThreadSnapshotSort - ) -> [AgentThreadSnapshot] { - records.sorted { lhs, rhs in - switch sort { - case let .updatedAt(order): - if lhs.updatedAt == rhs.updatedAt { - return lhs.threadID < rhs.threadID - } - return order == .ascending ? lhs.updatedAt < rhs.updatedAt : lhs.updatedAt > rhs.updatedAt - case let .createdAt(order): - if lhs.createdAt == rhs.createdAt { - return lhs.threadID < rhs.threadID - } - return order == .ascending ? lhs.createdAt < rhs.createdAt : lhs.createdAt > rhs.createdAt - } - } - } -} - -private struct AgentHistoryCursorPayload: Codable { - let version: Int - let threadID: String - let sequenceNumber: Int -} - -private extension AgentHistoryCursor { - init(threadID: String, sequenceNumber: Int?) { - guard let sequenceNumber else { - self.init(rawValue: "") - return - } - - let payload = AgentHistoryCursorPayload( - version: 1, - threadID: threadID, - sequenceNumber: sequenceNumber - ) - let data = (try? JSONEncoder().encode(payload)) ?? Data() - let base64 = data.base64EncodedString() - .replacingOccurrences(of: "+", with: "-") - .replacingOccurrences(of: "/", with: "_") - .replacingOccurrences(of: "=", with: "") - self.init(rawValue: base64) - } - - func decodedSequenceNumber(expectedThreadID: String) throws -> Int { - let padded = rawValue - .replacingOccurrences(of: "-", with: "+") - .replacingOccurrences(of: "_", with: "/") - let remainder = padded.count % 4 - let adjusted = padded + String(repeating: "=", count: remainder == 0 ? 0 : 4 - remainder) - - guard let data = Data(base64Encoded: adjusted) else { - throw AgentRuntimeError.invalidHistoryCursor() - } - - let payload = try JSONDecoder().decode(AgentHistoryCursorPayload.self, from: data) - guard payload.threadID == expectedThreadID else { - throw AgentRuntimeError.invalidHistoryCursor() - } - return payload.sequenceNumber - } -} diff --git a/Sources/CodexKit/Runtime/StoredRuntimeState+Execution.swift b/Sources/CodexKit/Runtime/StoredRuntimeState+Execution.swift new file mode 100644 index 0000000..cad2e99 --- /dev/null +++ b/Sources/CodexKit/Runtime/StoredRuntimeState+Execution.swift @@ -0,0 +1,368 @@ +import Foundation + +extension StoredRuntimeState { + func execute(_ query: HistoryItemsQuery) throws -> AgentHistoryQueryResult { + guard threads.contains(where: { $0.id == query.threadID }) else { + return AgentHistoryQueryResult( + threadID: query.threadID, + records: [], + nextCursor: nil, + previousCursor: nil, + hasMoreBefore: false, + hasMoreAfter: false + ) + } + + var records = historyByThread[query.threadID] ?? [] + if let kinds = query.kinds { + records = records.filter { kinds.contains($0.item.kind) } + } + if let createdAtRange = query.createdAtRange { + records = records.filter { createdAtRange.contains($0.createdAt) } + } + if let turnID = query.turnID { + records = records.filter { $0.item.turnID == turnID } + } + if !query.includeRedacted { + records = records.filter { $0.redaction == nil } + } + if !query.includeCompactionEvents { + records = records.filter { !$0.item.isCompactionMarker } + } + + records = sort(records, using: query.sort) + return try page(records, threadID: query.threadID, with: query.page, sort: query.sort) + } + + func execute(_ query: ThreadMetadataQuery) -> [AgentThread] { + var filtered = threads + if let threadIDs = query.threadIDs { + filtered = filtered.filter { threadIDs.contains($0.id) } + } + if let statuses = query.statuses { + filtered = filtered.filter { statuses.contains($0.status) } + } + if let updatedAtRange = query.updatedAtRange { + filtered = filtered.filter { updatedAtRange.contains($0.updatedAt) } + } + filtered = sort(filtered, using: query.sort) + if let limit = query.limit { + filtered = Array(filtered.prefix(max(0, limit))) + } + return filtered + } + + func execute(_ query: PendingStateQuery) -> [AgentPendingStateRecord] { + var records = summariesByThread.compactMap { threadID, summary -> AgentPendingStateRecord? in + guard let pendingState = summary.pendingState else { + return nil + } + return AgentPendingStateRecord( + threadID: threadID, + pendingState: pendingState, + updatedAt: summary.updatedAt + ) + } + + if let threadIDs = query.threadIDs { + records = records.filter { threadIDs.contains($0.threadID) } + } + if let kinds = query.kinds { + records = records.filter { kinds.contains($0.pendingState.kind) } + } + records = sort(records, using: query.sort) + if let limit = query.limit { + records = Array(records.prefix(max(0, limit))) + } + return records + } + + func execute(_ query: StructuredOutputQuery) -> [AgentStructuredOutputRecord] { + var records = historyByThread.values + .flatMap { $0 } + .compactMap { record -> AgentStructuredOutputRecord? in + switch record.item { + case let .structuredOutput(structuredOutput): + return structuredOutput + + case let .message(message): + guard let metadata = message.structuredOutput else { + return nil + } + return AgentStructuredOutputRecord( + threadID: message.threadID, + turnID: "", + messageID: message.id, + metadata: metadata, + committedAt: message.createdAt + ) + + default: + return nil + } + } + + if let threadIDs = query.threadIDs { + records = records.filter { threadIDs.contains($0.threadID) } + } + if let formatNames = query.formatNames { + records = records.filter { formatNames.contains($0.metadata.formatName) } + } + + records = sort(records, using: query.sort) + + if query.latestOnly { + var seen = Set() + records = records.filter { record in + seen.insert(record.threadID).inserted + } + } + + if let limit = query.limit { + records = Array(records.prefix(max(0, limit))) + } + return records + } + + func execute(_ query: ThreadSnapshotQuery) -> [AgentThreadSnapshot] { + var snapshots = threads.compactMap { thread -> AgentThreadSnapshot? in + guard query.threadIDs?.contains(thread.id) ?? true else { + return nil + } + let summary = summariesByThread[thread.id] ?? threadSummaryFallback(for: thread) + return summary.snapshot + } + snapshots = sort(snapshots, using: query.sort) + if let limit = query.limit { + snapshots = Array(snapshots.prefix(max(0, limit))) + } + return snapshots + } + + func execute(_ query: ThreadContextStateQuery) -> [AgentThreadContextState] { + var records = Array(contextStateByThread.values) + if let threadIDs = query.threadIDs { + records = records.filter { threadIDs.contains($0.threadID) } + } + records.sort { lhs, rhs in + if lhs.generation == rhs.generation { + return lhs.threadID < rhs.threadID + } + return lhs.generation > rhs.generation + } + if let limit = query.limit { + records = Array(records.prefix(max(0, limit))) + } + return records + } +} + +extension Array where Element == AgentHistoryRecord { + func endIndexForBackward(anchor: Int?) -> Int { + guard let anchor else { + return count + } + + return firstIndex(where: { $0.sequenceNumber >= anchor }) ?? count + } + + func startIndexForForward(anchor: Int?) -> Int { + guard let anchor else { + return 0 + } + + return firstIndex(where: { $0.sequenceNumber > anchor }) ?? count + } +} + +private extension StoredRuntimeState { + func sort( + _ records: [AgentHistoryRecord], + using sort: AgentHistorySort + ) -> [AgentHistoryRecord] { + records.sorted { lhs, rhs in + switch sort { + case let .sequence(order): + if lhs.sequenceNumber == rhs.sequenceNumber { + return lhs.createdAt < rhs.createdAt + } + return order == .ascending + ? lhs.sequenceNumber < rhs.sequenceNumber + : lhs.sequenceNumber > rhs.sequenceNumber + + case let .createdAt(order): + if lhs.createdAt == rhs.createdAt { + return lhs.sequenceNumber < rhs.sequenceNumber + } + return order == .ascending + ? lhs.createdAt < rhs.createdAt + : lhs.createdAt > rhs.createdAt + } + } + } + + func page( + _ records: [AgentHistoryRecord], + threadID: String, + with page: AgentQueryPage?, + sort: AgentHistorySort + ) throws -> AgentHistoryQueryResult { + guard let page else { + let ordered = normalizePageRecords(records, sort: sort) + return AgentHistoryQueryResult( + threadID: threadID, + records: ordered, + nextCursor: nil, + previousCursor: nil, + hasMoreBefore: false, + hasMoreAfter: false + ) + } + + let limit = max(1, page.limit) + let anchor = try page.cursor?.decodedSequenceNumber(expectedThreadID: threadID) + let ascending = normalizePageRecords(records, sort: sort) + let endIndex = if let anchor { + ascending.firstIndex(where: { $0.sequenceNumber >= anchor }) ?? ascending.count + } else { + ascending.count + } + let startIndex = max(0, endIndex - limit) + let sliced = Array(ascending[startIndex ..< endIndex]) + return AgentHistoryQueryResult( + threadID: threadID, + records: sliced, + nextCursor: startIndex > 0 ? AgentHistoryCursor(threadID: threadID, sequenceNumber: sliced.first?.sequenceNumber) : nil, + previousCursor: endIndex < ascending.count ? AgentHistoryCursor(threadID: threadID, sequenceNumber: sliced.last?.sequenceNumber) : nil, + hasMoreBefore: startIndex > 0, + hasMoreAfter: endIndex < ascending.count + ) + } + + func normalizePageRecords( + _ records: [AgentHistoryRecord], + sort: AgentHistorySort + ) -> [AgentHistoryRecord] { + switch sort { + case .sequence(.ascending), .createdAt(.ascending): + return records + case .sequence(.descending), .createdAt(.descending): + return records.reversed() + } + } + + func sort( + _ threads: [AgentThread], + using sort: AgentThreadMetadataSort + ) -> [AgentThread] { + threads.sorted { lhs, rhs in + switch sort { + case let .updatedAt(order): + if lhs.updatedAt == rhs.updatedAt { + return lhs.id < rhs.id + } + return order == .ascending ? lhs.updatedAt < rhs.updatedAt : lhs.updatedAt > rhs.updatedAt + case let .createdAt(order): + if lhs.createdAt == rhs.createdAt { + return lhs.id < rhs.id + } + return order == .ascending ? lhs.createdAt < rhs.createdAt : lhs.createdAt > rhs.createdAt + } + } + } + + func sort( + _ records: [AgentPendingStateRecord], + using sort: AgentPendingStateSort + ) -> [AgentPendingStateRecord] { + records.sorted { lhs, rhs in + switch sort { + case let .updatedAt(order): + if lhs.updatedAt == rhs.updatedAt { + return lhs.threadID < rhs.threadID + } + return order == .ascending ? lhs.updatedAt < rhs.updatedAt : lhs.updatedAt > rhs.updatedAt + } + } + } + + func sort( + _ records: [AgentStructuredOutputRecord], + using sort: AgentStructuredOutputSort + ) -> [AgentStructuredOutputRecord] { + records.sorted { lhs, rhs in + switch sort { + case let .committedAt(order): + if lhs.committedAt == rhs.committedAt { + return lhs.threadID < rhs.threadID + } + return order == .ascending ? lhs.committedAt < rhs.committedAt : lhs.committedAt > rhs.committedAt + } + } + } + + func sort( + _ records: [AgentThreadSnapshot], + using sort: AgentThreadSnapshotSort + ) -> [AgentThreadSnapshot] { + records.sorted { lhs, rhs in + switch sort { + case let .updatedAt(order): + if lhs.updatedAt == rhs.updatedAt { + return lhs.threadID < rhs.threadID + } + return order == .ascending ? lhs.updatedAt < rhs.updatedAt : lhs.updatedAt > rhs.updatedAt + case let .createdAt(order): + if lhs.createdAt == rhs.createdAt { + return lhs.threadID < rhs.threadID + } + return order == .ascending ? lhs.createdAt < rhs.createdAt : lhs.createdAt > rhs.createdAt + } + } + } +} + +struct AgentHistoryCursorPayload: Codable { + let version: Int + let threadID: String + let sequenceNumber: Int +} + +extension AgentHistoryCursor { + init(threadID: String, sequenceNumber: Int?) { + guard let sequenceNumber else { + self.init(rawValue: "") + return + } + + let payload = AgentHistoryCursorPayload( + version: 1, + threadID: threadID, + sequenceNumber: sequenceNumber + ) + let data = (try? JSONEncoder().encode(payload)) ?? Data() + let base64 = data.base64EncodedString() + .replacingOccurrences(of: "+", with: "-") + .replacingOccurrences(of: "/", with: "_") + .replacingOccurrences(of: "=", with: "") + self.init(rawValue: base64) + } + + func decodedSequenceNumber(expectedThreadID: String) throws -> Int { + let padded = rawValue + .replacingOccurrences(of: "-", with: "+") + .replacingOccurrences(of: "_", with: "/") + let remainder = padded.count % 4 + let adjusted = padded + String(repeating: "=", count: remainder == 0 ? 0 : 4 - remainder) + + guard let data = Data(base64Encoded: adjusted) else { + throw AgentRuntimeError.invalidHistoryCursor() + } + + let payload = try JSONDecoder().decode(AgentHistoryCursorPayload.self, from: data) + guard payload.threadID == expectedThreadID else { + throw AgentRuntimeError.invalidHistoryCursor() + } + return payload.sequenceNumber + } +} diff --git a/Sources/CodexKit/Runtime/StoredRuntimeState+Queries.swift b/Sources/CodexKit/Runtime/StoredRuntimeState+Queries.swift new file mode 100644 index 0000000..f6302f9 --- /dev/null +++ b/Sources/CodexKit/Runtime/StoredRuntimeState+Queries.swift @@ -0,0 +1,385 @@ +import Foundation + +extension StoredRuntimeState { + func normalized() -> StoredRuntimeState { + let sortedThreads = threads.sorted { lhs, rhs in + if lhs.updatedAt == rhs.updatedAt { + return lhs.id < rhs.id + } + return lhs.updatedAt > rhs.updatedAt + } + + var normalizedHistory = historyByThread + .mapValues { records in + records.sorted { lhs, rhs in + if lhs.sequenceNumber == rhs.sequenceNumber { + return lhs.createdAt < rhs.createdAt + } + return lhs.sequenceNumber < rhs.sequenceNumber + } + } + + for (threadID, messages) in messagesByThread where normalizedHistory[threadID]?.isEmpty != false { + normalizedHistory[threadID] = Self.syntheticHistory(from: messages) + } + + let normalizedMessages: [String: [AgentMessage]] = normalizedHistory.mapValues { records in + records.compactMap { record -> AgentMessage? in + guard case let .message(message) = record.item else { + return nil + } + return message + } + } + + var normalizedNextSequence = nextHistorySequenceByThread + for thread in sortedThreads { + let history = normalizedHistory[thread.id] ?? [] + let nextSequence = (history.last?.sequenceNumber ?? 0) + 1 + normalizedNextSequence[thread.id] = max(normalizedNextSequence[thread.id] ?? 0, nextSequence) + } + + var normalizedSummaries: [String: AgentThreadSummary] = [:] + var normalizedContextState = contextStateByThread + for thread in sortedThreads { + let history = normalizedHistory[thread.id] ?? [] + normalizedSummaries[thread.id] = Self.rebuildSummary( + for: thread, + history: history, + existing: summariesByThread[thread.id] + ) + if let existing = normalizedContextState[thread.id] { + normalizedContextState[thread.id] = AgentThreadContextState( + threadID: thread.id, + effectiveMessages: existing.effectiveMessages, + generation: existing.generation, + lastCompactedAt: existing.lastCompactedAt, + lastCompactionReason: existing.lastCompactionReason, + latestMarkerID: existing.latestMarkerID + ) + } + } + + return StoredRuntimeState( + threads: sortedThreads, + messagesByThread: normalizedMessages, + historyByThread: normalizedHistory, + summariesByThread: normalizedSummaries, + contextStateByThread: normalizedContextState, + nextHistorySequenceByThread: normalizedNextSequence, + normalizeState: false + ) + } + + func threadSummary(id: String) throws -> AgentThreadSummary { + guard let thread = threads.first(where: { $0.id == id }) else { + throw AgentRuntimeError.threadNotFound(id) + } + + return summariesByThread[id] ?? threadSummaryFallback(for: thread) + } + + func threadSummaryFallback(for thread: AgentThread) -> AgentThreadSummary { + Self.rebuildSummary( + for: thread, + history: historyByThread[thread.id] ?? [], + existing: summariesByThread[thread.id] + ) + } + + func threadHistoryPage( + id: String, + query: AgentHistoryQuery + ) throws -> AgentThreadHistoryPage { + guard threads.contains(where: { $0.id == id }) else { + throw AgentRuntimeError.threadNotFound(id) + } + + let limit = max(1, query.limit) + let filter = query.filter ?? AgentHistoryFilter() + let records = (historyByThread[id] ?? []).filter { filter.matches($0.item) } + let anchor = try query.cursor?.decodedSequenceNumber(expectedThreadID: id) + + switch query.direction { + case .backward: + let endIndex = records.endIndexForBackward(anchor: anchor) + let startIndex = max(0, endIndex - limit) + let pageRecords = Array(records[startIndex ..< endIndex]) + let hasMoreBefore = startIndex > 0 + let hasMoreAfter = endIndex < records.count + return AgentThreadHistoryPage( + threadID: id, + items: pageRecords.map { $0.item }, + nextCursor: hasMoreBefore ? AgentHistoryCursor(threadID: id, sequenceNumber: pageRecords.first?.sequenceNumber) : nil, + previousCursor: hasMoreAfter ? AgentHistoryCursor(threadID: id, sequenceNumber: pageRecords.last?.sequenceNumber) : nil, + hasMoreBefore: hasMoreBefore, + hasMoreAfter: hasMoreAfter + ) + + case .forward: + let startIndex = records.startIndexForForward(anchor: anchor) + let endIndex = min(records.count, startIndex + limit) + let pageRecords = Array(records[startIndex ..< endIndex]) + let hasMoreBefore = startIndex > 0 + let hasMoreAfter = endIndex < records.count + return AgentThreadHistoryPage( + threadID: id, + items: pageRecords.map { $0.item }, + nextCursor: hasMoreAfter ? AgentHistoryCursor(threadID: id, sequenceNumber: pageRecords.last?.sequenceNumber) : nil, + previousCursor: hasMoreBefore ? AgentHistoryCursor(threadID: id, sequenceNumber: pageRecords.first?.sequenceNumber) : nil, + hasMoreBefore: hasMoreBefore, + hasMoreAfter: hasMoreAfter + ) + } + } + + func applying(_ operations: [AgentStoreWriteOperation]) throws -> StoredRuntimeState { + var updated = self + + for operation in operations { + switch operation { + case let .upsertThread(thread): + if let index = updated.threads.firstIndex(where: { $0.id == thread.id }) { + updated.threads[index] = thread + } else { + updated.threads.append(thread) + } + + case let .upsertSummary(threadID, summary): + updated.summariesByThread[threadID] = summary + + case let .appendHistoryItems(threadID, items): + updated.historyByThread[threadID, default: []].append(contentsOf: items) + let nextSequence = (updated.historyByThread[threadID]?.last?.sequenceNumber ?? 0) + 1 + updated.nextHistorySequenceByThread[threadID] = nextSequence + + case let .appendCompactionMarker(threadID, marker): + updated.historyByThread[threadID, default: []].append(marker) + let nextSequence = (updated.historyByThread[threadID]?.last?.sequenceNumber ?? 0) + 1 + updated.nextHistorySequenceByThread[threadID] = nextSequence + + case let .upsertThreadContextState(threadID, state): + updated.contextStateByThread[threadID] = state + + case let .deleteThreadContextState(threadID): + updated.contextStateByThread.removeValue(forKey: threadID) + + case let .setPendingState(threadID, state): + if let thread = updated.threads.first(where: { $0.id == threadID }) { + let current = updated.summariesByThread[threadID] ?? updated.threadSummaryFallback(for: thread) + updated.summariesByThread[threadID] = AgentThreadSummary( + threadID: current.threadID, + createdAt: current.createdAt, + updatedAt: current.updatedAt, + latestItemAt: current.latestItemAt, + itemCount: current.itemCount, + latestAssistantMessagePreview: current.latestAssistantMessagePreview, + latestStructuredOutputMetadata: current.latestStructuredOutputMetadata, + latestPartialStructuredOutput: current.latestPartialStructuredOutput, + latestToolState: current.latestToolState, + latestTurnStatus: current.latestTurnStatus, + pendingState: state + ) + } + + case let .setPartialStructuredSnapshot(threadID, snapshot): + if let thread = updated.threads.first(where: { $0.id == threadID }) { + let current = updated.summariesByThread[threadID] ?? updated.threadSummaryFallback(for: thread) + updated.summariesByThread[threadID] = AgentThreadSummary( + threadID: current.threadID, + createdAt: current.createdAt, + updatedAt: current.updatedAt, + latestItemAt: current.latestItemAt, + itemCount: current.itemCount, + latestAssistantMessagePreview: current.latestAssistantMessagePreview, + latestStructuredOutputMetadata: current.latestStructuredOutputMetadata, + latestPartialStructuredOutput: snapshot, + latestToolState: current.latestToolState, + latestTurnStatus: current.latestTurnStatus, + pendingState: current.pendingState + ) + } + + case let .upsertToolSession(threadID, session): + if let thread = updated.threads.first(where: { $0.id == threadID }) { + let current = updated.summariesByThread[threadID] ?? updated.threadSummaryFallback(for: thread) + let latestToolState = AgentLatestToolState( + invocationID: session.invocationID, + turnID: session.turnID, + toolName: session.toolName, + status: .running, + success: nil, + sessionID: session.sessionID, + sessionStatus: session.sessionStatus, + metadata: session.metadata, + resumable: session.resumable, + updatedAt: session.updatedAt, + resultPreview: nil + ) + updated.summariesByThread[threadID] = AgentThreadSummary( + threadID: current.threadID, + createdAt: current.createdAt, + updatedAt: current.updatedAt, + latestItemAt: current.latestItemAt, + itemCount: current.itemCount, + latestAssistantMessagePreview: current.latestAssistantMessagePreview, + latestStructuredOutputMetadata: current.latestStructuredOutputMetadata, + latestPartialStructuredOutput: current.latestPartialStructuredOutput, + latestToolState: latestToolState, + latestTurnStatus: current.latestTurnStatus, + pendingState: .toolWait( + AgentPendingToolWaitState( + invocationID: session.invocationID, + turnID: session.turnID, + toolName: session.toolName, + startedAt: session.updatedAt, + sessionID: session.sessionID, + sessionStatus: session.sessionStatus, + metadata: session.metadata, + resumable: session.resumable + ) + ) + ) + } + + case let .redactHistoryItems(threadID, itemIDs, reason): + guard !itemIDs.isEmpty else { + continue + } + updated.historyByThread[threadID] = updated.historyByThread[threadID]?.map { record in + guard itemIDs.contains(record.id) else { + return record + } + return record.redacted(reason: reason) + } + + case let .deleteThread(threadID): + updated.threads.removeAll { $0.id == threadID } + updated.messagesByThread.removeValue(forKey: threadID) + updated.historyByThread.removeValue(forKey: threadID) + updated.summariesByThread.removeValue(forKey: threadID) + updated.contextStateByThread.removeValue(forKey: threadID) + updated.nextHistorySequenceByThread.removeValue(forKey: threadID) + } + } + + return updated.normalized() + } + + private static func syntheticHistory(from messages: [AgentMessage]) -> [AgentHistoryRecord] { + let orderedMessages = messages.enumerated().sorted { lhs, rhs in + let left = lhs.element + let right = rhs.element + if left.createdAt == right.createdAt { + return lhs.offset < rhs.offset + } + return left.createdAt < right.createdAt + } + + return orderedMessages.enumerated().map { index, pair in + AgentHistoryRecord( + sequenceNumber: index + 1, + createdAt: pair.element.createdAt, + item: .message(pair.element) + ) + } + } + + private static func rebuildSummary( + for thread: AgentThread, + history: [AgentHistoryRecord], + existing: AgentThreadSummary? + ) -> AgentThreadSummary { + var latestAssistantMessagePreview = existing?.latestAssistantMessagePreview + var latestStructuredOutputMetadata = existing?.latestStructuredOutputMetadata + var latestToolState = existing?.latestToolState + var latestTurnStatus = existing?.latestTurnStatus + let latestPartialStructuredOutput = existing?.latestPartialStructuredOutput + let pendingState = existing?.pendingState + + for record in history { + switch record.item { + case let .message(message): + if message.role == .assistant { + latestAssistantMessagePreview = message.displayText + if let structuredOutput = message.structuredOutput { + latestStructuredOutputMetadata = structuredOutput + } + } + + case let .toolCall(toolCall): + latestToolState = AgentLatestToolState( + invocationID: toolCall.invocation.id, + turnID: toolCall.invocation.turnID, + toolName: toolCall.invocation.toolName, + status: .waiting, + updatedAt: toolCall.requestedAt + ) + + case let .toolResult(toolResult): + latestToolState = Self.latestToolState(from: toolResult) + + case let .structuredOutput(structuredOutput): + latestStructuredOutputMetadata = structuredOutput.metadata + + case .approval: + break + + case let .systemEvent(systemEvent): + switch systemEvent.type { + case .turnStarted: + latestTurnStatus = .running + case .turnCompleted: + latestTurnStatus = .completed + case .turnFailed: + latestTurnStatus = .failed + case .threadCreated, .threadResumed, .threadStatusChanged, .contextCompacted: + break + } + } + } + + return AgentThreadSummary( + threadID: thread.id, + createdAt: thread.createdAt, + updatedAt: thread.updatedAt, + latestItemAt: history.last?.createdAt, + itemCount: history.count, + latestAssistantMessagePreview: latestAssistantMessagePreview, + latestStructuredOutputMetadata: latestStructuredOutputMetadata, + latestPartialStructuredOutput: latestPartialStructuredOutput, + latestToolState: latestToolState, + latestTurnStatus: latestTurnStatus, + pendingState: pendingState + ) + } + + private static func latestToolState(from toolResult: AgentToolResultRecord) -> AgentLatestToolState { + let preview = toolResult.result.primaryText + let session = toolResult.result.session + let status: AgentToolSessionStatus + if toolResult.result.errorMessage == "Tool execution was denied by the user." { + status = .denied + } else if let session, !session.isTerminal { + status = .running + } else if toolResult.result.success { + status = .completed + } else { + status = .failed + } + + return AgentLatestToolState( + invocationID: toolResult.result.invocationID, + turnID: toolResult.turnID, + toolName: toolResult.result.toolName, + status: status, + success: toolResult.result.success, + sessionID: session?.sessionID, + sessionStatus: session?.status, + metadata: session?.metadata, + resumable: session?.resumable ?? false, + updatedAt: toolResult.completedAt, + resultPreview: preview + ) + } +} diff --git a/Tests/CodexKitTests/AgentRuntimeHistoryCompactionTests.swift b/Tests/CodexKitTests/AgentRuntimeHistoryCompactionTests.swift new file mode 100644 index 0000000..bffb4c2 --- /dev/null +++ b/Tests/CodexKitTests/AgentRuntimeHistoryCompactionTests.swift @@ -0,0 +1,94 @@ +import CodexKit +import CodexKitUI +import XCTest + +extension AgentRuntimeTests { + func testManualCompactionPreservesVisibleHistoryAndHidesMarkersByDefault() async throws { + let backend = CompactingTestBackend() + let runtime = try makeHistoryRuntime(backend: backend, approvalPresenter: AutoApprovalPresenter(), stateStore: InMemoryRuntimeStateStore(), contextCompaction: AgentContextCompactionConfiguration(isEnabled: true, mode: .automatic)) + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(title: "Compaction") + _ = try await runtime.sendMessage(UserMessageRequest(text: "one"), in: thread.id) + _ = try await runtime.sendMessage(UserMessageRequest(text: "two"), in: thread.id) + _ = try await runtime.sendMessage(UserMessageRequest(text: "three"), in: thread.id) + + let visibleBefore = await runtime.messages(for: thread.id) + XCTAssertEqual(visibleBefore.count, 6) + + let contextState = try await runtime.compactThreadContext(id: thread.id) + let visibleAfter = await runtime.messages(for: thread.id) + let defaultSystemHistory = try await runtime.execute( + HistoryItemsQuery(threadID: thread.id, kinds: [.systemEvent]) + ) + let debugSystemHistory = try await runtime.execute( + HistoryItemsQuery( + threadID: thread.id, + kinds: [.systemEvent], + includeCompactionEvents: true + ) + ) + + XCTAssertEqual(contextState.generation, 1) + XCTAssertLessThan(contextState.effectiveMessages.count, visibleBefore.count) + XCTAssertEqual(visibleAfter, visibleBefore) + XCTAssertFalse(defaultSystemHistory.records.contains { $0.item.isCompactionMarker }) + XCTAssertTrue(debugSystemHistory.records.contains { $0.item.isCompactionMarker }) + } + + func testAutomaticRetryCompactionRecoversFromContextLimitError() async throws { + let backend = CompactingTestBackend(failOnHistoryCountAbove: 2) + let runtime = try makeHistoryRuntime( + backend: backend, + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore(), + contextCompaction: AgentContextCompactionConfiguration(isEnabled: true, mode: .automatic, trigger: AgentContextCompactionTrigger(estimatedTokenThreshold: 100_000, retryOnContextLimitError: true)) + ) + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(title: "Retry Compact") + _ = try await runtime.sendMessage(UserMessageRequest(text: "one"), in: thread.id) + _ = try await runtime.sendMessage(UserMessageRequest(text: "two"), in: thread.id) + _ = try await runtime.sendMessage(UserMessageRequest(text: "three"), in: thread.id) + let reply = try await runtime.sendMessage(UserMessageRequest(text: "four"), in: thread.id) + let compactCallCount = await backend.compactCallCount() + let historyCounts = await backend.beginTurnHistoryCounts() + + XCTAssertEqual(reply, "Echo: four") + XCTAssertEqual(compactCallCount, 1) + XCTAssertGreaterThanOrEqual(historyCounts.count, 5) + } + + func testContextStatePersistsAcrossGRDBReload() async throws { + let url = temporaryRuntimeSQLiteURL() + defer { try? FileManager.default.removeItem(at: url) } + + let backend = CompactingTestBackend() + let runtime = try makeHistoryRuntime(backend: backend, approvalPresenter: AutoApprovalPresenter(), stateStore: try GRDBRuntimeStateStore(url: url), contextCompaction: AgentContextCompactionConfiguration(isEnabled: true, mode: .automatic)) + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(title: "Persisted Context") + _ = try await runtime.sendMessage(UserMessageRequest(text: "alpha"), in: thread.id) + _ = try await runtime.sendMessage(UserMessageRequest(text: "beta"), in: thread.id) + _ = try await runtime.compactThreadContext(id: thread.id) + + let reloadedRuntime = try makeHistoryRuntime(backend: backend, approvalPresenter: AutoApprovalPresenter(), stateStore: try GRDBRuntimeStateStore(url: url), contextCompaction: AgentContextCompactionConfiguration(isEnabled: true, mode: .automatic)) + let restoredContext = try await reloadedRuntime.fetchThreadContextState(id: thread.id) + XCTAssertEqual(restoredContext?.generation, 1) + XCTAssertFalse(restoredContext?.effectiveMessages.isEmpty ?? true) + } + + func testContextCompactionConfigurationDefaultsAndCodableShape() throws { + let configuration = AgentContextCompactionConfiguration() + XCTAssertFalse(configuration.isEnabled) + XCTAssertEqual(configuration.mode, .automatic) + + let encoded = try JSONEncoder().encode(AgentContextCompactionConfiguration(isEnabled: true, mode: .manual)) + let decoded = try JSONDecoder().decode(AgentContextCompactionConfiguration.self, from: encoded) + XCTAssertTrue(decoded.isEnabled) + XCTAssertEqual(decoded.mode, .manual) + } +} diff --git a/Tests/CodexKitTests/AgentRuntimeHistoryGRDBTests.swift b/Tests/CodexKitTests/AgentRuntimeHistoryGRDBTests.swift new file mode 100644 index 0000000..b8d5a67 --- /dev/null +++ b/Tests/CodexKitTests/AgentRuntimeHistoryGRDBTests.swift @@ -0,0 +1,136 @@ +import CodexKit +import CodexKitUI +import XCTest + +extension AgentRuntimeTests { + func testGRDBRuntimeStateStorePersistsSummariesAndQueriesAcrossReload() async throws { + let url = temporaryRuntimeSQLiteURL() + defer { try? FileManager.default.removeItem(at: url) } + + let backend = InMemoryAgentBackend(structuredResponseText: #"{"reply":"The replacement is shipping today.","priority":"urgent"}"#) + let store = try GRDBRuntimeStateStore(url: url) + let runtime = try makeHistoryRuntime(backend: backend, approvalPresenter: AutoApprovalPresenter(), stateStore: store) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(title: "GRDB Thread") + _ = try await runtime.sendMessage(UserMessageRequest(text: "Draft the shipping update."), in: thread.id, expecting: ShippingReplyDraft.self) + + let reloadedStore = try GRDBRuntimeStateStore(url: url) + let reloadedRuntime = try makeHistoryRuntime(backend: backend, approvalPresenter: AutoApprovalPresenter(), stateStore: reloadedStore) + + let metadata = try await reloadedRuntime.prepareStore() + XCTAssertEqual(metadata.storeKind, "GRDBRuntimeStateStore") + XCTAssertEqual(metadata.storeSchemaVersion, 2) + + let summary = try await reloadedRuntime.fetchThreadSummary(id: thread.id) + XCTAssertEqual(summary.latestTurnStatus, .completed) + XCTAssertEqual(summary.latestStructuredOutputMetadata?.formatName, "shipping_reply_draft") + + let snapshots = try await reloadedRuntime.execute(ThreadSnapshotQuery(threadIDs: [thread.id])) + XCTAssertEqual(snapshots.count, 1) + XCTAssertEqual(snapshots.first?.threadID, thread.id) + + let history = try await reloadedRuntime.execute(HistoryItemsQuery(threadID: thread.id, kinds: [.message, .structuredOutput])) + XCTAssertFalse(history.records.isEmpty) + + let typed = try await reloadedRuntime.fetchLatestStructuredOutput(id: thread.id, as: ShippingReplyDraft.self) + XCTAssertEqual(typed, ShippingReplyDraft(reply: "The replacement is shipping today.", priority: "urgent")) + } + + func testGRDBRuntimeStateStorePersistsRedactionAndDeletion() async throws { + let url = temporaryRuntimeSQLiteURL() + defer { try? FileManager.default.removeItem(at: url) } + + let runtime = try makeHistoryRuntime(backend: InMemoryAgentBackend(), approvalPresenter: AutoApprovalPresenter(), stateStore: try GRDBRuntimeStateStore(url: url)) + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(title: "GRDB Mutations") + _ = try await runtime.sendMessage(UserMessageRequest(text: "please redact me"), in: thread.id) + + let messageHistory = try await runtime.execute(HistoryItemsQuery(threadID: thread.id, kinds: [.message])) + guard let firstMessage = messageHistory.records.first else { return XCTFail("Expected a persisted message record.") } + + try await runtime.redactHistoryItems([firstMessage.id], in: thread.id) + + let reloadedAfterRedaction = try makeHistoryRuntime(backend: InMemoryAgentBackend(), approvalPresenter: AutoApprovalPresenter(), stateStore: try GRDBRuntimeStateStore(url: url)) + let redactedHistory = try await reloadedAfterRedaction.execute(HistoryItemsQuery(threadID: thread.id, kinds: [.message])) + + guard let redactedRecord = redactedHistory.records.first(where: { $0.id == firstMessage.id }) else { + return XCTFail("Expected the redacted record to still be queryable.") + } + XCTAssertNotNil(redactedRecord.redaction) + + try await reloadedAfterRedaction.deleteThread(id: thread.id) + let deletedRuntime = try makeHistoryRuntime(backend: InMemoryAgentBackend(), approvalPresenter: AutoApprovalPresenter(), stateStore: try GRDBRuntimeStateStore(url: url)) + let deletedThreads = try await deletedRuntime.execute(ThreadMetadataQuery(threadIDs: [thread.id])) + XCTAssertTrue(deletedThreads.isEmpty) + } + + func testGRDBRuntimeStateStoreImportsLegacyFileStateOnFirstPrepare() async throws { + let directory = FileManager.default.temporaryDirectory.appendingPathComponent(UUID().uuidString) + try FileManager.default.createDirectory(at: directory, withIntermediateDirectories: true) + defer { try? FileManager.default.removeItem(at: directory) } + + let legacyURL = directory.appendingPathComponent("runtime-state").appendingPathExtension("json") + let sqliteURL = directory.appendingPathComponent("runtime-state").appendingPathExtension("sqlite") + + let backend = InMemoryAgentBackend(structuredResponseText: #"{"reply":"Legacy import payload.","priority":"normal"}"#) + let legacyRuntime = try makeHistoryRuntime(backend: backend, approvalPresenter: AutoApprovalPresenter(), stateStore: FileRuntimeStateStore(url: legacyURL)) + _ = try await legacyRuntime.restore() + _ = try await legacyRuntime.signIn() + + let thread = try await legacyRuntime.createThread(title: "Legacy File Thread") + _ = try await legacyRuntime.sendMessage(UserMessageRequest(text: "Create a legacy payload."), in: thread.id, expecting: ShippingReplyDraft.self) + + let importedStore = try GRDBRuntimeStateStore(url: sqliteURL) + let importedRuntime = try makeHistoryRuntime(backend: backend, approvalPresenter: AutoApprovalPresenter(), stateStore: importedStore) + _ = try await importedRuntime.prepareStore() + + let summary = try await importedRuntime.fetchThreadSummary(id: thread.id) + let importedHistory = try await importedRuntime.execute( + HistoryItemsQuery(threadID: thread.id, kinds: [.message, .structuredOutput]) + ) + XCTAssertEqual(summary.latestStructuredOutputMetadata?.formatName, "shipping_reply_draft") + XCTAssertFalse(importedHistory.records.isEmpty) + } + + func testGRDBRuntimeStateStoreExternalizesImageAttachments() async throws { + let url = temporaryRuntimeSQLiteURL() + let attachmentsDirectory = url.deletingLastPathComponent() + .appendingPathComponent("\(url.deletingPathExtension().lastPathComponent).codexkit-state", isDirectory: true) + .appendingPathComponent("attachments", isDirectory: true) + defer { + try? FileManager.default.removeItem(at: url) + try? FileManager.default.removeItem(at: attachmentsDirectory.deletingLastPathComponent()) + } + + let imageData = Data([0x89, 0x50, 0x4E, 0x47, 0xDE, 0xAD, 0xBE, 0xEF]) + let runtime = try makeHistoryRuntime(backend: InMemoryAgentBackend(), approvalPresenter: AutoApprovalPresenter(), stateStore: try GRDBRuntimeStateStore(url: url)) + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(title: "Attachment Thread") + _ = try await runtime.sendMessage(UserMessageRequest(text: "here is an image", images: [.png(imageData)]), in: thread.id) + + let reloadedRuntime = try makeHistoryRuntime(backend: InMemoryAgentBackend(), approvalPresenter: AutoApprovalPresenter(), stateStore: try GRDBRuntimeStateStore(url: url)) + let history = try await reloadedRuntime.execute(HistoryItemsQuery(threadID: thread.id, kinds: [.message])) + guard let userMessage = history.records.compactMap({ record -> AgentMessage? in + guard case let .message(message) = record.item, message.role == .user else { return nil } + return message + }).first else { return XCTFail("Expected a persisted user message with an attachment.") } + + XCTAssertEqual(userMessage.images.count, 1) + XCTAssertEqual(userMessage.images.first?.data, imageData) + + let attachmentFiles = try FileManager.default.contentsOfDirectory( + at: attachmentsDirectory.appendingPathComponent(thread.id, isDirectory: true) + .appendingPathComponent(userMessage.id, isDirectory: true), + includingPropertiesForKeys: nil + ) + XCTAssertEqual(attachmentFiles.count, 1) + XCTAssertNil(try Data(contentsOf: url).range(of: imageData.base64EncodedData())) + } +} diff --git a/Tests/CodexKitTests/AgentRuntimeHistoryTestSupport.swift b/Tests/CodexKitTests/AgentRuntimeHistoryTestSupport.swift new file mode 100644 index 0000000..045a4e8 --- /dev/null +++ b/Tests/CodexKitTests/AgentRuntimeHistoryTestSupport.swift @@ -0,0 +1,323 @@ +import CodexKit +import CodexKitUI +import XCTest + +func makeHistoryRuntime( + backend: any AgentBackend, + approvalPresenter: any ApprovalPresenting, + stateStore: any RuntimeStateStoring, + tools: [AgentRuntime.ToolRegistration] = [], + contextCompaction: AgentContextCompactionConfiguration = AgentContextCompactionConfiguration() +) throws -> AgentRuntime { + try AgentRuntime(configuration: .init( + authProvider: DemoChatGPTAuthProvider(), + secureStore: KeychainSessionSecureStore(service: "CodexKitTests.ChatGPTSession", account: UUID().uuidString), + backend: backend, + approvalPresenter: approvalPresenter, + stateStore: stateStore, + tools: tools, + contextCompaction: contextCompaction + )) +} + +func temporaryRuntimeSQLiteURL() -> URL { + FileManager.default.temporaryDirectory + .appendingPathComponent(UUID().uuidString) + .appendingPathExtension("sqlite") +} + +func messageTexts(in page: AgentThreadHistoryPage) -> [String] { + page.items.compactMap { item in + guard case let .message(message) = item else { return nil } + return message.displayText + } +} + +func waitUntil( + timeoutNanoseconds: UInt64 = 2_000_000_000, + intervalNanoseconds: UInt64 = 20_000_000, + condition: @escaping @Sendable () async throws -> Bool +) async throws { + let deadline = DispatchTime.now().uptimeNanoseconds + timeoutNanoseconds + while DispatchTime.now().uptimeNanoseconds < deadline { + if try await condition() { + return + } + try await Task.sleep(nanoseconds: intervalNanoseconds) + } + + XCTFail("Timed out waiting for condition.") +} + +enum TimedAsyncOperationError: Error { + case timedOut +} + +func awaitValue( + timeoutNanoseconds: UInt64 = 500_000_000, + operation: @escaping @Sendable () async throws -> T +) async throws -> T { + try await withThrowingTaskGroup(of: T.self) { group in + group.addTask { + try await operation() + } + group.addTask { + try await Task.sleep(nanoseconds: timeoutNanoseconds) + throw TimedAsyncOperationError.timedOut + } + + guard let value = try await group.next() else { + throw TimedAsyncOperationError.timedOut + } + group.cancelAll() + return value + } +} + +actor ToolExecutionGate { + private var waiters: [CheckedContinuation] = [] + private var released = false + + func wait() async { + guard !released else { return } + await withCheckedContinuation { continuation in + waiters.append(continuation) + } + } + + func release() { + released = true + let continuations = waiters + waiters.removeAll() + continuations.forEach { $0.resume() } + } +} + +private actor BeginTurnDelayGate { + private var didStart = false + private var startWaiters: [CheckedContinuation] = [] + private var released = false + private var releaseWaiters: [CheckedContinuation] = [] + + func markStarted() { + didStart = true + let continuations = startWaiters + startWaiters.removeAll() + continuations.forEach { $0.resume() } + } + + func waitForStart() async { + guard !didStart else { return } + await withCheckedContinuation { continuation in + startWaiters.append(continuation) + } + } + + func waitForRelease() async { + guard !released else { return } + await withCheckedContinuation { continuation in + releaseWaiters.append(continuation) + } + } + + func release() { + released = true + let continuations = releaseWaiters + releaseWaiters.removeAll() + continuations.forEach { $0.resume() } + } +} + +private actor PartialEmissionGate { + private var partialEmitted = false + private var partialWaiters: [CheckedContinuation] = [] + private var commitReleased = false + private var commitWaiters: [CheckedContinuation] = [] + + func markPartialEmitted() { + partialEmitted = true + let continuations = partialWaiters + partialWaiters.removeAll() + continuations.forEach { $0.resume() } + } + + func waitForPartialEmission() async { + guard !partialEmitted else { return } + await withCheckedContinuation { continuation in + partialWaiters.append(continuation) + } + } + + func waitForCommitRelease() async { + guard !commitReleased else { return } + await withCheckedContinuation { continuation in + commitWaiters.append(continuation) + } + } + + func releaseCommit() { + commitReleased = true + let continuations = commitWaiters + commitWaiters.removeAll() + continuations.forEach { $0.resume() } + } +} + +actor CompactingTestBackend: AgentBackend, AgentBackendContextCompacting { + nonisolated let baseInstructions: String? = nil + + private let failOnHistoryCountAbove: Int? + private var threads: [String: AgentThread] = [:] + private var compactCalls = 0 + private var historyCounts: [Int] = [] + + init(failOnHistoryCountAbove: Int? = nil) { + self.failOnHistoryCountAbove = failOnHistoryCountAbove + } + + func createThread(session _: ChatGPTSession) async throws -> AgentThread { + let thread = AgentThread(id: UUID().uuidString) + threads[thread.id] = thread + return thread + } + + func resumeThread(id: String, session _: ChatGPTSession) async throws -> AgentThread { + if let thread = threads[id] { return thread } + let thread = AgentThread(id: id) + threads[id] = thread + return thread + } + + func beginTurn( + thread: AgentThread, + history: [AgentMessage], + message: UserMessageRequest, + instructions _: String, + responseFormat _: AgentStructuredOutputFormat?, + streamedStructuredOutput _: AgentStreamedStructuredOutputRequest?, + tools _: [ToolDefinition], + session _: ChatGPTSession + ) async throws -> any AgentTurnStreaming { + historyCounts.append(history.count) + if let failOnHistoryCountAbove, + history.count > failOnHistoryCountAbove, + !history.contains(where: { $0.role == .system && $0.text.contains("Compacted conversation summary") }) { + throw AgentRuntimeError(code: "context_limit_exceeded", message: "Maximum context length exceeded.") + } + + return MockAgentTurnSession(thread: thread, message: message, selectedTool: nil, structuredResponseText: nil, streamedStructuredOutput: nil) + } + + func compactContext( + thread: AgentThread, + effectiveHistory: [AgentMessage], + instructions _: String, + tools _: [ToolDefinition], + session _: ChatGPTSession + ) async throws -> AgentCompactionResult { + compactCalls += 1 + let lastUser = effectiveHistory.last(where: { $0.role == .user }) + let lastAssistant = effectiveHistory.last(where: { $0.role == .assistant }) + var compacted = [AgentMessage(threadID: thread.id, role: .system, text: "Compacted conversation summary")] + if let lastUser { compacted.append(lastUser) } + if let lastAssistant { compacted.append(lastAssistant) } + return AgentCompactionResult(effectiveMessages: compacted, summaryPreview: "Compacted conversation summary") + } + + func compactCallCount() -> Int { compactCalls } + func beginTurnHistoryCounts() -> [Int] { historyCounts } +} + +actor DelayedBeginTurnBackend: AgentBackend { + private let gate = BeginTurnDelayGate() + + func createThread(session _: ChatGPTSession) async throws -> AgentThread { + AgentThread(id: UUID().uuidString) + } + + func resumeThread(id: String, session _: ChatGPTSession) async throws -> AgentThread { + AgentThread(id: id) + } + + func beginTurn( + thread: AgentThread, + history _: [AgentMessage], + message: UserMessageRequest, + instructions _: String, + responseFormat _: AgentStructuredOutputFormat?, + streamedStructuredOutput _: AgentStreamedStructuredOutputRequest?, + tools _: [ToolDefinition], + session _: ChatGPTSession + ) async throws -> any AgentTurnStreaming { + await gate.markStarted() + await gate.waitForRelease() + return MockAgentTurnSession( + thread: thread, + message: message, + selectedTool: nil, + structuredResponseText: nil, + streamedStructuredOutput: nil + ) + } + + func waitForBeginTurnStart() async { + await gate.waitForStart() + } + + func releaseBeginTurn() async { + await gate.release() + } +} + +actor BlockingStructuredPartialBackend: AgentBackend { + private let gate = PartialEmissionGate() + + func createThread(session _: ChatGPTSession) async throws -> AgentThread { + AgentThread(id: UUID().uuidString) + } + + func resumeThread(id: String, session _: ChatGPTSession) async throws -> AgentThread { + AgentThread(id: id) + } + + func beginTurn( + thread: AgentThread, + history _: [AgentMessage], + message _: UserMessageRequest, + instructions _: String, + responseFormat _: AgentStructuredOutputFormat?, + streamedStructuredOutput _: AgentStreamedStructuredOutputRequest?, + tools _: [ToolDefinition], + session _: ChatGPTSession + ) async throws -> any AgentTurnStreaming { + BlockingStructuredPartialTurnSession(threadID: thread.id, gate: gate) + } + + func waitForPartialEmission() async { await gate.waitForPartialEmission() } + func releaseCommit() async { await gate.releaseCommit() } +} + +private final class BlockingStructuredPartialTurnSession: AgentTurnStreaming, @unchecked Sendable { + let events: AsyncThrowingStream + + init(threadID: String, gate: PartialEmissionGate) { + let turn = AgentTurn(id: UUID().uuidString, threadID: threadID) + let payload: JSONValue = .object(["reply": .string("Your order is already in transit."), "priority": .string("high")]) + + events = AsyncThrowingStream { continuation in + Task { + continuation.yield(.turnStarted(turn)) + continuation.yield(.assistantMessageDelta(threadID: threadID, turnID: turn.id, delta: "Echo: Draft a shipping reply.")) + continuation.yield(.structuredOutputPartial(payload)) + await gate.markPartialEmitted() + await gate.waitForCommitRelease() + continuation.yield(.structuredOutputCommitted(payload)) + continuation.yield(.assistantMessageCompleted(AgentMessage(threadID: threadID, role: .assistant, text: "Echo: Draft a shipping reply.", structuredOutput: AgentStructuredOutputMetadata(formatName: "shipping_reply_draft", payload: payload)))) + continuation.yield(.turnCompleted(AgentTurnSummary(threadID: threadID, turnID: turn.id, usage: AgentUsage(inputTokens: 1, outputTokens: 1)))) + continuation.finish() + } + } + } + + func submitToolResult(_: ToolResultEnvelope, for _: String) async throws {} +} diff --git a/Tests/CodexKitTests/AgentRuntimeHistoryTests.swift b/Tests/CodexKitTests/AgentRuntimeHistoryTests.swift index e0f0556..4b6474a 100644 --- a/Tests/CodexKitTests/AgentRuntimeHistoryTests.swift +++ b/Tests/CodexKitTests/AgentRuntimeHistoryTests.swift @@ -448,663 +448,4 @@ extension AgentRuntimeTests { XCTAssertTrue(history.records.isEmpty) } - func testGRDBRuntimeStateStorePersistsSummariesAndQueriesAcrossReload() async throws { - let url = temporaryRuntimeSQLiteURL() - defer { try? FileManager.default.removeItem(at: url) } - - let backend = InMemoryAgentBackend( - structuredResponseText: #"{"reply":"The replacement is shipping today.","priority":"urgent"}"# - ) - let store = try GRDBRuntimeStateStore(url: url) - let runtime = try makeHistoryRuntime( - backend: backend, - approvalPresenter: AutoApprovalPresenter(), - stateStore: store - ) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - let thread = try await runtime.createThread(title: "GRDB Thread") - _ = try await runtime.sendMessage( - UserMessageRequest(text: "Draft the shipping update."), - in: thread.id, - expecting: ShippingReplyDraft.self - ) - - let reloadedStore = try GRDBRuntimeStateStore(url: url) - let reloadedRuntime = try makeHistoryRuntime( - backend: backend, - approvalPresenter: AutoApprovalPresenter(), - stateStore: reloadedStore - ) - - let metadata = try await reloadedRuntime.prepareStore() - XCTAssertEqual(metadata.storeKind, "GRDBRuntimeStateStore") - XCTAssertEqual(metadata.storeSchemaVersion, 2) - - let summary = try await reloadedRuntime.fetchThreadSummary(id: thread.id) - XCTAssertEqual(summary.latestTurnStatus, .completed) - XCTAssertEqual(summary.latestStructuredOutputMetadata?.formatName, "shipping_reply_draft") - - let snapshots = try await reloadedRuntime.execute( - ThreadSnapshotQuery(threadIDs: [thread.id]) - ) - XCTAssertEqual(snapshots.count, 1) - XCTAssertEqual(snapshots.first?.threadID, thread.id) - - let history = try await reloadedRuntime.execute( - HistoryItemsQuery( - threadID: thread.id, - kinds: [.message, .structuredOutput] - ) - ) - XCTAssertFalse(history.records.isEmpty) - - let typed = try await reloadedRuntime.fetchLatestStructuredOutput( - id: thread.id, - as: ShippingReplyDraft.self - ) - XCTAssertEqual( - typed, - ShippingReplyDraft( - reply: "The replacement is shipping today.", - priority: "urgent" - ) - ) - } - - func testGRDBRuntimeStateStorePersistsRedactionAndDeletion() async throws { - let url = temporaryRuntimeSQLiteURL() - defer { try? FileManager.default.removeItem(at: url) } - - let runtime = try makeHistoryRuntime( - backend: InMemoryAgentBackend(), - approvalPresenter: AutoApprovalPresenter(), - stateStore: try GRDBRuntimeStateStore(url: url) - ) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - let thread = try await runtime.createThread(title: "GRDB Mutations") - _ = try await runtime.sendMessage(UserMessageRequest(text: "please redact me"), in: thread.id) - - let messageHistory = try await runtime.execute( - HistoryItemsQuery(threadID: thread.id, kinds: [.message]) - ) - guard let firstMessage = messageHistory.records.first else { - return XCTFail("Expected a persisted message record.") - } - - try await runtime.redactHistoryItems([firstMessage.id], in: thread.id) - - let reloadedAfterRedaction = try makeHistoryRuntime( - backend: InMemoryAgentBackend(), - approvalPresenter: AutoApprovalPresenter(), - stateStore: try GRDBRuntimeStateStore(url: url) - ) - let redactedHistory = try await reloadedAfterRedaction.execute( - HistoryItemsQuery(threadID: thread.id, kinds: [.message]) - ) - - guard let redactedRecord = redactedHistory.records.first(where: { $0.id == firstMessage.id }) else { - return XCTFail("Expected the redacted record to still be queryable.") - } - XCTAssertNotNil(redactedRecord.redaction) - - try await reloadedAfterRedaction.deleteThread(id: thread.id) - - let deletedRuntime = try makeHistoryRuntime( - backend: InMemoryAgentBackend(), - approvalPresenter: AutoApprovalPresenter(), - stateStore: try GRDBRuntimeStateStore(url: url) - ) - let threads = try await deletedRuntime.execute( - ThreadMetadataQuery(threadIDs: [thread.id]) - ) - XCTAssertTrue(threads.isEmpty) - } - - func testGRDBRuntimeStateStoreImportsLegacyFileStateOnFirstPrepare() async throws { - let directory = FileManager.default.temporaryDirectory.appendingPathComponent(UUID().uuidString) - try FileManager.default.createDirectory(at: directory, withIntermediateDirectories: true) - defer { try? FileManager.default.removeItem(at: directory) } - - let legacyURL = directory.appendingPathComponent("runtime-state").appendingPathExtension("json") - let sqliteURL = directory.appendingPathComponent("runtime-state").appendingPathExtension("sqlite") - - let backend = InMemoryAgentBackend( - structuredResponseText: #"{"reply":"Legacy import payload.","priority":"normal"}"# - ) - let legacyRuntime = try makeHistoryRuntime( - backend: backend, - approvalPresenter: AutoApprovalPresenter(), - stateStore: FileRuntimeStateStore(url: legacyURL) - ) - - _ = try await legacyRuntime.restore() - _ = try await legacyRuntime.signIn() - - let thread = try await legacyRuntime.createThread(title: "Legacy File Thread") - _ = try await legacyRuntime.sendMessage( - UserMessageRequest(text: "Create a legacy payload."), - in: thread.id, - expecting: ShippingReplyDraft.self - ) - - let importedStore = try GRDBRuntimeStateStore(url: sqliteURL) - let importedRuntime = try makeHistoryRuntime( - backend: backend, - approvalPresenter: AutoApprovalPresenter(), - stateStore: importedStore - ) - - _ = try await importedRuntime.prepareStore() - - let summary = try await importedRuntime.fetchThreadSummary(id: thread.id) - XCTAssertEqual(summary.latestStructuredOutputMetadata?.formatName, "shipping_reply_draft") - - let history = try await importedRuntime.execute( - HistoryItemsQuery(threadID: thread.id, kinds: [.message, .structuredOutput]) - ) - XCTAssertFalse(history.records.isEmpty) - } - - func testGRDBRuntimeStateStoreExternalizesImageAttachments() async throws { - let url = temporaryRuntimeSQLiteURL() - let attachmentsDirectory = url.deletingLastPathComponent() - .appendingPathComponent("\(url.deletingPathExtension().lastPathComponent).codexkit-state", isDirectory: true) - .appendingPathComponent("attachments", isDirectory: true) - defer { - try? FileManager.default.removeItem(at: url) - try? FileManager.default.removeItem(at: attachmentsDirectory.deletingLastPathComponent()) - } - - let imageData = Data([0x89, 0x50, 0x4E, 0x47, 0xDE, 0xAD, 0xBE, 0xEF]) - let runtime = try makeHistoryRuntime( - backend: InMemoryAgentBackend(), - approvalPresenter: AutoApprovalPresenter(), - stateStore: try GRDBRuntimeStateStore(url: url) - ) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - let thread = try await runtime.createThread(title: "Attachment Thread") - _ = try await runtime.sendMessage( - UserMessageRequest( - text: "here is an image", - images: [.png(imageData)] - ), - in: thread.id - ) - - let reloadedRuntime = try makeHistoryRuntime( - backend: InMemoryAgentBackend(), - approvalPresenter: AutoApprovalPresenter(), - stateStore: try GRDBRuntimeStateStore(url: url) - ) - - let history = try await reloadedRuntime.execute( - HistoryItemsQuery(threadID: thread.id, kinds: [.message]) - ) - guard let userMessage = history.records.compactMap({ record -> AgentMessage? in - guard case let .message(message) = record.item, message.role == .user else { - return nil - } - return message - }).first else { - return XCTFail("Expected a persisted user message with an attachment.") - } - - XCTAssertEqual(userMessage.images.count, 1) - XCTAssertEqual(userMessage.images.first?.data, imageData) - - let attachmentFiles = try FileManager.default.contentsOfDirectory( - at: attachmentsDirectory.appendingPathComponent(thread.id, isDirectory: true) - .appendingPathComponent(userMessage.id, isDirectory: true), - includingPropertiesForKeys: nil - ) - XCTAssertEqual(attachmentFiles.count, 1) - - let databaseData = try Data(contentsOf: url) - XCTAssertNil(databaseData.range(of: imageData.base64EncodedData())) - } - - func testManualCompactionPreservesVisibleHistoryAndHidesMarkersByDefault() async throws { - let backend = CompactingTestBackend() - let runtime = try makeHistoryRuntime( - backend: backend, - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore(), - contextCompaction: AgentContextCompactionConfiguration( - isEnabled: true, - mode: .automatic - ) - ) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - let thread = try await runtime.createThread(title: "Compaction") - _ = try await runtime.sendMessage(UserMessageRequest(text: "one"), in: thread.id) - _ = try await runtime.sendMessage(UserMessageRequest(text: "two"), in: thread.id) - _ = try await runtime.sendMessage(UserMessageRequest(text: "three"), in: thread.id) - - let visibleBefore = await runtime.messages(for: thread.id) - XCTAssertEqual(visibleBefore.count, 6) - - let contextState = try await runtime.compactThreadContext(id: thread.id) - XCTAssertEqual(contextState.generation, 1) - XCTAssertLessThan(contextState.effectiveMessages.count, visibleBefore.count) - - let visibleAfter = await runtime.messages(for: thread.id) - XCTAssertEqual(visibleAfter, visibleBefore) - - let hiddenHistory = try await runtime.execute( - HistoryItemsQuery( - threadID: thread.id, - kinds: [.systemEvent] - ) - ) - XCTAssertFalse(hiddenHistory.records.contains(where: { $0.item.isCompactionMarker })) - - let debugHistory = try await runtime.execute( - HistoryItemsQuery( - threadID: thread.id, - kinds: [.systemEvent], - includeCompactionEvents: true - ) - ) - XCTAssertTrue(debugHistory.records.contains(where: { $0.item.isCompactionMarker })) - } - - func testAutomaticRetryCompactionRecoversFromContextLimitError() async throws { - let backend = CompactingTestBackend(failOnHistoryCountAbove: 2) - let runtime = try makeHistoryRuntime( - backend: backend, - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore(), - contextCompaction: AgentContextCompactionConfiguration( - isEnabled: true, - mode: .automatic, - trigger: AgentContextCompactionTrigger( - estimatedTokenThreshold: 100_000, - retryOnContextLimitError: true - ) - ) - ) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - let thread = try await runtime.createThread(title: "Retry Compact") - _ = try await runtime.sendMessage(UserMessageRequest(text: "one"), in: thread.id) - _ = try await runtime.sendMessage(UserMessageRequest(text: "two"), in: thread.id) - _ = try await runtime.sendMessage(UserMessageRequest(text: "three"), in: thread.id) - let reply = try await runtime.sendMessage(UserMessageRequest(text: "four"), in: thread.id) - - XCTAssertEqual(reply, "Echo: four") - let compactCallCount = await backend.compactCallCount() - let beginTurnHistoryCounts = await backend.beginTurnHistoryCounts() - XCTAssertEqual(compactCallCount, 1) - XCTAssertGreaterThanOrEqual(beginTurnHistoryCounts.count, 5) - } - - func testContextStatePersistsAcrossGRDBReload() async throws { - let url = temporaryRuntimeSQLiteURL() - defer { try? FileManager.default.removeItem(at: url) } - - let backend = CompactingTestBackend() - let runtime = try makeHistoryRuntime( - backend: backend, - approvalPresenter: AutoApprovalPresenter(), - stateStore: try GRDBRuntimeStateStore(url: url), - contextCompaction: AgentContextCompactionConfiguration( - isEnabled: true, - mode: .automatic - ) - ) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - let thread = try await runtime.createThread(title: "Persisted Context") - _ = try await runtime.sendMessage(UserMessageRequest(text: "alpha"), in: thread.id) - _ = try await runtime.sendMessage(UserMessageRequest(text: "beta"), in: thread.id) - _ = try await runtime.compactThreadContext(id: thread.id) - - let reloadedRuntime = try makeHistoryRuntime( - backend: backend, - approvalPresenter: AutoApprovalPresenter(), - stateStore: try GRDBRuntimeStateStore(url: url), - contextCompaction: AgentContextCompactionConfiguration( - isEnabled: true, - mode: .automatic - ) - ) - - let restoredContext = try await reloadedRuntime.fetchThreadContextState(id: thread.id) - XCTAssertEqual(restoredContext?.generation, 1) - XCTAssertFalse(restoredContext?.effectiveMessages.isEmpty ?? true) - } - - func testContextCompactionConfigurationDefaultsAndCodableShape() throws { - let configuration = AgentContextCompactionConfiguration() - XCTAssertFalse(configuration.isEnabled) - XCTAssertEqual(configuration.mode, .automatic) - - let encoded = try JSONEncoder().encode( - AgentContextCompactionConfiguration(isEnabled: true, mode: .manual) - ) - let decoded = try JSONDecoder().decode( - AgentContextCompactionConfiguration.self, - from: encoded - ) - - XCTAssertTrue(decoded.isEnabled) - XCTAssertEqual(decoded.mode, .manual) - } -} - -private func makeHistoryRuntime( - backend: any AgentBackend, - approvalPresenter: any ApprovalPresenting, - stateStore: any RuntimeStateStoring, - tools: [AgentRuntime.ToolRegistration] = [], - contextCompaction: AgentContextCompactionConfiguration = AgentContextCompactionConfiguration() -) throws -> AgentRuntime { - try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: backend, - approvalPresenter: approvalPresenter, - stateStore: stateStore, - tools: tools, - contextCompaction: contextCompaction - )) -} - -private func temporaryRuntimeSQLiteURL() -> URL { - FileManager.default.temporaryDirectory - .appendingPathComponent(UUID().uuidString) - .appendingPathExtension("sqlite") -} - -private func messageTexts(in page: AgentThreadHistoryPage) -> [String] { - page.items.compactMap { item in - guard case let .message(message) = item else { - return nil - } - return message.displayText - } -} - -private func waitUntil( - timeoutNanoseconds: UInt64 = 2_000_000_000, - intervalNanoseconds: UInt64 = 20_000_000, - condition: @escaping @Sendable () async throws -> Bool -) async throws { - let deadline = DispatchTime.now().uptimeNanoseconds + timeoutNanoseconds - while DispatchTime.now().uptimeNanoseconds < deadline { - if try await condition() { - return - } - try await Task.sleep(nanoseconds: intervalNanoseconds) - } - - XCTFail("Timed out waiting for condition.") -} - -private actor ToolExecutionGate { - private var waiters: [CheckedContinuation] = [] - private var released = false - - func wait() async { - guard !released else { - return - } - - await withCheckedContinuation { continuation in - waiters.append(continuation) - } - } - - func release() { - released = true - let continuations = waiters - waiters.removeAll() - continuations.forEach { $0.resume() } - } -} - -private actor PartialEmissionGate { - private var partialEmitted = false - private var partialWaiters: [CheckedContinuation] = [] - private var commitReleased = false - private var commitWaiters: [CheckedContinuation] = [] - - func markPartialEmitted() { - partialEmitted = true - let continuations = partialWaiters - partialWaiters.removeAll() - continuations.forEach { $0.resume() } - } - - func waitForPartialEmission() async { - guard !partialEmitted else { - return - } - - await withCheckedContinuation { continuation in - partialWaiters.append(continuation) - } - } - - func waitForCommitRelease() async { - guard !commitReleased else { - return - } - - await withCheckedContinuation { continuation in - commitWaiters.append(continuation) - } - } - - func releaseCommit() { - commitReleased = true - let continuations = commitWaiters - commitWaiters.removeAll() - continuations.forEach { $0.resume() } - } -} - -private actor CompactingTestBackend: AgentBackend, AgentBackendContextCompacting { - nonisolated let baseInstructions: String? = nil - - private let failOnHistoryCountAbove: Int? - private var threads: [String: AgentThread] = [:] - private var compactCalls = 0 - private var historyCounts: [Int] = [] - - init(failOnHistoryCountAbove: Int? = nil) { - self.failOnHistoryCountAbove = failOnHistoryCountAbove - } - - func createThread(session _: ChatGPTSession) async throws -> AgentThread { - let thread = AgentThread(id: UUID().uuidString) - threads[thread.id] = thread - return thread - } - - func resumeThread(id: String, session _: ChatGPTSession) async throws -> AgentThread { - if let thread = threads[id] { - return thread - } - let thread = AgentThread(id: id) - threads[id] = thread - return thread - } - - func beginTurn( - thread: AgentThread, - history: [AgentMessage], - message: UserMessageRequest, - instructions _: String, - responseFormat _: AgentStructuredOutputFormat?, - streamedStructuredOutput _: AgentStreamedStructuredOutputRequest?, - tools _: [ToolDefinition], - session _: ChatGPTSession - ) async throws -> any AgentTurnStreaming { - historyCounts.append(history.count) - if let failOnHistoryCountAbove, - history.count > failOnHistoryCountAbove, - !history.contains(where: { $0.role == .system && $0.text.contains("Compacted conversation summary") }) { - throw AgentRuntimeError( - code: "context_limit_exceeded", - message: "Maximum context length exceeded." - ) - } - - return MockAgentTurnSession( - thread: thread, - message: message, - selectedTool: nil, - structuredResponseText: nil, - streamedStructuredOutput: nil - ) - } - - func compactContext( - thread: AgentThread, - effectiveHistory: [AgentMessage], - instructions _: String, - tools _: [ToolDefinition], - session _: ChatGPTSession - ) async throws -> AgentCompactionResult { - compactCalls += 1 - let lastUser = effectiveHistory.last(where: { $0.role == .user }) - let lastAssistant = effectiveHistory.last(where: { $0.role == .assistant }) - var compacted = [ - AgentMessage( - threadID: thread.id, - role: .system, - text: "Compacted conversation summary" - ), - ] - if let lastUser { - compacted.append(lastUser) - } - if let lastAssistant { - compacted.append(lastAssistant) - } - return AgentCompactionResult( - effectiveMessages: compacted, - summaryPreview: "Compacted conversation summary" - ) - } - - func compactCallCount() -> Int { - compactCalls - } - - func beginTurnHistoryCounts() -> [Int] { - historyCounts - } -} - -private actor BlockingStructuredPartialBackend: AgentBackend { - private let gate = PartialEmissionGate() - private var latestThreadID: String? - - func createThread(session _: ChatGPTSession) async throws -> AgentThread { - AgentThread(id: UUID().uuidString) - } - - func resumeThread(id: String, session _: ChatGPTSession) async throws -> AgentThread { - AgentThread(id: id) - } - - func beginTurn( - thread: AgentThread, - history _: [AgentMessage], - message _: UserMessageRequest, - instructions _: String, - responseFormat _: AgentStructuredOutputFormat?, - streamedStructuredOutput _: AgentStreamedStructuredOutputRequest?, - tools _: [ToolDefinition], - session _: ChatGPTSession - ) async throws -> any AgentTurnStreaming { - latestThreadID = thread.id - return BlockingStructuredPartialTurnSession( - threadID: thread.id, - gate: gate - ) - } - - func waitForPartialEmission() async { - await gate.waitForPartialEmission() - } - - func releaseCommit() async { - await gate.releaseCommit() - } -} - -private final class BlockingStructuredPartialTurnSession: AgentTurnStreaming, @unchecked Sendable { - let events: AsyncThrowingStream - - init(threadID: String, gate: PartialEmissionGate) { - let turn = AgentTurn(id: UUID().uuidString, threadID: threadID) - let payload: JSONValue = .object([ - "reply": .string("Your order is already in transit."), - "priority": .string("high"), - ]) - - events = AsyncThrowingStream { continuation in - Task { - continuation.yield(.turnStarted(turn)) - continuation.yield( - .assistantMessageDelta( - threadID: threadID, - turnID: turn.id, - delta: "Echo: Draft a shipping reply." - ) - ) - continuation.yield(.structuredOutputPartial(payload)) - await gate.markPartialEmitted() - await gate.waitForCommitRelease() - continuation.yield(.structuredOutputCommitted(payload)) - continuation.yield( - .assistantMessageCompleted( - AgentMessage( - threadID: threadID, - role: .assistant, - text: "Echo: Draft a shipping reply.", - structuredOutput: AgentStructuredOutputMetadata( - formatName: "shipping_reply_draft", - payload: payload - ) - ) - ) - ) - continuation.yield( - .turnCompleted( - AgentTurnSummary( - threadID: threadID, - turnID: turn.id, - usage: AgentUsage(inputTokens: 1, outputTokens: 1) - ) - ) - ) - continuation.finish() - } - } - } - - func submitToolResult(_: ToolResultEnvelope, for _: String) async throws {} } diff --git a/Tests/CodexKitTests/AgentRuntimeMessageBehaviorTests.swift b/Tests/CodexKitTests/AgentRuntimeMessageBehaviorTests.swift new file mode 100644 index 0000000..2757b93 --- /dev/null +++ b/Tests/CodexKitTests/AgentRuntimeMessageBehaviorTests.swift @@ -0,0 +1,263 @@ +import Combine +import CodexKit +import XCTest + +extension AgentRuntimeTests { + func testImportedContentInitializerBuildsMessageWithSharedURLs() async throws { + let importedContent = AgentImportedContent(textSnippets: ["Customer says the package arrived damaged."], urls: [URL(string: "https://example.com/delivery-update")!]) + let request = UserMessageRequest(prompt: "Summarize and draft a reply.", importedContent: importedContent) + XCTAssertTrue(request.text.contains("Summarize and draft a reply.")) + XCTAssertTrue(request.text.contains("https://example.com/delivery-update")) + XCTAssertTrue(request.text.contains("Customer says the package arrived damaged.")) + } + + func testImageOnlyMessageIsAcceptedAndPersisted() async throws { + let runtime = try AgentRuntime(configuration: .init(authProvider: DemoChatGPTAuthProvider(), secureStore: KeychainSessionSecureStore(service: "CodexKitTests.ChatGPTSession", account: UUID().uuidString), backend: InMemoryAgentBackend(), approvalPresenter: AutoApprovalPresenter(), stateStore: InMemoryRuntimeStateStore())) + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(title: "Images") + let image = AgentImageAttachment.png(Data([0x89, 0x50, 0x4E, 0x47])) + _ = try await runtime.sendMessage(UserMessageRequest(text: "", images: [image]), in: thread.id) + + let messages = await runtime.messages(for: thread.id) + XCTAssertEqual(messages.first?.images.count, 1) + XCTAssertEqual(messages.first?.role, .user) + } + + func testAssistantImagesAreCommittedToThreadHistory() async throws { + let runtime = try AgentRuntime(configuration: .init(authProvider: DemoChatGPTAuthProvider(), secureStore: KeychainSessionSecureStore(service: "CodexKitTests.ChatGPTSession", account: UUID().uuidString), backend: ImageReplyAgentBackend(), approvalPresenter: AutoApprovalPresenter(), stateStore: InMemoryRuntimeStateStore())) + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(title: "Assistant Images") + let reply = try await runtime.sendMessage(UserMessageRequest(text: "show me an image"), in: thread.id) + + XCTAssertEqual(reply, "Attached 1 image") + let messages = await runtime.messages(for: thread.id) + XCTAssertEqual(messages.last?.role, .assistant) + XCTAssertEqual(messages.last?.images.count, 1) + } + + func testRuntimeStreamsToolApprovalAndCompletion() async throws { + let runtime = try AgentRuntime(configuration: .init(authProvider: DemoChatGPTAuthProvider(), secureStore: KeychainSessionSecureStore(service: "CodexKitTests.ChatGPTSession", account: UUID().uuidString), backend: InMemoryAgentBackend(), approvalPresenter: AutoApprovalPresenter(), stateStore: InMemoryRuntimeStateStore(), tools: [.init(definition: ToolDefinition(name: "demo_lookup_profile", description: "Lookup profile", inputSchema: .object([:]), approvalPolicy: .requiresApproval), executor: AnyToolExecutor { invocation, _ in .success(invocation: invocation, text: "demo-result") })])) + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread() + let stream = try await runtime.streamMessage(UserMessageRequest(text: "please use the tool"), in: thread.id) + + var sawApproval = false + var sawToolResult = false + for try await event in stream { + switch event { + case .approvalRequested: + sawApproval = true + case let .toolCallFinished(result): + sawToolResult = true + XCTAssertEqual(result.primaryText, "demo-result") + default: + break + } + } + + XCTAssertTrue(sawApproval) + XCTAssertTrue(sawToolResult) + } + + func testStructuredStreamWorksAlongsideToolCalls() async throws { + let runtime = try AgentRuntime(configuration: .init(authProvider: DemoChatGPTAuthProvider(), secureStore: KeychainSessionSecureStore(service: "CodexKitTests.ChatGPTSession", account: UUID().uuidString), backend: InMemoryAgentBackend(), approvalPresenter: AutoApprovalPresenter(), stateStore: InMemoryRuntimeStateStore(), tools: [.init(definition: ToolDefinition(name: "demo_lookup_profile", description: "Lookup profile", inputSchema: .object([:]), approvalPolicy: .requiresApproval), executor: AnyToolExecutor { invocation, _ in .success(invocation: invocation, text: "demo-result") })])) + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread() + let stream = try await runtime.streamMessage(UserMessageRequest(text: "please use the tool"), in: thread.id, expecting: ShippingReplyDraft.self) + + var sawToolResult = false + var sawCommitted = false + for try await event in stream { + switch event { + case .toolCallFinished: + sawToolResult = true + case .structuredOutputCommitted: + sawCommitted = true + default: + break + } + } + + XCTAssertTrue(sawToolResult) + XCTAssertTrue(sawCommitted) + } + + func testSendMessageRetriesUnauthorizedByRefreshingSession() async throws { + let authProvider = RotatingDemoAuthProvider() + let backend = UnauthorizedThenSuccessBackend() + let runtime = try AgentRuntime(configuration: .init(authProvider: authProvider, secureStore: KeychainSessionSecureStore(service: "CodexKitTests.ChatGPTSession", account: UUID().uuidString), backend: backend, approvalPresenter: AutoApprovalPresenter(), stateStore: InMemoryRuntimeStateStore())) + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(title: "Recovered Thread") + _ = try await runtime.sendMessage(UserMessageRequest(text: "Hello after refresh"), in: thread.id) + + let refreshCount = await authProvider.refreshCount() + let attemptedTokens = await backend.attemptedAccessTokens() + let assistantCount = await runtime.messages(for: thread.id) + .filter { $0.role == .assistant } + .count + + XCTAssertEqual(refreshCount, 1) + XCTAssertEqual(attemptedTokens.count, 2) + XCTAssertEqual(attemptedTokens[0], "demo-access-token-initial") + XCTAssertEqual(attemptedTokens[1], "demo-access-token-refreshed-1") + XCTAssertEqual(assistantCount, 1) + } + + func testCreateThreadRetriesUnauthorizedByRefreshingSession() async throws { + let authProvider = RotatingDemoAuthProvider() + let backend = UnauthorizedOnCreateThenSuccessBackend() + let runtime = try AgentRuntime(configuration: .init(authProvider: authProvider, secureStore: KeychainSessionSecureStore(service: "CodexKitTests.ChatGPTSession", account: UUID().uuidString), backend: backend, approvalPresenter: AutoApprovalPresenter(), stateStore: InMemoryRuntimeStateStore())) + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(title: "Recovered Thread") + let refreshCountBeforeAssertions = await authProvider.refreshCount() + XCTAssertEqual(thread.title, "Recovered Thread") + let attemptedTokens = await backend.attemptedAccessTokens() + XCTAssertEqual(refreshCountBeforeAssertions, 1) + XCTAssertEqual(attemptedTokens.count, 2) + XCTAssertEqual(attemptedTokens[0], "demo-access-token-initial") + XCTAssertEqual(attemptedTokens[1], "demo-access-token-refreshed-1") + } + + func testConfigurationRegistersInitialTools() async throws { + let runtime = try AgentRuntime(configuration: .init(authProvider: DemoChatGPTAuthProvider(), secureStore: KeychainSessionSecureStore(service: "CodexKitTests.ChatGPTSession", account: UUID().uuidString), backend: InMemoryAgentBackend(), approvalPresenter: AutoApprovalPresenter(), stateStore: InMemoryRuntimeStateStore(), tools: [.init(definition: ToolDefinition(name: "demo_lookup_profile", description: "Lookup profile", inputSchema: .object([:]), approvalPolicy: .requiresApproval), executor: AnyToolExecutor { invocation, _ in .success(invocation: invocation, text: "demo-result") })])) + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread() + let stream = try await runtime.streamMessage(UserMessageRequest(text: "please use the tool"), in: thread.id) + + var sawToolResult = false + for try await event in stream { + if case let .toolCallFinished(result) = event { + sawToolResult = true + XCTAssertEqual(result.primaryText, "demo-result") + } + } + + XCTAssertTrue(sawToolResult) + } + + func testStreamMessageReturnsAndYieldsUserMessageBeforeTurnStartupCompletes() async throws { + let backend = DelayedBeginTurnBackend() + let runtime = try AgentRuntime(configuration: .init( + authProvider: DemoChatGPTAuthProvider(), + secureStore: KeychainSessionSecureStore(service: "CodexKitTests.ChatGPTSession", account: UUID().uuidString), + backend: backend, + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore() + )) + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(title: "Immediate Stream") + let streamTask = Task { + try await runtime.streamMessage(UserMessageRequest(text: "Hello there"), in: thread.id) + } + + await backend.waitForBeginTurnStart() + let stream = try await awaitValue { + try await streamTask.value + } + + var iterator = stream.makeAsyncIterator() + let firstEvent = try await iterator.next() + let secondEvent = try await iterator.next() + + switch firstEvent { + case let .messageCommitted(message): + XCTAssertEqual(message.role, .user) + XCTAssertEqual(message.text, "Hello there") + default: + XCTFail("Expected the committed user message to arrive first.") + } + + switch secondEvent { + case let .threadStatusChanged(threadID, status): + XCTAssertEqual(threadID, thread.id) + XCTAssertEqual(status, .streaming) + default: + XCTFail("Expected a streaming status update after the committed user message.") + } + + let messagesBeforeRelease = await runtime.messages(for: thread.id) + XCTAssertEqual(messagesBeforeRelease.count, 1) + XCTAssertEqual(messagesBeforeRelease.first?.text, "Hello there") + + await backend.releaseBeginTurn() + while let _ = try await iterator.next() {} + } + + func testSendMessagePublishesUserMessageObservationBeforeFinalReplyCompletes() async throws { + let backend = DelayedBeginTurnBackend() + let runtime = try AgentRuntime(configuration: .init( + authProvider: DemoChatGPTAuthProvider(), + secureStore: KeychainSessionSecureStore(service: "CodexKitTests.ChatGPTSession", account: UUID().uuidString), + backend: backend, + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore() + )) + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(title: "Observed Send") + let observedUserMessage = expectation(description: "Observed the local user message") + observedUserMessage.assertForOverFulfill = false + var cancellables = Set() + + runtime.observations + .sink { observation in + guard case let .messagesChanged(threadID, messages) = observation, + threadID == thread.id, + messages.last?.role == .user, + messages.last?.text == "Observe me" + else { + return + } + observedUserMessage.fulfill() + } + .store(in: &cancellables) + + let sendTask = Task { + try await runtime.sendMessage(UserMessageRequest(text: "Observe me"), in: thread.id) + } + + await backend.waitForBeginTurnStart() + await fulfillment(of: [observedUserMessage], timeout: 0.5) + + let messagesBeforeRelease = await runtime.messages(for: thread.id) + XCTAssertEqual(messagesBeforeRelease.map(\.text), ["Observe me"]) + + await backend.releaseBeginTurn() + let reply = try await sendTask.value + XCTAssertEqual(reply, "Echo: Observe me") + } +} + +func drainStructuredStream( + _ stream: AsyncThrowingStream, Error> +) async throws { + for try await _ in stream {} +} + +func collectStructuredStreamFailures( + from stream: AsyncThrowingStream, Error>, + into failures: inout [AgentStructuredOutputValidationFailure] +) async throws { + for try await event in stream { + if case let .structuredOutputValidationFailed(validationFailure) = event { + failures.append(validationFailure) + } + } +} diff --git a/Tests/CodexKitTests/AgentRuntimeMessageTests.swift b/Tests/CodexKitTests/AgentRuntimeMessageTests.swift index dfe1f23..2266ffb 100644 --- a/Tests/CodexKitTests/AgentRuntimeMessageTests.swift +++ b/Tests/CodexKitTests/AgentRuntimeMessageTests.swift @@ -344,308 +344,4 @@ extension AgentRuntimeTests { XCTAssertEqual(validationFailures.last?.stage, .committed) } - func testImportedContentInitializerBuildsMessageWithSharedURLs() async throws { - let importedContent = AgentImportedContent( - textSnippets: ["Customer says the package arrived damaged."], - urls: [URL(string: "https://example.com/delivery-update")!] - ) - - let request = UserMessageRequest( - prompt: "Summarize and draft a reply.", - importedContent: importedContent - ) - - XCTAssertTrue(request.text.contains("Summarize and draft a reply.")) - XCTAssertTrue(request.text.contains("https://example.com/delivery-update")) - XCTAssertTrue(request.text.contains("Customer says the package arrived damaged.")) - } - - func testImageOnlyMessageIsAcceptedAndPersisted() async throws { - let runtime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: InMemoryAgentBackend(), - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore() - )) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - let thread = try await runtime.createThread(title: "Images") - let image = AgentImageAttachment.png(Data([0x89, 0x50, 0x4E, 0x47])) - - _ = try await runtime.sendMessage( - UserMessageRequest(text: "", images: [image]), - in: thread.id - ) - - let messages = await runtime.messages(for: thread.id) - XCTAssertEqual(messages.first?.images.count, 1) - XCTAssertEqual(messages.first?.role, .user) - } - - func testAssistantImagesAreCommittedToThreadHistory() async throws { - let runtime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: ImageReplyAgentBackend(), - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore() - )) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - let thread = try await runtime.createThread(title: "Assistant Images") - let reply = try await runtime.sendMessage( - UserMessageRequest(text: "show me an image"), - in: thread.id - ) - - XCTAssertEqual(reply, "Attached 1 image") - - let messages = await runtime.messages(for: thread.id) - XCTAssertEqual(messages.last?.role, .assistant) - XCTAssertEqual(messages.last?.images.count, 1) - } - - func testRuntimeStreamsToolApprovalAndCompletion() async throws { - let runtime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: InMemoryAgentBackend(), - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore(), - tools: [ - .init( - definition: ToolDefinition( - name: "demo_lookup_profile", - description: "Lookup profile", - inputSchema: .object([:]), - approvalPolicy: .requiresApproval - ), - executor: AnyToolExecutor { invocation, _ in - .success(invocation: invocation, text: "demo-result") - } - ), - ] - )) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - let thread = try await runtime.createThread() - let stream = try await runtime.streamMessage( - UserMessageRequest(text: "please use the tool"), - in: thread.id - ) - - var sawApproval = false - var sawToolResult = false - - for try await event in stream { - switch event { - case .approvalRequested: - sawApproval = true - case let .toolCallFinished(result): - sawToolResult = true - XCTAssertEqual(result.primaryText, "demo-result") - default: - break - } - } - - XCTAssertTrue(sawApproval) - XCTAssertTrue(sawToolResult) - } - - func testStructuredStreamWorksAlongsideToolCalls() async throws { - let runtime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: InMemoryAgentBackend(), - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore(), - tools: [ - .init( - definition: ToolDefinition( - name: "demo_lookup_profile", - description: "Lookup profile", - inputSchema: .object([:]), - approvalPolicy: .requiresApproval - ), - executor: AnyToolExecutor { invocation, _ in - .success(invocation: invocation, text: "demo-result") - } - ), - ] - )) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - let thread = try await runtime.createThread() - let stream = try await runtime.streamMessage( - UserMessageRequest(text: "please use the tool"), - in: thread.id, - expecting: ShippingReplyDraft.self - ) - - var sawToolResult = false - var sawCommitted = false - - for try await event in stream { - switch event { - case .toolCallFinished: - sawToolResult = true - case .structuredOutputCommitted: - sawCommitted = true - default: - break - } - } - - XCTAssertTrue(sawToolResult) - XCTAssertTrue(sawCommitted) - } - - func testSendMessageRetriesUnauthorizedByRefreshingSession() async throws { - let authProvider = RotatingDemoAuthProvider() - let backend = UnauthorizedThenSuccessBackend() - let runtime = try AgentRuntime(configuration: .init( - authProvider: authProvider, - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: backend, - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore() - )) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - let thread = try await runtime.createThread(title: "Recovered Thread") - _ = try await runtime.sendMessage( - UserMessageRequest(text: "Hello after refresh"), - in: thread.id - ) - - let refreshCount = await authProvider.refreshCount() - XCTAssertEqual(refreshCount, 1) - - let attemptedTokens = await backend.attemptedAccessTokens() - XCTAssertEqual(attemptedTokens.count, 2) - XCTAssertEqual(attemptedTokens[0], "demo-access-token-initial") - XCTAssertEqual(attemptedTokens[1], "demo-access-token-refreshed-1") - - let messages = await runtime.messages(for: thread.id) - XCTAssertEqual(messages.filter { $0.role == .assistant }.count, 1) - } - - func testCreateThreadRetriesUnauthorizedByRefreshingSession() async throws { - let authProvider = RotatingDemoAuthProvider() - let backend = UnauthorizedOnCreateThenSuccessBackend() - let runtime = try AgentRuntime(configuration: .init( - authProvider: authProvider, - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: backend, - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore() - )) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - let thread = try await runtime.createThread(title: "Recovered Thread") - XCTAssertEqual(thread.title, "Recovered Thread") - - let refreshCount = await authProvider.refreshCount() - XCTAssertEqual(refreshCount, 1) - - let attemptedTokens = await backend.attemptedAccessTokens() - XCTAssertEqual(attemptedTokens.count, 2) - XCTAssertEqual(attemptedTokens[0], "demo-access-token-initial") - XCTAssertEqual(attemptedTokens[1], "demo-access-token-refreshed-1") - } - - func testConfigurationRegistersInitialTools() async throws { - let runtime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: InMemoryAgentBackend(), - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore(), - tools: [ - .init( - definition: ToolDefinition( - name: "demo_lookup_profile", - description: "Lookup profile", - inputSchema: .object([:]), - approvalPolicy: .requiresApproval - ), - executor: AnyToolExecutor { invocation, _ in - .success(invocation: invocation, text: "demo-result") - } - ), - ] - )) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - let thread = try await runtime.createThread() - let stream = try await runtime.streamMessage( - UserMessageRequest(text: "please use the tool"), - in: thread.id - ) - - var sawToolResult = false - - for try await event in stream { - if case let .toolCallFinished(result) = event { - sawToolResult = true - XCTAssertEqual(result.primaryText, "demo-result") - } - } - - XCTAssertTrue(sawToolResult) - } -} - -private func drainStructuredStream( - _ stream: AsyncThrowingStream, Error> -) async throws { - for try await _ in stream {} -} - -private func collectStructuredStreamFailures( - from stream: AsyncThrowingStream, Error>, - into failures: inout [AgentStructuredOutputValidationFailure] -) async throws { - for try await event in stream { - if case let .structuredOutputValidationFailed(validationFailure) = event { - failures.append(validationFailure) - } - } } diff --git a/Tests/CodexKitTests/AgentRuntimePersonaSkillPolicyTests.swift b/Tests/CodexKitTests/AgentRuntimePersonaSkillPolicyTests.swift new file mode 100644 index 0000000..9cf4abb --- /dev/null +++ b/Tests/CodexKitTests/AgentRuntimePersonaSkillPolicyTests.swift @@ -0,0 +1,109 @@ +import CodexKit +import XCTest + +extension AgentRuntimeTests { + func testSkillPolicyBlocksDisallowedToolCalls() async throws { + let runtime = try AgentRuntime(configuration: .init(authProvider: DemoChatGPTAuthProvider(), secureStore: KeychainSessionSecureStore(service: "CodexKitTests.ChatGPTSession", account: UUID().uuidString), backend: InMemoryAgentBackend(), approvalPresenter: AutoApprovalPresenter(), stateStore: InMemoryRuntimeStateStore(), skills: [.init(id: "strict_support", name: "Strict Support", instructions: "Answer directly.", executionPolicy: .init(allowedToolNames: ["allowed_tool"]))])) + _ = try await runtime.restore() + _ = try await runtime.signIn() + + try await runtime.registerTool(ToolDefinition(name: "demo_lookup_profile", description: "Lookup profile", inputSchema: .object([:]), approvalPolicy: .automatic), executor: AnyToolExecutor { invocation, _ in .success(invocation: invocation, text: "profile-ok") }) + let thread = try await runtime.createThread(title: "Strict Tool Policy", skillIDs: ["strict_support"]) + let stream = try await runtime.streamMessage(UserMessageRequest(text: "please use the tool"), in: thread.id) + for try await _ in stream {} + + let assistantText = await runtime.messages(for: thread.id) + .filter { $0.role == .assistant } + .map(\.text) + .joined(separator: "\n") + XCTAssertTrue(assistantText.contains("not allowed by the active skill policy")) + } + + func testSkillPolicyFailsTurnWhenRequiredToolIsMissing() async throws { + let runtime = try AgentRuntime(configuration: .init(authProvider: DemoChatGPTAuthProvider(), secureStore: KeychainSessionSecureStore(service: "CodexKitTests.ChatGPTSession", account: UUID().uuidString), backend: InMemoryAgentBackend(), approvalPresenter: AutoApprovalPresenter(), stateStore: InMemoryRuntimeStateStore(), skills: [.init(id: "requires_tool", name: "Requires Tool", instructions: "Use the required tool.", executionPolicy: .init(requiredToolNames: ["demo_lookup_profile"]))])) + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(title: "Required Tool", skillIDs: ["requires_tool"]) + let stream = try await runtime.streamMessage(UserMessageRequest(text: "hello without tool"), in: thread.id) + + var sawTurnFailed = false + var failureError: AgentRuntimeError? + + do { + for try await event in stream { + if case let .turnFailed(error) = event { + sawTurnFailed = true + failureError = error + } + } + XCTFail("Expected turn stream to throw when required tools are missing.") + } catch { + XCTAssertEqual((error as? AgentRuntimeError)?.code, "skill_required_tools_missing") + } + + XCTAssertTrue(sawTurnFailed) + XCTAssertEqual(failureError?.code, "skill_required_tools_missing") + } + + func testRuntimeRejectsSkillWithInvalidPolicyToolName() async throws { + XCTAssertThrowsError( + try AgentRuntime(configuration: .init(authProvider: DemoChatGPTAuthProvider(), secureStore: KeychainSessionSecureStore(service: "CodexKitTests.ChatGPTSession", account: UUID().uuidString), backend: InMemoryAgentBackend(), approvalPresenter: AutoApprovalPresenter(), stateStore: InMemoryRuntimeStateStore(), skills: [.init(id: "invalid_policy", name: "Invalid Policy", instructions: "Invalid tool name policy.", executionPolicy: .init(requiredToolNames: ["bad tool name"]))])) + ) { error in + XCTAssertEqual((error as? AgentRuntimeError)?.code, "invalid_skill_tool_name") + } + } + + func testResolvedInstructionsPreviewIncludesThreadPersonaAndSkills() async throws { + let runtime = try AgentRuntime(configuration: .init(authProvider: DemoChatGPTAuthProvider(), secureStore: KeychainSessionSecureStore(service: "CodexKitTests.ChatGPTSession", account: UUID().uuidString), backend: InMemoryAgentBackend(baseInstructions: "Base host instructions."), approvalPresenter: AutoApprovalPresenter(), stateStore: InMemoryRuntimeStateStore(), skills: [.init(id: "health_coach", name: "Health Coach", instructions: "Coach users toward their daily step goals.")])) + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(title: "Preview", personaStack: AgentPersonaStack(layers: [.init(name: "planner", instructions: "Act as a planning specialist.")]), skillIDs: ["health_coach"]) + let preview = try await runtime.resolvedInstructionsPreview(for: thread.id, request: UserMessageRequest(text: "Give me a plan.")) + + XCTAssertTrue(preview.contains("Base host instructions.")) + XCTAssertTrue(preview.contains("Thread Persona Layers:")) + XCTAssertTrue(preview.contains("[planner]")) + XCTAssertTrue(preview.contains("Thread Skills:")) + XCTAssertTrue(preview.contains("[health_coach: Health Coach]")) + } + + func testCreateThreadLoadsPersonaFromFileSource() async throws { + let backend = InMemoryAgentBackend(baseInstructions: "Base host instructions.") + let runtime = try AgentRuntime(configuration: .init(authProvider: DemoChatGPTAuthProvider(), secureStore: KeychainSessionSecureStore(service: "CodexKitTests.ChatGPTSession", account: UUID().uuidString), backend: backend, approvalPresenter: AutoApprovalPresenter(), stateStore: InMemoryRuntimeStateStore())) + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let personaText = "Act as a migration planning assistant focused on sequencing." + let personaFile = try temporaryFile(with: personaText, pathExtension: "txt") + let thread = try await runtime.createThread(title: "Dynamic Persona", personaSource: .file(personaFile)) + + let personaStack = try XCTUnwrap(thread.personaStack) + XCTAssertEqual(personaStack.layers.count, 1) + XCTAssertEqual(personaStack.layers[0].instructions, personaText) + } + + func testRegisterSkillFromFileSourceCanBeUsedInThread() async throws { + let backend = InMemoryAgentBackend(baseInstructions: "Base host instructions.") + let runtime = try AgentRuntime(configuration: .init(authProvider: DemoChatGPTAuthProvider(), secureStore: KeychainSessionSecureStore(service: "CodexKitTests.ChatGPTSession", account: UUID().uuidString), backend: backend, approvalPresenter: AutoApprovalPresenter(), stateStore: InMemoryRuntimeStateStore())) + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let skillJSON = """ + { + "id": "travel_planner", + "name": "Travel Planner", + "instructions": "Keep itineraries compact and transit-aware." + } + """ + let skillFile = try temporaryFile(with: skillJSON, pathExtension: "json") + _ = try await runtime.registerSkill(from: .file(skillFile)) + let thread = try await runtime.createThread(title: "Travel", skillIDs: ["travel_planner"]) + _ = try await runtime.sendMessage(UserMessageRequest(text: "Plan my day"), in: thread.id) + + let receivedInstructions = await backend.receivedInstructions() + let resolved = try XCTUnwrap(receivedInstructions.last) + XCTAssertTrue(resolved.contains("[travel_planner: Travel Planner]")) + } +} diff --git a/Tests/CodexKitTests/AgentRuntimePersonaSkillTests.swift b/Tests/CodexKitTests/AgentRuntimePersonaSkillTests.swift index 5ecd58a..cc90409 100644 --- a/Tests/CodexKitTests/AgentRuntimePersonaSkillTests.swift +++ b/Tests/CodexKitTests/AgentRuntimePersonaSkillTests.swift @@ -349,265 +349,4 @@ extension AgentRuntimeTests { } } - func testSkillPolicyBlocksDisallowedToolCalls() async throws { - let runtime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: InMemoryAgentBackend(), - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore(), - skills: [ - .init( - id: "strict_support", - name: "Strict Support", - instructions: "Answer directly.", - executionPolicy: .init( - allowedToolNames: ["allowed_tool"] - ) - ), - ] - )) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - try await runtime.registerTool( - ToolDefinition( - name: "demo_lookup_profile", - description: "Lookup profile", - inputSchema: .object([:]), - approvalPolicy: .automatic - ), - executor: AnyToolExecutor { invocation, _ in - .success(invocation: invocation, text: "profile-ok") - } - ) - - let thread = try await runtime.createThread( - title: "Strict Tool Policy", - skillIDs: ["strict_support"] - ) - - let stream = try await runtime.streamMessage( - UserMessageRequest(text: "please use the tool"), - in: thread.id - ) - for try await _ in stream {} - - let messages = await runtime.messages(for: thread.id) - let assistantText = messages - .filter { $0.role == .assistant } - .map(\.text) - .joined(separator: "\n") - XCTAssertTrue(assistantText.contains("not allowed by the active skill policy")) - } - - func testSkillPolicyFailsTurnWhenRequiredToolIsMissing() async throws { - let runtime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: InMemoryAgentBackend(), - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore(), - skills: [ - .init( - id: "requires_tool", - name: "Requires Tool", - instructions: "Use the required tool.", - executionPolicy: .init( - requiredToolNames: ["demo_lookup_profile"] - ) - ), - ] - )) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - let thread = try await runtime.createThread( - title: "Required Tool", - skillIDs: ["requires_tool"] - ) - - let stream = try await runtime.streamMessage( - UserMessageRequest(text: "hello without tool"), - in: thread.id - ) - - var sawTurnFailed = false - var failureError: AgentRuntimeError? - - do { - for try await event in stream { - if case let .turnFailed(error) = event { - sawTurnFailed = true - failureError = error - } - } - XCTFail("Expected turn stream to throw when required tools are missing.") - } catch { - XCTAssertEqual((error as? AgentRuntimeError)?.code, "skill_required_tools_missing") - } - - XCTAssertTrue(sawTurnFailed) - XCTAssertEqual(failureError?.code, "skill_required_tools_missing") - } - - func testRuntimeRejectsSkillWithInvalidPolicyToolName() async throws { - XCTAssertThrowsError( - try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: InMemoryAgentBackend(), - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore(), - skills: [ - .init( - id: "invalid_policy", - name: "Invalid Policy", - instructions: "Invalid tool name policy.", - executionPolicy: .init( - requiredToolNames: ["bad tool name"] - ) - ), - ] - )) - ) { error in - XCTAssertEqual((error as? AgentRuntimeError)?.code, "invalid_skill_tool_name") - } - } - - func testResolvedInstructionsPreviewIncludesThreadPersonaAndSkills() async throws { - let runtime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: InMemoryAgentBackend(baseInstructions: "Base host instructions."), - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore(), - skills: [ - .init( - id: "health_coach", - name: "Health Coach", - instructions: "Coach users toward their daily step goals." - ), - ] - )) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - let thread = try await runtime.createThread( - title: "Preview", - personaStack: AgentPersonaStack(layers: [ - .init(name: "planner", instructions: "Act as a planning specialist.") - ]), - skillIDs: ["health_coach"] - ) - - let preview = try await runtime.resolvedInstructionsPreview( - for: thread.id, - request: UserMessageRequest(text: "Give me a plan.") - ) - - XCTAssertTrue(preview.contains("Base host instructions.")) - XCTAssertTrue(preview.contains("Thread Persona Layers:")) - XCTAssertTrue(preview.contains("[planner]")) - XCTAssertTrue(preview.contains("Thread Skills:")) - XCTAssertTrue(preview.contains("[health_coach: Health Coach]")) - } - - func testCreateThreadLoadsPersonaFromFileSource() async throws { - let backend = InMemoryAgentBackend( - baseInstructions: "Base host instructions." - ) - let runtime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: backend, - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore() - )) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - let personaText = "Act as a migration planning assistant focused on sequencing." - let personaFile = try temporaryFile( - with: personaText, - pathExtension: "txt" - ) - - let thread = try await runtime.createThread( - title: "Dynamic Persona", - personaSource: .file(personaFile) - ) - - let personaStack = try XCTUnwrap(thread.personaStack) - XCTAssertEqual(personaStack.layers.count, 1) - XCTAssertEqual(personaStack.layers[0].instructions, personaText) - } - - func testRegisterSkillFromFileSourceCanBeUsedInThread() async throws { - let backend = InMemoryAgentBackend( - baseInstructions: "Base host instructions." - ) - let runtime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: backend, - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore() - )) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - let skillJSON = """ - { - "id": "travel_planner", - "name": "Travel Planner", - "instructions": "Keep itineraries compact and transit-aware." - } - """ - let skillFile = try temporaryFile( - with: skillJSON, - pathExtension: "json" - ) - - _ = try await runtime.registerSkill( - from: .file(skillFile) - ) - - let thread = try await runtime.createThread( - title: "Travel", - skillIDs: ["travel_planner"] - ) - - _ = try await runtime.sendMessage( - UserMessageRequest(text: "Plan my day"), - in: thread.id - ) - - let instructions = await backend.receivedInstructions() - let resolved = try XCTUnwrap(instructions.last) - XCTAssertTrue(resolved.contains("[travel_planner: Travel Planner]")) - } } diff --git a/Tests/CodexKitTests/CodexResponsesBackendRetryTests.swift b/Tests/CodexKitTests/CodexResponsesBackendRetryTests.swift new file mode 100644 index 0000000..e7a92ba --- /dev/null +++ b/Tests/CodexKitTests/CodexResponsesBackendRetryTests.swift @@ -0,0 +1,98 @@ +import CodexKit +import XCTest + +extension CodexResponsesBackendTests { + func testBackendRetriesTransientStatusCodeWithBackoffPolicy() async throws { + let backend = CodexResponsesBackend( + configuration: CodexResponsesBackendConfiguration( + requestRetryPolicy: .init( + maxAttempts: 2, + initialBackoff: 0, + maxBackoff: 0, + jitterFactor: 0 + ) + ), + urlSession: makeTestURLSession() + ) + let session = ChatGPTSession( + accessToken: "access-token", + refreshToken: "refresh-token", + account: ChatGPTAccount(id: "workspace-123", email: "taylor@example.com", plan: .plus) + ) + + await TestURLProtocol.enqueue(.init(statusCode: 503, headers: ["Content-Type": "application/json"], body: Data(#"{"error":"upstream overloaded"}"#.utf8))) + await TestURLProtocol.enqueue(.init(headers: ["Content-Type": "text/event-stream"], body: Data(""" + event: response.output_item.done + data: {"type":"response.output_item.done","item":{"type":"message","role":"assistant","content":[{"type":"output_text","text":"Recovered"}]}} + + event: response.completed + data: {"type":"response.completed","response":{"id":"resp_retry","usage":{"input_tokens":5,"input_tokens_details":{"cached_tokens":0},"output_tokens":2}}} + + """.utf8))) + + let turnStream = try await backend.beginTurn(thread: AgentThread(id: "thread-retry"), history: [], message: UserMessageRequest(text: "Hi"), instructions: "Resolved instructions", responseFormat: nil, streamedStructuredOutput: nil, tools: [], session: session) + + var assistantMessage: AgentMessage? + for try await event in turnStream.events { + if case let .assistantMessageCompleted(message) = event { + assistantMessage = message + } + } + + XCTAssertEqual(assistantMessage?.text, "Recovered") + } + + func testBackendDoesNotRetryNonRetryableStatusCode() async throws { + let backend = CodexResponsesBackend( + configuration: CodexResponsesBackendConfiguration( + requestRetryPolicy: .init(maxAttempts: 3, initialBackoff: 0, maxBackoff: 0, jitterFactor: 0) + ), + urlSession: makeTestURLSession() + ) + let session = ChatGPTSession(accessToken: "access-token", refreshToken: "refresh-token", account: ChatGPTAccount(id: "workspace-123", email: "taylor@example.com", plan: .plus)) + + await TestURLProtocol.enqueue(.init(statusCode: 400, headers: ["Content-Type": "application/json"], body: Data(#"{"error":"bad request"}"#.utf8))) + await TestURLProtocol.enqueue(.init(headers: ["Content-Type": "text/event-stream"], body: Data(), inspect: { _ in XCTFail("Non-retryable 400 should not trigger a retry.") })) + + let turnStream = try await backend.beginTurn(thread: AgentThread(id: "thread-no-retry"), history: [], message: UserMessageRequest(text: "Hi"), instructions: "Resolved instructions", responseFormat: nil, streamedStructuredOutput: nil, tools: [], session: session) + + await XCTAssertThrowsErrorAsync(try await drainEvents(turnStream.events)) { error in + XCTAssertEqual(error as? AgentRuntimeError, AgentRuntimeError(code: "responses_http_status_400", message: "The ChatGPT responses request failed with status 400: {\"error\":\"bad request\"}")) + } + } + + func testBackendRetriesWhenNetworkConnectionIsLostBeforeOutput() async throws { + let backend = CodexResponsesBackend( + configuration: CodexResponsesBackendConfiguration( + requestRetryPolicy: .init(maxAttempts: 2, initialBackoff: 0, maxBackoff: 0, jitterFactor: 0) + ), + urlSession: makeTestURLSession() + ) + let session = ChatGPTSession(accessToken: "access-token", refreshToken: "refresh-token", account: ChatGPTAccount(id: "workspace-123", email: "taylor@example.com", plan: .plus)) + + await TestURLProtocol.enqueue(.init(body: Data(), error: URLError(.networkConnectionLost))) + await TestURLProtocol.enqueue(.init(headers: ["Content-Type": "text/event-stream"], body: Data(""" + event: response.output_item.done + data: {"type":"response.output_item.done","item":{"type":"message","role":"assistant","content":[{"type":"output_text","text":"Recovered after network loss"}]}} + + event: response.completed + data: {"type":"response.completed","response":{"id":"resp_network_retry","usage":{"input_tokens":5,"input_tokens_details":{"cached_tokens":0},"output_tokens":2}}} + + """.utf8))) + + let turnStream = try await backend.beginTurn(thread: AgentThread(id: "thread-network-retry"), history: [], message: UserMessageRequest(text: "Retry me"), instructions: "Resolved instructions", responseFormat: nil, streamedStructuredOutput: nil, tools: [], session: session) + + var assistantMessage: AgentMessage? + for try await event in turnStream.events { + if case let .assistantMessageCompleted(message) = event { + assistantMessage = message + } + } + + XCTAssertEqual(assistantMessage?.text, "Recovered after network loss") + } +} + +private func drainEvents(_ events: AsyncThrowingStream) async throws { + for try await _ in events {} +} diff --git a/Tests/CodexKitTests/CodexResponsesBackendStructuredTests.swift b/Tests/CodexKitTests/CodexResponsesBackendStructuredTests.swift new file mode 100644 index 0000000..3cc727f --- /dev/null +++ b/Tests/CodexKitTests/CodexResponsesBackendStructuredTests.swift @@ -0,0 +1,92 @@ +import CodexKit +import XCTest + +extension CodexResponsesBackendTests { + func testBackendEncodesStructuredOutputFormat() async throws { + let backend = CodexResponsesBackend(urlSession: makeTestURLSession()) + let session = ChatGPTSession(accessToken: "access-token", refreshToken: "refresh-token", account: ChatGPTAccount(id: "workspace-123", email: "taylor@example.com", plan: .plus)) + let responseFormat = AgentStructuredOutputFormat(name: "shipping_reply_draft", description: "A concise shipping support reply draft.", schema: .object(properties: ["reply": .string()], required: ["reply"], additionalProperties: false)) + + await TestURLProtocol.enqueue(.init(headers: ["Content-Type": "text/event-stream"], body: Data(""" + event: response.output_item.done + data: {"type":"response.output_item.done","item":{"type":"message","role":"assistant","content":[{"type":"output_text","text":"{\\"reply\\":\\"Done\\"}"}]}} + + event: response.completed + data: {"type":"response.completed","response":{"id":"resp_structured","usage":{"input_tokens":4,"input_tokens_details":{"cached_tokens":0},"output_tokens":1}}} + + """.utf8), inspect: { request in + let body = try XCTUnwrap(requestBodyData(for: request)) + let json = try JSONSerialization.jsonObject(with: body) as? [String: Any] + let text = try XCTUnwrap(json?["text"] as? [String: Any]) + let format = try XCTUnwrap(text["format"] as? [String: Any]) + XCTAssertEqual(format["type"] as? String, "json_schema") + XCTAssertEqual(format["name"] as? String, "shipping_reply_draft") + XCTAssertEqual(format["description"] as? String, "A concise shipping support reply draft.") + XCTAssertEqual(format["strict"] as? Bool, true) + let schema = try XCTUnwrap(format["schema"] as? [String: Any]) + XCTAssertEqual(schema["type"] as? String, "object") + })) + + let turnStream = try await backend.beginTurn(thread: AgentThread(id: "thread-structured"), history: [], message: UserMessageRequest(text: "Draft a reply."), instructions: "Resolved instructions", responseFormat: responseFormat, streamedStructuredOutput: nil, tools: [], session: session) + for try await _ in turnStream.events {} + } + + func testBackendStripsStructuredStreamFramingFromVisibleAssistantText() async throws { + let backend = CodexResponsesBackend(urlSession: makeTestURLSession()) + let session = ChatGPTSession(accessToken: "access-token", refreshToken: "refresh-token", account: ChatGPTAccount(id: "workspace-123", email: "taylor@example.com", plan: .plus)) + let streamedStructuredOutput = AgentStreamedStructuredOutputRequest( + responseFormat: AgentStructuredOutputFormat( + name: "shipping_reply_draft", + description: "A concise shipping support reply draft.", + schema: .object(properties: ["reply": .string(), "priority": .string()], required: ["reply", "priority"], additionalProperties: false) + ), + options: .init(required: true) + ) + + await TestURLProtocol.enqueue(.init(headers: ["Content-Type": "text/event-stream"], body: Data(""" + event: response.output_text.delta + data: {"type":"response.output_text.delta","delta":"Hello "} + + event: response.output_text.delta + data: {"type":"response.output_text.delta","delta":"{\\"reply\\":\\"Done\\",\\"priority\\":\\"high\\"}"} + + event: response.output_item.done + data: {"type":"response.output_item.done","item":{"type":"message","role":"assistant","content":[{"type":"output_text","text":"Hello {\\"reply\\":\\"Done\\",\\"priority\\":\\"high\\"}"}]}} + + event: response.completed + data: {"type":"response.completed","response":{"id":"resp_streamed","usage":{"input_tokens":4,"input_tokens_details":{"cached_tokens":0},"output_tokens":1}}} + + """.utf8), inspect: { request in + let body = try XCTUnwrap(requestBodyData(for: request)) + let json = try JSONSerialization.jsonObject(with: body) as? [String: Any] + let instructions = try XCTUnwrap(json?["instructions"] as? String) + XCTAssertTrue(instructions.contains("CodexKit private streaming contract")) + let text = json?["text"] as? [String: Any] + let format = try XCTUnwrap(text?["format"] as? [String: Any]) + XCTAssertEqual(format["type"] as? String, "text") + })) + + let turnStream = try await backend.beginTurn(thread: AgentThread(id: "thread-streamed-structured"), history: [], message: UserMessageRequest(text: "Draft a reply."), instructions: "Resolved instructions", responseFormat: nil, streamedStructuredOutput: streamedStructuredOutput, tools: [], session: session) + + var deltas: [String] = [] + var finalAssistantMessage: AgentMessage? + var committedValue: JSONValue? + + for try await event in turnStream.events { + switch event { + case let .assistantMessageDelta(_, _, delta): + deltas.append(delta) + case let .assistantMessageCompleted(message): + finalAssistantMessage = message + case let .structuredOutputCommitted(value): + committedValue = value + default: + break + } + } + + XCTAssertEqual(deltas.joined(), "Hello ") + XCTAssertEqual(finalAssistantMessage?.text, "Hello") + XCTAssertEqual(committedValue, .object(["reply": .string("Done"), "priority": .string("high")])) + } +} diff --git a/Tests/CodexKitTests/CodexResponsesBackendTests.swift b/Tests/CodexKitTests/CodexResponsesBackendTests.swift index ff8d789..062b248 100644 --- a/Tests/CodexKitTests/CodexResponsesBackendTests.swift +++ b/Tests/CodexKitTests/CodexResponsesBackendTests.swift @@ -462,344 +462,4 @@ final class CodexResponsesBackendTests: XCTestCase { XCTAssertEqual(assistantMessage?.images.first?.data, pngBytes) } - func testBackendRetriesTransientStatusCodeWithBackoffPolicy() async throws { - let backend = CodexResponsesBackend( - configuration: CodexResponsesBackendConfiguration( - requestRetryPolicy: .init( - maxAttempts: 2, - initialBackoff: 0, - maxBackoff: 0, - jitterFactor: 0 - ) - ), - urlSession: makeTestURLSession() - ) - let session = ChatGPTSession( - accessToken: "access-token", - refreshToken: "refresh-token", - account: ChatGPTAccount(id: "workspace-123", email: "taylor@example.com", plan: .plus) - ) - - await TestURLProtocol.enqueue( - .init( - statusCode: 503, - headers: ["Content-Type": "application/json"], - body: Data(#"{"error":"upstream overloaded"}"#.utf8) - ) - ) - - await TestURLProtocol.enqueue( - .init( - headers: ["Content-Type": "text/event-stream"], - body: Data( - """ - event: response.output_item.done - data: {"type":"response.output_item.done","item":{"type":"message","role":"assistant","content":[{"type":"output_text","text":"Recovered"}]}} - - event: response.completed - data: {"type":"response.completed","response":{"id":"resp_retry","usage":{"input_tokens":5,"input_tokens_details":{"cached_tokens":0},"output_tokens":2}}} - - """.utf8 - ) - ) - ) - - let turnStream = try await backend.beginTurn( - thread: AgentThread(id: "thread-retry"), - history: [], - message: UserMessageRequest(text: "Hi"), - instructions: "Resolved instructions", - responseFormat: nil, - streamedStructuredOutput: nil, - tools: [], - session: session - ) - - var assistantMessage: AgentMessage? - for try await event in turnStream.events { - if case let .assistantMessageCompleted(message) = event { - assistantMessage = message - } - } - - XCTAssertEqual(assistantMessage?.text, "Recovered") - } - - func testBackendDoesNotRetryNonRetryableStatusCode() async throws { - let backend = CodexResponsesBackend( - configuration: CodexResponsesBackendConfiguration( - requestRetryPolicy: .init( - maxAttempts: 3, - initialBackoff: 0, - maxBackoff: 0, - jitterFactor: 0 - ) - ), - urlSession: makeTestURLSession() - ) - let session = ChatGPTSession( - accessToken: "access-token", - refreshToken: "refresh-token", - account: ChatGPTAccount(id: "workspace-123", email: "taylor@example.com", plan: .plus) - ) - - await TestURLProtocol.enqueue( - .init( - statusCode: 400, - headers: ["Content-Type": "application/json"], - body: Data(#"{"error":"bad request"}"#.utf8) - ) - ) - await TestURLProtocol.enqueue( - .init( - headers: ["Content-Type": "text/event-stream"], - body: Data(), - inspect: { _ in - XCTFail("Non-retryable 400 should not trigger a retry.") - } - ) - ) - - let turnStream = try await backend.beginTurn( - thread: AgentThread(id: "thread-no-retry"), - history: [], - message: UserMessageRequest(text: "Hi"), - instructions: "Resolved instructions", - responseFormat: nil, - streamedStructuredOutput: nil, - tools: [], - session: session - ) - - await XCTAssertThrowsErrorAsync(try await drainEvents(turnStream.events)) { error in - XCTAssertEqual( - error as? AgentRuntimeError, - AgentRuntimeError( - code: "responses_http_status_400", - message: "The ChatGPT responses request failed with status 400: {\"error\":\"bad request\"}" - ) - ) - } - } - - func testBackendRetriesWhenNetworkConnectionIsLostBeforeOutput() async throws { - let backend = CodexResponsesBackend( - configuration: CodexResponsesBackendConfiguration( - requestRetryPolicy: .init( - maxAttempts: 2, - initialBackoff: 0, - maxBackoff: 0, - jitterFactor: 0 - ) - ), - urlSession: makeTestURLSession() - ) - let session = ChatGPTSession( - accessToken: "access-token", - refreshToken: "refresh-token", - account: ChatGPTAccount(id: "workspace-123", email: "taylor@example.com", plan: .plus) - ) - - await TestURLProtocol.enqueue( - .init( - body: Data(), - error: URLError(.networkConnectionLost) - ) - ) - - await TestURLProtocol.enqueue( - .init( - headers: ["Content-Type": "text/event-stream"], - body: Data( - """ - event: response.output_item.done - data: {"type":"response.output_item.done","item":{"type":"message","role":"assistant","content":[{"type":"output_text","text":"Recovered after network loss"}]}} - - event: response.completed - data: {"type":"response.completed","response":{"id":"resp_network_retry","usage":{"input_tokens":5,"input_tokens_details":{"cached_tokens":0},"output_tokens":2}}} - - """.utf8 - ) - ) - ) - - let turnStream = try await backend.beginTurn( - thread: AgentThread(id: "thread-network-retry"), - history: [], - message: UserMessageRequest(text: "Retry me"), - instructions: "Resolved instructions", - responseFormat: nil, - streamedStructuredOutput: nil, - tools: [], - session: session - ) - - var assistantMessage: AgentMessage? - for try await event in turnStream.events { - if case let .assistantMessageCompleted(message) = event { - assistantMessage = message - } - } - - XCTAssertEqual(assistantMessage?.text, "Recovered after network loss") - } - - func testBackendEncodesStructuredOutputFormat() async throws { - let backend = CodexResponsesBackend(urlSession: makeTestURLSession()) - let session = ChatGPTSession( - accessToken: "access-token", - refreshToken: "refresh-token", - account: ChatGPTAccount(id: "workspace-123", email: "taylor@example.com", plan: .plus) - ) - let responseFormat = AgentStructuredOutputFormat( - name: "shipping_reply_draft", - description: "A concise shipping support reply draft.", - schema: .object( - properties: [ - "reply": .string(), - ], - required: ["reply"], - additionalProperties: false - ) - ) - - await TestURLProtocol.enqueue( - .init( - headers: ["Content-Type": "text/event-stream"], - body: Data( - """ - event: response.output_item.done - data: {"type":"response.output_item.done","item":{"type":"message","role":"assistant","content":[{"type":"output_text","text":"{\\"reply\\":\\"Done\\"}"}]}} - - event: response.completed - data: {"type":"response.completed","response":{"id":"resp_structured","usage":{"input_tokens":4,"input_tokens_details":{"cached_tokens":0},"output_tokens":1}}} - - """.utf8 - ), - inspect: { request in - let body = try XCTUnwrap(requestBodyData(for: request)) - let json = try JSONSerialization.jsonObject(with: body) as? [String: Any] - let text = try XCTUnwrap(json?["text"] as? [String: Any]) - let format = try XCTUnwrap(text["format"] as? [String: Any]) - XCTAssertEqual(format["type"] as? String, "json_schema") - XCTAssertEqual(format["name"] as? String, "shipping_reply_draft") - XCTAssertEqual(format["description"] as? String, "A concise shipping support reply draft.") - XCTAssertEqual(format["strict"] as? Bool, true) - let schema = try XCTUnwrap(format["schema"] as? [String: Any]) - XCTAssertEqual(schema["type"] as? String, "object") - } - ) - ) - - let turnStream = try await backend.beginTurn( - thread: AgentThread(id: "thread-structured"), - history: [], - message: UserMessageRequest(text: "Draft a reply."), - instructions: "Resolved instructions", - responseFormat: responseFormat, - streamedStructuredOutput: nil, - tools: [], - session: session - ) - - for try await _ in turnStream.events {} - } - - func testBackendStripsStructuredStreamFramingFromVisibleAssistantText() async throws { - let backend = CodexResponsesBackend(urlSession: makeTestURLSession()) - let session = ChatGPTSession( - accessToken: "access-token", - refreshToken: "refresh-token", - account: ChatGPTAccount(id: "workspace-123", email: "taylor@example.com", plan: .plus) - ) - let streamedStructuredOutput = AgentStreamedStructuredOutputRequest( - responseFormat: AgentStructuredOutputFormat( - name: "shipping_reply_draft", - description: "A concise shipping support reply draft.", - schema: .object( - properties: [ - "reply": .string(), - "priority": .string(), - ], - required: ["reply", "priority"], - additionalProperties: false - ) - ), - options: .init(required: true) - ) - - await TestURLProtocol.enqueue( - .init( - headers: ["Content-Type": "text/event-stream"], - body: Data( - """ - event: response.output_text.delta - data: {"type":"response.output_text.delta","delta":"Hello "} - - event: response.output_text.delta - data: {"type":"response.output_text.delta","delta":"{\\"reply\\":\\"Done\\",\\"priority\\":\\"high\\"}"} - - event: response.output_item.done - data: {"type":"response.output_item.done","item":{"type":"message","role":"assistant","content":[{"type":"output_text","text":"Hello {\\"reply\\":\\"Done\\",\\"priority\\":\\"high\\"}"}]}} - - event: response.completed - data: {"type":"response.completed","response":{"id":"resp_streamed","usage":{"input_tokens":4,"input_tokens_details":{"cached_tokens":0},"output_tokens":1}}} - - """.utf8 - ), - inspect: { request in - let body = try XCTUnwrap(requestBodyData(for: request)) - let json = try JSONSerialization.jsonObject(with: body) as? [String: Any] - let instructions = try XCTUnwrap(json?["instructions"] as? String) - XCTAssertTrue(instructions.contains("CodexKit private streaming contract")) - let text = json?["text"] as? [String: Any] - let format = try XCTUnwrap(text?["format"] as? [String: Any]) - XCTAssertEqual(format["type"] as? String, "text") - } - ) - ) - - let turnStream = try await backend.beginTurn( - thread: AgentThread(id: "thread-streamed-structured"), - history: [], - message: UserMessageRequest(text: "Draft a reply."), - instructions: "Resolved instructions", - responseFormat: nil, - streamedStructuredOutput: streamedStructuredOutput, - tools: [], - session: session - ) - - var deltas: [String] = [] - var finalAssistantMessage: AgentMessage? - var committedValue: JSONValue? - - for try await event in turnStream.events { - switch event { - case let .assistantMessageDelta(_, _, delta): - deltas.append(delta) - case let .assistantMessageCompleted(message): - finalAssistantMessage = message - case let .structuredOutputCommitted(value): - committedValue = value - default: - break - } - } - - XCTAssertEqual(deltas.joined(), "Hello ") - XCTAssertEqual(finalAssistantMessage?.text, "Hello") - XCTAssertEqual( - committedValue, - .object([ - "reply": .string("Done"), - "priority": .string("high"), - ]) - ) - } - -} - -private func drainEvents(_ events: AsyncThrowingStream) async throws { - for try await _ in events {} } From 468a0e432298ee3bf7e1230f1fe47097ab85afbe Mon Sep 17 00:00:00 2001 From: Timothy Zelinsky Date: Tue, 24 Mar 2026 15:09:41 +1100 Subject: [PATCH 10/19] Add Combine state publishers for runtime observation --- .../Runtime/AgentRuntimeObservation.swift | 134 ++++++++++++++++++ .../AgentRuntimeMessageBehaviorTests.swift | 101 +++++++++++++ 2 files changed, 235 insertions(+) diff --git a/Sources/CodexKit/Runtime/AgentRuntimeObservation.swift b/Sources/CodexKit/Runtime/AgentRuntimeObservation.swift index 78565a5..51312ce 100644 --- a/Sources/CodexKit/Runtime/AgentRuntimeObservation.swift +++ b/Sources/CodexKit/Runtime/AgentRuntimeObservation.swift @@ -11,7 +11,13 @@ public enum AgentRuntimeObservation: Sendable { } public final class AgentRuntimeObservationCenter: @unchecked Sendable { + private let lock = NSLock() private let subject = PassthroughSubject() + private let threadsSubject = CurrentValueSubject<[AgentThread], Never>([]) + private var threadSubjects: [String: CurrentValueSubject] = [:] + private var messageSubjects: [String: CurrentValueSubject<[AgentMessage], Never>] = [:] + private var summarySubjects: [String: CurrentValueSubject] = [:] + private var contextStateSubjects: [String: CurrentValueSubject] = [:] public init() {} @@ -19,7 +25,135 @@ public final class AgentRuntimeObservationCenter: @unchecked Sendable { subject.eraseToAnyPublisher() } + public var threadListPublisher: AnyPublisher<[AgentThread], Never> { + threadsSubject.eraseToAnyPublisher() + } + + public func threadPublisher(for threadID: String) -> AnyPublisher { + withLock { + threadSubject(for: threadID).eraseToAnyPublisher() + } + } + + public func messagePublisher(for threadID: String) -> AnyPublisher<[AgentMessage], Never> { + withLock { + messageSubject(for: threadID).eraseToAnyPublisher() + } + } + + public func threadSummaryPublisher(for threadID: String) -> AnyPublisher { + withLock { + summarySubject(for: threadID).eraseToAnyPublisher() + } + } + + public func threadContextStatePublisher(for threadID: String) -> AnyPublisher { + withLock { + contextStateSubject(for: threadID).eraseToAnyPublisher() + } + } + func send(_ observation: AgentRuntimeObservation) { + var updates: [() -> Void] = [] + + withLock { + switch observation { + case let .threadsChanged(threads): + updates.append { self.threadsSubject.send(threads) } + + case let .threadChanged(thread): + let subject = threadSubject(for: thread.id) + updates.append { subject.send(thread) } + + case let .messagesChanged(threadID, messages): + let subject = messageSubject(for: threadID) + updates.append { subject.send(messages) } + + case let .threadSummaryChanged(summary): + let subject = summarySubject(for: summary.threadID) + updates.append { subject.send(summary) } + + case let .threadContextStateChanged(threadID, state): + let subject = contextStateSubject(for: threadID) + updates.append { subject.send(state) } + + case let .threadDeleted(threadID): + let threadSubject = threadSubject(for: threadID) + let messageSubject = messageSubject(for: threadID) + let summarySubject = summarySubject(for: threadID) + let contextStateSubject = contextStateSubject(for: threadID) + updates.append { threadSubject.send(nil) } + updates.append { messageSubject.send([]) } + updates.append { summarySubject.send(nil) } + updates.append { contextStateSubject.send(nil) } + } + } + + updates.forEach { $0() } subject.send(observation) } + + private func threadSubject(for threadID: String) -> CurrentValueSubject { + if let subject = threadSubjects[threadID] { + return subject + } + let subject = CurrentValueSubject(nil) + threadSubjects[threadID] = subject + return subject + } + + private func messageSubject(for threadID: String) -> CurrentValueSubject<[AgentMessage], Never> { + if let subject = messageSubjects[threadID] { + return subject + } + let subject = CurrentValueSubject<[AgentMessage], Never>([]) + messageSubjects[threadID] = subject + return subject + } + + private func summarySubject(for threadID: String) -> CurrentValueSubject { + if let subject = summarySubjects[threadID] { + return subject + } + let subject = CurrentValueSubject(nil) + summarySubjects[threadID] = subject + return subject + } + + private func contextStateSubject(for threadID: String) -> CurrentValueSubject { + if let subject = contextStateSubjects[threadID] { + return subject + } + let subject = CurrentValueSubject(nil) + contextStateSubjects[threadID] = subject + return subject + } + + private func withLock(_ body: () -> T) -> T { + lock.lock() + defer { lock.unlock() } + return body() + } +} + +extension AgentRuntime { + public nonisolated func observeThreads() -> AnyPublisher<[AgentThread], Never> { + observationCenter.threadListPublisher + } + + public nonisolated func observeThread(id threadID: String) -> AnyPublisher { + observationCenter.threadPublisher(for: threadID) + } + + public nonisolated func observeMessages(in threadID: String) -> AnyPublisher<[AgentMessage], Never> { + observationCenter.messagePublisher(for: threadID) + } + + public nonisolated func observeThreadSummary(id threadID: String) -> AnyPublisher { + observationCenter.threadSummaryPublisher(for: threadID) + } + + public nonisolated func observeThreadContextState(id threadID: String) -> AnyPublisher { + observationCenter.threadContextStatePublisher(for: threadID) + } } diff --git a/Tests/CodexKitTests/AgentRuntimeMessageBehaviorTests.swift b/Tests/CodexKitTests/AgentRuntimeMessageBehaviorTests.swift index 2757b93..ab8ef51 100644 --- a/Tests/CodexKitTests/AgentRuntimeMessageBehaviorTests.swift +++ b/Tests/CodexKitTests/AgentRuntimeMessageBehaviorTests.swift @@ -243,6 +243,107 @@ extension AgentRuntimeTests { let reply = try await sendTask.value XCTAssertEqual(reply, "Echo: Observe me") } + + func testObserveMessagesPublishesInitialStateAndLocalPendingMessage() async throws { + let backend = DelayedBeginTurnBackend() + let runtime = try AgentRuntime(configuration: .init( + authProvider: DemoChatGPTAuthProvider(), + secureStore: KeychainSessionSecureStore(service: "CodexKitTests.ChatGPTSession", account: UUID().uuidString), + backend: backend, + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore() + )) + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(title: "Observe Messages") + let observedInitialState = expectation(description: "Observed the initial empty message state") + let observedPendingMessage = expectation(description: "Observed the local pending user message") + observedPendingMessage.assertForOverFulfill = false + var snapshots: [[String]] = [] + var cancellables = Set() + + let publisher = runtime.observeMessages(in: thread.id) + publisher + .sink { messages in + snapshots.append(messages.map(\.text)) + if messages.isEmpty { + observedInitialState.fulfill() + } + if messages.count == 1, + messages.last?.role == .user, + messages.last?.text == "Observe messages" { + observedPendingMessage.fulfill() + } + } + .store(in: &cancellables) + + await fulfillment(of: [observedInitialState], timeout: 0.5) + + let sendTask = Task { + try await runtime.sendMessage(UserMessageRequest(text: "Observe messages"), in: thread.id) + } + + await backend.waitForBeginTurnStart() + await fulfillment(of: [observedPendingMessage], timeout: 0.5) + XCTAssertTrue(snapshots.contains([])) + XCTAssertTrue(snapshots.contains(["Observe messages"])) + + await backend.releaseBeginTurn() + let reply = try await sendTask.value + XCTAssertEqual(reply, "Echo: Observe messages") + } + + func testObserveThreadContextStatePublishesInitialAndCompactedValues() async throws { + let backend = CompactingTestBackend() + let runtime = try makeHistoryRuntime( + backend: backend, + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore(), + contextCompaction: .init( + isEnabled: true, + mode: .manual, + strategy: .localOnly + ) + ) + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(title: "Observe Context") + _ = try await runtime.sendMessage(UserMessageRequest(text: "one"), in: thread.id) + _ = try await runtime.sendMessage(UserMessageRequest(text: "two"), in: thread.id) + _ = try await runtime.sendMessage(UserMessageRequest(text: "three"), in: thread.id) + + let observedInitialState = expectation(description: "Observed the initial context state") + let observedCompactedState = expectation(description: "Observed the compacted context state") + observedCompactedState.assertForOverFulfill = false + var contexts: [AgentThreadContextState?] = [] + var cancellables = Set() + + let publisher = runtime.observeThreadContextState(id: thread.id) + publisher + .sink { state in + contexts.append(state) + if state?.generation == 0 { + observedInitialState.fulfill() + } + if let state, + state.generation == 1, + state.lastCompactionReason == .manual { + observedCompactedState.fulfill() + } + } + .store(in: &cancellables) + + await fulfillment(of: [observedInitialState], timeout: 0.5) + + let compacted = try await runtime.compactThreadContext(id: thread.id) + XCTAssertEqual(compacted.generation, 1) + + await fulfillment(of: [observedCompactedState], timeout: 0.5) + XCTAssertTrue(contexts.contains(where: { $0?.generation == 0 })) + XCTAssertTrue(contexts.contains(where: { $0?.generation == 1 })) + } } func drainStructuredStream( From a15c6a5c2f78bcec2668f9b5e17c900d1f85eeef Mon Sep 17 00:00:00 2001 From: Timothy Zelinsky Date: Tue, 24 Mar 2026 16:11:57 +1100 Subject: [PATCH 11/19] Add demo observation flow and thread title editing --- .../project.pbxproj | 4 + .../Shared/AgentDemoViewModel+Messaging.swift | 1 + .../AgentDemoViewModel+Observation.swift | 84 +++++++++++++++++++ .../AgentDemoViewModel+ThreadState.swift | 37 +++++++- .../Shared/AgentDemoViewModel.swift | 10 +++ .../Shared/ThreadDetailView.swift | 83 ++++++++++++++---- DemoApp/README.md | 15 ++++ README.md | 43 ++++++++++ .../Runtime/AgentRuntime+Threads.swift | 14 ++++ .../AgentRuntimeMessageBehaviorTests.swift | 53 ++++++++++++ 10 files changed, 324 insertions(+), 20 deletions(-) create mode 100644 DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+Observation.swift diff --git a/DemoApp/AssistantRuntimeDemoApp.xcodeproj/project.pbxproj b/DemoApp/AssistantRuntimeDemoApp.xcodeproj/project.pbxproj index bfb29e4..1591882 100644 --- a/DemoApp/AssistantRuntimeDemoApp.xcodeproj/project.pbxproj +++ b/DemoApp/AssistantRuntimeDemoApp.xcodeproj/project.pbxproj @@ -27,6 +27,7 @@ 1A2B3C4D5E6F700000000012 /* MemoryDemoView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 2A2B3C4D5E6F700000000012 /* MemoryDemoView.swift */; }; 1A2B3C4D5E6F700000000013 /* AgentDemoViewModel+ComposerState.swift in Sources */ = {isa = PBXBuildFile; fileRef = 2A2B3C4D5E6F700000000013 /* AgentDemoViewModel+ComposerState.swift */; }; 1A2B3C4D5E6F700000000014 /* AgentDemoViewModel+ThreadState.swift in Sources */ = {isa = PBXBuildFile; fileRef = 2A2B3C4D5E6F700000000014 /* AgentDemoViewModel+ThreadState.swift */; }; + 1A2B3C4D5E6F700000000015 /* AgentDemoViewModel+Observation.swift in Sources */ = {isa = PBXBuildFile; fileRef = 2A2B3C4D5E6F700000000015 /* AgentDemoViewModel+Observation.swift */; }; 7482123BC63AC10F104DE092 /* AssistantRuntimeDemoApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5A6999E6475919476E726E8C /* AssistantRuntimeDemoApp.swift */; }; 84726927B752451499D9257F /* Foundation.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 906A95007C8ECB92CFC2CE15 /* Foundation.framework */; }; B060448C6464C41789B56EED /* AgentDemoView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 3CA22585116A120BA97F76B8 /* AgentDemoView.swift */; }; @@ -57,6 +58,7 @@ 2A2B3C4D5E6F700000000012 /* MemoryDemoView.swift */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.swift; path = MemoryDemoView.swift; sourceTree = ""; }; 2A2B3C4D5E6F700000000013 /* AgentDemoViewModel+ComposerState.swift */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.swift; path = "AgentDemoViewModel+ComposerState.swift"; sourceTree = ""; }; 2A2B3C4D5E6F700000000014 /* AgentDemoViewModel+ThreadState.swift */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.swift; path = "AgentDemoViewModel+ThreadState.swift"; sourceTree = ""; }; + 2A2B3C4D5E6F700000000015 /* AgentDemoViewModel+Observation.swift */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.swift; path = "AgentDemoViewModel+Observation.swift"; sourceTree = ""; }; 2481147A958D00EB4A70C928 /* AgentDemoViewModel.swift */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.swift; path = AgentDemoViewModel.swift; sourceTree = ""; }; 3CA22585116A120BA97F76B8 /* AgentDemoView.swift */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.swift; path = AgentDemoView.swift; sourceTree = ""; }; 5A6999E6475919476E726E8C /* AssistantRuntimeDemoApp.swift */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.swift; path = AssistantRuntimeDemoApp.swift; sourceTree = ""; }; @@ -139,6 +141,7 @@ 2481147A958D00EB4A70C928 /* AgentDemoViewModel.swift */, 2A2B3C4D5E6F700000000013 /* AgentDemoViewModel+ComposerState.swift */, 2A2B3C4D5E6F700000000014 /* AgentDemoViewModel+ThreadState.swift */, + 2A2B3C4D5E6F700000000015 /* AgentDemoViewModel+Observation.swift */, 2A2B3C4D5E6F700000000006 /* AgentDemoViewModel+Messaging.swift */, 2A2B3C4D5E6F700000000007 /* AgentDemoViewModel+Tools.swift */, 2A2B3C4D5E6F700000000008 /* AgentDemoViewModel+HealthCoach.swift */, @@ -248,6 +251,7 @@ 1A2B3C4D5E6F700000000012 /* MemoryDemoView.swift in Sources */, 1A2B3C4D5E6F700000000013 /* AgentDemoViewModel+ComposerState.swift in Sources */, 1A2B3C4D5E6F700000000014 /* AgentDemoViewModel+ThreadState.swift in Sources */, + 1A2B3C4D5E6F700000000015 /* AgentDemoViewModel+Observation.swift in Sources */, 7482123BC63AC10F104DE092 /* AssistantRuntimeDemoApp.swift in Sources */, BB4F38E64D1EBBB3821AC4E3 /* AgentDemoRuntimeFactory.swift in Sources */, B060448C6464C41789B56EED /* AgentDemoView.swift in Sources */, diff --git a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+Messaging.swift b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+Messaging.swift index eda56b3..f498173 100644 --- a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+Messaging.swift +++ b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+Messaging.swift @@ -28,6 +28,7 @@ extension AgentDemoViewModel { ) threads = await runtime.threads() activeThreadID = thread.id + bindActiveThreadObservation(for: thread.id) setMessages(await runtime.messages(for: thread.id)) developerLog( "Created thread. id=\(thread.id) title=\(thread.title ?? "") totalThreads=\(threads.count)" diff --git a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+Observation.swift b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+Observation.swift new file mode 100644 index 0000000..cd3c49a --- /dev/null +++ b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+Observation.swift @@ -0,0 +1,84 @@ +import Combine +import CodexKit +import Foundation + +@MainActor +extension AgentDemoViewModel { + func configureRuntimeObservationBindings() { + runtimeObservationCancellables.removeAll() + + runtime.observeThreads() + .receive(on: DispatchQueue.main) + .sink { [weak self] threads in + guard let self else { + return + } + + self.threads = threads + if let activeThreadID = self.activeThreadID, + !threads.contains(where: { $0.id == activeThreadID }) { + self.activeThreadID = nil + self.resetObservedThreadState() + self.messages = [] + } else if let activeThreadID = self.activeThreadID { + self.observedThread = threads.first { $0.id == activeThreadID } + } + } + .store(in: &runtimeObservationCancellables) + + if let activeThreadID { + bindActiveThreadObservation(for: activeThreadID) + } else { + resetObservedThreadState() + } + } + + func bindActiveThreadObservation(for threadID: String) { + activeThreadObservationCancellables.removeAll() + resetObservedThreadState() + + runtime.observeThread(id: threadID) + .receive(on: DispatchQueue.main) + .sink { [weak self] thread in + self?.observedThread = thread + } + .store(in: &activeThreadObservationCancellables) + + runtime.observeMessages(in: threadID) + .receive(on: DispatchQueue.main) + .sink { [weak self] messages in + guard let self else { + return + } + self.observedMessages = messages + self.setMessages(messages) + } + .store(in: &activeThreadObservationCancellables) + + runtime.observeThreadSummary(id: threadID) + .receive(on: DispatchQueue.main) + .sink { [weak self] summary in + self?.observedThreadSummary = summary + } + .store(in: &activeThreadObservationCancellables) + + runtime.observeThreadContextState(id: threadID) + .receive(on: DispatchQueue.main) + .sink { [weak self] contextState in + guard let self else { + return + } + self.observedThreadContextState = contextState + self.activeThreadContextState = contextState + } + .store(in: &activeThreadObservationCancellables) + } + + func resetObservedThreadState() { + observedThread = nil + observedMessages = [] + observedThreadSummary = nil + observedThreadContextState = nil + activeThreadContextState = nil + } +} diff --git a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+ThreadState.swift b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+ThreadState.swift index e9b5f6f..fa78e26 100644 --- a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+ThreadState.swift +++ b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+ThreadState.swift @@ -39,6 +39,7 @@ extension AgentDemoViewModel { approvalInbox: approvalInbox, deviceCodePromptCoordinator: deviceCodePromptCoordinator ) + configureRuntimeObservationBindings() defer { isAuthenticating = false @@ -88,6 +89,7 @@ extension AgentDemoViewModel { approvalInbox: approvalInbox, deviceCodePromptCoordinator: deviceCodePromptCoordinator ) + configureRuntimeObservationBindings() do { _ = try await runtime.restore() @@ -172,6 +174,7 @@ extension AgentDemoViewModel { func activateThread(id: String) async { activeThreadID = id + bindActiveThreadObservation(for: id) setMessages(await runtime.messages(for: id)) streamingText = "" await refreshThreadContextState(for: id) @@ -202,6 +205,8 @@ extension AgentDemoViewModel { memoryPreviewResult = nil activeThreadID = nil healthCoachThreadID = nil + activeThreadObservationCancellables.removeAll() + resetObservedThreadState() healthCoachFeedback = "Set a step goal, then start moving." healthLastUpdatedAt = nil healthKitAuthorized = false @@ -234,6 +239,7 @@ extension AgentDemoViewModel { let selectedThreadID = activeThreadID if let selectedThreadID, threads.contains(where: { $0.id == selectedThreadID }) { + bindActiveThreadObservation(for: selectedThreadID) setMessages(await runtime.messages(for: selectedThreadID)) await refreshThreadContextState(for: selectedThreadID) return @@ -241,12 +247,13 @@ extension AgentDemoViewModel { if let firstThread = threads.first { activeThreadID = firstThread.id + bindActiveThreadObservation(for: firstThread.id) setMessages(await runtime.messages(for: firstThread.id)) await refreshThreadContextState(for: firstThread.id) } else { activeThreadID = nil messages = [] - activeThreadContextState = nil + resetObservedThreadState() } } @@ -260,23 +267,47 @@ extension AgentDemoViewModel { isRunningSkillPolicyProbe = false skillPolicyProbeResult = nil activeThreadID = nil - activeThreadContextState = nil + activeThreadObservationCancellables.removeAll() + resetObservedThreadState() } func refreshThreadContextState(for threadID: String? = nil) async { guard let resolvedThreadID = threadID ?? activeThreadID else { - activeThreadContextState = nil + resetObservedThreadState() return } do { activeThreadContextState = try await runtime.fetchThreadContextState(id: resolvedThreadID) + observedThreadContextState = activeThreadContextState } catch { activeThreadContextState = nil + observedThreadContextState = nil developerErrorLog("Failed to fetch thread context state. threadID=\(resolvedThreadID) error=\(error.localizedDescription)") } } + func updateActiveThreadTitle(_ title: String) async { + guard let activeThreadID else { + lastError = "Select a thread before renaming it." + return + } + + let normalizedTitle = title.trimmingCharacters(in: .whitespacesAndNewlines) + + do { + try await runtime.setTitle( + normalizedTitle.isEmpty ? nil : normalizedTitle, + for: activeThreadID + ) + developerLog( + "Updated thread title. threadID=\(activeThreadID) title=\(normalizedTitle.isEmpty ? "" : normalizedTitle)" + ) + } catch { + reportError(error) + } + } + func compactActiveThreadContext() async { guard let activeThreadID else { lastError = "Select a thread before compacting its prompt context." diff --git a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel.swift b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel.swift index 016eec3..4678a84 100644 --- a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel.swift +++ b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel.swift @@ -1,3 +1,4 @@ +import Combine import CodexKit import CodexKitUI import Foundation @@ -234,6 +235,10 @@ final class AgentDemoViewModel: @unchecked Sendable { var currentAuthenticationMethod: DemoAuthenticationMethod = .deviceCode var activeThreadContextState: AgentThreadContextState? var isCompactingThreadContext = false + var observedThread: AgentThread? + var observedMessages: [AgentMessage] = [] + var observedThreadSummary: AgentThreadSummary? + var observedThreadContextState: AgentThreadContextState? let approvalInbox: ApprovalInbox let deviceCodePromptCoordinator: DeviceCodePromptCoordinator @@ -249,6 +254,10 @@ final class AgentDemoViewModel: @unchecked Sendable { var runtime: AgentRuntime var activeThreadID: String? var healthCoachThreadID: String? + @ObservationIgnored + var runtimeObservationCancellables: Set = [] + @ObservationIgnored + var activeThreadObservationCancellables: Set = [] #if os(iOS) let healthStore = HKHealthStore() @@ -278,6 +287,7 @@ final class AgentDemoViewModel: @unchecked Sendable { self.keychainAccount = keychainAccount self.approvalInbox = approvalInbox self.deviceCodePromptCoordinator = deviceCodePromptCoordinator + configureRuntimeObservationBindings() } var activeThread: AgentThread? { diff --git a/DemoApp/AssistantRuntimeDemoApp/Shared/ThreadDetailView.swift b/DemoApp/AssistantRuntimeDemoApp/Shared/ThreadDetailView.swift index 4cd2b92..40ed20d 100644 --- a/DemoApp/AssistantRuntimeDemoApp/Shared/ThreadDetailView.swift +++ b/DemoApp/AssistantRuntimeDemoApp/Shared/ThreadDetailView.swift @@ -15,6 +15,7 @@ struct ThreadDetailView: View { @State private var selectedPhotoItem: PhotosPickerItem? @State private var isImportingPhoto = false + @State private var threadTitleDraft = "" @FocusState private var isComposerFocused: Bool init(viewModel: AgentDemoViewModel, threadID: String) { @@ -26,6 +27,7 @@ struct ThreadDetailView: View { ScrollView { VStack(alignment: .leading, spacing: 16) { threadHeaderCard + observationCard compactionCard transcriptCard } @@ -50,10 +52,16 @@ struct ThreadDetailView: View { } .background(.regularMaterial) } - .navigationTitle(activeThread?.title ?? "Thread") + .navigationTitle(observedThread?.title ?? activeThread?.title ?? "Thread") .navigationBarTitleDisplayMode(.inline) .task(id: threadID) { await viewModel.activateThread(id: threadID) + threadTitleDraft = observedThread?.title ?? "" + } + .onChange(of: observedThread?.title) { previousValue, newValue in + if threadTitleDraft == (previousValue ?? "") || threadTitleDraft.isEmpty { + threadTitleDraft = newValue ?? "" + } } .onChange(of: selectedPhotoItem) { _, newItem in guard let newItem else { @@ -73,12 +81,24 @@ private extension ThreadDetailView { viewModel.threads.first { $0.id == threadID } } + var observedThread: AgentThread? { + viewModel.observedThread?.id == threadID ? viewModel.observedThread : activeThread + } + + var observedContextState: AgentThreadContextState? { + viewModel.observedThreadContextState ?? viewModel.activeThreadContextState + } + + var observedSummary: AgentThreadSummary? { + viewModel.observedThreadSummary + } + var threadHeaderCard: some View { DemoSectionCard { - Text(activeThread?.title ?? "Thread") + Text(observedThread?.title ?? activeThread?.title ?? "Thread") .font(.title3.weight(.semibold)) - Text(activeThread?.status.rawValue.capitalized ?? "Idle") + Text(observedThread?.status.rawValue.capitalized ?? activeThread?.status.rawValue.capitalized ?? "Idle") .font(.subheadline) .foregroundStyle(.secondary) @@ -106,6 +126,38 @@ private extension ThreadDetailView { } } + var observationCard: some View { + DemoSectionCard { + VStack(alignment: .leading, spacing: 8) { + Text("Observation Demo") + .font(.headline) + + Text("This card is driven by `observeThread`, `observeMessages`, `observeThreadSummary`, and `observeThreadContextState`, so title, transcript, summary, and compaction changes update live without a manual refresh.") + .font(.subheadline) + .foregroundStyle(.secondary) + } + + HStack(spacing: 12) { + compactionMetric(title: "Observed Messages", value: "\(viewModel.observedMessages.count)") + compactionMetric(title: "Latest Preview", value: observedSummary?.latestAssistantMessagePreview?.isEmpty == false ? "Ready" : "None") + compactionMetric(title: "Observed Status", value: observedThread?.status.rawValue.capitalized ?? "Idle") + } + + TextField("Rename this thread", text: $threadTitleDraft) + .textFieldStyle(.roundedBorder) + + Button { + Task { + await viewModel.updateActiveThreadTitle(threadTitleDraft) + } + } label: { + Label("Save Thread Title", systemImage: "pencil") + } + .buttonStyle(.borderedProminent) + .disabled(viewModel.session == nil || observedThread == nil) + } + } + var transcriptCard: some View { DemoSectionCard { if threadMessages.isEmpty && viewModel.streamingText.isEmpty { @@ -145,15 +197,15 @@ private extension ThreadDetailView { ) compactionMetric( title: "Effective Messages", - value: "\(viewModel.activeThreadContextState?.effectiveMessages.count ?? threadMessages.count)" + value: "\(observedContextState?.effectiveMessages.count ?? threadMessages.count)" ) compactionMetric( title: "Generation", - value: "\(viewModel.activeThreadContextState?.generation ?? 0)" + value: "\(observedContextState?.generation ?? 0)" ) } - if let contextState = viewModel.activeThreadContextState { + if let contextState = observedContextState { VStack(alignment: .leading, spacing: 6) { if let reason = contextState.lastCompactionReason { Text("Last compaction: \(reason.rawValue)") @@ -195,17 +247,11 @@ private extension ThreadDetailView { } .buttonStyle(.borderedProminent) .disabled(viewModel.session == nil || activeThread == nil || viewModel.isCompactingThreadContext) - - Button { - Task { - await viewModel.refreshThreadContextState(for: threadID) - } - } label: { - Label("Refresh State", systemImage: "arrow.clockwise") - } - .buttonStyle(.bordered) - .disabled(viewModel.session == nil || activeThread == nil) } + + Text("Observed context state updates here automatically after compaction and future turns.") + .font(.caption) + .foregroundStyle(.secondary) } } @@ -286,7 +332,10 @@ private extension ThreadDetailView { } var threadMessages: [AgentMessage] { - viewModel.activeThreadID == threadID ? viewModel.messages : [] + guard viewModel.activeThreadID == threadID else { + return [] + } + return viewModel.observedMessages.isEmpty ? viewModel.messages : viewModel.observedMessages } var isStreamingActive: Bool { diff --git a/DemoApp/README.md b/DemoApp/README.md index 27f82b9..daad663 100644 --- a/DemoApp/README.md +++ b/DemoApp/README.md @@ -22,6 +22,8 @@ The Xcode project is the source of truth for the demo app. Edit it directly in X - lets you attach a photo from the library and send it with or without text - renders attached user images in the transcript - streams assistant output into the UI +- demonstrates live Combine observation of thread, message, summary, and context-state updates +- lets you rename the active thread from the thread detail screen using `setTitle(_:for:)` - includes a thread-level `Context Compaction` card so you can compact effective prompt state without removing visible transcript history - supports approval prompts for host-defined tools that opt into `requiresApproval` - demonstrates thread-pinned personas and one-turn persona overrides @@ -58,6 +60,19 @@ The checked-in demo enables context compaction in automatic mode. In a thread de - last compaction reason/time - a `Compact Context Now` action for manual testing +The same thread detail screen also includes an `Observation Demo` card. It subscribes to: + +- `observeThread(id:)` +- `observeMessages(in:)` +- `observeThreadSummary(id:)` +- `observeThreadContextState(id:)` + +Use that card to verify that: + +- local title changes propagate immediately through `setTitle(_:for:)` +- new messages appear from the observation stream without a manual refresh +- context compaction updates the observed context state live + ## Files - `DemoApp/AssistantRuntimeDemoApp/AssistantRuntimeDemoApp.swift` diff --git a/README.md b/README.md index 9e2ace3..d252456 100644 --- a/README.md +++ b/README.md @@ -266,6 +266,49 @@ let snapshots = try await runtime.execute( This path also supports explicit history redaction and whole-thread deletion without forcing hosts to replay raw event streams themselves. +## Live Observation + +`CodexKit` exposes Combine publishers so apps can react to runtime state changes without polling or manual callback wiring. + +```swift +import Combine + +var cancellables = Set() + +runtime.observeThread(id: thread.id) + .receive(on: DispatchQueue.main) + .sink { thread in + print("Observed title:", thread?.title ?? "Untitled") + } + .store(in: &cancellables) + +runtime.observeMessages(in: thread.id) + .receive(on: DispatchQueue.main) + .sink { messages in + print("Observed message count:", messages.count) + } + .store(in: &cancellables) + +runtime.observeThreadContextState(id: thread.id) + .receive(on: DispatchQueue.main) + .sink { contextState in + print("Observed compaction generation:", contextState?.generation ?? 0) + } + .store(in: &cancellables) + +try await runtime.setTitle("Shipping Triage", for: thread.id) +``` + +Available built-in publishers: + +- `observeThreads()` +- `observeThread(id:)` +- `observeMessages(in:)` +- `observeThreadSummary(id:)` +- `observeThreadContextState(id:)` + +The checked-in demo app includes a thread detail `Observation Demo` card that exercises these publishers live, along with a rename control that calls `setTitle(_:for:)`. + ## Effective Context Compaction `CodexKit` can compact the runtime's effective prompt context without mutating canonical thread history. diff --git a/Sources/CodexKit/Runtime/AgentRuntime+Threads.swift b/Sources/CodexKit/Runtime/AgentRuntime+Threads.swift index 1623053..c536851 100644 --- a/Sources/CodexKit/Runtime/AgentRuntime+Threads.swift +++ b/Sources/CodexKit/Runtime/AgentRuntime+Threads.swift @@ -91,6 +91,20 @@ extension AgentRuntime { return thread.personaStack } + public func setTitle( + _ title: String?, + for threadID: String + ) async throws { + guard let index = state.threads.firstIndex(where: { $0.id == threadID }) else { + throw AgentRuntimeError.threadNotFound(threadID) + } + + state.threads[index].title = title + state.threads[index].updatedAt = Date() + enqueueStoreOperation(.upsertThread(state.threads[index])) + try await persistState() + } + public func setPersonaStack( _ personaStack: AgentPersonaStack?, for threadID: String diff --git a/Tests/CodexKitTests/AgentRuntimeMessageBehaviorTests.swift b/Tests/CodexKitTests/AgentRuntimeMessageBehaviorTests.swift index ab8ef51..ac93ad3 100644 --- a/Tests/CodexKitTests/AgentRuntimeMessageBehaviorTests.swift +++ b/Tests/CodexKitTests/AgentRuntimeMessageBehaviorTests.swift @@ -344,6 +344,59 @@ extension AgentRuntimeTests { XCTAssertTrue(contexts.contains(where: { $0?.generation == 0 })) XCTAssertTrue(contexts.contains(where: { $0?.generation == 1 })) } + + func testSetTitlePublishesObservedThreadUpdateAndPersists() async throws { + let stateStore = InMemoryRuntimeStateStore() + let runtime = try AgentRuntime(configuration: .init( + authProvider: DemoChatGPTAuthProvider(), + secureStore: KeychainSessionSecureStore(service: "CodexKitTests.ChatGPTSession", account: UUID().uuidString), + backend: InMemoryAgentBackend(), + approvalPresenter: AutoApprovalPresenter(), + stateStore: stateStore + )) + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(title: "Original Title") + let observedInitialThread = expectation(description: "Observed the initial thread") + let observedRetitledThread = expectation(description: "Observed the retitled thread") + observedRetitledThread.assertForOverFulfill = false + var observedTitles: [String?] = [] + var cancellables = Set() + + runtime.observeThread(id: thread.id) + .sink { observedThread in + observedTitles.append(observedThread?.title) + if observedThread?.title == "Original Title" { + observedInitialThread.fulfill() + } + if observedThread?.title == "Updated Title" { + observedRetitledThread.fulfill() + } + } + .store(in: &cancellables) + + await fulfillment(of: [observedInitialThread], timeout: 0.5) + + try await runtime.setTitle("Updated Title", for: thread.id) + + await fulfillment(of: [observedRetitledThread], timeout: 0.5) + XCTAssertTrue(observedTitles.contains("Original Title")) + XCTAssertTrue(observedTitles.contains("Updated Title")) + + let restoredRuntime = try AgentRuntime(configuration: .init( + authProvider: DemoChatGPTAuthProvider(), + secureStore: KeychainSessionSecureStore(service: "CodexKitTests.ChatGPTSession", account: UUID().uuidString), + backend: InMemoryAgentBackend(), + approvalPresenter: AutoApprovalPresenter(), + stateStore: stateStore + )) + _ = try await restoredRuntime.restore() + + let restoredThreads = await restoredRuntime.threads() + let restoredThread = try XCTUnwrap(restoredThreads.first(where: { $0.id == thread.id })) + XCTAssertEqual(restoredThread.title, "Updated Title") + } } func drainStructuredStream( From 640f4abd303693c3f8bede547ba14cc31374cba2 Mon Sep 17 00:00:00 2001 From: Timothy Zelinsky Date: Tue, 24 Mar 2026 16:23:35 +1100 Subject: [PATCH 12/19] Document iOS and macOS platform support --- DemoApp/README.md | 2 +- README.md | 14 ++++++++++---- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/DemoApp/README.md b/DemoApp/README.md index daad663..d36d766 100644 --- a/DemoApp/README.md +++ b/DemoApp/README.md @@ -1,6 +1,6 @@ # CodexKit Demo App -This folder contains the checked-in iOS example app for exercising the `CodexKit` embedded agent runtime. +This folder contains the checked-in iOS example app for exercising the `CodexKit` embedded agent runtime. The package itself supports both iOS and macOS; this demo remains the iOS sample app. ## Open the app in Xcode diff --git a/README.md b/README.md index d252456..981ffde 100644 --- a/README.md +++ b/README.md @@ -3,13 +3,13 @@ [![CI](https://github.com/timazed/CodexKit/actions/workflows/ci.yml/badge.svg?branch=main)](https://github.com/timazed/CodexKit/actions/workflows/ci.yml) ![Version](https://img.shields.io/badge/main-2.0.0--dev-orange) -`CodexKit` is a lightweight iOS-first SDK for embedding OpenAI Codex-style agents in Apple apps. +`CodexKit` is a lightweight SDK for embedding OpenAI Codex-style agents in Apple apps, with explicit support for iOS and macOS. `main` documents the upcoming `2.0` development line. If you are integrating the latest stable release, use the [`v1.1.0` docs](https://github.com/timazed/CodexKit/blob/v1.1.0/README.md) instead. ## Who This Is For -Use `CodexKit` if you are building a SwiftUI/iOS app and want: +Use `CodexKit` if you are building a SwiftUI app for iOS or macOS and want: - ChatGPT sign-in (device code or OAuth) - secure session persistence @@ -102,7 +102,7 @@ let runtime = try AgentRuntime(configuration: .init( let _ = try await runtime.signIn() let thread = try await runtime.createThread(title: "First Chat") let stream = try await runtime.streamMessage( - UserMessageRequest(text: "Hello from iOS."), + UserMessageRequest(text: "Hello from Apple platforms."), in: thread.id ) ``` @@ -111,6 +111,7 @@ let stream = try await runtime.streamMessage( | Capability | Support | | --- | --- | +| Supported platforms | iOS 17+, macOS 14+ | | iOS auth: device code | Yes | | iOS auth: browser OAuth (localhost callback) | Yes | | Threaded runtime state + restore | Yes | @@ -134,6 +135,11 @@ let stream = try await runtime.streamMessage( - `CodexKit`: core runtime, auth, backend, tools, approvals - `CodexKitUI`: optional SwiftUI-facing helpers +Supported package platforms: + +- iOS 17+ +- macOS 14+ + ## Architecture ```mermaid @@ -150,7 +156,7 @@ flowchart LR ## Recommended Live Setup -The recommended production path for iOS is: +The recommended production path for iOS and macOS is: - `ChatGPTAuthProvider` - `KeychainSessionSecureStore` From 033bb9882f7cb33e06e868cd40d31e0338c4f1de Mon Sep 17 00:00:00 2001 From: Timothy Zelinsky Date: Tue, 24 Mar 2026 16:41:48 +1100 Subject: [PATCH 13/19] Add demo turn activity indicator --- .../project.pbxproj | 4 ++ .../Shared/ThreadDetailView.swift | 20 ++++++- .../Shared/ThreadTurnActivityView.swift | 60 +++++++++++++++++++ 3 files changed, 81 insertions(+), 3 deletions(-) create mode 100644 DemoApp/AssistantRuntimeDemoApp/Shared/ThreadTurnActivityView.swift diff --git a/DemoApp/AssistantRuntimeDemoApp.xcodeproj/project.pbxproj b/DemoApp/AssistantRuntimeDemoApp.xcodeproj/project.pbxproj index 1591882..7fa94ea 100644 --- a/DemoApp/AssistantRuntimeDemoApp.xcodeproj/project.pbxproj +++ b/DemoApp/AssistantRuntimeDemoApp.xcodeproj/project.pbxproj @@ -28,6 +28,7 @@ 1A2B3C4D5E6F700000000013 /* AgentDemoViewModel+ComposerState.swift in Sources */ = {isa = PBXBuildFile; fileRef = 2A2B3C4D5E6F700000000013 /* AgentDemoViewModel+ComposerState.swift */; }; 1A2B3C4D5E6F700000000014 /* AgentDemoViewModel+ThreadState.swift in Sources */ = {isa = PBXBuildFile; fileRef = 2A2B3C4D5E6F700000000014 /* AgentDemoViewModel+ThreadState.swift */; }; 1A2B3C4D5E6F700000000015 /* AgentDemoViewModel+Observation.swift in Sources */ = {isa = PBXBuildFile; fileRef = 2A2B3C4D5E6F700000000015 /* AgentDemoViewModel+Observation.swift */; }; + 1A2B3C4D5E6F700000000016 /* ThreadTurnActivityView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 2A2B3C4D5E6F700000000016 /* ThreadTurnActivityView.swift */; }; 7482123BC63AC10F104DE092 /* AssistantRuntimeDemoApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5A6999E6475919476E726E8C /* AssistantRuntimeDemoApp.swift */; }; 84726927B752451499D9257F /* Foundation.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 906A95007C8ECB92CFC2CE15 /* Foundation.framework */; }; B060448C6464C41789B56EED /* AgentDemoView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 3CA22585116A120BA97F76B8 /* AgentDemoView.swift */; }; @@ -59,6 +60,7 @@ 2A2B3C4D5E6F700000000013 /* AgentDemoViewModel+ComposerState.swift */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.swift; path = "AgentDemoViewModel+ComposerState.swift"; sourceTree = ""; }; 2A2B3C4D5E6F700000000014 /* AgentDemoViewModel+ThreadState.swift */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.swift; path = "AgentDemoViewModel+ThreadState.swift"; sourceTree = ""; }; 2A2B3C4D5E6F700000000015 /* AgentDemoViewModel+Observation.swift */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.swift; path = "AgentDemoViewModel+Observation.swift"; sourceTree = ""; }; + 2A2B3C4D5E6F700000000016 /* ThreadTurnActivityView.swift */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.swift; path = ThreadTurnActivityView.swift; sourceTree = ""; }; 2481147A958D00EB4A70C928 /* AgentDemoViewModel.swift */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.swift; path = AgentDemoViewModel.swift; sourceTree = ""; }; 3CA22585116A120BA97F76B8 /* AgentDemoView.swift */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.swift; path = AgentDemoView.swift; sourceTree = ""; }; 5A6999E6475919476E726E8C /* AssistantRuntimeDemoApp.swift */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.swift; path = AssistantRuntimeDemoApp.swift; sourceTree = ""; }; @@ -152,6 +154,7 @@ 2A2B3C4D5E6F70000000000D /* DemoUIComponents.swift */, 2A2B3C4D5E6F70000000000E /* StructuredOutputDemoView.swift */, 2A2B3C4D5E6F70000000000F /* ThreadDetailView.swift */, + 2A2B3C4D5E6F700000000016 /* ThreadTurnActivityView.swift */, 2A2B3C4D5E6F700000000010 /* DemoMemoryExamples.swift */, 2A2B3C4D5E6F700000000011 /* AgentDemoViewModel+Memory.swift */, 2A2B3C4D5E6F700000000012 /* MemoryDemoView.swift */, @@ -252,6 +255,7 @@ 1A2B3C4D5E6F700000000013 /* AgentDemoViewModel+ComposerState.swift in Sources */, 1A2B3C4D5E6F700000000014 /* AgentDemoViewModel+ThreadState.swift in Sources */, 1A2B3C4D5E6F700000000015 /* AgentDemoViewModel+Observation.swift in Sources */, + 1A2B3C4D5E6F700000000016 /* ThreadTurnActivityView.swift in Sources */, 7482123BC63AC10F104DE092 /* AssistantRuntimeDemoApp.swift in Sources */, BB4F38E64D1EBBB3821AC4E3 /* AgentDemoRuntimeFactory.swift in Sources */, B060448C6464C41789B56EED /* AgentDemoView.swift in Sources */, diff --git a/DemoApp/AssistantRuntimeDemoApp/Shared/ThreadDetailView.swift b/DemoApp/AssistantRuntimeDemoApp/Shared/ThreadDetailView.swift index 40ed20d..1f57e8d 100644 --- a/DemoApp/AssistantRuntimeDemoApp/Shared/ThreadDetailView.swift +++ b/DemoApp/AssistantRuntimeDemoApp/Shared/ThreadDetailView.swift @@ -89,8 +89,19 @@ private extension ThreadDetailView { viewModel.observedThreadContextState ?? viewModel.activeThreadContextState } - var observedSummary: AgentThreadSummary? { - viewModel.observedThreadSummary + var observedSummary: AgentThreadSummary? { viewModel.observedThreadSummary } + + var turnActivityStatus: AgentThreadStatus? { + guard let status = observedThread?.status else { + return nil + } + + switch status { + case .streaming where !isStreamingActive, .waitingForApproval, .waitingForToolResult: + return status + default: + return nil + } } var threadHeaderCard: some View { @@ -131,7 +142,6 @@ private extension ThreadDetailView { VStack(alignment: .leading, spacing: 8) { Text("Observation Demo") .font(.headline) - Text("This card is driven by `observeThread`, `observeMessages`, `observeThreadSummary`, and `observeThreadContextState`, so title, transcript, summary, and compaction changes update live without a manual refresh.") .font(.subheadline) .foregroundStyle(.secondary) @@ -170,6 +180,10 @@ private extension ThreadDetailView { ThreadMessageBubble(message: message) } + if let turnActivityStatus { + ThreadTurnActivityView(status: turnActivityStatus) + } + if isStreamingActive { ThreadStreamingBubble(text: viewModel.streamingText) } diff --git a/DemoApp/AssistantRuntimeDemoApp/Shared/ThreadTurnActivityView.swift b/DemoApp/AssistantRuntimeDemoApp/Shared/ThreadTurnActivityView.swift new file mode 100644 index 0000000..3a0e543 --- /dev/null +++ b/DemoApp/AssistantRuntimeDemoApp/Shared/ThreadTurnActivityView.swift @@ -0,0 +1,60 @@ +import CodexKit +import SwiftUI + +@available(iOS 17.0, macOS 14.0, *) +struct ThreadTurnActivityView: View { + let status: AgentThreadStatus + + var body: some View { + HStack(spacing: 12) { + ProgressView() + .controlSize(.small) + + VStack(alignment: .leading, spacing: 4) { + Text(title) + .font(.subheadline.weight(.semibold)) + + Text(subtitle) + .font(.caption) + .foregroundStyle(.secondary) + } + + Spacer() + } + .padding(14) + .background( + RoundedRectangle(cornerRadius: 16, style: .continuous) + .fill(Color.primary.opacity(0.04)) + ) + } + + private var title: String { + switch status { + case .idle: + "Idle" + case .streaming: + "Thinking..." + case .waitingForApproval: + "Waiting for approval..." + case .waitingForToolResult: + "Running tool..." + case .failed: + "Turn failed" + } + } + + private var subtitle: String { + switch status { + case .idle: + "No active turn." + case .streaming: + "The assistant is preparing a reply." + case .waitingForApproval: + "Approve or deny the pending tool request to continue." + case .waitingForToolResult: + "A host tool is still executing." + case .failed: + "Check the latest error and try again." + } + } +} From 31eb976d22abd774f5944cfedb315c0d7455e917 Mon Sep 17 00:00:00 2001 From: Timothy Zelinsky Date: Tue, 24 Mar 2026 17:07:52 +1100 Subject: [PATCH 14/19] Use GRDB requests for memory store mutations --- .../Memory/SQLiteMemoryStoreRepository.swift | 221 ++++++++++++------ .../Memory/SQLiteMemoryStoreSchema.swift | 4 + 2 files changed, 153 insertions(+), 72 deletions(-) diff --git a/Sources/CodexKit/Memory/SQLiteMemoryStoreRepository.swift b/Sources/CodexKit/Memory/SQLiteMemoryStoreRepository.swift index db2d02f..ea8b7f7 100644 --- a/Sources/CodexKit/Memory/SQLiteMemoryStoreRepository.swift +++ b/Sources/CodexKit/Memory/SQLiteMemoryStoreRepository.swift @@ -64,7 +64,7 @@ struct SQLiteMemoryStoreRepository: Sendable { namespace: String, in db: Database ) throws -> [MemoryRecord] { - let rows = try MemoryRecordRow + let rows = try MemoryRecordDatabaseRow .filter(Column("namespace") == namespace) .fetchAll(db) @@ -102,7 +102,7 @@ struct SQLiteMemoryStoreRepository: Sendable { namespace: String, in db: Database ) throws -> MemoryRecord? { - let row = try MemoryRecordRow + let row = try MemoryRecordDatabaseRow .filter(Column("namespace") == namespace) .filter(Column("id") == id) .fetchOne(db) @@ -117,14 +117,10 @@ struct SQLiteMemoryStoreRepository: Sendable { namespace: String, in db: Database ) throws { - try db.execute( - sql: """ - UPDATE memory_records - SET status = ? - WHERE namespace = ? AND id = ?; - """, - arguments: [MemoryRecordStatus.archived.rawValue, namespace, id] - ) + try MemoryRecordDatabaseRow + .filter(Column("namespace") == namespace) + .filter(Column("id") == id) + .updateAll(db, Column("status").set(to: MemoryRecordStatus.archived.rawValue)) } func deleteRecord( @@ -132,14 +128,14 @@ struct SQLiteMemoryStoreRepository: Sendable { namespace: String, in db: Database ) throws { - try db.execute( - sql: "DELETE FROM memory_fts WHERE namespace = ? AND record_id = ?;", - arguments: [namespace, id] - ) - try db.execute( - sql: "DELETE FROM memory_records WHERE namespace = ? AND id = ?;", - arguments: [namespace, id] - ) + try MemoryFTSRow + .filter(Column("namespace") == namespace) + .filter(Column("record_id") == id) + .deleteAll(db) + try MemoryRecordDatabaseRow + .filter(Column("namespace") == namespace) + .filter(Column("id") == id) + .deleteAll(db) } func deleteRecord( @@ -159,74 +155,73 @@ struct SQLiteMemoryStoreRepository: Sendable { _ record: MemoryRecord, in db: Database ) throws { - try db.execute( - sql: """ - INSERT OR REPLACE INTO memory_records ( - namespace, id, scope, kind, summary, evidence_json, importance, - created_at, observed_at, expires_at, tags_json, related_ids_json, - dedupe_key, is_pinned, attributes_json, status - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?); - """, - arguments: [ - record.namespace, - record.id, - record.scope.rawValue, - record.kind, - record.summary, - try codec.encode(record.evidence), - record.importance, - record.createdAt.timeIntervalSince1970, - record.observedAt?.timeIntervalSince1970, - record.expiresAt?.timeIntervalSince1970, - try codec.encode(record.tags), - try codec.encode(record.relatedIDs), - record.dedupeKey, - record.isPinned ? 1 : 0, - try codec.encodeNullable(record.attributes), - record.status.rawValue, - ] - ) + try MemoryRecordDatabaseRow + .filter(Column("namespace") == record.namespace) + .filter(Column("id") == record.id) + .deleteAll(db) + try makeDatabaseRow(from: record).insert(db) - try db.execute( - sql: "DELETE FROM memory_tags WHERE namespace = ? AND record_id = ?;", - arguments: [record.namespace, record.id] - ) - try db.execute( - sql: "DELETE FROM memory_related_ids WHERE namespace = ? AND record_id = ?;", - arguments: [record.namespace, record.id] - ) - try db.execute( - sql: "DELETE FROM memory_fts WHERE namespace = ? AND record_id = ?;", - arguments: [record.namespace, record.id] - ) + try MemoryTagRow + .filter(Column("namespace") == record.namespace) + .filter(Column("record_id") == record.id) + .deleteAll(db) + try MemoryRelatedIDRow + .filter(Column("namespace") == record.namespace) + .filter(Column("record_id") == record.id) + .deleteAll(db) + try MemoryFTSRow + .filter(Column("namespace") == record.namespace) + .filter(Column("record_id") == record.id) + .deleteAll(db) for tag in record.tags { - try db.execute( - sql: "INSERT INTO memory_tags(namespace, record_id, tag) VALUES (?, ?, ?);", - arguments: [record.namespace, record.id, tag] - ) + try MemoryTagRow( + namespace: record.namespace, + recordID: record.id, + tag: tag + ).insert(db) } for relatedID in record.relatedIDs { - try db.execute( - sql: """ - INSERT INTO memory_related_ids(namespace, record_id, related_id) - VALUES (?, ?, ?); - """, - arguments: [record.namespace, record.id, relatedID] - ) + try MemoryRelatedIDRow( + namespace: record.namespace, + recordID: record.id, + relatedID: relatedID + ).insert(db) } let ftsContent = ([record.summary] + record.evidence + record.tags + [record.kind]) .joined(separator: " ") - try db.execute( - sql: "INSERT INTO memory_fts(namespace, record_id, content) VALUES (?, ?, ?);", - arguments: [record.namespace, record.id, ftsContent] + try MemoryFTSRow( + namespace: record.namespace, + recordID: record.id, + content: ftsContent + ).insert(db) + } + + private func makeDatabaseRow(from record: MemoryRecord) throws -> MemoryRecordDatabaseRow { + try MemoryRecordDatabaseRow( + namespace: record.namespace, + id: record.id, + scope: record.scope.rawValue, + kind: record.kind, + summary: record.summary, + evidenceJSON: codec.encode(record.evidence), + importance: record.importance, + createdAt: record.createdAt.timeIntervalSince1970, + observedAt: record.observedAt?.timeIntervalSince1970, + expiresAt: record.expiresAt?.timeIntervalSince1970, + tagsJSON: codec.encode(record.tags), + relatedIDsJSON: codec.encode(record.relatedIDs), + dedupeKey: record.dedupeKey, + isPinned: record.isPinned, + attributesJSON: codec.encodeNullable(record.attributes), + status: record.status.rawValue ) } private func makeRecord( - from row: MemoryRecordRow, + from row: MemoryRecordDatabaseRow, namespace: String ) throws -> MemoryRecord { MemoryRecord( @@ -326,3 +321,85 @@ struct MemoryRecordRow: FetchableRecord, TableRecord { status = row["status"] } } + +struct MemoryRecordDatabaseRow: Codable, FetchableRecord, PersistableRecord, TableRecord { + static let databaseTableName = "memory_records" + + let namespace: String + let id: String + let scope: String + let kind: String + let summary: String + let evidenceJSON: String + let importance: Double + let createdAt: Double + let observedAt: Double? + let expiresAt: Double? + let tagsJSON: String + let relatedIDsJSON: String + let dedupeKey: String? + let isPinned: Bool + let attributesJSON: String? + let status: String + + enum CodingKeys: String, CodingKey { + case namespace + case id + case scope + case kind + case summary + case evidenceJSON = "evidence_json" + case importance + case createdAt = "created_at" + case observedAt = "observed_at" + case expiresAt = "expires_at" + case tagsJSON = "tags_json" + case relatedIDsJSON = "related_ids_json" + case dedupeKey = "dedupe_key" + case isPinned = "is_pinned" + case attributesJSON = "attributes_json" + case status + } +} + +struct MemoryTagRow: Codable, FetchableRecord, PersistableRecord, TableRecord { + static let databaseTableName = "memory_tags" + + let namespace: String + let recordID: String + let tag: String + + enum CodingKeys: String, CodingKey { + case namespace + case recordID = "record_id" + case tag + } +} + +struct MemoryRelatedIDRow: Codable, FetchableRecord, PersistableRecord, TableRecord { + static let databaseTableName = "memory_related_ids" + + let namespace: String + let recordID: String + let relatedID: String + + enum CodingKeys: String, CodingKey { + case namespace + case recordID = "record_id" + case relatedID = "related_id" + } +} + +struct MemoryFTSRow: Codable, FetchableRecord, PersistableRecord, TableRecord { + static let databaseTableName = "memory_fts" + + let namespace: String + let recordID: String + let content: String + + enum CodingKeys: String, CodingKey { + case namespace + case recordID = "record_id" + case content + } +} diff --git a/Sources/CodexKit/Memory/SQLiteMemoryStoreSchema.swift b/Sources/CodexKit/Memory/SQLiteMemoryStoreSchema.swift index 208d9f1..afd01fa 100644 --- a/Sources/CodexKit/Memory/SQLiteMemoryStoreSchema.swift +++ b/Sources/CodexKit/Memory/SQLiteMemoryStoreSchema.swift @@ -12,6 +12,9 @@ struct SQLiteMemoryStoreSchema: Sendable { var migrator = DatabaseMigrator() migrator.registerMigration("memory_store_v1") { db in + // Schema DDL stays as raw SQL here because GRDB's query interface + // is designed for data access, while table/index/FTS creation is + // clearest and most direct in SQLite DDL. try db.execute(sql: """ CREATE TABLE IF NOT EXISTS memory_records ( namespace TEXT NOT NULL, @@ -87,6 +90,7 @@ struct SQLiteMemoryStoreSchema: Sendable { USING fts5(namespace UNINDEXED, record_id UNINDEXED, content); """) + // PRAGMA user_version is SQLite-specific migration state. try db.execute(sql: "PRAGMA user_version = \(currentVersion)") } From 179aadcd0adbc02a564c36a362832481e10758a9 Mon Sep 17 00:00:00 2001 From: Timothy Zelinsky Date: Tue, 24 Mar 2026 17:14:13 +1100 Subject: [PATCH 15/19] Update demo Xcode project settings --- .../project.pbxproj | 31 ++++++++++++------- .../AssistantRuntimeDemoApp.xcscheme | 2 +- 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/DemoApp/AssistantRuntimeDemoApp.xcodeproj/project.pbxproj b/DemoApp/AssistantRuntimeDemoApp.xcodeproj/project.pbxproj index 7fa94ea..de4695d 100644 --- a/DemoApp/AssistantRuntimeDemoApp.xcodeproj/project.pbxproj +++ b/DemoApp/AssistantRuntimeDemoApp.xcodeproj/project.pbxproj @@ -3,7 +3,7 @@ archiveVersion = 1; classes = { }; - objectVersion = 46; + objectVersion = 60; objects = { /* Begin PBXBuildFile section */ @@ -39,6 +39,7 @@ /* End PBXBuildFile section */ /* Begin PBXFileReference section */ + 2481147A958D00EB4A70C928 /* AgentDemoViewModel.swift */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.swift; path = AgentDemoViewModel.swift; sourceTree = ""; }; 2A2B3C4D5E6F700000000001 /* AgentDemoView+ChatSections.swift */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.swift; path = "AgentDemoView+ChatSections.swift"; sourceTree = ""; }; 2A2B3C4D5E6F700000000002 /* AgentDemoView+ComposerAndSheets.swift */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.swift; path = "AgentDemoView+ComposerAndSheets.swift"; sourceTree = ""; }; 2A2B3C4D5E6F700000000003 /* HealthCoachView.swift */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.swift; path = HealthCoachView.swift; sourceTree = ""; }; @@ -61,7 +62,6 @@ 2A2B3C4D5E6F700000000014 /* AgentDemoViewModel+ThreadState.swift */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.swift; path = "AgentDemoViewModel+ThreadState.swift"; sourceTree = ""; }; 2A2B3C4D5E6F700000000015 /* AgentDemoViewModel+Observation.swift */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.swift; path = "AgentDemoViewModel+Observation.swift"; sourceTree = ""; }; 2A2B3C4D5E6F700000000016 /* ThreadTurnActivityView.swift */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.swift; path = ThreadTurnActivityView.swift; sourceTree = ""; }; - 2481147A958D00EB4A70C928 /* AgentDemoViewModel.swift */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.swift; path = AgentDemoViewModel.swift; sourceTree = ""; }; 3CA22585116A120BA97F76B8 /* AgentDemoView.swift */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.swift; path = AgentDemoView.swift; sourceTree = ""; }; 5A6999E6475919476E726E8C /* AssistantRuntimeDemoApp.swift */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.swift; path = AssistantRuntimeDemoApp.swift; sourceTree = ""; }; 690A3E5A03E545F88FAF9A44 /* AssistantRuntimeDemoApp.entitlements */ = {isa = PBXFileReference; lastKnownFileType = text.plist.entitlements; path = AssistantRuntimeDemoApp.entitlements; sourceTree = ""; }; @@ -93,7 +93,6 @@ 690A3E5A03E545F88FAF9A44 /* AssistantRuntimeDemoApp.entitlements */, FE27E9F3273C1EAF47AE5198 /* Info.plist */, ); - name = AssistantRuntimeDemoApp; path = AssistantRuntimeDemoApp; sourceTree = ""; }; @@ -159,7 +158,6 @@ 2A2B3C4D5E6F700000000011 /* AgentDemoViewModel+Memory.swift */, 2A2B3C4D5E6F700000000012 /* MemoryDemoView.swift */, ); - name = Shared; path = Shared; sourceTree = ""; }; @@ -193,8 +191,9 @@ 26901593A1B92DDA950D134D /* Project object */ = { isa = PBXProject; attributes = { + BuildIndependentTargetsInParallel = YES; LastSwiftUpdateCheck = 1600; - LastUpgradeCheck = 1600; + LastUpgradeCheck = 2630; }; buildConfigurationList = 01B4A49BA7FDE96517E1DFD0 /* Build configuration list for PBXProject "AssistantRuntimeDemoApp" */; compatibilityVersion = "Xcode 3.2"; @@ -205,11 +204,9 @@ Base, ); mainGroup = B7BD2AF1370DA44DA0C52952; - minimizedProjectReferenceProxies = 0; packageReferences = ( 802CFFA933DE41B7A27EFE75 /* XCLocalSwiftPackageReference ".." */, ); - preferredProjectObjectVersion = 77; productRefGroup = CA88F59A68B805C4F9DA5CAA /* Products */; projectDirPath = ""; projectRoot = ""; @@ -300,9 +297,12 @@ CLANG_WARN_UNREACHABLE_CODE = YES; CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; COPY_PHASE_STRIP = NO; + DEAD_CODE_STRIPPING = YES; DEBUG_INFORMATION_FORMAT = dwarf; + DEVELOPMENT_TEAM = T9G4574SJG; ENABLE_STRICT_OBJC_MSGSEND = YES; ENABLE_TESTABILITY = YES; + ENABLE_USER_SCRIPT_SANDBOXING = YES; GCC_C_LANGUAGE_STANDARD = gnu11; GCC_DYNAMIC_NO_PIC = NO; GCC_NO_COMMON_BLOCKS = YES; @@ -321,6 +321,7 @@ MTL_FAST_MATH = YES; ONLY_ACTIVE_ARCH = YES; PRODUCT_NAME = "$(TARGET_NAME)"; + STRING_CATALOG_GENERATE_SYMBOLS = YES; SWIFT_ACTIVE_COMPILATION_CONDITIONS = DEBUG; SWIFT_OPTIMIZATION_LEVEL = "-Onone"; SWIFT_VERSION = 5.0; @@ -336,11 +337,13 @@ CLANG_ENABLE_OBJC_WEAK = NO; CODE_SIGN_ENTITLEMENTS = AssistantRuntimeDemoApp/AssistantRuntimeDemoApp.entitlements; CODE_SIGN_STYLE = Automatic; - DEVELOPMENT_TEAM = T9G4574SJG; GENERATE_INFOPLIST_FILE = NO; INFOPLIST_FILE = AssistantRuntimeDemoApp/Info.plist; IPHONEOS_DEPLOYMENT_TARGET = 17.0; - LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; + LD_RUNPATH_SEARCH_PATHS = ( + "$(inherited)", + "@executable_path/Frameworks", + ); PRODUCT_BUNDLE_IDENTIFIER = ai.assistantruntime.demoapp; PRODUCT_NAME = "$(TARGET_NAME)"; SDKROOT = iphoneos; @@ -359,11 +362,13 @@ CLANG_ENABLE_OBJC_WEAK = NO; CODE_SIGN_ENTITLEMENTS = AssistantRuntimeDemoApp/AssistantRuntimeDemoApp.entitlements; CODE_SIGN_STYLE = Automatic; - DEVELOPMENT_TEAM = T9G4574SJG; GENERATE_INFOPLIST_FILE = NO; INFOPLIST_FILE = AssistantRuntimeDemoApp/Info.plist; IPHONEOS_DEPLOYMENT_TARGET = 17.0; - LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; + LD_RUNPATH_SEARCH_PATHS = ( + "$(inherited)", + "@executable_path/Frameworks", + ); PRODUCT_BUNDLE_IDENTIFIER = ai.assistantruntime.demoapp; PRODUCT_NAME = "$(TARGET_NAME)"; SDKROOT = iphoneos; @@ -406,9 +411,12 @@ CLANG_WARN_UNREACHABLE_CODE = YES; CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; COPY_PHASE_STRIP = NO; + DEAD_CODE_STRIPPING = YES; DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; + DEVELOPMENT_TEAM = T9G4574SJG; ENABLE_NS_ASSERTIONS = NO; ENABLE_STRICT_OBJC_MSGSEND = YES; + ENABLE_USER_SCRIPT_SANDBOXING = YES; GCC_C_LANGUAGE_STANDARD = gnu11; GCC_NO_COMMON_BLOCKS = YES; GCC_WARN_64_TO_32_BIT_CONVERSION = YES; @@ -420,6 +428,7 @@ MTL_ENABLE_DEBUG_INFO = NO; MTL_FAST_MATH = YES; PRODUCT_NAME = "$(TARGET_NAME)"; + STRING_CATALOG_GENERATE_SYMBOLS = YES; SWIFT_COMPILATION_MODE = wholemodule; SWIFT_OPTIMIZATION_LEVEL = "-O"; SWIFT_VERSION = 5.0; diff --git a/DemoApp/AssistantRuntimeDemoApp.xcodeproj/xcshareddata/xcschemes/AssistantRuntimeDemoApp.xcscheme b/DemoApp/AssistantRuntimeDemoApp.xcodeproj/xcshareddata/xcschemes/AssistantRuntimeDemoApp.xcscheme index ea3a6a0..d87abd5 100644 --- a/DemoApp/AssistantRuntimeDemoApp.xcodeproj/xcshareddata/xcschemes/AssistantRuntimeDemoApp.xcscheme +++ b/DemoApp/AssistantRuntimeDemoApp.xcodeproj/xcshareddata/xcschemes/AssistantRuntimeDemoApp.xcscheme @@ -1,6 +1,6 @@ Date: Tue, 24 Mar 2026 17:38:51 +1100 Subject: [PATCH 16/19] Refactor GRDB store collaborators and schema --- .../Memory/SQLiteMemoryStoreSchema.swift | 103 ++++++-------- .../GRDBRuntimeStateStore+Persistence.swift | 126 ++++++++---------- .../GRDBRuntimeStateStore+Queries.swift | 89 +++++++------ .../Runtime/GRDBRuntimeStateStore.swift | 53 +++++--- 4 files changed, 173 insertions(+), 198 deletions(-) diff --git a/Sources/CodexKit/Memory/SQLiteMemoryStoreSchema.swift b/Sources/CodexKit/Memory/SQLiteMemoryStoreSchema.swift index afd01fa..9a60427 100644 --- a/Sources/CodexKit/Memory/SQLiteMemoryStoreSchema.swift +++ b/Sources/CodexKit/Memory/SQLiteMemoryStoreSchema.swift @@ -12,79 +12,54 @@ struct SQLiteMemoryStoreSchema: Sendable { var migrator = DatabaseMigrator() migrator.registerMigration("memory_store_v1") { db in - // Schema DDL stays as raw SQL here because GRDB's query interface - // is designed for data access, while table/index/FTS creation is - // clearest and most direct in SQLite DDL. - try db.execute(sql: """ - CREATE TABLE IF NOT EXISTS memory_records ( - namespace TEXT NOT NULL, - id TEXT NOT NULL, - scope TEXT NOT NULL, - kind TEXT NOT NULL, - summary TEXT NOT NULL, - evidence_json TEXT NOT NULL, - importance REAL NOT NULL, - created_at REAL NOT NULL, - observed_at REAL, - expires_at REAL, - tags_json TEXT NOT NULL, - related_ids_json TEXT NOT NULL, - dedupe_key TEXT, - is_pinned INTEGER NOT NULL, - attributes_json TEXT, - status TEXT NOT NULL, - PRIMARY KEY(namespace, id) - ); - """) + try db.create(table: "memory_records", ifNotExists: true) { table in + table.column("namespace", .text).notNull() + table.column("id", .text).notNull() + table.column("scope", .text).notNull() + table.column("kind", .text).notNull() + table.column("summary", .text).notNull() + table.column("evidence_json", .text).notNull() + table.column("importance", .double).notNull() + table.column("created_at", .double).notNull() + table.column("observed_at", .double) + table.column("expires_at", .double) + table.column("tags_json", .text).notNull() + table.column("related_ids_json", .text).notNull() + table.column("dedupe_key", .text) + table.column("is_pinned", .boolean).notNull() + table.column("attributes_json", .text) + table.column("status", .text).notNull() + table.primaryKey(["namespace", "id"]) + } + // The dedupe index stays raw because it is both partial and unique. try db.execute(sql: """ CREATE UNIQUE INDEX IF NOT EXISTS memory_records_namespace_dedupe ON memory_records(namespace, dedupe_key) WHERE dedupe_key IS NOT NULL; """) - try db.execute(sql: """ - CREATE INDEX IF NOT EXISTS memory_records_namespace_scope - ON memory_records(namespace, scope); - """) - try db.execute(sql: """ - CREATE INDEX IF NOT EXISTS memory_records_namespace_kind - ON memory_records(namespace, kind); - """) - try db.execute(sql: """ - CREATE INDEX IF NOT EXISTS memory_records_namespace_status - ON memory_records(namespace, status); - """) + try db.create(index: "memory_records_namespace_scope", on: "memory_records", columns: ["namespace", "scope"]) + try db.create(index: "memory_records_namespace_kind", on: "memory_records", columns: ["namespace", "kind"]) + try db.create(index: "memory_records_namespace_status", on: "memory_records", columns: ["namespace", "status"]) - try db.execute(sql: """ - CREATE TABLE IF NOT EXISTS memory_tags ( - namespace TEXT NOT NULL, - record_id TEXT NOT NULL, - tag TEXT NOT NULL, - FOREIGN KEY(namespace, record_id) - REFERENCES memory_records(namespace, id) - ON DELETE CASCADE - ); - """) - try db.execute(sql: """ - CREATE INDEX IF NOT EXISTS memory_tags_lookup - ON memory_tags(namespace, tag, record_id); - """) + try db.create(table: "memory_tags", ifNotExists: true) { table in + table.column("namespace", .text).notNull() + table.column("record_id", .text).notNull() + table.column("tag", .text).notNull() + table.foreignKey(["namespace", "record_id"], references: "memory_records", columns: ["namespace", "id"], onDelete: .cascade) + } + try db.create(index: "memory_tags_lookup", on: "memory_tags", columns: ["namespace", "tag", "record_id"]) - try db.execute(sql: """ - CREATE TABLE IF NOT EXISTS memory_related_ids ( - namespace TEXT NOT NULL, - record_id TEXT NOT NULL, - related_id TEXT NOT NULL, - FOREIGN KEY(namespace, record_id) - REFERENCES memory_records(namespace, id) - ON DELETE CASCADE - ); - """) - try db.execute(sql: """ - CREATE INDEX IF NOT EXISTS memory_related_lookup - ON memory_related_ids(namespace, related_id, record_id); - """) + try db.create(table: "memory_related_ids", ifNotExists: true) { table in + table.column("namespace", .text).notNull() + table.column("record_id", .text).notNull() + table.column("related_id", .text).notNull() + table.foreignKey(["namespace", "record_id"], references: "memory_records", columns: ["namespace", "id"], onDelete: .cascade) + } + try db.create(index: "memory_related_lookup", on: "memory_related_ids", columns: ["namespace", "related_id", "record_id"]) + // FTS virtual-table creation still stays raw because it relies on + // SQLite's module-specific DDL syntax rather than ordinary table creation. try db.execute(sql: """ CREATE VIRTUAL TABLE IF NOT EXISTS memory_fts USING fts5(namespace UNINDEXED, record_id UNINDEXED, content); diff --git a/Sources/CodexKit/Runtime/GRDBRuntimeStateStore+Persistence.swift b/Sources/CodexKit/Runtime/GRDBRuntimeStateStore+Persistence.swift index 7caceb6..94731ae 100644 --- a/Sources/CodexKit/Runtime/GRDBRuntimeStateStore+Persistence.swift +++ b/Sources/CodexKit/Runtime/GRDBRuntimeStateStore+Persistence.swift @@ -33,13 +33,10 @@ extension GRDBRuntimeStateStore { return } + let persistence = self.persistence try await dbQueue.write { db in try attachmentStore.reset() - try Self.replaceDatabaseContents( - with: state, - in: db, - attachmentStore: attachmentStore - ) + try persistence.replaceDatabaseContents(with: state, in: db) } } @@ -48,24 +45,27 @@ extension GRDBRuntimeStateStore { try RuntimeUserVersionQuery().execute(in: db) } } +} + +struct GRDBRuntimeStorePersistence: Sendable { + let attachmentStore: RuntimeAttachmentStore - static func replaceDatabaseContents( + func replaceDatabaseContents( with normalized: StoredRuntimeState, - in db: Database, - attachmentStore: RuntimeAttachmentStore + in db: Database ) throws { - let threadRows = try normalized.threads.map(Self.makeThreadRow) + let threadRows = try normalized.threads.map(makeThreadRow) let summaryRows = try normalized.threads.compactMap { thread -> RuntimeSummaryRow? in guard let summary = normalized.summariesByThread[thread.id] else { return nil } - return try Self.makeSummaryRow(from: summary) + return try makeSummaryRow(from: summary) } let historyRows = try normalized.historyByThread.values .flatMap { $0 } - .map { try Self.makeHistoryRow(from: $0, attachmentStore: attachmentStore) } - let structuredOutputRows = try Self.structuredOutputRows(from: normalized.historyByThread) - let contextRows = try normalized.contextStateByThread.values.map(Self.makeContextStateRow) + .map(makeHistoryRow) + let structuredOutputRows = try self.structuredOutputRows(from: normalized.historyByThread) + let contextRows = try normalized.contextStateByThread.values.map(makeContextStateRow) try RuntimeContextStateRow.deleteAll(db) try RuntimeStructuredOutputRow.deleteAll(db) @@ -73,27 +73,16 @@ extension GRDBRuntimeStateStore { try RuntimeSummaryRow.deleteAll(db) try RuntimeThreadRow.deleteAll(db) - for row in threadRows { - try row.insert(db) - } - for row in summaryRows { - try row.insert(db) - } - for row in historyRows { - try row.insert(db) - } - for row in structuredOutputRows { - try row.insert(db) - } - for row in contextRows { - try row.insert(db) - } + for row in threadRows { try row.insert(db) } + for row in summaryRows { try row.insert(db) } + for row in historyRows { try row.insert(db) } + for row in structuredOutputRows { try row.insert(db) } + for row in contextRows { try row.insert(db) } } - static func loadPartialState( + func loadPartialState( for threadIDs: Set, - from db: Database, - attachmentStore: RuntimeAttachmentStore + from db: Database ) throws -> StoredRuntimeState { guard !threadIDs.isEmpty else { return .empty @@ -106,12 +95,10 @@ extension GRDBRuntimeStateStore { let summaryRows = try RuntimeSummaryRow .filter(ids.contains(Column("threadID"))) .fetchAll(db) - // History loading keeps raw SQL here so we can preserve a deterministic - // thread + sequence ordering across multiple thread IDs in one fetch. let historyRows = try RuntimeHistoryRowsRequest( sql: """ SELECT * FROM \(RuntimeHistoryRow.databaseTableName) - WHERE threadID IN \(Self.sqlPlaceholders(count: ids.count)) + WHERE threadID IN \(sqlPlaceholders(count: ids.count)) ORDER BY threadID ASC, sequenceNumber ASC """, arguments: StatementArguments(ids) @@ -120,16 +107,14 @@ extension GRDBRuntimeStateStore { .filter(ids.contains(Column("threadID"))) .fetchAll(db) - let threads = try threadRows.map { try Self.decodeThread(from: $0) } + let threads = try threadRows.map(decodeThread) let summaries = try Dictionary( - uniqueKeysWithValues: summaryRows.map { ($0.threadID, try Self.decodeSummary(from: $0)) } + uniqueKeysWithValues: summaryRows.map { ($0.threadID, try decodeSummary(from: $0)) } ) - let decodedHistoryRows = try historyRows.map { - try Self.decodeHistoryRecord(from: $0, attachmentStore: attachmentStore) - } + let decodedHistoryRows = try historyRows.map(decodeHistoryRecord) let history = Dictionary(grouping: decodedHistoryRows, by: { $0.item.threadID }) let contextState = try Dictionary( - uniqueKeysWithValues: contextRows.map { ($0.threadID, try Self.decodeContextState(from: $0)) } + uniqueKeysWithValues: contextRows.map { ($0.threadID, try decodeContextState(from: $0)) } ) let nextSequence = history.mapValues { ($0.last?.sequenceNumber ?? 0) + 1 } @@ -142,11 +127,10 @@ extension GRDBRuntimeStateStore { ) } - static func persistThreads( + func persistThreads( ids threadIDs: Set, from state: StoredRuntimeState, - in db: Database, - attachmentStore: RuntimeAttachmentStore + in db: Database ) throws { let normalized = state.normalized() let threads = normalized.threads.filter { threadIDs.contains($0.id) } @@ -155,33 +139,33 @@ extension GRDBRuntimeStateStore { } for thread in threads { - try Self.makeThreadRow(from: thread).insert(db) + try makeThreadRow(from: thread).insert(db) if let summary = normalized.summariesByThread[thread.id] { - try Self.makeSummaryRow(from: summary).insert(db) + try makeSummaryRow(from: summary).insert(db) } if let contextState = normalized.contextStateByThread[thread.id] { - try Self.makeContextStateRow(from: contextState).insert(db) + try makeContextStateRow(from: contextState).insert(db) } for record in normalized.historyByThread[thread.id] ?? [] { - try Self.makeHistoryRow(from: record, attachmentStore: attachmentStore).insert(db) + try makeHistoryRow(from: record).insert(db) } } - for row in try Self.structuredOutputRows( + for row in try structuredOutputRows( from: normalized.historyByThread.filter { threadIDs.contains($0.key) } ) { try row.insert(db) } } - static func deletePersistedThread( + func deletePersistedThread( _ threadID: String, in db: Database ) throws { _ = try RuntimeThreadRow.deleteOne(db, key: threadID) } - static func makeThreadRow(from thread: AgentThread) throws -> RuntimeThreadRow { + func makeThreadRow(from thread: AgentThread) throws -> RuntimeThreadRow { RuntimeThreadRow( threadID: thread.id, createdAt: thread.createdAt.timeIntervalSince1970, @@ -191,7 +175,7 @@ extension GRDBRuntimeStateStore { ) } - static func makeSummaryRow(from summary: AgentThreadSummary) throws -> RuntimeSummaryRow { + func makeSummaryRow(from summary: AgentThreadSummary) throws -> RuntimeSummaryRow { RuntimeSummaryRow( threadID: summary.threadID, createdAt: summary.createdAt.timeIntervalSince1970, @@ -204,10 +188,7 @@ extension GRDBRuntimeStateStore { ) } - static func makeHistoryRow( - from record: AgentHistoryRecord, - attachmentStore: RuntimeAttachmentStore - ) throws -> RuntimeHistoryRow { + func makeHistoryRow(from record: AgentHistoryRecord) throws -> RuntimeHistoryRow { let persisted = try PersistedAgentHistoryRecord( record: record, attachmentStore: attachmentStore @@ -226,7 +207,7 @@ extension GRDBRuntimeStateStore { ) } - static func makeContextStateRow(from state: AgentThreadContextState) throws -> RuntimeContextStateRow { + func makeContextStateRow(from state: AgentThreadContextState) throws -> RuntimeContextStateRow { RuntimeContextStateRow( threadID: state.threadID, generation: state.generation, @@ -234,7 +215,7 @@ extension GRDBRuntimeStateStore { ) } - static func structuredOutputRows( + func structuredOutputRows( from historyByThread: [String: [AgentHistoryRecord]] ) throws -> [RuntimeStructuredOutputRow] { try historyByThread.values @@ -242,16 +223,15 @@ extension GRDBRuntimeStateStore { .compactMap { record -> RuntimeStructuredOutputRow? in switch record.item { case let .structuredOutput(output): - return try Self.makeStructuredOutputRow( + return try makeStructuredOutputRow( id: "structured:\(record.id)", record: output ) - case let .message(message): guard let metadata = message.structuredOutput else { return nil } - return try Self.makeStructuredOutputRow( + return try makeStructuredOutputRow( id: "message:\(message.id)", record: AgentStructuredOutputRecord( threadID: message.threadID, @@ -261,14 +241,13 @@ extension GRDBRuntimeStateStore { committedAt: message.createdAt ) ) - default: return nil } } } - static func makeStructuredOutputRow( + func makeStructuredOutputRow( id: String, record: AgentStructuredOutputRecord ) throws -> RuntimeStructuredOutputRow { @@ -281,22 +260,19 @@ extension GRDBRuntimeStateStore { ) } - static func decodeThread(from row: RuntimeThreadRow) throws -> AgentThread { + func decodeThread(from row: RuntimeThreadRow) throws -> AgentThread { try JSONDecoder().decode(AgentThread.self, from: row.encodedThread) } - static func decodeSummary(from row: RuntimeSummaryRow) throws -> AgentThreadSummary { + func decodeSummary(from row: RuntimeSummaryRow) throws -> AgentThreadSummary { try JSONDecoder().decode(AgentThreadSummary.self, from: row.encodedSummary) } - static func decodeContextState(from row: RuntimeContextStateRow) throws -> AgentThreadContextState { + func decodeContextState(from row: RuntimeContextStateRow) throws -> AgentThreadContextState { try JSONDecoder().decode(AgentThreadContextState.self, from: row.encodedState) } - static func decodeHistoryRecord( - from row: RuntimeHistoryRow, - attachmentStore: RuntimeAttachmentStore - ) throws -> AgentHistoryRecord { + func decodeHistoryRecord(from row: RuntimeHistoryRow) throws -> AgentHistoryRecord { let decoder = JSONDecoder() if let persisted = try? decoder.decode(PersistedAgentHistoryRecord.self, from: row.encodedRecord) { return try persisted.decode(using: attachmentStore) @@ -304,11 +280,19 @@ extension GRDBRuntimeStateStore { return try decoder.decode(AgentHistoryRecord.self, from: row.encodedRecord) } - static func decodeStructuredOutputRecord(from row: RuntimeStructuredOutputRow) throws -> AgentStructuredOutputRecord { + func decodeStructuredOutputRecord(from row: RuntimeStructuredOutputRow) throws -> AgentStructuredOutputRecord { try JSONDecoder().decode(AgentStructuredOutputRecord.self, from: row.encodedRecord) } - static func makeMigrator() -> DatabaseMigrator { + private func sqlPlaceholders(count: Int) -> String { + "(" + Array(repeating: "?", count: count).joined(separator: ", ") + ")" + } +} + +struct GRDBRuntimeStoreSchema: Sendable { + let currentStoreSchemaVersion: Int + + func makeMigrator() -> DatabaseMigrator { var migrator = DatabaseMigrator() migrator.registerMigration("runtime_store_v1") { db in diff --git a/Sources/CodexKit/Runtime/GRDBRuntimeStateStore+Queries.swift b/Sources/CodexKit/Runtime/GRDBRuntimeStateStore+Queries.swift index a47511f..b3ebbb5 100644 --- a/Sources/CodexKit/Runtime/GRDBRuntimeStateStore+Queries.swift +++ b/Sources/CodexKit/Runtime/GRDBRuntimeStateStore+Queries.swift @@ -3,7 +3,9 @@ import GRDB extension GRDBRuntimeStateStore { func executeHistoryQuery(_ query: HistoryItemsQuery) async throws -> AgentHistoryQueryResult { - try await dbQueue.read { db in + let persistence = self.persistence + let queries = self.queries + return try await dbQueue.read { db in guard let threadRow = try RuntimeThreadRow.fetchOne(db, key: query.threadID) else { return AgentHistoryQueryResult( threadID: query.threadID, @@ -15,16 +17,15 @@ extension GRDBRuntimeStateStore { ) } - let thread = try Self.decodeThread(from: threadRow) - let history = try Self.fetchHistoryRows( + let thread = try persistence.decodeThread(from: threadRow) + let history = try queries.fetchHistoryRows( threadID: query.threadID, kinds: query.kinds, createdAtRange: query.createdAtRange, turnID: query.turnID, includeRedacted: query.includeRedacted, includeCompactionEvents: query.includeCompactionEvents, - in: db, - attachmentStore: attachmentStore + in: db ) let state = StoredRuntimeState( @@ -36,7 +37,8 @@ extension GRDBRuntimeStateStore { } func executeThreadQuery(_ query: ThreadMetadataQuery) async throws -> [AgentThread] { - try await dbQueue.read { db in + let persistence = self.persistence + return try await dbQueue.read { db in var request = RuntimeThreadRow.all() if let threadIDs = query.threadIDs, !threadIDs.isEmpty { request = request.filter(threadIDs.contains(Column("threadID"))) @@ -64,13 +66,13 @@ extension GRDBRuntimeStateStore { request = request.limit(max(0, limit)) } - let rows = try request.fetchAll(db) - return try rows.map { try Self.decodeThread(from: $0) } + return try request.fetchAll(db).map { try persistence.decodeThread(from: $0) } } } func executeThreadContextStateQuery(_ query: ThreadContextStateQuery) async throws -> [AgentThreadContextState] { - try await dbQueue.read { db in + let persistence = self.persistence + return try await dbQueue.read { db in var request = RuntimeContextStateRow.all() if let threadIDs = query.threadIDs, !threadIDs.isEmpty { request = request.filter(threadIDs.contains(Column("threadID"))) @@ -80,14 +82,14 @@ extension GRDBRuntimeStateStore { request = request.limit(max(0, limit)) } - return try request.fetchAll(db).map { try Self.decodeContextState(from: $0) } + return try request.fetchAll(db).map { try persistence.decodeContextState(from: $0) } } } func executePendingStateQuery(_ query: PendingStateQuery) async throws -> [AgentPendingStateRecord] { - try await dbQueue.read { db in - var request = RuntimeSummaryRow - .filter(Column("pendingStateKind") != nil) + let persistence = self.persistence + return try await dbQueue.read { db in + var request = RuntimeSummaryRow.filter(Column("pendingStateKind") != nil) if let threadIDs = query.threadIDs, !threadIDs.isEmpty { request = request.filter(threadIDs.contains(Column("threadID"))) @@ -109,7 +111,7 @@ extension GRDBRuntimeStateStore { let summaries = try request.fetchAll(db) return try summaries.compactMap { row -> AgentPendingStateRecord? in - let summary = try Self.decodeSummary(from: row) + let summary = try persistence.decodeSummary(from: row) guard let pendingState = summary.pendingState else { return nil } @@ -123,7 +125,8 @@ extension GRDBRuntimeStateStore { } func executeStructuredOutputQuery(_ query: StructuredOutputQuery) async throws -> [AgentStructuredOutputRecord] { - try await dbQueue.read { db in + let persistence = self.persistence + return try await dbQueue.read { db in var request = RuntimeStructuredOutputRow.all() if let threadIDs = query.threadIDs, !threadIDs.isEmpty { request = request.filter(threadIDs.contains(Column("threadID"))) @@ -144,7 +147,7 @@ extension GRDBRuntimeStateStore { } var records = try request.fetchAll(db) - .map { try Self.decodeStructuredOutputRecord(from: $0) } + .map { try persistence.decodeStructuredOutputRecord(from: $0) } if query.latestOnly { var seen = Set() @@ -159,7 +162,8 @@ extension GRDBRuntimeStateStore { } func executeThreadSnapshotQuery(_ query: ThreadSnapshotQuery) async throws -> [AgentThreadSnapshot] { - try await dbQueue.read { db in + let persistence = self.persistence + return try await dbQueue.read { db in var request = RuntimeSummaryRow.all() if let threadIDs = query.threadIDs, !threadIDs.isEmpty { request = request.filter(threadIDs.contains(Column("threadID"))) @@ -181,20 +185,23 @@ extension GRDBRuntimeStateStore { } return try request.fetchAll(db) - .map { try Self.decodeSummary(from: $0) } + .map { try persistence.decodeSummary(from: $0) } .map(\.snapshot) } } +} + +struct GRDBRuntimeStoreQueries: Sendable { + let attachmentStore: RuntimeAttachmentStore - static func fetchHistoryRows( + func fetchHistoryRows( threadID: String, kinds: Set?, createdAtRange: ClosedRange?, turnID: String?, includeRedacted: Bool, includeCompactionEvents: Bool, - in db: Database, - attachmentStore: RuntimeAttachmentStore + in db: Database ) throws -> [AgentHistoryRecord] { var clauses = ["threadID = ?"] var arguments: [any DatabaseValueConvertible] = [threadID] @@ -220,8 +227,6 @@ extension GRDBRuntimeStateStore { clauses.append("isCompactionMarker = 0") } - // This stays in SQL because the history query shape is highly dynamic and - // we always want sequence-ordered reads for restore/query replay semantics. let sql = """ SELECT * FROM \(RuntimeHistoryRow.databaseTableName) WHERE \(clauses.joined(separator: " AND ")) @@ -230,14 +235,13 @@ extension GRDBRuntimeStateStore { return try RuntimeHistoryRowsRequest( sql: sql, arguments: StatementArguments(arguments) - ).execute(in: db).map { try Self.decodeHistoryRecord(from: $0, attachmentStore: attachmentStore) } + ).execute(in: db).map(decodeHistoryRecord) } - static func fetchHistoryPage( + func fetchHistoryPage( threadID: String, query: AgentHistoryQuery, - in db: Database, - attachmentStore: RuntimeAttachmentStore + in db: Database ) throws -> AgentThreadHistoryPage { let limit = max(1, query.limit) let kinds = historyKinds(from: query.filter) @@ -274,9 +278,7 @@ extension GRDBRuntimeStateStore { ).execute(in: db) let hasMoreBefore = fetched.count > limit let pageRowsDescending = Array(fetched.prefix(limit)) - let pageRecords = try pageRowsDescending - .map { try Self.decodeHistoryRecord(from: $0, attachmentStore: attachmentStore) } - .reversed() + let pageRecords = try pageRowsDescending.map(decodeHistoryRecord).reversed() let hasMoreAfter: Bool if let anchor { @@ -330,9 +332,7 @@ extension GRDBRuntimeStateStore { ).execute(in: db) let hasMoreAfter = fetched.count > limit let pageRows = Array(fetched.prefix(limit)) - let pageRecords = try pageRows.map { - try Self.decodeHistoryRecord(from: $0, attachmentStore: attachmentStore) - } + let pageRecords = try pageRows.map(decodeHistoryRecord) let hasMoreBefore: Bool if let anchor { @@ -359,7 +359,7 @@ extension GRDBRuntimeStateStore { } } - static func historyRecordExists( + private func historyRecordExists( threadID: String, kinds: Set?, includeCompactionEvents: Bool, @@ -377,8 +377,6 @@ extension GRDBRuntimeStateStore { clauses.append("isCompactionMarker = 0") } - // EXISTS is one of the few cases where the raw SQL is both shorter and more obvious - // than the equivalent GRDB request composition for cursor-bound history checks. let sql = """ SELECT EXISTS( SELECT 1 FROM \(RuntimeHistoryRow.databaseTableName) @@ -391,7 +389,7 @@ extension GRDBRuntimeStateStore { ).execute(in: db) } - static func historyKinds(from filter: AgentHistoryFilter?) -> Set? { + private func historyKinds(from filter: AgentHistoryFilter?) -> Set? { guard let filter else { return nil } @@ -406,7 +404,7 @@ extension GRDBRuntimeStateStore { return kinds } - static func makeCursor(threadID: String, sequenceNumber: Int?) -> AgentHistoryCursor? { + private func makeCursor(threadID: String, sequenceNumber: Int?) -> AgentHistoryCursor? { guard let sequenceNumber else { return nil } @@ -424,7 +422,7 @@ extension GRDBRuntimeStateStore { return AgentHistoryCursor(rawValue: base64) } - static func decodeCursorSequence( + private func decodeCursorSequence( _ cursor: AgentHistoryCursor?, expectedThreadID: String ) throws -> Int? { @@ -449,11 +447,18 @@ extension GRDBRuntimeStateStore { return payload.sequenceNumber } - static func defaultLegacyImportURL(for url: URL) -> URL { - url.deletingPathExtension().appendingPathExtension("json") + private func decodeHistoryRecord(from row: RuntimeHistoryRow) throws -> AgentHistoryRecord { + let persistence = GRDBRuntimeStorePersistence(attachmentStore: attachmentStore) + return try persistence.decodeHistoryRecord(from: row) } - static func sqlPlaceholders(count: Int) -> String { + private func sqlPlaceholders(count: Int) -> String { "(" + Array(repeating: "?", count: count).joined(separator: ", ") + ")" } } + +extension GRDBRuntimeStateStore { + static func defaultLegacyImportURL(for url: URL) -> URL { + url.deletingPathExtension().appendingPathExtension("json") + } +} diff --git a/Sources/CodexKit/Runtime/GRDBRuntimeStateStore.swift b/Sources/CodexKit/Runtime/GRDBRuntimeStateStore.swift index 08941cf..cf6748c 100644 --- a/Sources/CodexKit/Runtime/GRDBRuntimeStateStore.swift +++ b/Sources/CodexKit/Runtime/GRDBRuntimeStateStore.swift @@ -12,6 +12,14 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, let migrator: DatabaseMigrator var isPrepared = false + var persistence: GRDBRuntimeStorePersistence { + GRDBRuntimeStorePersistence(attachmentStore: attachmentStore) + } + + var queries: GRDBRuntimeStoreQueries { + GRDBRuntimeStoreQueries(attachmentStore: attachmentStore) + } + public init( url: URL, importingLegacyStateFrom legacyStateURL: URL? = nil @@ -39,7 +47,8 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, configuration.foreignKeysEnabled = true configuration.label = "CodexKit.GRDBRuntimeStateStore" dbQueue = try DatabaseQueue(path: url.path, configuration: configuration) - migrator = Self.makeMigrator() + migrator = GRDBRuntimeStoreSchema(currentStoreSchemaVersion: Self.currentStoreSchemaVersion) + .makeMigrator() } public func prepare() async throws -> AgentStoreMetadata { @@ -67,6 +76,7 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, public func loadState() async throws -> StoredRuntimeState { try await ensurePrepared() + let persistence = self.persistence return try await dbQueue.read { db in let threadRows = try RuntimeThreadRow.fetchAll(db) @@ -74,19 +84,19 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, let historyRows = try RuntimeHistoryRow.fetchAll(db) let contextRows = try RuntimeContextStateRow.fetchAll(db) - let threads = try threadRows.map { try Self.decodeThread(from: $0) } + let threads = try threadRows.map { try persistence.decodeThread(from: $0) } let summariesByThread = try Dictionary( uniqueKeysWithValues: summaryRows.map { row in - (row.threadID, try Self.decodeSummary(from: row)) + (row.threadID, try persistence.decodeSummary(from: row)) } ) let decodedHistoryRows = try historyRows.map { - try Self.decodeHistoryRecord(from: $0, attachmentStore: attachmentStore) + try persistence.decodeHistoryRecord(from: $0) } let historyByThread = Dictionary(grouping: decodedHistoryRows, by: { $0.item.threadID }) let contextStateByThread = try Dictionary( uniqueKeysWithValues: contextRows.map { row in - (row.threadID, try Self.decodeContextState(from: row)) + (row.threadID, try persistence.decodeContextState(from: row)) } ) @@ -103,12 +113,12 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, try await ensurePrepared() let normalized = state.normalized() + let persistence = self.persistence try attachmentStore.reset() try await dbQueue.write { db in - try Self.replaceDatabaseContents( + try persistence.replaceDatabaseContents( with: normalized, - in: db, - attachmentStore: attachmentStore + in: db ) } } @@ -124,39 +134,39 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, return } + let persistence = self.persistence try await dbQueue.write { db in - var partialState = try Self.loadPartialState( + var partialState = try persistence.loadPartialState( for: affectedThreadIDs, - from: db, - attachmentStore: attachmentStore + from: db ) partialState = try partialState.applying(operations) for threadID in affectedThreadIDs { - try Self.deletePersistedThread(threadID, in: db) + try persistence.deletePersistedThread(threadID, in: db) try attachmentStore.removeThread(threadID) } - try Self.persistThreads( + try persistence.persistThreads( ids: affectedThreadIDs, from: partialState, - in: db, - attachmentStore: attachmentStore + in: db ) } } public func fetchThreadSummary(id: String) async throws -> AgentThreadSummary { try await ensurePrepared() + let persistence = self.persistence return try await dbQueue.read { db in guard let threadRow = try RuntimeThreadRow.fetchOne(db, key: id) else { throw AgentRuntimeError.threadNotFound(id) } if let summaryRow = try RuntimeSummaryRow.fetchOne(db, key: id) { - return try Self.decodeSummary(from: summaryRow) + return try persistence.decodeSummary(from: summaryRow) } - let thread = try Self.decodeThread(from: threadRow) + let thread = try persistence.decodeThread(from: threadRow) return StoredRuntimeState(threads: [thread]).threadSummaryFallback(for: thread) } } @@ -166,17 +176,17 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, query: AgentHistoryQuery ) async throws -> AgentThreadHistoryPage { try await ensurePrepared() + let queries = self.queries return try await dbQueue.read { db in guard try RuntimeThreadRow.fetchOne(db, key: id) != nil else { throw AgentRuntimeError.threadNotFound(id) } - return try Self.fetchHistoryPage( + return try queries.fetchHistoryPage( threadID: id, query: query, - in: db, - attachmentStore: attachmentStore + in: db ) } } @@ -188,6 +198,7 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, public func fetchThreadContextState(id: String) async throws -> AgentThreadContextState? { try await ensurePrepared() + let persistence = self.persistence return try await dbQueue.read { db in guard try RuntimeThreadRow.fetchOne(db, key: id) != nil else { throw AgentRuntimeError.threadNotFound(id) @@ -195,7 +206,7 @@ public actor GRDBRuntimeStateStore: RuntimeStateStoring, RuntimeStateInspecting, guard let row = try RuntimeContextStateRow.fetchOne(db, key: id) else { return nil } - return try Self.decodeContextState(from: row) + return try persistence.decodeContextState(from: row) } } From a62b12133c59b83f926e086e8930ade3cde48263 Mon Sep 17 00:00:00 2001 From: Timothy Zelinsky Date: Tue, 24 Mar 2026 18:18:49 +1100 Subject: [PATCH 17/19] Refactor responses turn runner and transport parsing --- .../CodexResponsesBackend+Compaction.swift | 6 +- .../CodexResponsesBackend+Streaming.swift | 368 --------------- .../Runtime/CodexResponsesBackend.swift | 281 +----------- .../Runtime/CodexResponsesTransport.swift | 379 +++++++++++++++ .../Runtime/CodexResponsesTurnRunner.swift | 434 ++++++++++++++++++ .../Runtime/StoredRuntimeState+Queries.swift | 22 +- 6 files changed, 846 insertions(+), 644 deletions(-) create mode 100644 Sources/CodexKit/Runtime/CodexResponsesTransport.swift create mode 100644 Sources/CodexKit/Runtime/CodexResponsesTurnRunner.swift diff --git a/Sources/CodexKit/Runtime/CodexResponsesBackend+Compaction.swift b/Sources/CodexKit/Runtime/CodexResponsesBackend+Compaction.swift index 9be1f36..e6510db 100644 --- a/Sources/CodexKit/Runtime/CodexResponsesBackend+Compaction.swift +++ b/Sources/CodexKit/Runtime/CodexResponsesBackend+Compaction.swift @@ -8,13 +8,17 @@ extension CodexResponsesBackend: AgentBackendContextCompacting { tools: [ToolDefinition], session: ChatGPTSession ) async throws -> AgentCompactionResult { + let requestFactory = CodexResponsesRequestFactory( + configuration: configuration, + encoder: encoder + ) let requestBody = ResponsesCompactRequestBody( model: configuration.model, reasoning: .init(effort: configuration.reasoningEffort), instructions: instructions, text: .init(format: .init(responseFormat: nil)), input: effectiveHistory.map { WorkingHistoryItem.visibleMessage($0).jsonValue }, - tools: CodexResponsesTurnSession.responsesTools( + tools: requestFactory.responsesTools( from: tools, enableWebSearch: configuration.enableWebSearch ), diff --git a/Sources/CodexKit/Runtime/CodexResponsesBackend+Streaming.swift b/Sources/CodexKit/Runtime/CodexResponsesBackend+Streaming.swift index 331df3d..2fb05bd 100644 --- a/Sources/CodexKit/Runtime/CodexResponsesBackend+Streaming.swift +++ b/Sources/CodexKit/Runtime/CodexResponsesBackend+Streaming.swift @@ -177,371 +177,3 @@ struct StreamErrorPayload: Decodable { struct StreamIncompleteDetails: Decodable { let reason: String? } - -extension CodexResponsesTurnSession { - static func buildURLRequest( - configuration: CodexResponsesBackendConfiguration, - instructions: String, - responseFormat: AgentStructuredOutputFormat?, - streamedStructuredOutput: AgentStreamedStructuredOutputRequest?, - threadID: String, - items: [WorkingHistoryItem], - tools: [ToolDefinition], - session: ChatGPTSession, - encoder: JSONEncoder - ) throws -> URLRequest { - let resolvedInstructions = if let streamedStructuredOutput { - instructions + "\n\n" + streamedStructuredOutputInstructions(for: streamedStructuredOutput) - } else { - instructions - } - let requestBody = ResponsesRequestBody( - model: configuration.model, - reasoning: .init(effort: configuration.reasoningEffort), - instructions: resolvedInstructions, - text: .init( - format: .init( - responseFormat: streamedStructuredOutput == nil ? responseFormat : nil - ) - ), - input: items.map(\.jsonValue), - tools: responsesTools( - from: tools, - enableWebSearch: configuration.enableWebSearch - ), - toolChoice: "auto", - parallelToolCalls: false, - store: false, - stream: true, - include: [], - promptCacheKey: threadID - ) - - var request = URLRequest(url: configuration.baseURL.appendingPathComponent("responses")) - request.httpMethod = "POST" - request.httpBody = try encoder.encode(requestBody) - request.setValue("application/json", forHTTPHeaderField: "Content-Type") - request.setValue("text/event-stream", forHTTPHeaderField: "Accept") - request.setValue("Bearer \(session.accessToken)", forHTTPHeaderField: "Authorization") - request.setValue(session.account.id, forHTTPHeaderField: "ChatGPT-Account-ID") - request.setValue(threadID, forHTTPHeaderField: "session_id") - request.setValue(threadID, forHTTPHeaderField: "x-client-request-id") - request.setValue(configuration.originator, forHTTPHeaderField: "originator") - - for (header, value) in configuration.extraHeaders { - request.setValue(value, forHTTPHeaderField: header) - } - - return request - } - - static func responsesTools( - from tools: [ToolDefinition], - enableWebSearch: Bool - ) -> [JSONValue] { - var responsesTools = tools.map(\.responsesJSONValue) - - if enableWebSearch { - responsesTools.append(.object(["type": .string("web_search")])) - } - - return responsesTools - } - - static func streamEvents( - request: URLRequest, - urlSession: URLSession, - decoder: JSONDecoder - ) async throws -> AsyncThrowingStream { - let (bytes, response) = try await urlSession.bytes(for: request) - guard let httpResponse = response as? HTTPURLResponse else { - throw AgentRuntimeError( - code: "responses_invalid_response", - message: "The ChatGPT responses endpoint returned an invalid response." - ) - } - - if !(200 ..< 300).contains(httpResponse.statusCode) { - let bodyData = try await readAll(bytes) - let body = String(data: bodyData, encoding: .utf8) ?? "Unknown error" - if httpResponse.statusCode == 401 || httpResponse.statusCode == 403 { - throw AgentRuntimeError.unauthorized(body) - } - throw AgentRuntimeError( - code: "responses_http_status_\(httpResponse.statusCode)", - message: "The ChatGPT responses request failed with status \(httpResponse.statusCode): \(body)" - ) - } - - return AsyncThrowingStream { continuation in - Task { - var parser = SSEEventParser() - - do { - var lineBuffer = Data() - - for try await byte in bytes { - if byte == UInt8(ascii: "\n") { - var line = String(decoding: lineBuffer, as: UTF8.self) - if line.hasSuffix("\r") { - line.removeLast() - } - lineBuffer.removeAll(keepingCapacity: true) - - if let payload = parser.consume(line: line) { - if let event = try parseStreamEvent( - from: payload, - decoder: decoder - ) { - continuation.yield(event) - } - } - continue - } - - lineBuffer.append(byte) - } - - if !lineBuffer.isEmpty { - var line = String(decoding: lineBuffer, as: UTF8.self) - if line.hasSuffix("\r") { - line.removeLast() - } - if let payload = parser.consume(line: line) { - if let event = try parseStreamEvent( - from: payload, - decoder: decoder - ) { - continuation.yield(event) - } - } - } - - if let payload = parser.finish(), - let event = try parseStreamEvent(from: payload, decoder: decoder) { - continuation.yield(event) - } - - continuation.finish() - } catch { - continuation.finish(throwing: error) - } - } - } - } - - static func shouldRetry( - _ error: Error, - policy: RequestRetryPolicy - ) -> Bool { - if let runtimeError = error as? AgentRuntimeError { - if runtimeError.code == AgentRuntimeError.unauthorized().code { - return false - } - if let statusCode = httpStatusCode(from: runtimeError.code) { - return policy.retryableHTTPStatusCodes.contains(statusCode) - } - return false - } - - if let urlError = error as? URLError { - return policy.retryableURLErrorCodes.contains(urlError.errorCode) - } - - let nsError = error as NSError - if nsError.domain == NSURLErrorDomain { - return policy.retryableURLErrorCodes.contains(nsError.code) - } - - return false - } - - static func httpStatusCode(from errorCode: String) -> Int? { - let prefix = "responses_http_status_" - guard errorCode.hasPrefix(prefix) else { - return nil - } - return Int(errorCode.dropFirst(prefix.count)) - } - - static func parseStreamEvent( - from payload: SSEEventPayload, - decoder: JSONDecoder - ) throws -> CodexResponsesStreamEvent? { - guard !payload.data.isEmpty else { - return nil - } - - let envelope = try decoder.decode( - StreamEnvelope.self, - from: Data(payload.data.utf8) - ) - - switch envelope.type { - case "response.output_text.delta": - return envelope.delta.map(CodexResponsesStreamEvent.assistantTextDelta) - - case "response.output_item.done": - guard let item = envelope.item else { - return nil - } - - switch item { - case let .message(message): - let text = message.content - .compactMap(\.displayText) - .joined() - .trimmingCharacters(in: .whitespacesAndNewlines) - let images = message.content.compactMap(\.imageAttachment) - guard !text.isEmpty || !images.isEmpty else { - return nil - } - return .assistantMessage( - AgentMessage( - threadID: "", - role: .assistant, - text: text, - images: images - ) - ) - - case let .functionCall(functionCall): - return .functionCall( - FunctionCallRecord( - name: functionCall.name, - callID: functionCall.callID, - argumentsRaw: functionCall.arguments - ) - ) - - case .other: - return nil - } - - case "response.completed": - let usage = envelope.response?.usage?.assistantUsage ?? AgentUsage() - return .completed(usage) - - case "response.failed": - let message = envelope.response?.error?.message ?? "The ChatGPT responses stream failed." - throw AgentRuntimeError(code: "responses_stream_failed", message: message) - - case "response.incomplete": - let reason = envelope.response?.incompleteDetails?.reason ?? "unknown" - throw AgentRuntimeError( - code: "responses_stream_incomplete", - message: "The ChatGPT responses stream completed early: \(reason)." - ) - - default: - return nil - } - } - - static func readAll(_ bytes: URLSession.AsyncBytes) async throws -> Data { - var data = Data() - for try await byte in bytes { - data.append(byte) - } - return data - } - - static func toolOutputText(from result: ToolResultEnvelope) -> String { - var segments: [String] = [] - - if let primaryText = result.primaryText, !primaryText.isEmpty { - segments.append(primaryText) - } - - let imageURLs = result.content.compactMap { content -> URL? in - guard case let .image(url) = content else { - return nil - } - return url - } - if !imageURLs.isEmpty { - segments.append( - "Image URLs:\n" + imageURLs.map(\.absoluteString).joined(separator: "\n") - ) - } - - if !segments.isEmpty { - return segments.joined(separator: "\n\n") - } - if let errorMessage = result.errorMessage, !errorMessage.isEmpty { - return errorMessage - } - return result.success ? "Tool execution completed." : "Tool execution failed." - } - - static func toolOutputImages( - from result: ToolResultEnvelope, - urlSession: URLSession - ) async -> [AgentImageAttachment] { - var attachments: [AgentImageAttachment] = [] - for content in result.content { - guard case let .image(url) = content else { - continue - } - if let attachment = await imageAttachment(from: url, urlSession: urlSession) { - attachments.append(attachment) - } - } - return attachments.uniqued() - } - - static func imageAttachment( - from url: URL, - urlSession: URLSession - ) async -> AgentImageAttachment? { - if url.scheme?.lowercased() == "data" { - let decoded = url.absoluteString.removingPercentEncoding ?? url.absoluteString - return AgentImageAttachment(dataURLString: decoded) - } - - if url.isFileURL { - guard let mimeType = inferredImageMimeType(from: url.pathExtension), - let data = try? Data(contentsOf: url), - !data.isEmpty - else { - return nil - } - return AgentImageAttachment(mimeType: mimeType, data: data) - } - - do { - let (data, response) = try await urlSession.data(from: url) - guard !data.isEmpty else { - return nil - } - - let mimeType = response.mimeType?.lowercased() ?? inferredImageMimeType(from: url.pathExtension) - let normalized = mimeType ?? "image/png" - guard normalized.hasPrefix("image/") else { - return nil - } - return AgentImageAttachment(mimeType: normalized, data: data) - } catch { - return nil - } - } - - static func inferredImageMimeType(from pathExtension: String) -> String? { - switch pathExtension.lowercased() { - case "png": - "image/png" - case "jpg", "jpeg": - "image/jpeg" - case "gif": - "image/gif" - case "webp": - "image/webp" - case "heic": - "image/heic" - case "heif": - "image/heif" - default: - nil - } - } -} diff --git a/Sources/CodexKit/Runtime/CodexResponsesBackend.swift b/Sources/CodexKit/Runtime/CodexResponsesBackend.swift index 587bb12..c8dd175 100644 --- a/Sources/CodexKit/Runtime/CodexResponsesBackend.swift +++ b/Sources/CodexKit/Runtime/CodexResponsesBackend.swift @@ -88,7 +88,7 @@ public actor CodexResponsesBackend: AgentBackend { } } -private extension CodexResponsesBackend { +extension CodexResponsesBackend { static func structuredMetadata( from text: String, responseFormat: AgentStructuredOutputFormat? @@ -135,25 +135,27 @@ final class CodexResponsesTurnSession: AgentTurnStreaming, @unchecked Sendable { events = AsyncThrowingStream { continuation in continuation.yield(.turnStarted(turn)) + let runner = CodexResponsesTurnRunner( + configuration: configuration, + instructions: instructions, + responseFormat: responseFormat, + streamedStructuredOutput: streamedStructuredOutput, + urlSession: urlSession, + encoder: encoder, + decoder: decoder, + threadID: thread.id, + turnID: turn.id, + tools: tools, + session: session, + pendingToolResults: pendingToolResults, + continuation: continuation + ) Task { do { - let usage = try await Self.runTurnLoop( - configuration: configuration, - instructions: instructions, - responseFormat: responseFormat, - streamedStructuredOutput: streamedStructuredOutput, - urlSession: urlSession, - encoder: encoder, - decoder: decoder, - threadID: thread.id, - turnID: turn.id, + let usage = try await runner.run( history: history, - newMessage: message, - tools: tools, - session: session, - pendingToolResults: pendingToolResults, - continuation: continuation + newMessage: message ) continuation.yield( @@ -179,251 +181,4 @@ final class CodexResponsesTurnSession: AgentTurnStreaming, @unchecked Sendable { ) async throws { await pendingToolResults.resolve(result, for: invocationID) } - - private static func runTurnLoop( - configuration: CodexResponsesBackendConfiguration, - instructions: String, - responseFormat: AgentStructuredOutputFormat?, - streamedStructuredOutput: AgentStreamedStructuredOutputRequest?, - urlSession: URLSession, - encoder: JSONEncoder, - decoder: JSONDecoder, - threadID: String, - turnID: String, - history: [AgentMessage], - newMessage: UserMessageRequest, - tools: [ToolDefinition], - session: ChatGPTSession, - pendingToolResults: PendingToolResults, - continuation: AsyncThrowingStream.Continuation - ) async throws -> AgentUsage { - var workingHistory = history.map(WorkingHistoryItem.visibleMessage) - workingHistory.append( - .userMessage( - AgentMessage( - threadID: threadID, - role: .user, - text: newMessage.text, - images: newMessage.images - ) - ) - ) - - var aggregateUsage = AgentUsage() - var shouldContinue = true - var pendingToolImages: [AgentImageAttachment] = [] - var pendingToolFallbackTexts: [String] = [] - var structuredParser = CodexResponsesStructuredStreamParser() - var pendingStructuredOutputMetadata: AgentStructuredOutputMetadata? - - while shouldContinue { - shouldContinue = false - var sawToolCall = false - let retryPolicy = configuration.requestRetryPolicy - var attempt = 1 - - retryLoop: while true { - var emittedRetryUnsafeOutput = false - - do { - let request = try buildURLRequest( - configuration: configuration, - instructions: instructions, - responseFormat: responseFormat, - streamedStructuredOutput: streamedStructuredOutput, - threadID: threadID, - items: workingHistory, - tools: tools, - session: session, - encoder: encoder - ) - - let stream = try await streamEvents( - request: request, - urlSession: urlSession, - decoder: decoder - ) - - for try await event in stream { - switch event { - case let .assistantTextDelta(delta): - emittedRetryUnsafeOutput = true - if streamedStructuredOutput != nil { - for parsedEvent in structuredParser.consume(delta: delta) { - switch parsedEvent { - case let .visibleText(visibleDelta): - guard !visibleDelta.isEmpty else { - continue - } - continuation.yield( - .assistantMessageDelta( - threadID: threadID, - turnID: turnID, - delta: visibleDelta - ) - ) - case let .structuredOutputPartial(value): - continuation.yield(.structuredOutputPartial(value)) - case let .structuredOutputValidationFailed(validationFailure): - continuation.yield(.structuredOutputValidationFailed(validationFailure)) - } - } - } else { - continuation.yield( - .assistantMessageDelta( - threadID: threadID, - turnID: turnID, - delta: delta - ) - ) - } - - case let .assistantMessage(messageTemplate): - emittedRetryUnsafeOutput = true - - let normalizedMessage: AgentMessage - if let streamedStructuredOutput { - let extraction = structuredParser.finalize(rawMessage: messageTemplate.text) - - switch extraction.finalResult { - case .none: - break - case let .committed(value): - pendingStructuredOutputMetadata = AgentStructuredOutputMetadata( - formatName: streamedStructuredOutput.responseFormat.name, - payload: value - ) - continuation.yield(.structuredOutputCommitted(value)) - case let .invalid(validationFailure): - continuation.yield(.structuredOutputValidationFailed(validationFailure)) - throw AgentRuntimeError.structuredOutputInvalid( - stage: validationFailure.stage, - underlyingMessage: validationFailure.message - ) - } - - normalizedMessage = AgentMessage( - threadID: threadID, - role: .assistant, - text: extraction.visibleText, - images: messageTemplate.images - ) - } else { - normalizedMessage = AgentMessage( - threadID: threadID, - role: .assistant, - text: messageTemplate.text, - images: messageTemplate.images - ) - } - - let assistantText: String - if normalizedMessage.text.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty, - !pendingToolFallbackTexts.isEmpty { - assistantText = pendingToolFallbackTexts.joined(separator: "\n\n") - } else { - assistantText = normalizedMessage.text - } - - let mergedImages = (normalizedMessage.images + pendingToolImages).uniqued() - let message = AgentMessage( - threadID: threadID, - role: .assistant, - text: assistantText, - images: mergedImages, - structuredOutput: pendingStructuredOutputMetadata - ?? CodexResponsesBackend.structuredMetadata( - from: assistantText, - responseFormat: responseFormat - ) - ) - workingHistory.append(.assistantMessage(message)) - continuation.yield(.assistantMessageCompleted(message)) - pendingToolImages.removeAll(keepingCapacity: true) - pendingToolFallbackTexts.removeAll(keepingCapacity: true) - pendingStructuredOutputMetadata = nil - - case let .structuredOutputPartial(value): - continuation.yield(.structuredOutputPartial(value)) - - case let .structuredOutputCommitted(value): - continuation.yield(.structuredOutputCommitted(value)) - - case let .structuredOutputValidationFailed(validationFailure): - continuation.yield(.structuredOutputValidationFailed(validationFailure)) - - case let .functionCall(functionCall): - emittedRetryUnsafeOutput = true - sawToolCall = true - workingHistory.append(.functionCall(functionCall)) - - let invocation = ToolInvocation( - id: functionCall.callID, - threadID: threadID, - turnID: turnID, - toolName: functionCall.name, - arguments: functionCall.arguments - ) - - continuation.yield(.toolCallRequested(invocation)) - let toolResult = try await pendingToolResults.wait(for: invocation.id) - let toolImages = await toolOutputImages(from: toolResult, urlSession: urlSession) - pendingToolImages.append(contentsOf: toolImages) - pendingToolImages = pendingToolImages.uniqued() - - if let primaryText = toolResult.primaryText? - .trimmingCharacters(in: .whitespacesAndNewlines), - !primaryText.isEmpty { - pendingToolFallbackTexts.append(primaryText) - } - - workingHistory.append( - .functionCallOutput( - callID: invocation.id, - output: toolOutputText(from: toolResult) - ) - ) - - case let .completed(usage): - aggregateUsage.inputTokens += usage.inputTokens - aggregateUsage.cachedInputTokens += usage.cachedInputTokens - aggregateUsage.outputTokens += usage.outputTokens - } - } - - break retryLoop - } catch { - guard !emittedRetryUnsafeOutput, - attempt < retryPolicy.maxAttempts, - shouldRetry(error, policy: retryPolicy) - else { - throw error - } - - let delay = retryPolicy.delayBeforeRetry(attempt: attempt) - if delay > 0 { - let nanoseconds = UInt64((delay * 1_000_000_000).rounded()) - try await Task.sleep(nanoseconds: nanoseconds) - } - attempt += 1 - } - } - - shouldContinue = sawToolCall - if !shouldContinue, (!pendingToolImages.isEmpty || !pendingToolFallbackTexts.isEmpty) { - let message = AgentMessage( - threadID: threadID, - role: .assistant, - text: pendingToolFallbackTexts.joined(separator: "\n\n"), - images: pendingToolImages - ) - workingHistory.append(.assistantMessage(message)) - continuation.yield(.assistantMessageCompleted(message)) - pendingToolImages.removeAll(keepingCapacity: true) - pendingToolFallbackTexts.removeAll(keepingCapacity: true) - } - } - - return aggregateUsage - } } diff --git a/Sources/CodexKit/Runtime/CodexResponsesTransport.swift b/Sources/CodexKit/Runtime/CodexResponsesTransport.swift new file mode 100644 index 0000000..d0c5d6f --- /dev/null +++ b/Sources/CodexKit/Runtime/CodexResponsesTransport.swift @@ -0,0 +1,379 @@ +import Foundation + +struct CodexResponsesRequestFactory: Sendable { + let configuration: CodexResponsesBackendConfiguration + let encoder: JSONEncoder + + func buildURLRequest( + instructions: String, + responseFormat: AgentStructuredOutputFormat?, + streamedStructuredOutput: AgentStreamedStructuredOutputRequest?, + threadID: String, + items: [WorkingHistoryItem], + tools: [ToolDefinition], + session: ChatGPTSession + ) throws -> URLRequest { + let resolvedInstructions = if let streamedStructuredOutput { + instructions + "\n\n" + CodexResponsesTurnSession.streamedStructuredOutputInstructions(for: streamedStructuredOutput) + } else { + instructions + } + let requestBody = ResponsesRequestBody( + model: configuration.model, + reasoning: .init(effort: configuration.reasoningEffort), + instructions: resolvedInstructions, + text: .init( + format: .init( + responseFormat: streamedStructuredOutput == nil ? responseFormat : nil + ) + ), + input: items.map(\.jsonValue), + tools: responsesTools( + from: tools, + enableWebSearch: configuration.enableWebSearch + ), + toolChoice: "auto", + parallelToolCalls: false, + store: false, + stream: true, + include: [], + promptCacheKey: threadID + ) + + var request = URLRequest(url: configuration.baseURL.appendingPathComponent("responses")) + request.httpMethod = "POST" + request.httpBody = try encoder.encode(requestBody) + request.setValue("application/json", forHTTPHeaderField: "Content-Type") + request.setValue("text/event-stream", forHTTPHeaderField: "Accept") + request.setValue("Bearer \(session.accessToken)", forHTTPHeaderField: "Authorization") + request.setValue(session.account.id, forHTTPHeaderField: "ChatGPT-Account-ID") + request.setValue(threadID, forHTTPHeaderField: "session_id") + request.setValue(threadID, forHTTPHeaderField: "x-client-request-id") + request.setValue(configuration.originator, forHTTPHeaderField: "originator") + + for (header, value) in configuration.extraHeaders { + request.setValue(value, forHTTPHeaderField: header) + } + + return request + } + + func responsesTools( + from tools: [ToolDefinition], + enableWebSearch: Bool + ) -> [JSONValue] { + var responsesTools = tools.map(\.responsesJSONValue) + if enableWebSearch { + responsesTools.append(.object(["type": .string("web_search")])) + } + return responsesTools + } +} + +struct CodexResponsesEventStreamClient: Sendable { + let urlSession: URLSession + let decoder: JSONDecoder + + func streamEvents( + request: URLRequest + ) async throws -> AsyncThrowingStream { + let (bytes, response) = try await urlSession.bytes(for: request) + guard let httpResponse = response as? HTTPURLResponse else { + throw AgentRuntimeError( + code: "responses_invalid_response", + message: "The ChatGPT responses endpoint returned an invalid response." + ) + } + + if !(200 ..< 300).contains(httpResponse.statusCode) { + let bodyData = try await readAll(bytes) + let body = String(data: bodyData, encoding: .utf8) ?? "Unknown error" + if httpResponse.statusCode == 401 || httpResponse.statusCode == 403 { + throw AgentRuntimeError.unauthorized(body) + } + throw AgentRuntimeError( + code: "responses_http_status_\(httpResponse.statusCode)", + message: "The ChatGPT responses request failed with status \(httpResponse.statusCode): \(body)" + ) + } + + return AsyncThrowingStream { continuation in + Task { + var parser = SSEEventParser() + + do { + var lineBuffer = Data() + + for try await byte in bytes { + if byte == UInt8(ascii: "\n") { + var line = String(decoding: lineBuffer, as: UTF8.self) + if line.hasSuffix("\r") { + line.removeLast() + } + lineBuffer.removeAll(keepingCapacity: true) + + if let payload = parser.consume(line: line), + let event = try parseStreamEvent(from: payload) { + continuation.yield(event) + } + continue + } + + lineBuffer.append(byte) + } + + if !lineBuffer.isEmpty { + var line = String(decoding: lineBuffer, as: UTF8.self) + if line.hasSuffix("\r") { + line.removeLast() + } + if let payload = parser.consume(line: line), + let event = try parseStreamEvent(from: payload) { + continuation.yield(event) + } + } + + if let payload = parser.finish(), + let event = try parseStreamEvent(from: payload) { + continuation.yield(event) + } + + continuation.finish() + } catch { + continuation.finish(throwing: error) + } + } + } + } + + func shouldRetry( + _ error: Error, + policy: RequestRetryPolicy + ) -> Bool { + if let runtimeError = error as? AgentRuntimeError { + if runtimeError.code == AgentRuntimeError.unauthorized().code { + return false + } + if let statusCode = httpStatusCode(from: runtimeError.code) { + return policy.retryableHTTPStatusCodes.contains(statusCode) + } + return false + } + + if let urlError = error as? URLError { + return policy.retryableURLErrorCodes.contains(urlError.errorCode) + } + + let nsError = error as NSError + if nsError.domain == NSURLErrorDomain { + return policy.retryableURLErrorCodes.contains(nsError.code) + } + + return false + } + + private func httpStatusCode(from errorCode: String) -> Int? { + let prefix = "responses_http_status_" + guard errorCode.hasPrefix(prefix) else { + return nil + } + return Int(errorCode.dropFirst(prefix.count)) + } + + private func parseStreamEvent( + from payload: SSEEventPayload + ) throws -> CodexResponsesStreamEvent? { + guard !payload.data.isEmpty else { + return nil + } + + let envelope = try decoder.decode( + StreamEnvelope.self, + from: Data(payload.data.utf8) + ) + + switch envelope.type { + case "response.output_text.delta": + return envelope.delta.map(CodexResponsesStreamEvent.assistantTextDelta) + case "response.output_item.done": + guard let item = envelope.item else { + return nil + } + + switch item { + case let .message(message): + let text = message.content + .compactMap(\.displayText) + .joined() + .trimmingCharacters(in: .whitespacesAndNewlines) + let images = message.content.compactMap(\.imageAttachment) + guard !text.isEmpty || !images.isEmpty else { + return nil + } + return .assistantMessage( + AgentMessage( + threadID: "", + role: .assistant, + text: text, + images: images + ) + ) + case let .functionCall(functionCall): + return .functionCall( + FunctionCallRecord( + name: functionCall.name, + callID: functionCall.callID, + argumentsRaw: functionCall.arguments + ) + ) + case .other: + return nil + } + case "response.completed": + let usage = envelope.response?.usage?.assistantUsage ?? AgentUsage() + return .completed(usage) + case "response.failed": + let message = envelope.response?.error?.message ?? "The ChatGPT responses stream failed." + throw AgentRuntimeError(code: "responses_stream_failed", message: message) + case "response.incomplete": + let reason = envelope.response?.incompleteDetails?.reason ?? "unknown" + throw AgentRuntimeError( + code: "responses_stream_incomplete", + message: "The ChatGPT responses stream completed early: \(reason)." + ) + default: + return nil + } + } + + private func readAll(_ bytes: URLSession.AsyncBytes) async throws -> Data { + var data = Data() + for try await byte in bytes { + data.append(byte) + } + return data + } +} + +struct CodexResponsesToolOutputAdapter: Sendable { + let urlSession: URLSession + + func text(from result: ToolResultEnvelope) -> String { + var segments: [String] = [] + + if let primaryText = result.primaryText, !primaryText.isEmpty { + segments.append(primaryText) + } + + let imageURLs = result.content.compactMap { content -> URL? in + guard case let .image(url) = content else { + return nil + } + return url + } + if !imageURLs.isEmpty { + segments.append("Image URLs:\n" + imageURLs.map(\.absoluteString).joined(separator: "\n")) + } + + if !segments.isEmpty { + return segments.joined(separator: "\n\n") + } + if let errorMessage = result.errorMessage, !errorMessage.isEmpty { + return errorMessage + } + return result.success ? "Tool execution completed." : "Tool execution failed." + } + + func images(from result: ToolResultEnvelope) async -> [AgentImageAttachment] { + var attachments: [AgentImageAttachment] = [] + for content in result.content { + guard case let .image(url) = content else { + continue + } + if let attachment = await imageAttachment(from: url) { + attachments.append(attachment) + } + } + return attachments.uniqued() + } + + private func imageAttachment(from url: URL) async -> AgentImageAttachment? { + if url.scheme?.lowercased() == "data" { + let decoded = url.absoluteString.removingPercentEncoding ?? url.absoluteString + return AgentImageAttachment(dataURLString: decoded) + } + + if url.isFileURL { + guard let mimeType = RuntimeImageMimeType(pathExtension: url.pathExtension), + let data = try? Data(contentsOf: url), + !data.isEmpty else { + return nil + } + return AgentImageAttachment(mimeType: mimeType.rawValue, data: data) + } + + do { + let (data, response) = try await urlSession.data(from: url) + guard !data.isEmpty else { + return nil + } + + let mimeType = RuntimeImageMimeType( + responseMimeType: response.mimeType, + pathExtension: url.pathExtension + ) ?? .png + guard mimeType.isImage else { + return nil + } + return AgentImageAttachment(mimeType: mimeType.rawValue, data: data) + } catch { + return nil + } + } +} + +private enum RuntimeImageMimeType: String { + case png = "image/png" + case jpeg = "image/jpeg" + case gif = "image/gif" + case webp = "image/webp" + case heic = "image/heic" + case heif = "image/heif" + + init?(pathExtension: String) { + switch pathExtension.lowercased() { + case "png": + self = .png + case "jpg", "jpeg": + self = .jpeg + case "gif": + self = .gif + case "webp": + self = .webp + case "heic": + self = .heic + case "heif": + self = .heif + default: + return nil + } + } + + init?(responseMimeType: String?, pathExtension: String) { + if let responseMimeType, + let normalized = Self(rawValue: responseMimeType.lowercased()) { + self = normalized + return + } + + guard let inferred = Self(pathExtension: pathExtension) else { + return nil + } + + self = inferred + } + + var isImage: Bool { + rawValue.hasPrefix("image/") + } +} diff --git a/Sources/CodexKit/Runtime/CodexResponsesTurnRunner.swift b/Sources/CodexKit/Runtime/CodexResponsesTurnRunner.swift new file mode 100644 index 0000000..f4ce554 --- /dev/null +++ b/Sources/CodexKit/Runtime/CodexResponsesTurnRunner.swift @@ -0,0 +1,434 @@ +import Foundation + +struct CodexResponsesTurnRunner { + let configuration: CodexResponsesBackendConfiguration + let instructions: String + let responseFormat: AgentStructuredOutputFormat? + let streamedStructuredOutput: AgentStreamedStructuredOutputRequest? + let requestFactory: CodexResponsesRequestFactory + let streamClient: CodexResponsesEventStreamClient + let toolOutputAdapter: CodexResponsesToolOutputAdapter + let threadID: String + let turnID: String + let tools: [ToolDefinition] + let session: ChatGPTSession + let pendingToolResults: PendingToolResults + let continuation: AsyncThrowingStream.Continuation + + init( + configuration: CodexResponsesBackendConfiguration, + instructions: String, + responseFormat: AgentStructuredOutputFormat?, + streamedStructuredOutput: AgentStreamedStructuredOutputRequest?, + urlSession: URLSession, + encoder: JSONEncoder, + decoder: JSONDecoder, + threadID: String, + turnID: String, + tools: [ToolDefinition], + session: ChatGPTSession, + pendingToolResults: PendingToolResults, + continuation: AsyncThrowingStream.Continuation + ) { + self.configuration = configuration + self.instructions = instructions + self.responseFormat = responseFormat + self.streamedStructuredOutput = streamedStructuredOutput + self.requestFactory = CodexResponsesRequestFactory(configuration: configuration, encoder: encoder) + self.streamClient = CodexResponsesEventStreamClient(urlSession: urlSession, decoder: decoder) + self.toolOutputAdapter = CodexResponsesToolOutputAdapter(urlSession: urlSession) + self.threadID = threadID + self.turnID = turnID + self.tools = tools + self.session = session + self.pendingToolResults = pendingToolResults + self.continuation = continuation + } + + func run( + history: [AgentMessage], + newMessage: UserMessageRequest + ) async throws -> AgentUsage { + var state = TurnRunState( + workingHistory: initialWorkingHistory(history: history, newMessage: newMessage) + ) + + try await runTurnPasses(state: &state) + emitPendingAssistantFallbackIfNeeded(state: &state) + return state.aggregateUsage + } + + private func initialWorkingHistory( + history: [AgentMessage], + newMessage: UserMessageRequest + ) -> [WorkingHistoryItem] { + var workingHistory = history.map(WorkingHistoryItem.visibleMessage) + workingHistory.append( + .userMessage( + AgentMessage( + threadID: threadID, + role: .user, + text: newMessage.text, + images: newMessage.images + ) + ) + ) + return workingHistory + } + + private func runTurnPasses( + state: inout TurnRunState + ) async throws { + var nextPass: TurnPassDisposition = .needsAnotherPass + + while case .needsAnotherPass = nextPass { + nextPass = try await runTurnPassWithRetry(state: &state) + } + } + + private func runTurnPassWithRetry( + state: inout TurnRunState + ) async throws -> TurnPassDisposition { + let retryPolicy = configuration.requestRetryPolicy + // Build one request per pass. Retries replay the same request, while a new pass + // is only started after tool output mutates the working history. + let request = try makeRequest(for: state) + + for attempt in 1...retryPolicy.maxAttempts { + var retryState = RetryAttemptState() + do { + return try await consumeEventStream( + request: request, + state: &state, + retryState: &retryState + ) + } catch { + guard shouldRetry( + error, + attempt: attempt, + policy: retryPolicy, + retryState: retryState + ) else { + throw error + } + try await sleepBeforeRetry(attempt: attempt, policy: retryPolicy) + } + } + + return .completed + } + + private func makeRequest( + for state: TurnRunState + ) throws -> URLRequest { + try requestFactory.buildURLRequest( + instructions: instructions, + responseFormat: responseFormat, + streamedStructuredOutput: streamedStructuredOutput, + threadID: threadID, + items: state.workingHistory, + tools: tools, + session: session + ) + } + + private func consumeEventStream( + request: URLRequest, + state: inout TurnRunState, + retryState: inout RetryAttemptState + ) async throws -> TurnPassDisposition { + let stream = try await streamClient.streamEvents(request: request) + var passDisposition: TurnPassDisposition = .completed + + for try await event in stream { + let eventResult = try await handleStreamEvent(event, state: &state) + passDisposition = passDisposition.merging(with: eventResult.passDisposition) + retryState.record(eventResult) + } + + return passDisposition + } + + private func handleStreamEvent( + _ event: CodexResponsesStreamEvent, + state: inout TurnRunState + ) async throws -> StreamEventResult { + switch event { + case let .assistantTextDelta(delta): + try handleAssistantTextDelta(delta, state: &state) + return .visibleOutput + + case let .assistantMessage(messageTemplate): + try handleAssistantMessage(messageTemplate, state: &state) + return .visibleOutput + + case let .structuredOutputPartial(value): + continuation.yield(.structuredOutputPartial(value)) + return .none + + case let .structuredOutputCommitted(value): + continuation.yield(.structuredOutputCommitted(value)) + return .none + + case let .structuredOutputValidationFailed(validationFailure): + continuation.yield(.structuredOutputValidationFailed(validationFailure)) + return .none + + case let .functionCall(functionCall): + try await handleFunctionCall(functionCall, state: &state) + return .toolCall + + case let .completed(usage): + state.aggregateUsage.inputTokens += usage.inputTokens + state.aggregateUsage.cachedInputTokens += usage.cachedInputTokens + state.aggregateUsage.outputTokens += usage.outputTokens + return .none + } + } + + private func handleAssistantTextDelta( + _ delta: String, + state: inout TurnRunState + ) throws { + guard streamedStructuredOutput != nil else { + continuation.yield( + .assistantMessageDelta( + threadID: threadID, + turnID: turnID, + delta: delta + ) + ) + return + } + + for parsedEvent in state.structuredParser.consume(delta: delta) { + switch parsedEvent { + case let .visibleText(visibleDelta): + guard !visibleDelta.isEmpty else { + continue + } + continuation.yield( + .assistantMessageDelta( + threadID: threadID, + turnID: turnID, + delta: visibleDelta + ) + ) + case let .structuredOutputPartial(value): + continuation.yield(.structuredOutputPartial(value)) + case let .structuredOutputValidationFailed(validationFailure): + continuation.yield(.structuredOutputValidationFailed(validationFailure)) + } + } + } + + private func handleAssistantMessage( + _ messageTemplate: AgentMessage, + state: inout TurnRunState + ) throws { + let normalizedMessage = try normalizedAssistantMessage( + from: messageTemplate, + state: &state + ) + let assistantText = resolvedAssistantText( + for: normalizedMessage, + fallbackTexts: state.pendingToolFallbackTexts + ) + let mergedImages = (normalizedMessage.images + state.pendingToolImages).uniqued() + let message = AgentMessage( + threadID: threadID, + role: .assistant, + text: assistantText, + images: mergedImages, + structuredOutput: state.pendingStructuredOutputMetadata + ?? CodexResponsesBackend.structuredMetadata( + from: assistantText, + responseFormat: responseFormat + ) + ) + + state.workingHistory.append(.assistantMessage(message)) + continuation.yield(.assistantMessageCompleted(message)) + state.pendingToolImages.removeAll(keepingCapacity: true) + state.pendingToolFallbackTexts.removeAll(keepingCapacity: true) + state.pendingStructuredOutputMetadata = nil + } + + private func normalizedAssistantMessage( + from messageTemplate: AgentMessage, + state: inout TurnRunState + ) throws -> AgentMessage { + guard let streamedStructuredOutput else { + return AgentMessage( + threadID: threadID, + role: .assistant, + text: messageTemplate.text, + images: messageTemplate.images + ) + } + + let extraction = state.structuredParser.finalize(rawMessage: messageTemplate.text) + + switch extraction.finalResult { + case .none: + break + case let .committed(value): + state.pendingStructuredOutputMetadata = AgentStructuredOutputMetadata( + formatName: streamedStructuredOutput.responseFormat.name, + payload: value + ) + continuation.yield(.structuredOutputCommitted(value)) + case let .invalid(validationFailure): + continuation.yield(.structuredOutputValidationFailed(validationFailure)) + throw AgentRuntimeError.structuredOutputInvalid( + stage: validationFailure.stage, + underlyingMessage: validationFailure.message + ) + } + + return AgentMessage( + threadID: threadID, + role: .assistant, + text: extraction.visibleText, + images: messageTemplate.images + ) + } + + private func resolvedAssistantText( + for message: AgentMessage, + fallbackTexts: [String] + ) -> String { + let trimmed = message.text.trimmingCharacters(in: .whitespacesAndNewlines) + guard trimmed.isEmpty, !fallbackTexts.isEmpty else { + return message.text + } + return fallbackTexts.joined(separator: "\n\n") + } + + private func handleFunctionCall( + _ functionCall: FunctionCallRecord, + state: inout TurnRunState + ) async throws { + state.workingHistory.append(.functionCall(functionCall)) + + let invocation = ToolInvocation( + id: functionCall.callID, + threadID: threadID, + turnID: turnID, + toolName: functionCall.name, + arguments: functionCall.arguments + ) + + continuation.yield(.toolCallRequested(invocation)) + let toolResult = try await pendingToolResults.wait(for: invocation.id) + let toolImages = await toolOutputAdapter.images(from: toolResult) + state.pendingToolImages.append(contentsOf: toolImages) + state.pendingToolImages = state.pendingToolImages.uniqued() + + if let primaryText = toolResult.primaryText? + .trimmingCharacters(in: .whitespacesAndNewlines), + !primaryText.isEmpty { + state.pendingToolFallbackTexts.append(primaryText) + } + + state.workingHistory.append( + .functionCallOutput( + callID: invocation.id, + output: toolOutputAdapter.text(from: toolResult) + ) + ) + } + + private func emitPendingAssistantFallbackIfNeeded( + state: inout TurnRunState + ) { + guard !state.pendingToolImages.isEmpty || !state.pendingToolFallbackTexts.isEmpty else { + return + } + + let message = AgentMessage( + threadID: threadID, + role: .assistant, + text: state.pendingToolFallbackTexts.joined(separator: "\n\n"), + images: state.pendingToolImages + ) + state.workingHistory.append(.assistantMessage(message)) + continuation.yield(.assistantMessageCompleted(message)) + state.pendingToolImages.removeAll(keepingCapacity: true) + state.pendingToolFallbackTexts.removeAll(keepingCapacity: true) + } + + private func shouldRetry( + _ error: Error, + attempt: Int, + policy: RequestRetryPolicy, + retryState: RetryAttemptState + ) -> Bool { + !retryState.hasVisibleOutput + && attempt < policy.maxAttempts + && streamClient.shouldRetry(error, policy: policy) + } + + private func sleepBeforeRetry( + attempt: Int, + policy: RequestRetryPolicy + ) async throws { + let delay = policy.delayBeforeRetry(attempt: attempt) + guard delay > 0 else { + return + } + let nanoseconds = UInt64((delay * 1_000_000_000).rounded()) + try await Task.sleep(nanoseconds: nanoseconds) + } +} + +private enum TurnPassDisposition { + case needsAnotherPass + case completed + + func merging(with other: TurnPassDisposition) -> TurnPassDisposition { + switch (self, other) { + case (.needsAnotherPass, _), (_, .needsAnotherPass): + return .needsAnotherPass + case (.completed, .completed): + return .completed + } + } +} + +private struct TurnRunState { + var workingHistory: [WorkingHistoryItem] + var aggregateUsage = AgentUsage() + var pendingToolImages: [AgentImageAttachment] = [] + var pendingToolFallbackTexts: [String] = [] + var structuredParser = CodexResponsesStructuredStreamParser() + var pendingStructuredOutputMetadata: AgentStructuredOutputMetadata? +} + +private struct RetryAttemptState { + var hasVisibleOutput = false + + mutating func record(_ eventResult: StreamEventResult) { + hasVisibleOutput = hasVisibleOutput || eventResult.emittedVisibleOutput + } +} + +private struct StreamEventResult { + let emittedVisibleOutput: Bool + let passDisposition: TurnPassDisposition + + static let none = StreamEventResult( + emittedVisibleOutput: false, + passDisposition: .completed + ) + + static let visibleOutput = StreamEventResult( + emittedVisibleOutput: true, + passDisposition: .completed + ) + + static let toolCall = StreamEventResult( + emittedVisibleOutput: true, + passDisposition: .needsAnotherPass + ) +} diff --git a/Sources/CodexKit/Runtime/StoredRuntimeState+Queries.swift b/Sources/CodexKit/Runtime/StoredRuntimeState+Queries.swift index f6302f9..77bdc43 100644 --- a/Sources/CodexKit/Runtime/StoredRuntimeState+Queries.swift +++ b/Sources/CodexKit/Runtime/StoredRuntimeState+Queries.swift @@ -2,6 +2,7 @@ import Foundation extension StoredRuntimeState { func normalized() -> StoredRuntimeState { + let projections = StoredRuntimeStateProjectionBuilder() let sortedThreads = threads.sorted { lhs, rhs in if lhs.updatedAt == rhs.updatedAt { return lhs.id < rhs.id @@ -20,7 +21,7 @@ extension StoredRuntimeState { } for (threadID, messages) in messagesByThread where normalizedHistory[threadID]?.isEmpty != false { - normalizedHistory[threadID] = Self.syntheticHistory(from: messages) + normalizedHistory[threadID] = projections.syntheticHistory(from: messages) } let normalizedMessages: [String: [AgentMessage]] = normalizedHistory.mapValues { records in @@ -43,7 +44,7 @@ extension StoredRuntimeState { var normalizedContextState = contextStateByThread for thread in sortedThreads { let history = normalizedHistory[thread.id] ?? [] - normalizedSummaries[thread.id] = Self.rebuildSummary( + normalizedSummaries[thread.id] = projections.rebuildSummary( for: thread, history: history, existing: summariesByThread[thread.id] @@ -80,7 +81,7 @@ extension StoredRuntimeState { } func threadSummaryFallback(for thread: AgentThread) -> AgentThreadSummary { - Self.rebuildSummary( + StoredRuntimeStateProjectionBuilder().rebuildSummary( for: thread, history: historyByThread[thread.id] ?? [], existing: summariesByThread[thread.id] @@ -265,8 +266,10 @@ extension StoredRuntimeState { return updated.normalized() } +} - private static func syntheticHistory(from messages: [AgentMessage]) -> [AgentHistoryRecord] { +struct StoredRuntimeStateProjectionBuilder: Sendable { + func syntheticHistory(from messages: [AgentMessage]) -> [AgentHistoryRecord] { let orderedMessages = messages.enumerated().sorted { lhs, rhs in let left = lhs.element let right = rhs.element @@ -285,7 +288,7 @@ extension StoredRuntimeState { } } - private static func rebuildSummary( + func rebuildSummary( for thread: AgentThread, history: [AgentHistoryRecord], existing: AgentThreadSummary? @@ -306,7 +309,6 @@ extension StoredRuntimeState { latestStructuredOutputMetadata = structuredOutput } } - case let .toolCall(toolCall): latestToolState = AgentLatestToolState( invocationID: toolCall.invocation.id, @@ -315,16 +317,12 @@ extension StoredRuntimeState { status: .waiting, updatedAt: toolCall.requestedAt ) - case let .toolResult(toolResult): - latestToolState = Self.latestToolState(from: toolResult) - + latestToolState = self.latestToolState(from: toolResult) case let .structuredOutput(structuredOutput): latestStructuredOutputMetadata = structuredOutput.metadata - case .approval: break - case let .systemEvent(systemEvent): switch systemEvent.type { case .turnStarted: @@ -354,7 +352,7 @@ extension StoredRuntimeState { ) } - private static func latestToolState(from toolResult: AgentToolResultRecord) -> AgentLatestToolState { + func latestToolState(from toolResult: AgentToolResultRecord) -> AgentLatestToolState { let preview = toolResult.result.primaryText let session = toolResult.result.session let status: AgentToolSessionStatus From fc377be836654e8873ef4e9e7ac6eabbb55783a6 Mon Sep 17 00:00:00 2001 From: Timothy Zelinsky Date: Tue, 24 Mar 2026 19:37:07 +1100 Subject: [PATCH 18/19] Add runtime thread context usage projection --- .../AgentDemoViewModel+Observation.swift | 23 ++++++ .../AgentDemoViewModel+ThreadState.swift | 7 +- .../Shared/AgentDemoViewModel.swift | 2 + .../Shared/ThreadDetailView.swift | 78 +++++++++---------- DemoApp/README.md | 9 ++- README.md | 14 +++- Sources/CodexKit/Runtime/AgentHistory.swift | 1 + .../AgentRuntime+ContextCompaction.swift | 4 +- .../Runtime/AgentRuntime+ContextUsage.swift | 41 ++++++++++ .../Runtime/AgentRuntime+History.swift | 4 + Sources/CodexKit/Runtime/AgentRuntime.swift | 6 ++ .../Runtime/AgentRuntimeObservation.swift | 27 +++++++ .../Runtime/AgentThreadContextUsage.swift | 44 +++++++++++ .../Runtime/CodexResponsesBackend.swift | 27 +++++++ .../AgentRuntimeHistoryCompactionTests.swift | 61 ++++++++++++++- .../AgentRuntimeHistoryTestSupport.swift | 4 +- .../AgentRuntimeMessageBehaviorTests.swift | 66 +++++++++++++++- 17 files changed, 364 insertions(+), 54 deletions(-) create mode 100644 Sources/CodexKit/Runtime/AgentRuntime+ContextUsage.swift create mode 100644 Sources/CodexKit/Runtime/AgentThreadContextUsage.swift diff --git a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+Observation.swift b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+Observation.swift index cd3c49a..6a2bef9 100644 --- a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+Observation.swift +++ b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+Observation.swift @@ -72,6 +72,17 @@ extension AgentDemoViewModel { self.activeThreadContextState = contextState } .store(in: &activeThreadObservationCancellables) + + runtime.observeThreadContextUsage(id: threadID) + .receive(on: DispatchQueue.main) + .sink { [weak self] contextUsage in + guard let self else { + return + } + self.observedThreadContextUsage = contextUsage + self.activeThreadContextUsage = contextUsage + } + .store(in: &activeThreadObservationCancellables) } func resetObservedThreadState() { @@ -80,5 +91,17 @@ extension AgentDemoViewModel { observedThreadSummary = nil observedThreadContextState = nil activeThreadContextState = nil + observedThreadContextUsage = nil + activeThreadContextUsage = nil + } + + func formattedTokenCount(_ tokens: Int) -> String { + if tokens >= 1_000_000 { + return String(format: "%.1fM", Double(tokens) / 1_000_000) + } + if tokens >= 1_000 { + return String(format: "%.1fk", Double(tokens) / 1_000) + } + return "\(tokens)" } } diff --git a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+ThreadState.swift b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+ThreadState.swift index fa78e26..7a73040 100644 --- a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+ThreadState.swift +++ b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel+ThreadState.swift @@ -280,9 +280,13 @@ extension AgentDemoViewModel { do { activeThreadContextState = try await runtime.fetchThreadContextState(id: resolvedThreadID) observedThreadContextState = activeThreadContextState + activeThreadContextUsage = try await runtime.fetchThreadContextUsage(id: resolvedThreadID) + observedThreadContextUsage = activeThreadContextUsage } catch { activeThreadContextState = nil observedThreadContextState = nil + activeThreadContextUsage = nil + observedThreadContextUsage = nil developerErrorLog("Failed to fetch thread context state. threadID=\(resolvedThreadID) error=\(error.localizedDescription)") } } @@ -325,10 +329,11 @@ extension AgentDemoViewModel { do { developerLog("Manual context compaction started. threadID=\(activeThreadID)") activeThreadContextState = try await runtime.compactThreadContext(id: activeThreadID) + activeThreadContextUsage = try await runtime.fetchThreadContextUsage(id: activeThreadID) threads = await runtime.threads() setMessages(await runtime.messages(for: activeThreadID)) developerLog( - "Manual context compaction finished. threadID=\(activeThreadID) generation=\(activeThreadContextState?.generation ?? 0) effectiveMessages=\(activeThreadContextState?.effectiveMessages.count ?? 0)" + "Manual context compaction finished. threadID=\(activeThreadID) generation=\(activeThreadContextState?.generation ?? 0) effectiveTokens=\(activeThreadContextUsage?.effectiveEstimatedTokenCount ?? 0)" ) } catch { reportError(error) diff --git a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel.swift b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel.swift index 4678a84..337a650 100644 --- a/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel.swift +++ b/DemoApp/AssistantRuntimeDemoApp/Shared/AgentDemoViewModel.swift @@ -239,6 +239,8 @@ final class AgentDemoViewModel: @unchecked Sendable { var observedMessages: [AgentMessage] = [] var observedThreadSummary: AgentThreadSummary? var observedThreadContextState: AgentThreadContextState? + var activeThreadContextUsage: AgentThreadContextUsage? + var observedThreadContextUsage: AgentThreadContextUsage? let approvalInbox: ApprovalInbox let deviceCodePromptCoordinator: DeviceCodePromptCoordinator diff --git a/DemoApp/AssistantRuntimeDemoApp/Shared/ThreadDetailView.swift b/DemoApp/AssistantRuntimeDemoApp/Shared/ThreadDetailView.swift index 1f57e8d..2d2b4d9 100644 --- a/DemoApp/AssistantRuntimeDemoApp/Shared/ThreadDetailView.swift +++ b/DemoApp/AssistantRuntimeDemoApp/Shared/ThreadDetailView.swift @@ -85,17 +85,13 @@ private extension ThreadDetailView { viewModel.observedThread?.id == threadID ? viewModel.observedThread : activeThread } - var observedContextState: AgentThreadContextState? { - viewModel.observedThreadContextState ?? viewModel.activeThreadContextState - } + var observedContextState: AgentThreadContextState? { viewModel.observedThreadContextState ?? viewModel.activeThreadContextState } + var observedContextUsage: AgentThreadContextUsage? { viewModel.observedThreadContextUsage ?? viewModel.activeThreadContextUsage } var observedSummary: AgentThreadSummary? { viewModel.observedThreadSummary } var turnActivityStatus: AgentThreadStatus? { - guard let status = observedThread?.status else { - return nil - } - + guard let status = observedThread?.status else { return nil } switch status { case .streaming where !isStreamingActive, .waitingForApproval, .waitingForToolResult: return status @@ -120,17 +116,11 @@ private extension ThreadDetailView { } HStack(spacing: 10) { - Label( - viewModel.model, - systemImage: "cpu" - ) + Label(viewModel.model, systemImage: "cpu") .font(.caption) .foregroundStyle(.secondary) - Label( - reasoningEffortTitle, - systemImage: reasoningEffortSymbol - ) + Label(reasoningEffortTitle, systemImage: reasoningEffortSymbol) .font(.caption) .foregroundStyle(.secondary) } @@ -142,7 +132,7 @@ private extension ThreadDetailView { VStack(alignment: .leading, spacing: 8) { Text("Observation Demo") .font(.headline) - Text("This card is driven by `observeThread`, `observeMessages`, `observeThreadSummary`, and `observeThreadContextState`, so title, transcript, summary, and compaction changes update live without a manual refresh.") + Text("This card is driven by `observeThread`, `observeMessages`, `observeThreadSummary`, `observeThreadContextState`, and `observeThreadContextUsage`, so title, transcript, summary, compaction state, and context usage all update live without a manual refresh.") .font(.subheadline) .foregroundStyle(.secondary) } @@ -157,9 +147,7 @@ private extension ThreadDetailView { .textFieldStyle(.roundedBorder) Button { - Task { - await viewModel.updateActiveThreadTitle(threadTitleDraft) - } + Task { await viewModel.updateActiveThreadTitle(threadTitleDraft) } } label: { Label("Save Thread Title", systemImage: "pencil") } @@ -198,7 +186,6 @@ private extension ThreadDetailView { VStack(alignment: .leading, spacing: 8) { Text("Context Compaction") .font(.headline) - Text("Preserves the visible transcript, but rewrites the runtime’s hidden effective prompt context for future turns.") .font(.subheadline) .foregroundStyle(.secondary) @@ -206,17 +193,37 @@ private extension ThreadDetailView { HStack(spacing: 12) { compactionMetric( - title: "Visible Messages", - value: "\(threadMessages.count)" - ) - compactionMetric( - title: "Effective Messages", - value: "\(observedContextState?.effectiveMessages.count ?? threadMessages.count)" + title: "Visible Context", + value: "\(viewModel.formattedTokenCount(observedContextUsage?.visibleEstimatedTokenCount ?? 0)) tokens" ) compactionMetric( - title: "Generation", - value: "\(observedContextState?.generation ?? 0)" + title: "Effective Context", + value: "\(viewModel.formattedTokenCount(observedContextUsage?.effectiveEstimatedTokenCount ?? 0)) tokens" ) + compactionMetric(title: "Generation", value: "\(observedContextState?.generation ?? 0)") + } + + if let contextUsage = observedContextUsage, + let percentFull = contextUsage.percentUsed, + let windowTokens = contextUsage.usableContextWindowTokenCount { + let effectiveTokens = viewModel.formattedTokenCount(contextUsage.effectiveEstimatedTokenCount) + let usableTokens = viewModel.formattedTokenCount(windowTokens) + Text("Estimated context window: \(percentFull)% full (\(effectiveTokens) / \(usableTokens) usable tokens)") + .font(.caption) + .foregroundStyle(.secondary) + } else { + Text("Estimated effective prompt usage: \(viewModel.formattedTokenCount(observedContextUsage?.effectiveEstimatedTokenCount ?? 0)) tokens") + .font(.caption) + .foregroundStyle(.secondary) + } + + if let contextUsage = observedContextUsage, + contextUsage.estimatedTokenSavings > 0 { + let visibleTokens = viewModel.formattedTokenCount(contextUsage.visibleEstimatedTokenCount) + let effectiveTokens = viewModel.formattedTokenCount(contextUsage.effectiveEstimatedTokenCount) + Text("Current compaction savings: ~\(visibleTokens) -> \(effectiveTokens) estimated tokens") + .font(.caption) + .foregroundStyle(.secondary) } if let contextState = observedContextState { @@ -226,13 +233,11 @@ private extension ThreadDetailView { .font(.caption) .foregroundStyle(.secondary) } - if let lastCompactedAt = contextState.lastCompactedAt { Text("Updated \(lastCompactedAt.formatted(date: .abbreviated, time: .shortened))") .font(.caption) .foregroundStyle(.secondary) } - if let summaryMessage = contextState.effectiveMessages.first(where: { $0.role == .system }), !summaryMessage.text.isEmpty { Text(summaryMessage.text) @@ -250,14 +255,9 @@ private extension ThreadDetailView { HStack(spacing: 10) { Button { - Task { - await viewModel.compactActiveThreadContext() - } + Task { await viewModel.compactActiveThreadContext() } } label: { - Label( - viewModel.isCompactingThreadContext ? "Compacting..." : "Compact Context Now", - systemImage: viewModel.isCompactingThreadContext ? "hourglass" : "arrow.triangle.branch" - ) + Label(viewModel.isCompactingThreadContext ? "Compacting..." : "Compact Context Now", systemImage: viewModel.isCompactingThreadContext ? "hourglass" : "arrow.triangle.branch") } .buttonStyle(.borderedProminent) .disabled(viewModel.session == nil || activeThread == nil || viewModel.isCompactingThreadContext) @@ -346,9 +346,7 @@ private extension ThreadDetailView { } var threadMessages: [AgentMessage] { - guard viewModel.activeThreadID == threadID else { - return [] - } + guard viewModel.activeThreadID == threadID else { return [] } return viewModel.observedMessages.isEmpty ? viewModel.messages : viewModel.observedMessages } diff --git a/DemoApp/README.md b/DemoApp/README.md index d36d766..0455eda 100644 --- a/DemoApp/README.md +++ b/DemoApp/README.md @@ -22,7 +22,7 @@ The Xcode project is the source of truth for the demo app. Edit it directly in X - lets you attach a photo from the library and send it with or without text - renders attached user images in the transcript - streams assistant output into the UI -- demonstrates live Combine observation of thread, message, summary, and context-state updates +- demonstrates live Combine observation of thread, message, summary, context-state, and context-usage updates - lets you rename the active thread from the thread detail screen using `setTitle(_:for:)` - includes a thread-level `Context Compaction` card so you can compact effective prompt state without removing visible transcript history - supports approval prompts for host-defined tools that opt into `requiresApproval` @@ -54,8 +54,9 @@ The app links `CodexKit` and `CodexKitUI` from the repo's local `Package.swift`, The checked-in demo enables context compaction in automatic mode. In a thread detail screen, the `Context Compaction` card shows: -- visible transcript message count -- effective prompt message count +- visible transcript token usage +- effective prompt token usage +- estimated context window fullness when available - compaction generation - last compaction reason/time - a `Compact Context Now` action for manual testing @@ -66,12 +67,14 @@ The same thread detail screen also includes an `Observation Demo` card. It subsc - `observeMessages(in:)` - `observeThreadSummary(id:)` - `observeThreadContextState(id:)` +- `observeThreadContextUsage(id:)` Use that card to verify that: - local title changes propagate immediately through `setTitle(_:for:)` - new messages appear from the observation stream without a manual refresh - context compaction updates the observed context state live +- effective prompt usage updates live in estimated tokens ## Files diff --git a/README.md b/README.md index 981ffde..aeb5e1e 100644 --- a/README.md +++ b/README.md @@ -302,6 +302,13 @@ runtime.observeThreadContextState(id: thread.id) } .store(in: &cancellables) +runtime.observeThreadContextUsage(id: thread.id) + .receive(on: DispatchQueue.main) + .sink { usage in + print("Estimated effective tokens:", usage?.effectiveEstimatedTokenCount ?? 0) + } + .store(in: &cancellables) + try await runtime.setTitle("Shipping Triage", for: thread.id) ``` @@ -312,6 +319,7 @@ Available built-in publishers: - `observeMessages(in:)` - `observeThreadSummary(id:)` - `observeThreadContextState(id:)` +- `observeThreadContextUsage(id:)` The checked-in demo app includes a thread detail `Observation Demo` card that exercises these publishers live, along with a rename control that calls `setTitle(_:for:)`. @@ -339,12 +347,16 @@ let runtime = try AgentRuntime(configuration: .init( let contextState = try await runtime.compactThreadContext(id: thread.id) print(contextState.generation) + +let usage = try await runtime.fetchThreadContextUsage(id: thread.id) +print(usage?.effectiveEstimatedTokenCount ?? 0) ``` -For debug tooling or host inspection, you can also read the compacted effective context directly: +For debug tooling or host inspection, you can also read the compacted effective context and the current estimated context-window usage directly: ```swift let contextState = try await runtime.fetchThreadContextState(id: thread.id) +let usage = try await runtime.fetchThreadContextUsage(id: thread.id) let contexts = try await runtime.execute( ThreadContextStateQuery(threadIDs: [thread.id]) ) diff --git a/Sources/CodexKit/Runtime/AgentHistory.swift b/Sources/CodexKit/Runtime/AgentHistory.swift index e27d831..5963be5 100644 --- a/Sources/CodexKit/Runtime/AgentHistory.swift +++ b/Sources/CodexKit/Runtime/AgentHistory.swift @@ -178,6 +178,7 @@ public protocol AgentRuntimeThreadInspecting: Sendable { ) async throws -> AgentThreadHistoryPage func fetchLatestStructuredOutputMetadata(id: String) async throws -> AgentStructuredOutputMetadata? func fetchThreadContextState(id: String) async throws -> AgentThreadContextState? + func fetchThreadContextUsage(id: String) async throws -> AgentThreadContextUsage? } public extension AgentThreadSummary { diff --git a/Sources/CodexKit/Runtime/AgentRuntime+ContextCompaction.swift b/Sources/CodexKit/Runtime/AgentRuntime+ContextCompaction.swift index aff780c..13b0243 100644 --- a/Sources/CodexKit/Runtime/AgentRuntime+ContextCompaction.swift +++ b/Sources/CodexKit/Runtime/AgentRuntime+ContextCompaction.swift @@ -16,10 +16,12 @@ extension AgentRuntime { return } + let currentEffectiveMessages = state.contextStateByThread[message.threadID]?.effectiveMessages + ?? Array((state.messagesByThread[message.threadID] ?? []).dropLast()) let current = state.contextStateByThread[message.threadID] ?? AgentThreadContextState( threadID: message.threadID, - effectiveMessages: state.messagesByThread[message.threadID] ?? [] + effectiveMessages: currentEffectiveMessages ) let updated = AgentThreadContextState( threadID: current.threadID, diff --git a/Sources/CodexKit/Runtime/AgentRuntime+ContextUsage.swift b/Sources/CodexKit/Runtime/AgentRuntime+ContextUsage.swift new file mode 100644 index 0000000..c1cf670 --- /dev/null +++ b/Sources/CodexKit/Runtime/AgentRuntime+ContextUsage.swift @@ -0,0 +1,41 @@ +import Foundation + +extension AgentRuntime { + func threadContextUsage(for threadID: String) -> AgentThreadContextUsage? { + guard state.threads.contains(where: { $0.id == threadID }) else { + return nil + } + + let visibleMessages = state.messagesByThread[threadID] ?? [] + let effectiveMessages = effectiveHistory(for: threadID) + + return AgentThreadContextUsage( + threadID: threadID, + visibleEstimatedTokenCount: approximateTokenCount(for: visibleMessages), + effectiveEstimatedTokenCount: approximateTokenCount(for: effectiveMessages), + modelContextWindowTokenCount: modelContextWindowTokenCount(), + usableContextWindowTokenCount: usableContextWindowTokenCount() + ) + } + + func approximateTokenCount(for messages: [AgentMessage]) -> Int { + guard !messages.isEmpty else { + return 0 + } + + return max( + 1, + messages.reduce(into: 0) { total, message in + total += message.text.count + (message.images.count * 512) + } / 4 + ) + } + + private func modelContextWindowTokenCount() -> Int? { + (backend as? any AgentBackendContextWindowProviding)?.modelContextWindowTokenCount + } + + private func usableContextWindowTokenCount() -> Int? { + (backend as? any AgentBackendContextWindowProviding)?.usableContextWindowTokenCount + } +} diff --git a/Sources/CodexKit/Runtime/AgentRuntime+History.swift b/Sources/CodexKit/Runtime/AgentRuntime+History.swift index 4eb444c..24152db 100644 --- a/Sources/CodexKit/Runtime/AgentRuntime+History.swift +++ b/Sources/CodexKit/Runtime/AgentRuntime+History.swift @@ -84,6 +84,10 @@ extension AgentRuntime: AgentRuntimeQueryable, AgentRuntimeThreadInspecting { ).first } + public func fetchThreadContextUsage(id: String) async throws -> AgentThreadContextUsage? { + threadContextUsage(for: id) + } + public func fetchLatestStructuredOutput( id: String, as outputType: Output.Type, diff --git a/Sources/CodexKit/Runtime/AgentRuntime.swift b/Sources/CodexKit/Runtime/AgentRuntime.swift index aeda489..61a564f 100644 --- a/Sources/CodexKit/Runtime/AgentRuntime.swift +++ b/Sources/CodexKit/Runtime/AgentRuntime.swift @@ -297,6 +297,12 @@ public actor AgentRuntime { state: state.contextStateByThread[threadID] ) ) + observationCenter.send( + .threadContextUsageChanged( + threadID: threadID, + usage: threadContextUsage(for: threadID) + ) + ) } func resolveInstructions( diff --git a/Sources/CodexKit/Runtime/AgentRuntimeObservation.swift b/Sources/CodexKit/Runtime/AgentRuntimeObservation.swift index 51312ce..a2693f5 100644 --- a/Sources/CodexKit/Runtime/AgentRuntimeObservation.swift +++ b/Sources/CodexKit/Runtime/AgentRuntimeObservation.swift @@ -7,6 +7,7 @@ public enum AgentRuntimeObservation: Sendable { case messagesChanged(threadID: String, messages: [AgentMessage]) case threadSummaryChanged(AgentThreadSummary) case threadContextStateChanged(threadID: String, state: AgentThreadContextState?) + case threadContextUsageChanged(threadID: String, usage: AgentThreadContextUsage?) case threadDeleted(threadID: String) } @@ -18,6 +19,7 @@ public final class AgentRuntimeObservationCenter: @unchecked Sendable { private var messageSubjects: [String: CurrentValueSubject<[AgentMessage], Never>] = [:] private var summarySubjects: [String: CurrentValueSubject] = [:] private var contextStateSubjects: [String: CurrentValueSubject] = [:] + private var contextUsageSubjects: [String: CurrentValueSubject] = [:] public init() {} @@ -53,6 +55,12 @@ public final class AgentRuntimeObservationCenter: @unchecked Sendable { } } + public func threadContextUsagePublisher(for threadID: String) -> AnyPublisher { + withLock { + contextUsageSubject(for: threadID).eraseToAnyPublisher() + } + } + func send(_ observation: AgentRuntimeObservation) { var updates: [() -> Void] = [] @@ -77,15 +85,21 @@ public final class AgentRuntimeObservationCenter: @unchecked Sendable { let subject = contextStateSubject(for: threadID) updates.append { subject.send(state) } + case let .threadContextUsageChanged(threadID, usage): + let subject = contextUsageSubject(for: threadID) + updates.append { subject.send(usage) } + case let .threadDeleted(threadID): let threadSubject = threadSubject(for: threadID) let messageSubject = messageSubject(for: threadID) let summarySubject = summarySubject(for: threadID) let contextStateSubject = contextStateSubject(for: threadID) + let contextUsageSubject = contextUsageSubject(for: threadID) updates.append { threadSubject.send(nil) } updates.append { messageSubject.send([]) } updates.append { summarySubject.send(nil) } updates.append { contextStateSubject.send(nil) } + updates.append { contextUsageSubject.send(nil) } } } @@ -129,6 +143,15 @@ public final class AgentRuntimeObservationCenter: @unchecked Sendable { return subject } + private func contextUsageSubject(for threadID: String) -> CurrentValueSubject { + if let subject = contextUsageSubjects[threadID] { + return subject + } + let subject = CurrentValueSubject(nil) + contextUsageSubjects[threadID] = subject + return subject + } + private func withLock(_ body: () -> T) -> T { lock.lock() defer { lock.unlock() } @@ -156,4 +179,8 @@ extension AgentRuntime { public nonisolated func observeThreadContextState(id threadID: String) -> AnyPublisher { observationCenter.threadContextStatePublisher(for: threadID) } + + public nonisolated func observeThreadContextUsage(id threadID: String) -> AnyPublisher { + observationCenter.threadContextUsagePublisher(for: threadID) + } } diff --git a/Sources/CodexKit/Runtime/AgentThreadContextUsage.swift b/Sources/CodexKit/Runtime/AgentThreadContextUsage.swift new file mode 100644 index 0000000..eaffdc9 --- /dev/null +++ b/Sources/CodexKit/Runtime/AgentThreadContextUsage.swift @@ -0,0 +1,44 @@ +import Foundation + +public struct AgentThreadContextUsage: Codable, Hashable, Sendable { + public let threadID: String + public let visibleEstimatedTokenCount: Int + public let effectiveEstimatedTokenCount: Int + public let modelContextWindowTokenCount: Int? + public let usableContextWindowTokenCount: Int? + + public init( + threadID: String, + visibleEstimatedTokenCount: Int, + effectiveEstimatedTokenCount: Int, + modelContextWindowTokenCount: Int? = nil, + usableContextWindowTokenCount: Int? = nil + ) { + self.threadID = threadID + self.visibleEstimatedTokenCount = visibleEstimatedTokenCount + self.effectiveEstimatedTokenCount = effectiveEstimatedTokenCount + self.modelContextWindowTokenCount = modelContextWindowTokenCount + self.usableContextWindowTokenCount = usableContextWindowTokenCount + } +} + +public extension AgentThreadContextUsage { + var estimatedTokenSavings: Int { + max(0, visibleEstimatedTokenCount - effectiveEstimatedTokenCount) + } + + var percentUsed: Int? { + guard let usableContextWindowTokenCount, + usableContextWindowTokenCount > 0 else { + return nil + } + + let percent = Double(effectiveEstimatedTokenCount) / Double(usableContextWindowTokenCount) * 100 + return min(100, Int(percent.rounded())) + } +} + +public protocol AgentBackendContextWindowProviding: Sendable { + var modelContextWindowTokenCount: Int? { get } + var usableContextWindowTokenCount: Int? { get } +} diff --git a/Sources/CodexKit/Runtime/CodexResponsesBackend.swift b/Sources/CodexKit/Runtime/CodexResponsesBackend.swift index c8dd175..ade2d4a 100644 --- a/Sources/CodexKit/Runtime/CodexResponsesBackend.swift +++ b/Sources/CodexKit/Runtime/CodexResponsesBackend.swift @@ -36,6 +36,23 @@ public struct CodexResponsesBackendConfiguration: Sendable { } } +extension CodexResponsesBackendConfiguration { + var modelContextWindowTokenCount: Int? { + let normalizedModel = model.lowercased() + if normalizedModel.hasPrefix("gpt-5") { + return 272_000 + } + return nil + } + + var usableContextWindowTokenCount: Int? { + guard let modelContextWindowTokenCount else { + return nil + } + return (modelContextWindowTokenCount * 95) / 100 + } +} + public actor CodexResponsesBackend: AgentBackend { public nonisolated let baseInstructions: String? @@ -88,6 +105,16 @@ public actor CodexResponsesBackend: AgentBackend { } } +extension CodexResponsesBackend: AgentBackendContextWindowProviding { + public nonisolated var modelContextWindowTokenCount: Int? { + configuration.modelContextWindowTokenCount + } + + public nonisolated var usableContextWindowTokenCount: Int? { + configuration.usableContextWindowTokenCount + } +} + extension CodexResponsesBackend { static func structuredMetadata( from text: String, diff --git a/Tests/CodexKitTests/AgentRuntimeHistoryCompactionTests.swift b/Tests/CodexKitTests/AgentRuntimeHistoryCompactionTests.swift index bffb4c2..88cb0b5 100644 --- a/Tests/CodexKitTests/AgentRuntimeHistoryCompactionTests.swift +++ b/Tests/CodexKitTests/AgentRuntimeHistoryCompactionTests.swift @@ -10,9 +10,14 @@ extension AgentRuntimeTests { _ = try await runtime.signIn() let thread = try await runtime.createThread(title: "Compaction") - _ = try await runtime.sendMessage(UserMessageRequest(text: "one"), in: thread.id) - _ = try await runtime.sendMessage(UserMessageRequest(text: "two"), in: thread.id) - _ = try await runtime.sendMessage(UserMessageRequest(text: "three"), in: thread.id) + let longMessages = [ + String(repeating: "first message context ", count: 30), + String(repeating: "second message context ", count: 30), + String(repeating: "third message context ", count: 30), + ] + for message in longMessages { + _ = try await runtime.sendMessage(UserMessageRequest(text: message), in: thread.id) + } let visibleBefore = await runtime.messages(for: thread.id) XCTAssertEqual(visibleBefore.count, 6) @@ -81,6 +86,56 @@ extension AgentRuntimeTests { XCTAssertFalse(restoredContext?.effectiveMessages.isEmpty ?? true) } + func testFetchThreadContextUsageReportsVisibleAndEffectiveTokenCounts() async throws { + let backend = CompactingTestBackend() + let runtime = try makeHistoryRuntime( + backend: backend, + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore(), + contextCompaction: .init( + isEnabled: true, + mode: .manual, + strategy: .preferRemoteThenLocal + ) + ) + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(title: "Usage") + let longMessages = [ + String(repeating: "first message context ", count: 30), + String(repeating: "second message context ", count: 30), + String(repeating: "third message context ", count: 30), + ] + for message in longMessages { + _ = try await runtime.sendMessage(UserMessageRequest(text: message), in: thread.id) + } + + let usageBefore = try await runtime.fetchThreadContextUsage(id: thread.id) + let unwrappedUsageBefore = try XCTUnwrap(usageBefore) + XCTAssertEqual( + unwrappedUsageBefore.visibleEstimatedTokenCount, + unwrappedUsageBefore.effectiveEstimatedTokenCount + ) + XCTAssertEqual(unwrappedUsageBefore.modelContextWindowTokenCount, 272_000) + XCTAssertEqual(unwrappedUsageBefore.usableContextWindowTokenCount, 258_400) + XCTAssertNotNil(unwrappedUsageBefore.percentUsed) + + _ = try await runtime.compactThreadContext(id: thread.id) + + let usageAfter = try await runtime.fetchThreadContextUsage(id: thread.id) + let unwrappedUsageAfter = try XCTUnwrap(usageAfter) + XCTAssertEqual( + unwrappedUsageAfter.visibleEstimatedTokenCount, + unwrappedUsageBefore.visibleEstimatedTokenCount + ) + XCTAssertLessThan( + unwrappedUsageAfter.effectiveEstimatedTokenCount, + unwrappedUsageBefore.effectiveEstimatedTokenCount + ) + XCTAssertGreaterThan(unwrappedUsageAfter.estimatedTokenSavings, 0) + } + func testContextCompactionConfigurationDefaultsAndCodableShape() throws { let configuration = AgentContextCompactionConfiguration() XCTAssertFalse(configuration.isEnabled) diff --git a/Tests/CodexKitTests/AgentRuntimeHistoryTestSupport.swift b/Tests/CodexKitTests/AgentRuntimeHistoryTestSupport.swift index 045a4e8..5d612d2 100644 --- a/Tests/CodexKitTests/AgentRuntimeHistoryTestSupport.swift +++ b/Tests/CodexKitTests/AgentRuntimeHistoryTestSupport.swift @@ -163,8 +163,10 @@ private actor PartialEmissionGate { } } -actor CompactingTestBackend: AgentBackend, AgentBackendContextCompacting { +actor CompactingTestBackend: AgentBackend, AgentBackendContextCompacting, AgentBackendContextWindowProviding { nonisolated let baseInstructions: String? = nil + nonisolated let modelContextWindowTokenCount: Int? = 272_000 + nonisolated let usableContextWindowTokenCount: Int? = 258_400 private let failOnHistoryCountAbove: Int? private var threads: [String: AgentThread] = [:] diff --git a/Tests/CodexKitTests/AgentRuntimeMessageBehaviorTests.swift b/Tests/CodexKitTests/AgentRuntimeMessageBehaviorTests.swift index ac93ad3..d50a51e 100644 --- a/Tests/CodexKitTests/AgentRuntimeMessageBehaviorTests.swift +++ b/Tests/CodexKitTests/AgentRuntimeMessageBehaviorTests.swift @@ -303,16 +303,21 @@ extension AgentRuntimeTests { contextCompaction: .init( isEnabled: true, mode: .manual, - strategy: .localOnly + strategy: .preferRemoteThenLocal ) ) _ = try await runtime.restore() _ = try await runtime.signIn() let thread = try await runtime.createThread(title: "Observe Context") - _ = try await runtime.sendMessage(UserMessageRequest(text: "one"), in: thread.id) - _ = try await runtime.sendMessage(UserMessageRequest(text: "two"), in: thread.id) - _ = try await runtime.sendMessage(UserMessageRequest(text: "three"), in: thread.id) + let longMessages = [ + String(repeating: "first message context ", count: 30), + String(repeating: "second message context ", count: 30), + String(repeating: "third message context ", count: 30), + ] + for message in longMessages { + _ = try await runtime.sendMessage(UserMessageRequest(text: message), in: thread.id) + } let observedInitialState = expectation(description: "Observed the initial context state") let observedCompactedState = expectation(description: "Observed the compacted context state") @@ -345,6 +350,59 @@ extension AgentRuntimeTests { XCTAssertTrue(contexts.contains(where: { $0?.generation == 1 })) } + func testObserveThreadContextUsagePublishesInitialAndCompactedUsage() async throws { + let backend = CompactingTestBackend() + let runtime = try makeHistoryRuntime( + backend: backend, + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore(), + contextCompaction: .init( + isEnabled: true, + mode: .manual, + strategy: .preferRemoteThenLocal + ) + ) + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(title: "Observe Usage") + let longMessages = [ + String(repeating: "first message context ", count: 30), + String(repeating: "second message context ", count: 30), + String(repeating: "third message context ", count: 30), + ] + for message in longMessages { + _ = try await runtime.sendMessage(UserMessageRequest(text: message), in: thread.id) + } + + let observedInitialUsage = expectation(description: "Observed the initial context usage") + let observedCompactedUsage = expectation(description: "Observed compacted context usage") + observedCompactedUsage.assertForOverFulfill = false + var usages: [AgentThreadContextUsage?] = [] + var cancellables = Set() + + runtime.observeThreadContextUsage(id: thread.id) + .sink { usage in + usages.append(usage) + if let usage, + usage.visibleEstimatedTokenCount == usage.effectiveEstimatedTokenCount, + usage.visibleEstimatedTokenCount > 0 { + observedInitialUsage.fulfill() + } + if let usage, usage.estimatedTokenSavings > 0 { + observedCompactedUsage.fulfill() + } + } + .store(in: &cancellables) + + await fulfillment(of: [observedInitialUsage], timeout: 0.5) + + _ = try await runtime.compactThreadContext(id: thread.id) + + await fulfillment(of: [observedCompactedUsage], timeout: 0.5) + XCTAssertTrue(usages.contains(where: { ($0?.estimatedTokenSavings ?? 0) > 0 })) + } + func testSetTitlePublishesObservedThreadUpdateAndPersists() async throws { let stateStore = InMemoryRuntimeStateStore() let runtime = try AgentRuntime(configuration: .init( From c778a3b8c355dd23da07ce0a5370d8f17ec3fb4a Mon Sep 17 00:00:00 2001 From: Timothy Zelinsky Date: Tue, 24 Mar 2026 19:44:18 +1100 Subject: [PATCH 19/19] Refresh docs for token-based context usage --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index aeb5e1e..af6c733 100644 --- a/README.md +++ b/README.md @@ -796,7 +796,7 @@ The demo app exercises: - Responses web search in checked-in configuration - thread-pinned personas and one-turn overrides - a one-tap skill policy probe that compares tool behavior in normal vs skill-constrained threads -- a thread-level `Context Compaction` card that shows visible-vs-effective message counts and lets you trigger manual compaction +- a thread-level `Context Compaction` card that shows visible-vs-effective token usage and lets you trigger manual compaction - a Health Coach tab with HealthKit steps, AI-generated coaching, local reminders, and tone switching - GRDB-backed runtime persistence with automatic import from older `runtime-state.json` state on first launch