From 2076a027e6a30cc05311f676cb12ddaa58355887 Mon Sep 17 00:00:00 2001 From: Timothy Zelinsky Date: Sun, 22 Mar 2026 19:22:24 +1100 Subject: [PATCH] Split runtime and backend files by concern --- .../Runtime/AgentRuntime+Memory.swift | 406 +++++ .../Runtime/AgentRuntime+Messaging.swift | 391 +++++ .../Runtime/AgentRuntime+Skills.swift | 205 +++ .../Runtime/AgentRuntime+Threads.swift | 163 ++ Sources/CodexKit/Runtime/AgentRuntime.swift | 1164 +------------ .../CodexResponsesBackend+Models.swift | 239 +++ .../CodexResponsesBackend+Streaming.swift | 537 ++++++ .../Runtime/CodexResponsesBackend.swift | 772 --------- .../AgentRuntimeMemoryTests.swift | 503 ++++++ .../AgentRuntimeMessageTests.swift | 334 ++++ .../AgentRuntimePersonaSkillTests.swift | 613 +++++++ Tests/CodexKitTests/AgentRuntimeTests.swift | 1540 +---------------- 12 files changed, 3434 insertions(+), 3433 deletions(-) create mode 100644 Sources/CodexKit/Runtime/AgentRuntime+Memory.swift create mode 100644 Sources/CodexKit/Runtime/AgentRuntime+Messaging.swift create mode 100644 Sources/CodexKit/Runtime/AgentRuntime+Skills.swift create mode 100644 Sources/CodexKit/Runtime/AgentRuntime+Threads.swift create mode 100644 Sources/CodexKit/Runtime/CodexResponsesBackend+Models.swift create mode 100644 Sources/CodexKit/Runtime/CodexResponsesBackend+Streaming.swift create mode 100644 Tests/CodexKitTests/AgentRuntimeMemoryTests.swift create mode 100644 Tests/CodexKitTests/AgentRuntimeMessageTests.swift create mode 100644 Tests/CodexKitTests/AgentRuntimePersonaSkillTests.swift diff --git a/Sources/CodexKit/Runtime/AgentRuntime+Memory.swift b/Sources/CodexKit/Runtime/AgentRuntime+Memory.swift new file mode 100644 index 0000000..484dc86 --- /dev/null +++ b/Sources/CodexKit/Runtime/AgentRuntime+Memory.swift @@ -0,0 +1,406 @@ +import Foundation + +extension AgentRuntime { + // MARK: - Memory Previews + + public func memoryQueryPreview( + for threadID: String, + request: UserMessageRequest + ) async throws -> MemoryQueryResult? { + guard let thread = thread(for: threadID) else { + throw AgentRuntimeError.threadNotFound(threadID) + } + + return await resolvedMemoryQuery( + thread: thread, + message: request + ) + } + + // MARK: - Automatic Capture + + func automaticallyCaptureMemoriesIfConfigured( + for threadID: String, + userMessage: AgentMessage, + assistantMessages: [AgentMessage] + ) async { + guard let memoryConfiguration, + let policy = memoryConfiguration.automaticCapturePolicy + else { + return + } + + guard let thread = thread(for: threadID) else { + return + } + + if policy.requiresThreadMemoryContext, thread.memoryContext == nil { + return + } + + let source: MemoryCaptureSource + let sourceDescription: String + switch policy.source { + case .lastTurn: + let turnMessages = [userMessage] + assistantMessages.filter { $0.threadID == threadID } + guard turnMessages.contains(where: { $0.role == .assistant }) else { + return + } + source = .messages(turnMessages) + sourceDescription = "last_turn" + + case let .threadHistory(maxMessages): + source = .threadHistory(maxMessages: maxMessages) + sourceDescription = "thread_history_\(max(1, maxMessages))" + } + + if let observer = memoryConfiguration.observer { + await observer.handle( + event: .captureStarted( + threadID: threadID, + sourceDescription: sourceDescription + ) + ) + } + + do { + let result = try await captureMemories( + from: source, + for: threadID, + options: policy.options + ) + if let observer = memoryConfiguration.observer { + await observer.handle(event: .captureSucceeded(threadID: threadID, result: result)) + } + } catch { + if let observer = memoryConfiguration.observer { + await observer.handle( + event: .captureFailed( + threadID: threadID, + message: error.localizedDescription + ) + ) + } + } + } + + // MARK: - Memory Context + + public func memoryContext(for threadID: String) throws -> AgentMemoryContext? { + guard let thread = thread(for: threadID) else { + throw AgentRuntimeError.threadNotFound(threadID) + } + + return thread.memoryContext + } + + // MARK: - Memory Writing + + public func memoryWriter( + defaults: MemoryWriterDefaults = .init() + ) throws -> MemoryWriter { + guard let memoryConfiguration else { + throw AgentRuntimeError.memoryNotConfigured() + } + + return MemoryWriter( + store: memoryConfiguration.store, + defaults: defaults + ) + } + + public func memoryWriter( + for threadID: String, + defaults: MemoryWriterDefaults = .init() + ) throws -> MemoryWriter { + guard let thread = thread(for: threadID) else { + throw AgentRuntimeError.threadNotFound(threadID) + } + + let inheritedDefaults: MemoryWriterDefaults + if let memoryContext = thread.memoryContext { + inheritedDefaults = MemoryWriterDefaults( + namespace: memoryContext.namespace, + scope: memoryContext.scopes.count == 1 ? memoryContext.scopes[0] : nil, + kind: memoryContext.kinds.count == 1 ? memoryContext.kinds[0] : nil, + tags: memoryContext.tags, + relatedIDs: memoryContext.relatedIDs + ) + } else { + inheritedDefaults = .init() + } + + return try memoryWriter( + defaults: defaults.fillingMissingValues(from: inheritedDefaults) + ) + } + + // MARK: - Memory Capture + + public func captureMemories( + from source: MemoryCaptureSource = .threadHistory(), + for threadID: String, + options: MemoryCaptureOptions = .init(), + decoder: JSONDecoder = JSONDecoder() + ) async throws -> MemoryCaptureResult { + guard let thread = thread(for: threadID) else { + throw AgentRuntimeError.threadNotFound(threadID) + } + + let sourceText = formattedMemoryCaptureSource( + source, + threadID: threadID + ) + guard !sourceText.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty else { + return MemoryCaptureResult( + sourceText: sourceText, + drafts: [], + records: [] + ) + } + + let writer = try memoryWriter( + for: threadID, + defaults: options.defaults + ) + let request = UserMessageRequest( + text: MemoryExtractionDraftResponse.prompt( + sourceText: sourceText, + maxMemories: max(1, options.maxMemories) + ) + ) + let session = try await sessionManager.requireSession() + let turnStart = try await beginTurnWithUnauthorizedRecovery( + thread: thread, + history: [], + message: request, + instructions: options.instructions ?? MemoryExtractionDraftResponse.instructions, + responseFormat: MemoryExtractionDraftResponse.responseFormat( + maxMemories: max(1, options.maxMemories) + ), + tools: [], + session: session + ) + let assistantMessage = try await collectFinalAssistantMessage( + from: turnStart.turnStream + ) + let payload = Data(assistantMessage.text.trimmingCharacters(in: .whitespacesAndNewlines).utf8) + + let extraction: MemoryExtractionDraftResponse + do { + extraction = try decoder.decode(MemoryExtractionDraftResponse.self, from: payload) + } catch { + throw AgentRuntimeError.structuredOutputDecodingFailed( + typeName: "MemoryExtractionDraftResponse", + underlyingMessage: error.localizedDescription + ) + } + + let drafts = extraction.memories.map(\.memoryDraft) + var records: [MemoryRecord] = [] + records.reserveCapacity(drafts.count) + for draft in drafts { + if draft.dedupeKey != nil { + records.append(try await writer.upsert(draft)) + } else { + records.append(try await writer.put(draft)) + } + } + + return MemoryCaptureResult( + sourceText: sourceText, + drafts: drafts, + records: records + ) + } + + // MARK: - Memory Formatting + + func formattedMemoryCaptureSource( + _ source: MemoryCaptureSource, + threadID: String + ) -> String { + switch source { + case let .threadHistory(maxMessages): + let history = Array((state.messagesByThread[threadID] ?? []).suffix(max(1, maxMessages))) + return formattedMemoryTranscript(from: history) + + case let .messages(messages): + return formattedMemoryTranscript(from: messages) + + case let .text(text): + return text.trimmingCharacters(in: .whitespacesAndNewlines) + } + } + + func formattedMemoryTranscript(from messages: [AgentMessage]) -> String { + messages + .map { message in + let role = message.role.rawValue.capitalized + let text = message.displayText.trimmingCharacters(in: .whitespacesAndNewlines) + + if text.isEmpty, !message.images.isEmpty { + return "\(role): [\(message.images.count) image attachment(s)]" + } + + return "\(role): \(text)" + } + .joined(separator: "\n") + } + + // MARK: - Memory Query Resolution + + func resolvedMemoryQuery( + thread: AgentThread, + message: UserMessageRequest + ) async -> MemoryQueryResult? { + guard let memoryConfiguration else { + return nil + } + + guard let query = resolvedMemoryQuery( + thread: thread, + message: message, + fallbackRanking: memoryConfiguration.defaultRanking, + fallbackBudget: memoryConfiguration.defaultReadBudget + ) else { + return nil + } + + if let observer = memoryConfiguration.observer { + await observer.handle(event: .queryStarted(query)) + } + + do { + let result = try await memoryConfiguration.store.query(query) + if let observer = memoryConfiguration.observer { + await observer.handle(event: .querySucceeded(query: query, result: result)) + } + return result + } catch { + if let observer = memoryConfiguration.observer { + await observer.handle( + event: .queryFailed( + query: query, + message: error.localizedDescription + ) + ) + } + return nil + } + } + + func resolvedMemoryQuery( + thread: AgentThread, + message: UserMessageRequest, + fallbackRanking: MemoryRankingWeights, + fallbackBudget: MemoryReadBudget + ) -> MemoryQuery? { + let selection = message.memorySelection + if selection?.mode == .disable { + return nil + } + + let threadContext = thread.memoryContext + let namespace = selection?.namespace ?? + threadContext?.namespace + + guard let namespace else { + return nil + } + + let scopes: [MemoryScope] + switch selection?.mode ?? .inherit { + case .append: + scopes = uniqueScopes((threadContext?.scopes ?? []) + (selection?.scopes ?? [])) + case .replace: + scopes = selection?.scopes ?? [] + case .disable: + return nil + case .inherit: + if let selection, + !selection.scopes.isEmpty { + scopes = selection.scopes + } else { + scopes = threadContext?.scopes ?? [] + } + } + + let kinds = resolvedValues( + mode: selection?.mode ?? .inherit, + threadValues: threadContext?.kinds ?? [], + selectionValues: selection?.kinds ?? [] + ) + let tags = resolvedValues( + mode: selection?.mode ?? .inherit, + threadValues: threadContext?.tags ?? [], + selectionValues: selection?.tags ?? [] + ) + let relatedIDs = resolvedValues( + mode: selection?.mode ?? .inherit, + threadValues: threadContext?.relatedIDs ?? [], + selectionValues: selection?.relatedIDs ?? [] + ) + + let recencyWindow = selection?.recencyWindow + ?? threadContext?.recencyWindow + let minImportance = selection?.minImportance + ?? threadContext?.minImportance + let ranking = selection?.ranking + ?? threadContext?.ranking + ?? fallbackRanking + let budget = resolvedMemoryBudget( + thread: thread, + message: message, + fallback: fallbackBudget + ) + let text = selection?.text ?? message.text + + return MemoryQuery( + namespace: namespace, + scopes: scopes, + text: text, + kinds: kinds, + tags: tags, + relatedIDs: relatedIDs, + recencyWindow: recencyWindow, + minImportance: minImportance, + ranking: ranking, + limit: budget.maxItems, + maxCharacters: budget.maxCharacters, + includeArchived: false + ) + } + + func resolvedMemoryBudget( + thread: AgentThread, + message: UserMessageRequest, + fallback: MemoryReadBudget + ) -> MemoryReadBudget { + message.memorySelection?.readBudget + ?? thread.memoryContext?.readBudget + ?? fallback + } + + func uniqueScopes(_ scopes: [MemoryScope]) -> [MemoryScope] { + var seen: Set = [] + return scopes.filter { seen.insert($0).inserted } + } + + func resolvedValues( + mode: MemorySelectionMode, + threadValues: [String], + selectionValues: [String] + ) -> [String] { + switch mode { + case .append: + return Array(Set(threadValues + selectionValues)).sorted() + case .replace: + return selectionValues + case .disable: + return [] + case .inherit: + return selectionValues.isEmpty ? threadValues : selectionValues + } + } +} diff --git a/Sources/CodexKit/Runtime/AgentRuntime+Messaging.swift b/Sources/CodexKit/Runtime/AgentRuntime+Messaging.swift new file mode 100644 index 0000000..a90ae86 --- /dev/null +++ b/Sources/CodexKit/Runtime/AgentRuntime+Messaging.swift @@ -0,0 +1,391 @@ +import Foundation + +extension AgentRuntime { + // MARK: - Messaging + + public func streamMessage( + _ request: UserMessageRequest, + in threadID: String + ) async throws -> AsyncThrowingStream { + try await streamMessage( + request, + in: threadID, + responseFormat: nil + ) + } + + public func sendMessage( + _ request: UserMessageRequest, + in threadID: String + ) async throws -> String { + let stream = try await streamMessage( + request, + in: threadID, + responseFormat: nil + ) + let message = try await collectFinalAssistantMessage(from: stream) + return message.displayText + } + + public func sendMessage( + _ request: UserMessageRequest, + in threadID: String, + expecting outputType: Output.Type = Output.self, + decoder: JSONDecoder = JSONDecoder() + ) async throws -> Output { + try await sendMessage( + request, + in: threadID, + expecting: outputType, + responseFormat: outputType.responseFormat, + decoder: decoder + ) + } + + public func sendMessage( + _ request: UserMessageRequest, + in threadID: String, + expecting outputType: Output.Type, + responseFormat: AgentStructuredOutputFormat, + decoder: JSONDecoder = JSONDecoder() + ) async throws -> Output { + let stream = try await streamMessage( + request, + in: threadID, + responseFormat: responseFormat + ) + let message = try await collectFinalAssistantMessage(from: stream) + let payload = Data(message.text.trimmingCharacters(in: .whitespacesAndNewlines).utf8) + + do { + return try decoder.decode(Output.self, from: payload) + } catch { + throw AgentRuntimeError.structuredOutputDecodingFailed( + typeName: String(describing: outputType), + underlyingMessage: error.localizedDescription + ) + } + } + + func streamMessage( + _ request: UserMessageRequest, + in threadID: String, + responseFormat: AgentStructuredOutputFormat? + ) async throws -> AsyncThrowingStream { + guard request.hasContent else { + throw AgentRuntimeError.invalidMessageContent() + } + + guard let thread = thread(for: threadID) else { + throw AgentRuntimeError.threadNotFound(threadID) + } + + let session = try await sessionManager.requireSession() + let userMessage = AgentMessage( + threadID: threadID, + role: .user, + text: request.text, + images: request.images + ) + let priorMessages = state.messagesByThread[threadID] ?? [] + let resolvedTurnSkills = try resolveTurnSkills( + thread: thread, + message: request + ) + let resolvedInstructions = await resolveInstructions( + thread: thread, + message: request, + resolvedTurnSkills: resolvedTurnSkills + ) + + try await appendMessage(userMessage) + try await setThreadStatus(.streaming, for: threadID) + + let tools = await toolRegistry.allDefinitions() + let turnStart = try await beginTurnWithUnauthorizedRecovery( + thread: thread, + history: priorMessages, + message: request, + instructions: resolvedInstructions, + responseFormat: responseFormat, + tools: tools, + session: session + ) + let turnStream = turnStart.turnStream + let turnSession = turnStart.session + + return AsyncThrowingStream { continuation in + continuation.yield(.messageCommitted(userMessage)) + continuation.yield(.threadStatusChanged(threadID: threadID, status: .streaming)) + + Task { + await self.consumeTurnStream( + turnStream, + for: threadID, + userMessage: userMessage, + session: turnSession, + resolvedTurnSkills: resolvedTurnSkills, + continuation: continuation + ) + } + } + } + + func beginTurnWithUnauthorizedRecovery( + thread: AgentThread, + history: [AgentMessage], + message: UserMessageRequest, + instructions: String, + responseFormat: AgentStructuredOutputFormat?, + tools: [ToolDefinition], + session: ChatGPTSession + ) async throws -> ( + turnStream: any AgentTurnStreaming, + session: ChatGPTSession + ) { + let beginTurn = try await withUnauthorizedRecovery( + initialSession: session + ) { session in + try await backend.beginTurn( + thread: thread, + history: history, + message: message, + instructions: instructions, + responseFormat: responseFormat, + tools: tools, + session: session + ) + } + return (beginTurn.result, beginTurn.session) + } + + // MARK: - Previews + + public func resolvedInstructionsPreview( + for threadID: String, + request: UserMessageRequest + ) async throws -> String { + guard let thread = thread(for: threadID) else { + throw AgentRuntimeError.threadNotFound(threadID) + } + + let resolvedTurnSkills = try resolveTurnSkills( + thread: thread, + message: request + ) + + return await resolveInstructions( + thread: thread, + message: request, + resolvedTurnSkills: resolvedTurnSkills + ) + } + + // MARK: - Turn Consumption + + private func consumeTurnStream( + _ turnStream: any AgentTurnStreaming, + for threadID: String, + userMessage: AgentMessage, + session: ChatGPTSession, + resolvedTurnSkills: ResolvedTurnSkills, + continuation: AsyncThrowingStream.Continuation + ) async { + let policyTracker: TurnSkillPolicyTracker? = if resolvedTurnSkills.compiledToolPolicy.hasConstraints { + TurnSkillPolicyTracker(policy: resolvedTurnSkills.compiledToolPolicy) + } else { + nil + } + var assistantMessages: [AgentMessage] = [] + + do { + for try await backendEvent in turnStream.events { + switch backendEvent { + case let .turnStarted(turn): + continuation.yield(.turnStarted(turn)) + + case let .assistantMessageDelta(threadID, turnID, delta): + continuation.yield( + .assistantMessageDelta( + threadID: threadID, + turnID: turnID, + delta: delta + ) + ) + + case let .assistantMessageCompleted(message): + try await appendMessage(message) + if message.role == .assistant { + assistantMessages.append(message) + } + continuation.yield(.messageCommitted(message)) + + case let .toolCallRequested(invocation): + continuation.yield(.toolCallStarted(invocation)) + + let result: ToolResultEnvelope + if let policyTracker, + let validationError = policyTracker.validate(toolName: invocation.toolName) { + result = .failure( + invocation: invocation, + message: validationError.message + ) + } else { + let resolvedResult = try await resolveToolInvocation( + invocation, + session: session, + continuation: continuation + ) + result = resolvedResult + policyTracker?.recordAccepted(toolName: invocation.toolName) + } + + try await turnStream.submitToolResult(result, for: invocation.id) + continuation.yield(.toolCallFinished(result)) + try await setThreadStatus(.streaming, for: threadID) + continuation.yield(.threadStatusChanged(threadID: threadID, status: .streaming)) + + case let .turnCompleted(summary): + if let completionError = policyTracker?.completionError() { + try await setThreadStatus(.failed, for: threadID) + continuation.yield(.threadStatusChanged(threadID: threadID, status: .failed)) + continuation.yield(.turnFailed(completionError)) + continuation.finish(throwing: completionError) + return + } + + try await setThreadStatus(.idle, for: threadID) + await automaticallyCaptureMemoriesIfConfigured( + for: threadID, + userMessage: userMessage, + assistantMessages: assistantMessages + ) + continuation.yield(.threadStatusChanged(threadID: threadID, status: .idle)) + continuation.yield(.turnCompleted(summary)) + } + } + + continuation.finish() + } catch { + let runtimeError = (error as? AgentRuntimeError) + ?? AgentRuntimeError( + code: "turn_failed", + message: error.localizedDescription + ) + try? await setThreadStatus(.failed, for: threadID) + continuation.yield(.threadStatusChanged(threadID: threadID, status: .failed)) + continuation.yield(.turnFailed(runtimeError)) + continuation.finish(throwing: error) + } + } + + private func resolveToolInvocation( + _ invocation: ToolInvocation, + session: ChatGPTSession, + continuation: AsyncThrowingStream.Continuation + ) async throws -> ToolResultEnvelope { + if let definition = await toolRegistry.definition(named: invocation.toolName), + definition.approvalPolicy == .requiresApproval { + let approval = ApprovalRequest( + threadID: invocation.threadID, + turnID: invocation.turnID, + toolInvocation: invocation, + title: "Approve \(invocation.toolName)?", + message: definition.approvalMessage + ?? "This tool requires explicit approval before it can run." + ) + + try await setThreadStatus(.waitingForApproval, for: invocation.threadID) + continuation.yield( + .threadStatusChanged( + threadID: invocation.threadID, + status: .waitingForApproval + ) + ) + continuation.yield(.approvalRequested(approval)) + + let decision = try await approvalCoordinator.requestApproval(approval) + continuation.yield( + .approvalResolved( + ApprovalResolution( + requestID: approval.id, + threadID: approval.threadID, + turnID: approval.turnID, + decision: decision + ) + ) + ) + + guard decision == .approved else { + return .denied(invocation: invocation) + } + } + + try await setThreadStatus(.waitingForToolResult, for: invocation.threadID) + continuation.yield( + .threadStatusChanged( + threadID: invocation.threadID, + status: .waitingForToolResult + ) + ) + + return await toolRegistry.execute(invocation, session: session) + } + + // MARK: - Message Collection + + func collectFinalAssistantMessage( + from stream: AsyncThrowingStream + ) async throws -> AgentMessage { + var latestAssistantMessage: AgentMessage? + + for try await event in stream { + guard case let .messageCommitted(message) = event, + message.role == .assistant + else { + continue + } + + latestAssistantMessage = message + } + + guard let latestAssistantMessage else { + throw AgentRuntimeError.assistantResponseMissing() + } + + return latestAssistantMessage + } + + func collectFinalAssistantMessage( + from turnStream: any AgentTurnStreaming + ) async throws -> AgentMessage { + var latestAssistantMessage: AgentMessage? + + for try await event in turnStream.events { + switch event { + case let .assistantMessageCompleted(message): + if message.role == .assistant { + latestAssistantMessage = message + } + + case let .toolCallRequested(invocation): + try await turnStream.submitToolResult( + .failure( + invocation: invocation, + message: "Automatic memory capture does not allow tool calls." + ), + for: invocation.id + ) + + default: + break + } + } + + guard let latestAssistantMessage else { + throw AgentRuntimeError.assistantResponseMissing() + } + + return latestAssistantMessage + } +} diff --git a/Sources/CodexKit/Runtime/AgentRuntime+Skills.swift b/Sources/CodexKit/Runtime/AgentRuntime+Skills.swift new file mode 100644 index 0000000..9192b21 --- /dev/null +++ b/Sources/CodexKit/Runtime/AgentRuntime+Skills.swift @@ -0,0 +1,205 @@ +import Foundation + +extension AgentRuntime { + // MARK: - Skills + + public func skills() -> [AgentSkill] { + skillsByID.values.sorted { $0.id < $1.id } + } + + public func skill(for skillID: String) -> AgentSkill? { + skillsByID[skillID] + } + + public func registerSkill(_ skill: AgentSkill) throws { + guard AgentSkill.isValidID(skill.id) else { + throw AgentRuntimeError.invalidSkillID(skill.id) + } + try Self.validateSkillExecutionPolicy(skill) + guard skillsByID[skill.id] == nil else { + throw AgentRuntimeError.duplicateSkill(skill.id) + } + + skillsByID[skill.id] = skill + } + + public func replaceSkill(_ skill: AgentSkill) throws { + guard AgentSkill.isValidID(skill.id) else { + throw AgentRuntimeError.invalidSkillID(skill.id) + } + try Self.validateSkillExecutionPolicy(skill) + + skillsByID[skill.id] = skill + } + + @discardableResult + public func registerSkill( + from source: AgentDefinitionSource, + id: String? = nil, + name: String? = nil + ) async throws -> AgentSkill { + let skill = try await definitionSourceLoader.loadSkill( + from: source, + id: id, + name: name + ) + try registerSkill(skill) + return skill + } + + @discardableResult + public func replaceSkill( + from source: AgentDefinitionSource, + id: String? = nil, + name: String? = nil + ) async throws -> AgentSkill { + let skill = try await definitionSourceLoader.loadSkill( + from: source, + id: id, + name: name + ) + try replaceSkill(skill) + return skill + } + + public func skillIDs(for threadID: String) throws -> [String] { + guard let thread = thread(for: threadID) else { + throw AgentRuntimeError.threadNotFound(threadID) + } + + return thread.skillIDs + } + + public func setSkillIDs( + _ skillIDs: [String], + for threadID: String + ) async throws { + guard let index = state.threads.firstIndex(where: { $0.id == threadID }) else { + throw AgentRuntimeError.threadNotFound(threadID) + } + try assertSkillsExist(skillIDs) + + state.threads[index].skillIDs = skillIDs + state.threads[index].updatedAt = Date() + try await persistState() + } + + // MARK: - Skill Policy + + func resolveTurnSkills( + thread: AgentThread, + message: UserMessageRequest + ) throws -> ResolvedTurnSkills { + if let skillOverrideIDs = message.skillOverrideIDs { + try assertSkillsExist(skillOverrideIDs) + } + + let threadSkills = resolveSkills(for: thread.skillIDs) + let turnSkills = resolveSkills(for: message.skillOverrideIDs ?? []) + let allSkills = threadSkills + turnSkills + + return ResolvedTurnSkills( + threadSkills: threadSkills, + turnSkills: turnSkills, + compiledToolPolicy: compileToolPolicy(from: allSkills) + ) + } + + private func compileToolPolicy(from skills: [AgentSkill]) -> CompiledSkillToolPolicy { + var allowedToolNames: Set? + var requiredToolNames: Set = [] + var toolSequence: [String]? + var maxToolCalls: Int? + + for skill in skills { + guard let executionPolicy = skill.executionPolicy else { + continue + } + + if let allowed = executionPolicy.allowedToolNames, + !allowed.isEmpty { + let allowedSet = Set(allowed) + if let existingAllowed = allowedToolNames { + allowedToolNames = existingAllowed.intersection(allowedSet) + } else { + allowedToolNames = allowedSet + } + } + + if !executionPolicy.requiredToolNames.isEmpty { + requiredToolNames.formUnion(executionPolicy.requiredToolNames) + } + + if let sequence = executionPolicy.toolSequence, + !sequence.isEmpty { + toolSequence = sequence + } + + if let maxCalls = executionPolicy.maxToolCalls { + if let existingMaxCalls = maxToolCalls { + maxToolCalls = min(existingMaxCalls, maxCalls) + } else { + maxToolCalls = maxCalls + } + } + } + + return CompiledSkillToolPolicy( + allowedToolNames: allowedToolNames, + requiredToolNames: requiredToolNames, + toolSequence: toolSequence, + maxToolCalls: maxToolCalls + ) + } + + private func resolveSkills(for skillIDs: [String]) -> [AgentSkill] { + skillIDs.compactMap { skillsByID[$0] } + } + + func assertSkillsExist(_ skillIDs: [String]) throws { + let missing = Array(Set(skillIDs.filter { skillsByID[$0] == nil })).sorted() + guard missing.isEmpty else { + throw AgentRuntimeError.skillsNotFound(missing) + } + } + + static func validatedSkills(from skills: [AgentSkill]) throws -> [String: AgentSkill] { + var dictionary: [String: AgentSkill] = [:] + for skill in skills { + guard AgentSkill.isValidID(skill.id) else { + throw AgentRuntimeError.invalidSkillID(skill.id) + } + try validateSkillExecutionPolicy(skill) + guard dictionary[skill.id] == nil else { + throw AgentRuntimeError.duplicateSkill(skill.id) + } + dictionary[skill.id] = skill + } + return dictionary + } + + static func validateSkillExecutionPolicy(_ skill: AgentSkill) throws { + guard let executionPolicy = skill.executionPolicy else { + return + } + + if let maxToolCalls = executionPolicy.maxToolCalls, + maxToolCalls < 0 { + throw AgentRuntimeError.invalidSkillMaxToolCalls(skillID: skill.id) + } + + let policyToolNames: [String] = + (executionPolicy.allowedToolNames ?? []) + + executionPolicy.requiredToolNames + + (executionPolicy.toolSequence ?? []) + + for toolName in policyToolNames { + guard ToolDefinition.isValidName(toolName) else { + throw AgentRuntimeError.invalidSkillToolName( + skillID: skill.id, + toolName: toolName + ) + } + } + } +} diff --git a/Sources/CodexKit/Runtime/AgentRuntime+Threads.swift b/Sources/CodexKit/Runtime/AgentRuntime+Threads.swift new file mode 100644 index 0000000..f32669e --- /dev/null +++ b/Sources/CodexKit/Runtime/AgentRuntime+Threads.swift @@ -0,0 +1,163 @@ +import Foundation + +extension AgentRuntime { + // MARK: - Threads + + @discardableResult + public func createThread( + title: String? = nil, + personaStack: AgentPersonaStack? = nil, + personaSource: AgentDefinitionSource? = nil, + skillIDs: [String] = [], + memoryContext: AgentMemoryContext? = nil + ) async throws -> AgentThread { + try assertSkillsExist(skillIDs) + let resolvedPersonaStack: AgentPersonaStack? + if let personaStack { + resolvedPersonaStack = personaStack + } else if let personaSource { + resolvedPersonaStack = try await definitionSourceLoader.loadPersonaStack(from: personaSource) + } else { + resolvedPersonaStack = nil + } + + let session = try await sessionManager.requireSession() + let creation = try await withUnauthorizedRecovery( + initialSession: session + ) { session in + try await backend.createThread(session: session) + } + var thread = creation.result + if let title { + thread.title = title + } + thread.personaStack = resolvedPersonaStack + thread.skillIDs = skillIDs + thread.memoryContext = memoryContext + try await upsertThread(thread) + return thread + } + + @discardableResult + public func resumeThread(id: String) async throws -> AgentThread { + let session = try await sessionManager.requireSession() + let resume = try await withUnauthorizedRecovery( + initialSession: session + ) { session in + try await backend.resumeThread(id: id, session: session) + } + let thread = resume.result + try await upsertThread(thread) + return thread + } + + // MARK: - Thread Configuration + + func thread(for threadID: String) -> AgentThread? { + state.threads.first { $0.id == threadID } + } + + public func personaStack(for threadID: String) throws -> AgentPersonaStack? { + guard let thread = thread(for: threadID) else { + throw AgentRuntimeError.threadNotFound(threadID) + } + + return thread.personaStack + } + + public func setPersonaStack( + _ personaStack: AgentPersonaStack?, + for threadID: String + ) async throws { + guard let index = state.threads.firstIndex(where: { $0.id == threadID }) else { + throw AgentRuntimeError.threadNotFound(threadID) + } + + state.threads[index].personaStack = personaStack + state.threads[index].updatedAt = Date() + try await persistState() + } + + @discardableResult + public func setPersonaStack( + from source: AgentDefinitionSource, + for threadID: String, + defaultLayerName: String = "dynamic_persona" + ) async throws -> AgentPersonaStack { + let personaStack = try await definitionSourceLoader.loadPersonaStack( + from: source, + defaultLayerName: defaultLayerName + ) + try await setPersonaStack(personaStack, for: threadID) + return personaStack + } + + public func setMemoryContext( + _ memoryContext: AgentMemoryContext?, + for threadID: String + ) async throws { + guard let index = state.threads.firstIndex(where: { $0.id == threadID }) else { + throw AgentRuntimeError.threadNotFound(threadID) + } + + state.threads[index].memoryContext = memoryContext + state.threads[index].updatedAt = Date() + try await persistState() + } + + // MARK: - State Mutation + + func upsertThread(_ thread: AgentThread) async throws { + if let index = state.threads.firstIndex(where: { $0.id == thread.id }) { + var mergedThread = thread + if mergedThread.title == nil { + mergedThread.title = state.threads[index].title + } + if mergedThread.personaStack == nil { + mergedThread.personaStack = state.threads[index].personaStack + } + if mergedThread.skillIDs.isEmpty { + mergedThread.skillIDs = state.threads[index].skillIDs + } + if mergedThread.memoryContext == nil { + mergedThread.memoryContext = state.threads[index].memoryContext + } + state.threads[index] = mergedThread + } else { + state.threads.append(thread) + } + try await persistState() + } + + func setThreadStatus( + _ status: AgentThreadStatus, + for threadID: String + ) async throws { + guard let index = state.threads.firstIndex(where: { $0.id == threadID }) else { + throw AgentRuntimeError.threadNotFound(threadID) + } + + state.threads[index].status = status + state.threads[index].updatedAt = Date() + try await persistState() + } + + func appendMessage(_ message: AgentMessage) async throws { + state.messagesByThread[message.threadID, default: []].append(message) + + if let index = state.threads.firstIndex(where: { $0.id == message.threadID }) { + state.threads[index].updatedAt = message.createdAt + if state.threads[index].title == nil, message.role == .user { + if !message.text.isEmpty { + state.threads[index].title = String(message.text.prefix(48)) + } else if !message.images.isEmpty { + state.threads[index].title = message.images.count == 1 + ? "Image message" + : "Image message (\(message.images.count))" + } + } + } + + try await persistState() + } +} diff --git a/Sources/CodexKit/Runtime/AgentRuntime.swift b/Sources/CodexKit/Runtime/AgentRuntime.swift index abd5a13..90668d3 100644 --- a/Sources/CodexKit/Runtime/AgentRuntime.swift +++ b/Sources/CodexKit/Runtime/AgentRuntime.swift @@ -1,6 +1,8 @@ import Foundation public actor AgentRuntime { + // MARK: - Configuration + public struct ToolRegistration: Sendable { public let definition: ToolDefinition public let executor: AnyToolExecutor @@ -51,25 +53,25 @@ public actor AgentRuntime { } } - private let backend: any AgentBackend - private let stateStore: any RuntimeStateStoring - private let sessionManager: ChatGPTSessionManager - private let toolRegistry: ToolRegistry - private let approvalCoordinator: ApprovalCoordinator - private let memoryConfiguration: AgentMemoryConfiguration? - private let baseInstructions: String? - private let definitionSourceLoader: AgentDefinitionSourceLoader - private var skillsByID: [String: AgentSkill] + let backend: any AgentBackend + let stateStore: any RuntimeStateStoring + let sessionManager: ChatGPTSessionManager + let toolRegistry: ToolRegistry + let approvalCoordinator: ApprovalCoordinator + let memoryConfiguration: AgentMemoryConfiguration? + let baseInstructions: String? + let definitionSourceLoader: AgentDefinitionSourceLoader + var skillsByID: [String: AgentSkill] - private var state: StoredRuntimeState = .empty + var state: StoredRuntimeState = .empty - private struct ResolvedTurnSkills { + struct ResolvedTurnSkills { let threadSkills: [AgentSkill] let turnSkills: [AgentSkill] let compiledToolPolicy: CompiledSkillToolPolicy } - private struct CompiledSkillToolPolicy { + struct CompiledSkillToolPolicy { let allowedToolNames: Set? let requiredToolNames: Set let toolSequence: [String]? @@ -83,7 +85,7 @@ public actor AgentRuntime { } } - private final class TurnSkillPolicyTracker: @unchecked Sendable { + final class TurnSkillPolicyTracker: @unchecked Sendable { private let policy: CompiledSkillToolPolicy private var toolCallsCount = 0 private var usedToolNames: Set = [] @@ -146,6 +148,8 @@ public actor AgentRuntime { } } + // MARK: - Lifecycle + public init(configuration: Configuration) throws { self.backend = configuration.backend self.stateStore = configuration.stateStore @@ -183,6 +187,8 @@ public actor AgentRuntime { try await sessionManager.signOut() } + // MARK: - Read State + public func threads() -> [AgentThread] { state.threads.sorted { $0.updatedAt > $1.updatedAt } } @@ -191,6 +197,8 @@ public actor AgentRuntime { state.messagesByThread[threadID] ?? [] } + // MARK: - Tools + public func registerTool( _ definition: ToolDefinition, executor: AnyToolExecutor @@ -205,861 +213,13 @@ public actor AgentRuntime { try await toolRegistry.replace(definition, executor: executor) } - public func skills() -> [AgentSkill] { - skillsByID.values.sorted { $0.id < $1.id } - } - - public func skill(for skillID: String) -> AgentSkill? { - skillsByID[skillID] - } - - public func registerSkill(_ skill: AgentSkill) throws { - guard AgentSkill.isValidID(skill.id) else { - throw AgentRuntimeError.invalidSkillID(skill.id) - } - try Self.validateSkillExecutionPolicy(skill) - guard skillsByID[skill.id] == nil else { - throw AgentRuntimeError.duplicateSkill(skill.id) - } - - skillsByID[skill.id] = skill - } - - public func replaceSkill(_ skill: AgentSkill) throws { - guard AgentSkill.isValidID(skill.id) else { - throw AgentRuntimeError.invalidSkillID(skill.id) - } - try Self.validateSkillExecutionPolicy(skill) - - skillsByID[skill.id] = skill - } - - @discardableResult - public func registerSkill( - from source: AgentDefinitionSource, - id: String? = nil, - name: String? = nil - ) async throws -> AgentSkill { - let skill = try await definitionSourceLoader.loadSkill( - from: source, - id: id, - name: name - ) - try registerSkill(skill) - return skill - } - - @discardableResult - public func replaceSkill( - from source: AgentDefinitionSource, - id: String? = nil, - name: String? = nil - ) async throws -> AgentSkill { - let skill = try await definitionSourceLoader.loadSkill( - from: source, - id: id, - name: name - ) - try replaceSkill(skill) - return skill - } - - @discardableResult - public func createThread( - title: String? = nil, - personaStack: AgentPersonaStack? = nil, - personaSource: AgentDefinitionSource? = nil, - skillIDs: [String] = [], - memoryContext: AgentMemoryContext? = nil - ) async throws -> AgentThread { - try assertSkillsExist(skillIDs) - let resolvedPersonaStack: AgentPersonaStack? - if let personaStack { - resolvedPersonaStack = personaStack - } else if let personaSource { - resolvedPersonaStack = try await definitionSourceLoader.loadPersonaStack(from: personaSource) - } else { - resolvedPersonaStack = nil - } - - let session = try await sessionManager.requireSession() - let creation = try await withUnauthorizedRecovery( - initialSession: session - ) { session in - try await backend.createThread(session: session) - } - var thread = creation.result - if let title { - thread.title = title - } - thread.personaStack = resolvedPersonaStack - thread.skillIDs = skillIDs - thread.memoryContext = memoryContext - try await upsertThread(thread) - return thread - } - - @discardableResult - public func resumeThread(id: String) async throws -> AgentThread { - let session = try await sessionManager.requireSession() - let resume = try await withUnauthorizedRecovery( - initialSession: session - ) { session in - try await backend.resumeThread(id: id, session: session) - } - let thread = resume.result - try await upsertThread(thread) - return thread - } - - public func streamMessage( - _ request: UserMessageRequest, - in threadID: String - ) async throws -> AsyncThrowingStream { - try await streamMessage( - request, - in: threadID, - responseFormat: nil - ) - } - - public func sendMessage( - _ request: UserMessageRequest, - in threadID: String - ) async throws -> String { - let stream = try await streamMessage( - request, - in: threadID, - responseFormat: nil - ) - let message = try await collectFinalAssistantMessage(from: stream) - return message.displayText - } - - public func sendMessage( - _ request: UserMessageRequest, - in threadID: String, - expecting outputType: Output.Type = Output.self, - decoder: JSONDecoder = JSONDecoder() - ) async throws -> Output { - try await sendMessage( - request, - in: threadID, - expecting: outputType, - responseFormat: outputType.responseFormat, - decoder: decoder - ) - } - - public func sendMessage( - _ request: UserMessageRequest, - in threadID: String, - expecting outputType: Output.Type, - responseFormat: AgentStructuredOutputFormat, - decoder: JSONDecoder = JSONDecoder() - ) async throws -> Output { - let stream = try await streamMessage( - request, - in: threadID, - responseFormat: responseFormat - ) - let message = try await collectFinalAssistantMessage(from: stream) - let payload = Data(message.text.trimmingCharacters(in: .whitespacesAndNewlines).utf8) - - do { - return try decoder.decode(Output.self, from: payload) - } catch { - throw AgentRuntimeError.structuredOutputDecodingFailed( - typeName: String(describing: outputType), - underlyingMessage: error.localizedDescription - ) - } - } - - private func streamMessage( - _ request: UserMessageRequest, - in threadID: String, - responseFormat: AgentStructuredOutputFormat? - ) async throws -> AsyncThrowingStream { - guard request.hasContent else { - throw AgentRuntimeError.invalidMessageContent() - } - - guard let thread = thread(for: threadID) else { - throw AgentRuntimeError.threadNotFound(threadID) - } - - let session = try await sessionManager.requireSession() - let userMessage = AgentMessage( - threadID: threadID, - role: .user, - text: request.text, - images: request.images - ) - let priorMessages = state.messagesByThread[threadID] ?? [] - let resolvedTurnSkills = try resolveTurnSkills( - thread: thread, - message: request - ) - let resolvedInstructions = await resolveInstructions( - thread: thread, - message: request, - resolvedTurnSkills: resolvedTurnSkills - ) - - try await appendMessage(userMessage) - try await setThreadStatus(.streaming, for: threadID) - - let tools = await toolRegistry.allDefinitions() - let turnStart = try await beginTurnWithUnauthorizedRecovery( - thread: thread, - history: priorMessages, - message: request, - instructions: resolvedInstructions, - responseFormat: responseFormat, - tools: tools, - session: session - ) - let turnStream = turnStart.turnStream - let turnSession = turnStart.session - - return AsyncThrowingStream { continuation in - continuation.yield(.messageCommitted(userMessage)) - continuation.yield(.threadStatusChanged(threadID: threadID, status: .streaming)) - - Task { - await self.consumeTurnStream( - turnStream, - for: threadID, - userMessage: userMessage, - session: turnSession, - resolvedTurnSkills: resolvedTurnSkills, - continuation: continuation - ) - } - } - } - - private func beginTurnWithUnauthorizedRecovery( - thread: AgentThread, - history: [AgentMessage], - message: UserMessageRequest, - instructions: String, - responseFormat: AgentStructuredOutputFormat?, - tools: [ToolDefinition], - session: ChatGPTSession - ) async throws -> ( - turnStream: any AgentTurnStreaming, - session: ChatGPTSession - ) { - let beginTurn = try await withUnauthorizedRecovery( - initialSession: session - ) { session in - try await backend.beginTurn( - thread: thread, - history: history, - message: message, - instructions: instructions, - responseFormat: responseFormat, - tools: tools, - session: session - ) - } - return (beginTurn.result, beginTurn.session) - } - - public func resolvedInstructionsPreview( - for threadID: String, - request: UserMessageRequest - ) async throws -> String { - guard let thread = thread(for: threadID) else { - throw AgentRuntimeError.threadNotFound(threadID) - } - - let resolvedTurnSkills = try resolveTurnSkills( - thread: thread, - message: request - ) - - return await resolveInstructions( - thread: thread, - message: request, - resolvedTurnSkills: resolvedTurnSkills - ) - } - - public func memoryQueryPreview( - for threadID: String, - request: UserMessageRequest - ) async throws -> MemoryQueryResult? { - guard let thread = thread(for: threadID) else { - throw AgentRuntimeError.threadNotFound(threadID) - } - - return await resolvedMemoryQuery( - thread: thread, - message: request - ) - } - - private func consumeTurnStream( - _ turnStream: any AgentTurnStreaming, - for threadID: String, - userMessage: AgentMessage, - session: ChatGPTSession, - resolvedTurnSkills: ResolvedTurnSkills, - continuation: AsyncThrowingStream.Continuation - ) async { - let policyTracker: TurnSkillPolicyTracker? = if resolvedTurnSkills.compiledToolPolicy.hasConstraints { - TurnSkillPolicyTracker(policy: resolvedTurnSkills.compiledToolPolicy) - } else { - nil - } - var assistantMessages: [AgentMessage] = [] - - do { - for try await backendEvent in turnStream.events { - switch backendEvent { - case let .turnStarted(turn): - continuation.yield(.turnStarted(turn)) - - case let .assistantMessageDelta(threadID, turnID, delta): - continuation.yield( - .assistantMessageDelta( - threadID: threadID, - turnID: turnID, - delta: delta - ) - ) - - case let .assistantMessageCompleted(message): - try await appendMessage(message) - if message.role == .assistant { - assistantMessages.append(message) - } - continuation.yield(.messageCommitted(message)) - - case let .toolCallRequested(invocation): - continuation.yield(.toolCallStarted(invocation)) - - let result: ToolResultEnvelope - if let policyTracker, - let validationError = policyTracker.validate(toolName: invocation.toolName) { - result = .failure( - invocation: invocation, - message: validationError.message - ) - } else { - let resolvedResult = try await resolveToolInvocation( - invocation, - session: session, - continuation: continuation - ) - result = resolvedResult - policyTracker?.recordAccepted(toolName: invocation.toolName) - } - - try await turnStream.submitToolResult(result, for: invocation.id) - continuation.yield(.toolCallFinished(result)) - try await setThreadStatus(.streaming, for: threadID) - continuation.yield(.threadStatusChanged(threadID: threadID, status: .streaming)) - - case let .turnCompleted(summary): - if let completionError = policyTracker?.completionError() { - try await setThreadStatus(.failed, for: threadID) - continuation.yield(.threadStatusChanged(threadID: threadID, status: .failed)) - continuation.yield(.turnFailed(completionError)) - continuation.finish(throwing: completionError) - return - } - - try await setThreadStatus(.idle, for: threadID) - await automaticallyCaptureMemoriesIfConfigured( - for: threadID, - userMessage: userMessage, - assistantMessages: assistantMessages - ) - continuation.yield(.threadStatusChanged(threadID: threadID, status: .idle)) - continuation.yield(.turnCompleted(summary)) - } - } - - continuation.finish() - } catch { - let runtimeError = (error as? AgentRuntimeError) - ?? AgentRuntimeError( - code: "turn_failed", - message: error.localizedDescription - ) - try? await setThreadStatus(.failed, for: threadID) - continuation.yield(.threadStatusChanged(threadID: threadID, status: .failed)) - continuation.yield(.turnFailed(runtimeError)) - continuation.finish(throwing: error) - } - } - - private func automaticallyCaptureMemoriesIfConfigured( - for threadID: String, - userMessage: AgentMessage, - assistantMessages: [AgentMessage] - ) async { - guard let memoryConfiguration, - let policy = memoryConfiguration.automaticCapturePolicy - else { - return - } - - guard let thread = thread(for: threadID) else { - return - } - - if policy.requiresThreadMemoryContext, thread.memoryContext == nil { - return - } - - let source: MemoryCaptureSource - let sourceDescription: String - switch policy.source { - case .lastTurn: - let turnMessages = [userMessage] + assistantMessages.filter { $0.threadID == threadID } - guard turnMessages.contains(where: { $0.role == .assistant }) else { - return - } - source = .messages(turnMessages) - sourceDescription = "last_turn" - - case let .threadHistory(maxMessages): - source = .threadHistory(maxMessages: maxMessages) - sourceDescription = "thread_history_\(max(1, maxMessages))" - } - - if let observer = memoryConfiguration.observer { - await observer.handle( - event: .captureStarted( - threadID: threadID, - sourceDescription: sourceDescription - ) - ) - } - - do { - let result = try await captureMemories( - from: source, - for: threadID, - options: policy.options - ) - if let observer = memoryConfiguration.observer { - await observer.handle(event: .captureSucceeded(threadID: threadID, result: result)) - } - } catch { - if let observer = memoryConfiguration.observer { - await observer.handle( - event: .captureFailed( - threadID: threadID, - message: error.localizedDescription - ) - ) - } - } - } - - private func resolveToolInvocation( - _ invocation: ToolInvocation, - session: ChatGPTSession, - continuation: AsyncThrowingStream.Continuation - ) async throws -> ToolResultEnvelope { - if let definition = await toolRegistry.definition(named: invocation.toolName), - definition.approvalPolicy == .requiresApproval { - let approval = ApprovalRequest( - threadID: invocation.threadID, - turnID: invocation.turnID, - toolInvocation: invocation, - title: "Approve \(invocation.toolName)?", - message: definition.approvalMessage - ?? "This tool requires explicit approval before it can run." - ) - - try await setThreadStatus(.waitingForApproval, for: invocation.threadID) - continuation.yield( - .threadStatusChanged( - threadID: invocation.threadID, - status: .waitingForApproval - ) - ) - continuation.yield(.approvalRequested(approval)) - - let decision = try await approvalCoordinator.requestApproval(approval) - continuation.yield( - .approvalResolved( - ApprovalResolution( - requestID: approval.id, - threadID: approval.threadID, - turnID: approval.turnID, - decision: decision - ) - ) - ) - - guard decision == .approved else { - return .denied(invocation: invocation) - } - } - - try await setThreadStatus(.waitingForToolResult, for: invocation.threadID) - continuation.yield( - .threadStatusChanged( - threadID: invocation.threadID, - status: .waitingForToolResult - ) - ) - - return await toolRegistry.execute(invocation, session: session) - } - - private func thread(for threadID: String) -> AgentThread? { - state.threads.first { $0.id == threadID } - } - - public func personaStack(for threadID: String) throws -> AgentPersonaStack? { - guard let thread = thread(for: threadID) else { - throw AgentRuntimeError.threadNotFound(threadID) - } - - return thread.personaStack - } - - public func setPersonaStack( - _ personaStack: AgentPersonaStack?, - for threadID: String - ) async throws { - guard let index = state.threads.firstIndex(where: { $0.id == threadID }) else { - throw AgentRuntimeError.threadNotFound(threadID) - } - - state.threads[index].personaStack = personaStack - state.threads[index].updatedAt = Date() - try await persistState() - } - - @discardableResult - public func setPersonaStack( - from source: AgentDefinitionSource, - for threadID: String, - defaultLayerName: String = "dynamic_persona" - ) async throws -> AgentPersonaStack { - let personaStack = try await definitionSourceLoader.loadPersonaStack( - from: source, - defaultLayerName: defaultLayerName - ) - try await setPersonaStack(personaStack, for: threadID) - return personaStack - } - - public func skillIDs(for threadID: String) throws -> [String] { - guard let thread = thread(for: threadID) else { - throw AgentRuntimeError.threadNotFound(threadID) - } - - return thread.skillIDs - } - - public func memoryContext(for threadID: String) throws -> AgentMemoryContext? { - guard let thread = thread(for: threadID) else { - throw AgentRuntimeError.threadNotFound(threadID) - } - - return thread.memoryContext - } - - public func memoryWriter( - defaults: MemoryWriterDefaults = .init() - ) throws -> MemoryWriter { - guard let memoryConfiguration else { - throw AgentRuntimeError.memoryNotConfigured() - } - - return MemoryWriter( - store: memoryConfiguration.store, - defaults: defaults - ) - } - - public func memoryWriter( - for threadID: String, - defaults: MemoryWriterDefaults = .init() - ) throws -> MemoryWriter { - guard let thread = thread(for: threadID) else { - throw AgentRuntimeError.threadNotFound(threadID) - } - - let inheritedDefaults: MemoryWriterDefaults - if let memoryContext = thread.memoryContext { - inheritedDefaults = MemoryWriterDefaults( - namespace: memoryContext.namespace, - scope: memoryContext.scopes.count == 1 ? memoryContext.scopes[0] : nil, - kind: memoryContext.kinds.count == 1 ? memoryContext.kinds[0] : nil, - tags: memoryContext.tags, - relatedIDs: memoryContext.relatedIDs - ) - } else { - inheritedDefaults = .init() - } - - return try memoryWriter( - defaults: defaults.fillingMissingValues(from: inheritedDefaults) - ) - } - - public func captureMemories( - from source: MemoryCaptureSource = .threadHistory(), - for threadID: String, - options: MemoryCaptureOptions = .init(), - decoder: JSONDecoder = JSONDecoder() - ) async throws -> MemoryCaptureResult { - guard let thread = thread(for: threadID) else { - throw AgentRuntimeError.threadNotFound(threadID) - } - - let sourceText = formattedMemoryCaptureSource( - source, - threadID: threadID - ) - guard !sourceText.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty else { - return MemoryCaptureResult( - sourceText: sourceText, - drafts: [], - records: [] - ) - } - - let writer = try memoryWriter( - for: threadID, - defaults: options.defaults - ) - let request = UserMessageRequest( - text: MemoryExtractionDraftResponse.prompt( - sourceText: sourceText, - maxMemories: max(1, options.maxMemories) - ) - ) - let session = try await sessionManager.requireSession() - let turnStart = try await beginTurnWithUnauthorizedRecovery( - thread: thread, - history: [], - message: request, - instructions: options.instructions ?? MemoryExtractionDraftResponse.instructions, - responseFormat: MemoryExtractionDraftResponse.responseFormat( - maxMemories: max(1, options.maxMemories) - ), - tools: [], - session: session - ) - let assistantMessage = try await collectFinalAssistantMessage( - from: turnStart.turnStream - ) - let payload = Data(assistantMessage.text.trimmingCharacters(in: .whitespacesAndNewlines).utf8) - - let extraction: MemoryExtractionDraftResponse - do { - extraction = try decoder.decode(MemoryExtractionDraftResponse.self, from: payload) - } catch { - throw AgentRuntimeError.structuredOutputDecodingFailed( - typeName: "MemoryExtractionDraftResponse", - underlyingMessage: error.localizedDescription - ) - } + // MARK: - Instruction Resolution - let drafts = extraction.memories.map(\.memoryDraft) - var records: [MemoryRecord] = [] - records.reserveCapacity(drafts.count) - for draft in drafts { - if draft.dedupeKey != nil { - records.append(try await writer.upsert(draft)) - } else { - records.append(try await writer.put(draft)) - } - } - - return MemoryCaptureResult( - sourceText: sourceText, - drafts: drafts, - records: records - ) - } - - public func setSkillIDs( - _ skillIDs: [String], - for threadID: String - ) async throws { - guard let index = state.threads.firstIndex(where: { $0.id == threadID }) else { - throw AgentRuntimeError.threadNotFound(threadID) - } - try assertSkillsExist(skillIDs) - - state.threads[index].skillIDs = skillIDs - state.threads[index].updatedAt = Date() - try await persistState() - } - - public func setMemoryContext( - _ memoryContext: AgentMemoryContext?, - for threadID: String - ) async throws { - guard let index = state.threads.firstIndex(where: { $0.id == threadID }) else { - throw AgentRuntimeError.threadNotFound(threadID) - } - - state.threads[index].memoryContext = memoryContext - state.threads[index].updatedAt = Date() - try await persistState() - } - - private func upsertThread(_ thread: AgentThread) async throws { - if let index = state.threads.firstIndex(where: { $0.id == thread.id }) { - var mergedThread = thread - if mergedThread.title == nil { - mergedThread.title = state.threads[index].title - } - if mergedThread.personaStack == nil { - mergedThread.personaStack = state.threads[index].personaStack - } - if mergedThread.skillIDs.isEmpty { - mergedThread.skillIDs = state.threads[index].skillIDs - } - if mergedThread.memoryContext == nil { - mergedThread.memoryContext = state.threads[index].memoryContext - } - state.threads[index] = mergedThread - } else { - state.threads.append(thread) - } - try await persistState() - } - - private func setThreadStatus( - _ status: AgentThreadStatus, - for threadID: String - ) async throws { - guard let index = state.threads.firstIndex(where: { $0.id == threadID }) else { - throw AgentRuntimeError.threadNotFound(threadID) - } - - state.threads[index].status = status - state.threads[index].updatedAt = Date() - try await persistState() - } - - private func appendMessage(_ message: AgentMessage) async throws { - state.messagesByThread[message.threadID, default: []].append(message) - - if let index = state.threads.firstIndex(where: { $0.id == message.threadID }) { - state.threads[index].updatedAt = message.createdAt - if state.threads[index].title == nil, message.role == .user { - if !message.text.isEmpty { - state.threads[index].title = String(message.text.prefix(48)) - } else if !message.images.isEmpty { - state.threads[index].title = message.images.count == 1 - ? "Image message" - : "Image message (\(message.images.count))" - } - } - } - - try await persistState() - } - - private func collectFinalAssistantMessage( - from stream: AsyncThrowingStream - ) async throws -> AgentMessage { - var latestAssistantMessage: AgentMessage? - - for try await event in stream { - guard case let .messageCommitted(message) = event, - message.role == .assistant - else { - continue - } - - latestAssistantMessage = message - } - - guard let latestAssistantMessage else { - throw AgentRuntimeError.assistantResponseMissing() - } - - return latestAssistantMessage - } - - private func collectFinalAssistantMessage( - from turnStream: any AgentTurnStreaming - ) async throws -> AgentMessage { - var latestAssistantMessage: AgentMessage? - - for try await event in turnStream.events { - switch event { - case let .assistantMessageCompleted(message): - if message.role == .assistant { - latestAssistantMessage = message - } - - case let .toolCallRequested(invocation): - try await turnStream.submitToolResult( - .failure( - invocation: invocation, - message: "Automatic memory capture does not allow tool calls." - ), - for: invocation.id - ) - - default: - break - } - } - - guard let latestAssistantMessage else { - throw AgentRuntimeError.assistantResponseMissing() - } - - return latestAssistantMessage - } - - private func formattedMemoryCaptureSource( - _ source: MemoryCaptureSource, - threadID: String - ) -> String { - switch source { - case let .threadHistory(maxMessages): - let history = Array((state.messagesByThread[threadID] ?? []).suffix(max(1, maxMessages))) - return formattedMemoryTranscript(from: history) - - case let .messages(messages): - return formattedMemoryTranscript(from: messages) - - case let .text(text): - return text.trimmingCharacters(in: .whitespacesAndNewlines) - } - } - - private func formattedMemoryTranscript(from messages: [AgentMessage]) -> String { - messages - .map { message in - let role = message.role.rawValue.capitalized - let text = message.displayText.trimmingCharacters(in: .whitespacesAndNewlines) - - if text.isEmpty, !message.images.isEmpty { - return "\(role): [\(message.images.count) image attachment(s)]" - } - - return "\(role): \(text)" - } - .joined(separator: "\n") - } - - private func persistState() async throws { + func persistState() async throws { try await stateStore.saveState(state) } - private func resolveInstructions( + func resolveInstructions( thread: AgentThread, message: UserMessageRequest, resolvedTurnSkills: ResolvedTurnSkills @@ -1105,11 +265,13 @@ public actor AgentRuntime { """ } - private static func isUnauthorizedError(_ error: Error) -> Bool { + // MARK: - Auth Recovery + + static func isUnauthorizedError(_ error: Error) -> Bool { (error as? AgentRuntimeError)?.code == AgentRuntimeError.unauthorized().code } - private func withUnauthorizedRecovery( + func withUnauthorizedRecovery( initialSession: ChatGPTSession, operation: (ChatGPTSession) async throws -> Result ) async throws -> ( @@ -1130,274 +292,4 @@ public actor AgentRuntime { } } - private func resolveTurnSkills( - thread: AgentThread, - message: UserMessageRequest - ) throws -> ResolvedTurnSkills { - if let skillOverrideIDs = message.skillOverrideIDs { - try assertSkillsExist(skillOverrideIDs) - } - - let threadSkills = resolveSkills(for: thread.skillIDs) - let turnSkills = resolveSkills(for: message.skillOverrideIDs ?? []) - let allSkills = threadSkills + turnSkills - - return ResolvedTurnSkills( - threadSkills: threadSkills, - turnSkills: turnSkills, - compiledToolPolicy: compileToolPolicy(from: allSkills) - ) - } - - private func resolvedMemoryQuery( - thread: AgentThread, - message: UserMessageRequest - ) async -> MemoryQueryResult? { - guard let memoryConfiguration else { - return nil - } - - guard let query = resolvedMemoryQuery( - thread: thread, - message: message, - fallbackRanking: memoryConfiguration.defaultRanking, - fallbackBudget: memoryConfiguration.defaultReadBudget - ) else { - return nil - } - - if let observer = memoryConfiguration.observer { - await observer.handle(event: .queryStarted(query)) - } - - do { - let result = try await memoryConfiguration.store.query(query) - if let observer = memoryConfiguration.observer { - await observer.handle(event: .querySucceeded(query: query, result: result)) - } - return result - } catch { - if let observer = memoryConfiguration.observer { - await observer.handle( - event: .queryFailed( - query: query, - message: error.localizedDescription - ) - ) - } - return nil - } - } - - private func resolvedMemoryQuery( - thread: AgentThread, - message: UserMessageRequest, - fallbackRanking: MemoryRankingWeights, - fallbackBudget: MemoryReadBudget - ) -> MemoryQuery? { - let selection = message.memorySelection - if selection?.mode == .disable { - return nil - } - - let threadContext = thread.memoryContext - let namespace = selection?.namespace ?? - threadContext?.namespace - - guard let namespace else { - return nil - } - - let scopes: [MemoryScope] - switch selection?.mode ?? .inherit { - case .append: - scopes = uniqueScopes((threadContext?.scopes ?? []) + (selection?.scopes ?? [])) - case .replace: - scopes = selection?.scopes ?? [] - case .disable: - return nil - case .inherit: - if let selection, - !selection.scopes.isEmpty { - scopes = selection.scopes - } else { - scopes = threadContext?.scopes ?? [] - } - } - - let kinds = resolvedValues( - mode: selection?.mode ?? .inherit, - threadValues: threadContext?.kinds ?? [], - selectionValues: selection?.kinds ?? [] - ) - let tags = resolvedValues( - mode: selection?.mode ?? .inherit, - threadValues: threadContext?.tags ?? [], - selectionValues: selection?.tags ?? [] - ) - let relatedIDs = resolvedValues( - mode: selection?.mode ?? .inherit, - threadValues: threadContext?.relatedIDs ?? [], - selectionValues: selection?.relatedIDs ?? [] - ) - - let recencyWindow = selection?.recencyWindow - ?? threadContext?.recencyWindow - let minImportance = selection?.minImportance - ?? threadContext?.minImportance - let ranking = selection?.ranking - ?? threadContext?.ranking - ?? fallbackRanking - let budget = resolvedMemoryBudget( - thread: thread, - message: message, - fallback: fallbackBudget - ) - let text = selection?.text ?? message.text - - return MemoryQuery( - namespace: namespace, - scopes: scopes, - text: text, - kinds: kinds, - tags: tags, - relatedIDs: relatedIDs, - recencyWindow: recencyWindow, - minImportance: minImportance, - ranking: ranking, - limit: budget.maxItems, - maxCharacters: budget.maxCharacters, - includeArchived: false - ) - } - - private func resolvedMemoryBudget( - thread: AgentThread, - message: UserMessageRequest, - fallback: MemoryReadBudget - ) -> MemoryReadBudget { - message.memorySelection?.readBudget - ?? thread.memoryContext?.readBudget - ?? fallback - } - - private func uniqueScopes(_ scopes: [MemoryScope]) -> [MemoryScope] { - var seen: Set = [] - return scopes.filter { seen.insert($0).inserted } - } - - private func resolvedValues( - mode: MemorySelectionMode, - threadValues: [String], - selectionValues: [String] - ) -> [String] { - switch mode { - case .append: - return Array(Set(threadValues + selectionValues)).sorted() - case .replace: - return selectionValues - case .disable: - return [] - case .inherit: - return selectionValues.isEmpty ? threadValues : selectionValues - } - } - - private func compileToolPolicy(from skills: [AgentSkill]) -> CompiledSkillToolPolicy { - var allowedToolNames: Set? - var requiredToolNames: Set = [] - var toolSequence: [String]? - var maxToolCalls: Int? - - for skill in skills { - guard let executionPolicy = skill.executionPolicy else { - continue - } - - if let allowed = executionPolicy.allowedToolNames, - !allowed.isEmpty { - let allowedSet = Set(allowed) - if let existingAllowed = allowedToolNames { - allowedToolNames = existingAllowed.intersection(allowedSet) - } else { - allowedToolNames = allowedSet - } - } - - if !executionPolicy.requiredToolNames.isEmpty { - requiredToolNames.formUnion(executionPolicy.requiredToolNames) - } - - if let sequence = executionPolicy.toolSequence, - !sequence.isEmpty { - toolSequence = sequence - } - - if let maxCalls = executionPolicy.maxToolCalls { - if let existingMaxCalls = maxToolCalls { - maxToolCalls = min(existingMaxCalls, maxCalls) - } else { - maxToolCalls = maxCalls - } - } - } - - return CompiledSkillToolPolicy( - allowedToolNames: allowedToolNames, - requiredToolNames: requiredToolNames, - toolSequence: toolSequence, - maxToolCalls: maxToolCalls - ) - } - - private func resolveSkills(for skillIDs: [String]) -> [AgentSkill] { - skillIDs.compactMap { skillsByID[$0] } - } - - private func assertSkillsExist(_ skillIDs: [String]) throws { - let missing = Array(Set(skillIDs.filter { skillsByID[$0] == nil })).sorted() - guard missing.isEmpty else { - throw AgentRuntimeError.skillsNotFound(missing) - } - } - - private static func validatedSkills(from skills: [AgentSkill]) throws -> [String: AgentSkill] { - var dictionary: [String: AgentSkill] = [:] - for skill in skills { - guard AgentSkill.isValidID(skill.id) else { - throw AgentRuntimeError.invalidSkillID(skill.id) - } - try validateSkillExecutionPolicy(skill) - guard dictionary[skill.id] == nil else { - throw AgentRuntimeError.duplicateSkill(skill.id) - } - dictionary[skill.id] = skill - } - return dictionary - } - - private static func validateSkillExecutionPolicy(_ skill: AgentSkill) throws { - guard let executionPolicy = skill.executionPolicy else { - return - } - - if let maxToolCalls = executionPolicy.maxToolCalls, - maxToolCalls < 0 { - throw AgentRuntimeError.invalidSkillMaxToolCalls(skillID: skill.id) - } - - let policyToolNames: [String] = - (executionPolicy.allowedToolNames ?? []) + - executionPolicy.requiredToolNames + - (executionPolicy.toolSequence ?? []) - - for toolName in policyToolNames { - guard ToolDefinition.isValidName(toolName) else { - throw AgentRuntimeError.invalidSkillToolName( - skillID: skill.id, - toolName: toolName - ) - } - } - } } diff --git a/Sources/CodexKit/Runtime/CodexResponsesBackend+Models.swift b/Sources/CodexKit/Runtime/CodexResponsesBackend+Models.swift new file mode 100644 index 0000000..23dc289 --- /dev/null +++ b/Sources/CodexKit/Runtime/CodexResponsesBackend+Models.swift @@ -0,0 +1,239 @@ +import Foundation + +struct ResponsesRequestBody: Encodable { + let model: String + let reasoning: ResponsesReasoningConfiguration + let instructions: String + let text: ResponsesTextConfiguration + let input: [JSONValue] + let tools: [JSONValue] + let toolChoice: String + let parallelToolCalls: Bool + let store: Bool + let stream: Bool + let include: [String] + let promptCacheKey: String? + + enum CodingKeys: String, CodingKey { + case model + case reasoning + case instructions + case text + case input + case tools + case toolChoice = "tool_choice" + case parallelToolCalls = "parallel_tool_calls" + case store + case stream + case include + case promptCacheKey = "prompt_cache_key" + } +} + +struct ResponsesReasoningConfiguration: Encodable { + let effort: String + + init(effort: ReasoningEffort) { + self.effort = effort.apiValue + } +} + +struct ResponsesTextConfiguration: Encodable { + let format: ResponsesTextFormat +} + +struct ResponsesTextFormat: Encodable { + let type: String + let name: String? + let description: String? + let schema: JSONValue? + let strict: Bool? + + init(responseFormat: AgentStructuredOutputFormat?) { + if let responseFormat { + type = "json_schema" + name = responseFormat.name + description = responseFormat.description + schema = responseFormat.schema.jsonValue + strict = responseFormat.strict + } else { + type = "text" + name = nil + description = nil + schema = nil + strict = nil + } + } +} + +enum WorkingHistoryItem: Sendable { + case visibleMessage(AgentMessage) + case userMessage(AgentMessage) + case assistantMessage(AgentMessage) + case functionCall(FunctionCallRecord) + case functionCallOutput(callID: String, output: String) + + var jsonValue: JSONValue { + switch self { + case let .visibleMessage(message): + Self.messageJSONValue(for: message) + case let .userMessage(message): + Self.messageJSONValue(for: message) + case let .assistantMessage(message): + Self.messageJSONValue(for: message) + case let .functionCall(functionCall): + .object([ + "type": .string("function_call"), + "name": .string(functionCall.name), + "arguments": .string(functionCall.argumentsRaw), + "call_id": .string(functionCall.callID), + ]) + case let .functionCallOutput(callID, output): + .object([ + "type": .string("function_call_output"), + "call_id": .string(callID), + "output": .string(output), + ]) + } + } + + private static func messageJSONValue(for message: AgentMessage) -> JSONValue { + let roleValue: String = switch message.role { + case .assistant: + "assistant" + case .system: + "system" + case .tool: + "assistant" + case .user: + "user" + } + + var content: [JSONValue] = [] + + switch message.role { + case .assistant: + if !message.text.isEmpty { + content.append(.object([ + "type": .string("output_text"), + "text": .string(message.text), + ])) + } + content.append(contentsOf: message.images.map { image in + .object([ + "type": .string("output_image"), + "image_url": .string(image.dataURLString), + ]) + }) + + default: + if !message.text.isEmpty { + content.append(.object([ + "type": .string("input_text"), + "text": .string(message.text), + ])) + } + + if message.role == .user { + content.append(contentsOf: message.images.map { image in + .object([ + "type": .string("input_image"), + "image_url": .string(image.dataURLString), + ]) + }) + } + } + + return .object([ + "type": .string("message"), + "role": .string(roleValue), + "content": .array(content), + ]) + } +} + +struct FunctionCallRecord: Sendable { + let name: String + let callID: String + let argumentsRaw: String + + var arguments: JSONValue { + guard let data = argumentsRaw.data(using: .utf8), + let value = try? JSONDecoder().decode(JSONValue.self, from: data) + else { + return .string(argumentsRaw) + } + return value + } +} + +enum CodexResponsesStreamEvent: Sendable { + case assistantTextDelta(String) + case assistantMessage(AgentMessage) + case functionCall(FunctionCallRecord) + case completed(AgentUsage) +} + +struct PendingToolResults: Sendable { + private actor Storage { + private var waiting: [String: CheckedContinuation] = [:] + private var resolved: [String: ToolResultEnvelope] = [:] + + func wait(for invocationID: String) async throws -> ToolResultEnvelope { + if let resolved = resolved.removeValue(forKey: invocationID) { + return resolved + } + + return try await withCheckedThrowingContinuation { continuation in + waiting[invocationID] = continuation + } + } + + func resolve(_ result: ToolResultEnvelope, for invocationID: String) { + if let continuation = waiting.removeValue(forKey: invocationID) { + continuation.resume(returning: result) + } else { + resolved[invocationID] = result + } + } + } + + private let storage = Storage() + + func wait(for invocationID: String) async throws -> ToolResultEnvelope { + try await storage.wait(for: invocationID) + } + + func resolve(_ result: ToolResultEnvelope, for invocationID: String) async { + await storage.resolve(result, for: invocationID) + } +} + +extension ToolDefinition { + var responsesJSONValue: JSONValue { + .object([ + "type": .string("function"), + "name": .string(name), + "description": .string(description), + "strict": .bool(false), + "parameters": normalizedSchema, + ]) + } + + var normalizedSchema: JSONValue { + guard case var .object(schema) = inputSchema else { + return inputSchema + } + if schema["properties"] == nil { + schema["properties"] = .object([:]) + } + return .object(schema) + } +} + +extension Array where Element: Hashable { + func uniqued() -> [Element] { + var seen = Set() + return filter { seen.insert($0).inserted } + } +} diff --git a/Sources/CodexKit/Runtime/CodexResponsesBackend+Streaming.swift b/Sources/CodexKit/Runtime/CodexResponsesBackend+Streaming.swift new file mode 100644 index 0000000..97bb364 --- /dev/null +++ b/Sources/CodexKit/Runtime/CodexResponsesBackend+Streaming.swift @@ -0,0 +1,537 @@ +import Foundation + +struct SSEEventPayload { + let event: String? + let data: String +} + +struct SSEEventParser { + private var eventName: String? + private var dataLines: [String] = [] + + mutating func consume(line: String) -> SSEEventPayload? { + if line.isEmpty { + return flush() + } + + if line.hasPrefix("event:") { + eventName = Self.trimmedFieldValue(from: line) + } else if line.hasPrefix("data:") { + dataLines.append(Self.trimmedFieldValue(from: line)) + } + + return nil + } + + mutating func finish() -> SSEEventPayload? { + flush() + } + + private mutating func flush() -> SSEEventPayload? { + guard !dataLines.isEmpty else { + eventName = nil + return nil + } + + let payload = SSEEventPayload( + event: eventName, + data: dataLines.joined(separator: "\n") + ) + eventName = nil + dataLines.removeAll(keepingCapacity: true) + return payload + } + + private static func trimmedFieldValue(from line: String) -> String { + let value = line.drop { $0 != ":" } + return value.dropFirst().trimmingCharacters(in: .whitespaces) + } +} + +struct StreamEnvelope: Decodable { + let type: String + let delta: String? + let item: StreamItem? + let response: StreamResponsePayload? +} + +enum StreamItem: Decodable { + case message(StreamMessageItem) + case functionCall(StreamFunctionCallItem) + case other + + init(from decoder: Decoder) throws { + let container = try decoder.singleValueContainer() + let object = try container.decode([String: JSONValue].self) + let type = object["type"]?.stringValue + + switch type { + case "message": + let data = try JSONEncoder().encode(object) + self = .message(try JSONDecoder().decode(StreamMessageItem.self, from: data)) + case "function_call": + let data = try JSONEncoder().encode(object) + self = .functionCall(try JSONDecoder().decode(StreamFunctionCallItem.self, from: data)) + default: + self = .other + } + } +} + +struct StreamMessageItem: Decodable { + let role: String + let content: [StreamMessageContent] +} + +struct StreamMessageContent: Decodable { + let type: String + let displayText: String? + let imageAttachment: AgentImageAttachment? + + init(from decoder: Decoder) throws { + let container = try decoder.singleValueContainer() + let object = try container.decode([String: JSONValue].self) + type = object["type"]?.stringValue ?? "" + displayText = object["text"]?.stringValue ?? object["refusal"]?.stringValue + imageAttachment = Self.parseImageAttachment(from: object) + } + + private static func parseImageAttachment(from object: [String: JSONValue]) -> AgentImageAttachment? { + if let dataURL = object["image_url"]?.stringValue, + let attachment = AgentImageAttachment(dataURLString: dataURL) { + return attachment + } + + if let imageObject = object["image"]?.objectValue, + let dataURL = imageObject["image_url"]?.stringValue, + let attachment = AgentImageAttachment(dataURLString: dataURL) { + return attachment + } + + if let b64 = object["b64_json"]?.stringValue { + return AgentImageAttachment(base64String: b64) + } + + return nil + } +} + +struct StreamFunctionCallItem: Decodable { + let name: String + let arguments: String + let callID: String + + enum CodingKeys: String, CodingKey { + case name + case arguments + case callID = "call_id" + } +} + +struct StreamResponsePayload: Decodable { + let id: String? + let usage: StreamUsage? + let error: StreamErrorPayload? + let incompleteDetails: StreamIncompleteDetails? + + enum CodingKeys: String, CodingKey { + case id + case usage + case error + case incompleteDetails = "incomplete_details" + } +} + +struct StreamUsage: Decodable { + let inputTokens: Int + let inputTokensDetails: StreamInputTokenDetails? + let outputTokens: Int + + enum CodingKeys: String, CodingKey { + case inputTokens = "input_tokens" + case inputTokensDetails = "input_tokens_details" + case outputTokens = "output_tokens" + } + + var assistantUsage: AgentUsage { + AgentUsage( + inputTokens: inputTokens, + cachedInputTokens: inputTokensDetails?.cachedTokens ?? 0, + outputTokens: outputTokens + ) + } +} + +struct StreamInputTokenDetails: Decodable { + let cachedTokens: Int + + enum CodingKeys: String, CodingKey { + case cachedTokens = "cached_tokens" + } +} + +struct StreamErrorPayload: Decodable { + let message: String? +} + +struct StreamIncompleteDetails: Decodable { + let reason: String? +} + +extension CodexResponsesTurnSession { + static func buildURLRequest( + configuration: CodexResponsesBackendConfiguration, + instructions: String, + responseFormat: AgentStructuredOutputFormat?, + threadID: String, + items: [WorkingHistoryItem], + tools: [ToolDefinition], + session: ChatGPTSession, + encoder: JSONEncoder + ) throws -> URLRequest { + let requestBody = ResponsesRequestBody( + model: configuration.model, + reasoning: .init(effort: configuration.reasoningEffort), + instructions: instructions, + text: .init(format: .init(responseFormat: responseFormat)), + input: items.map(\.jsonValue), + tools: responsesTools( + from: tools, + enableWebSearch: configuration.enableWebSearch + ), + toolChoice: "auto", + parallelToolCalls: false, + store: false, + stream: true, + include: [], + promptCacheKey: threadID + ) + + var request = URLRequest(url: configuration.baseURL.appendingPathComponent("responses")) + request.httpMethod = "POST" + request.httpBody = try encoder.encode(requestBody) + request.setValue("application/json", forHTTPHeaderField: "Content-Type") + request.setValue("text/event-stream", forHTTPHeaderField: "Accept") + request.setValue("Bearer \(session.accessToken)", forHTTPHeaderField: "Authorization") + request.setValue(session.account.id, forHTTPHeaderField: "ChatGPT-Account-ID") + request.setValue(threadID, forHTTPHeaderField: "session_id") + request.setValue(threadID, forHTTPHeaderField: "x-client-request-id") + request.setValue(configuration.originator, forHTTPHeaderField: "originator") + + for (header, value) in configuration.extraHeaders { + request.setValue(value, forHTTPHeaderField: header) + } + + return request + } + + static func responsesTools( + from tools: [ToolDefinition], + enableWebSearch: Bool + ) -> [JSONValue] { + var responsesTools = tools.map(\.responsesJSONValue) + + if enableWebSearch { + responsesTools.append(.object(["type": .string("web_search")])) + } + + return responsesTools + } + + static func streamEvents( + request: URLRequest, + urlSession: URLSession, + decoder: JSONDecoder + ) async throws -> AsyncThrowingStream { + let (bytes, response) = try await urlSession.bytes(for: request) + guard let httpResponse = response as? HTTPURLResponse else { + throw AgentRuntimeError( + code: "responses_invalid_response", + message: "The ChatGPT responses endpoint returned an invalid response." + ) + } + + if !(200 ..< 300).contains(httpResponse.statusCode) { + let bodyData = try await readAll(bytes) + let body = String(data: bodyData, encoding: .utf8) ?? "Unknown error" + if httpResponse.statusCode == 401 || httpResponse.statusCode == 403 { + throw AgentRuntimeError.unauthorized(body) + } + throw AgentRuntimeError( + code: "responses_http_status_\(httpResponse.statusCode)", + message: "The ChatGPT responses request failed with status \(httpResponse.statusCode): \(body)" + ) + } + + return AsyncThrowingStream { continuation in + Task { + var parser = SSEEventParser() + + do { + var lineBuffer = Data() + + for try await byte in bytes { + if byte == UInt8(ascii: "\n") { + var line = String(decoding: lineBuffer, as: UTF8.self) + if line.hasSuffix("\r") { + line.removeLast() + } + lineBuffer.removeAll(keepingCapacity: true) + + if let payload = parser.consume(line: line) { + if let event = try parseStreamEvent( + from: payload, + decoder: decoder + ) { + continuation.yield(event) + } + } + continue + } + + lineBuffer.append(byte) + } + + if !lineBuffer.isEmpty { + var line = String(decoding: lineBuffer, as: UTF8.self) + if line.hasSuffix("\r") { + line.removeLast() + } + if let payload = parser.consume(line: line) { + if let event = try parseStreamEvent( + from: payload, + decoder: decoder + ) { + continuation.yield(event) + } + } + } + + if let payload = parser.finish(), + let event = try parseStreamEvent(from: payload, decoder: decoder) { + continuation.yield(event) + } + + continuation.finish() + } catch { + continuation.finish(throwing: error) + } + } + } + } + + static func shouldRetry( + _ error: Error, + policy: RequestRetryPolicy + ) -> Bool { + if let runtimeError = error as? AgentRuntimeError { + if runtimeError.code == AgentRuntimeError.unauthorized().code { + return false + } + if let statusCode = httpStatusCode(from: runtimeError.code) { + return policy.retryableHTTPStatusCodes.contains(statusCode) + } + return false + } + + if let urlError = error as? URLError { + return policy.retryableURLErrorCodes.contains(urlError.errorCode) + } + + let nsError = error as NSError + if nsError.domain == NSURLErrorDomain { + return policy.retryableURLErrorCodes.contains(nsError.code) + } + + return false + } + + static func httpStatusCode(from errorCode: String) -> Int? { + let prefix = "responses_http_status_" + guard errorCode.hasPrefix(prefix) else { + return nil + } + return Int(errorCode.dropFirst(prefix.count)) + } + + static func parseStreamEvent( + from payload: SSEEventPayload, + decoder: JSONDecoder + ) throws -> CodexResponsesStreamEvent? { + guard !payload.data.isEmpty else { + return nil + } + + let envelope = try decoder.decode( + StreamEnvelope.self, + from: Data(payload.data.utf8) + ) + + switch envelope.type { + case "response.output_text.delta": + return envelope.delta.map(CodexResponsesStreamEvent.assistantTextDelta) + + case "response.output_item.done": + guard let item = envelope.item else { + return nil + } + + switch item { + case let .message(message): + let text = message.content + .compactMap(\.displayText) + .joined() + .trimmingCharacters(in: .whitespacesAndNewlines) + let images = message.content.compactMap(\.imageAttachment) + guard !text.isEmpty || !images.isEmpty else { + return nil + } + return .assistantMessage( + AgentMessage( + threadID: "", + role: .assistant, + text: text, + images: images + ) + ) + + case let .functionCall(functionCall): + return .functionCall( + FunctionCallRecord( + name: functionCall.name, + callID: functionCall.callID, + argumentsRaw: functionCall.arguments + ) + ) + + case .other: + return nil + } + + case "response.completed": + let usage = envelope.response?.usage?.assistantUsage ?? AgentUsage() + return .completed(usage) + + case "response.failed": + let message = envelope.response?.error?.message ?? "The ChatGPT responses stream failed." + throw AgentRuntimeError(code: "responses_stream_failed", message: message) + + case "response.incomplete": + let reason = envelope.response?.incompleteDetails?.reason ?? "unknown" + throw AgentRuntimeError( + code: "responses_stream_incomplete", + message: "The ChatGPT responses stream completed early: \(reason)." + ) + + default: + return nil + } + } + + static func readAll(_ bytes: URLSession.AsyncBytes) async throws -> Data { + var data = Data() + for try await byte in bytes { + data.append(byte) + } + return data + } + + static func toolOutputText(from result: ToolResultEnvelope) -> String { + var segments: [String] = [] + + if let primaryText = result.primaryText, !primaryText.isEmpty { + segments.append(primaryText) + } + + let imageURLs = result.content.compactMap { content -> URL? in + guard case let .image(url) = content else { + return nil + } + return url + } + if !imageURLs.isEmpty { + segments.append( + "Image URLs:\n" + imageURLs.map(\.absoluteString).joined(separator: "\n") + ) + } + + if !segments.isEmpty { + return segments.joined(separator: "\n\n") + } + if let errorMessage = result.errorMessage, !errorMessage.isEmpty { + return errorMessage + } + return result.success ? "Tool execution completed." : "Tool execution failed." + } + + static func toolOutputImages( + from result: ToolResultEnvelope, + urlSession: URLSession + ) async -> [AgentImageAttachment] { + var attachments: [AgentImageAttachment] = [] + for content in result.content { + guard case let .image(url) = content else { + continue + } + if let attachment = await imageAttachment(from: url, urlSession: urlSession) { + attachments.append(attachment) + } + } + return attachments.uniqued() + } + + static func imageAttachment( + from url: URL, + urlSession: URLSession + ) async -> AgentImageAttachment? { + if url.scheme?.lowercased() == "data" { + let decoded = url.absoluteString.removingPercentEncoding ?? url.absoluteString + return AgentImageAttachment(dataURLString: decoded) + } + + if url.isFileURL { + guard let mimeType = inferredImageMimeType(from: url.pathExtension), + let data = try? Data(contentsOf: url), + !data.isEmpty + else { + return nil + } + return AgentImageAttachment(mimeType: mimeType, data: data) + } + + do { + let (data, response) = try await urlSession.data(from: url) + guard !data.isEmpty else { + return nil + } + + let mimeType = response.mimeType?.lowercased() ?? inferredImageMimeType(from: url.pathExtension) + let normalized = mimeType ?? "image/png" + guard normalized.hasPrefix("image/") else { + return nil + } + return AgentImageAttachment(mimeType: normalized, data: data) + } catch { + return nil + } + } + + static func inferredImageMimeType(from pathExtension: String) -> String? { + switch pathExtension.lowercased() { + case "png": + "image/png" + case "jpg", "jpeg": + "image/jpeg" + case "gif": + "image/gif" + case "webp": + "image/webp" + case "heic": + "image/heic" + case "heif": + "image/heif" + default: + nil + } + } +} diff --git a/Sources/CodexKit/Runtime/CodexResponsesBackend.swift b/Sources/CodexKit/Runtime/CodexResponsesBackend.swift index 77581dc..8f045ca 100644 --- a/Sources/CodexKit/Runtime/CodexResponsesBackend.swift +++ b/Sources/CodexKit/Runtime/CodexResponsesBackend.swift @@ -323,776 +323,4 @@ final class CodexResponsesTurnSession: AgentTurnStreaming, @unchecked Sendable { return aggregateUsage } - - private static func buildURLRequest( - configuration: CodexResponsesBackendConfiguration, - instructions: String, - responseFormat: AgentStructuredOutputFormat?, - threadID: String, - items: [WorkingHistoryItem], - tools: [ToolDefinition], - session: ChatGPTSession, - encoder: JSONEncoder - ) throws -> URLRequest { - let requestBody = ResponsesRequestBody( - model: configuration.model, - reasoning: .init(effort: configuration.reasoningEffort), - instructions: instructions, - text: .init(format: .init(responseFormat: responseFormat)), - input: items.map(\.jsonValue), - tools: responsesTools( - from: tools, - enableWebSearch: configuration.enableWebSearch - ), - toolChoice: "auto", - parallelToolCalls: false, - store: false, - stream: true, - include: [], - promptCacheKey: threadID - ) - - var request = URLRequest(url: configuration.baseURL.appendingPathComponent("responses")) - request.httpMethod = "POST" - request.httpBody = try encoder.encode(requestBody) - request.setValue("application/json", forHTTPHeaderField: "Content-Type") - request.setValue("text/event-stream", forHTTPHeaderField: "Accept") - request.setValue("Bearer \(session.accessToken)", forHTTPHeaderField: "Authorization") - request.setValue(session.account.id, forHTTPHeaderField: "ChatGPT-Account-ID") - request.setValue(threadID, forHTTPHeaderField: "session_id") - request.setValue(threadID, forHTTPHeaderField: "x-client-request-id") - request.setValue(configuration.originator, forHTTPHeaderField: "originator") - - for (header, value) in configuration.extraHeaders { - request.setValue(value, forHTTPHeaderField: header) - } - - return request - } - - private static func responsesTools( - from tools: [ToolDefinition], - enableWebSearch: Bool - ) -> [JSONValue] { - var responsesTools = tools.map(\.responsesJSONValue) - - if enableWebSearch { - responsesTools.append(.object(["type": .string("web_search")])) - } - - return responsesTools - } - - private static func streamEvents( - request: URLRequest, - urlSession: URLSession, - decoder: JSONDecoder - ) async throws -> AsyncThrowingStream { - let (bytes, response) = try await urlSession.bytes(for: request) - guard let httpResponse = response as? HTTPURLResponse else { - throw AgentRuntimeError( - code: "responses_invalid_response", - message: "The ChatGPT responses endpoint returned an invalid response." - ) - } - - if !(200 ..< 300).contains(httpResponse.statusCode) { - let bodyData = try await readAll(bytes) - let body = String(data: bodyData, encoding: .utf8) ?? "Unknown error" - if httpResponse.statusCode == 401 || httpResponse.statusCode == 403 { - throw AgentRuntimeError.unauthorized(body) - } - throw AgentRuntimeError( - code: "responses_http_status_\(httpResponse.statusCode)", - message: "The ChatGPT responses request failed with status \(httpResponse.statusCode): \(body)" - ) - } - - return AsyncThrowingStream { continuation in - Task { - var parser = SSEEventParser() - - do { - var lineBuffer = Data() - - for try await byte in bytes { - if byte == UInt8(ascii: "\n") { - var line = String(decoding: lineBuffer, as: UTF8.self) - if line.hasSuffix("\r") { - line.removeLast() - } - lineBuffer.removeAll(keepingCapacity: true) - - if let payload = parser.consume(line: line) { - if let event = try parseStreamEvent( - from: payload, - decoder: decoder - ) { - continuation.yield(event) - } - } - continue - } - - lineBuffer.append(byte) - } - - if !lineBuffer.isEmpty { - var line = String(decoding: lineBuffer, as: UTF8.self) - if line.hasSuffix("\r") { - line.removeLast() - } - if let payload = parser.consume(line: line) { - if let event = try parseStreamEvent( - from: payload, - decoder: decoder - ) { - continuation.yield(event) - } - } - } - - if let payload = parser.finish(), - let event = try parseStreamEvent(from: payload, decoder: decoder) { - continuation.yield(event) - } - - continuation.finish() - } catch { - continuation.finish(throwing: error) - } - } - } - } - - private static func shouldRetry( - _ error: Error, - policy: RequestRetryPolicy - ) -> Bool { - if let runtimeError = error as? AgentRuntimeError { - if runtimeError.code == AgentRuntimeError.unauthorized().code { - return false - } - if let statusCode = httpStatusCode(from: runtimeError.code) { - return policy.retryableHTTPStatusCodes.contains(statusCode) - } - return false - } - - if let urlError = error as? URLError { - return policy.retryableURLErrorCodes.contains(urlError.errorCode) - } - - let nsError = error as NSError - if nsError.domain == NSURLErrorDomain { - return policy.retryableURLErrorCodes.contains(nsError.code) - } - - return false - } - - private static func httpStatusCode(from errorCode: String) -> Int? { - let prefix = "responses_http_status_" - guard errorCode.hasPrefix(prefix) else { - return nil - } - return Int(errorCode.dropFirst(prefix.count)) - } - - private static func parseStreamEvent( - from payload: SSEEventPayload, - decoder: JSONDecoder - ) throws -> CodexResponsesStreamEvent? { - guard !payload.data.isEmpty else { - return nil - } - - let envelope = try decoder.decode( - StreamEnvelope.self, - from: Data(payload.data.utf8) - ) - - switch envelope.type { - case "response.output_text.delta": - return envelope.delta.map(CodexResponsesStreamEvent.assistantTextDelta) - - case "response.output_item.done": - guard let item = envelope.item else { - return nil - } - - switch item { - case let .message(message): - let text = message.content - .compactMap(\.displayText) - .joined() - .trimmingCharacters(in: .whitespacesAndNewlines) - let images = message.content.compactMap(\.imageAttachment) - guard !text.isEmpty || !images.isEmpty else { - return nil - } - return .assistantMessage( - AgentMessage( - threadID: "", - role: .assistant, - text: text, - images: images - ) - ) - - case let .functionCall(functionCall): - return .functionCall( - FunctionCallRecord( - name: functionCall.name, - callID: functionCall.callID, - argumentsRaw: functionCall.arguments - ) - ) - - case .other: - return nil - } - - case "response.completed": - let usage = envelope.response?.usage?.assistantUsage ?? AgentUsage() - return .completed(usage) - - case "response.failed": - let message = envelope.response?.error?.message ?? "The ChatGPT responses stream failed." - throw AgentRuntimeError(code: "responses_stream_failed", message: message) - - case "response.incomplete": - let reason = envelope.response?.incompleteDetails?.reason ?? "unknown" - throw AgentRuntimeError( - code: "responses_stream_incomplete", - message: "The ChatGPT responses stream completed early: \(reason)." - ) - - default: - return nil - } - } - - private static func readAll(_ bytes: URLSession.AsyncBytes) async throws -> Data { - var data = Data() - for try await byte in bytes { - data.append(byte) - } - return data - } - - private static func toolOutputText(from result: ToolResultEnvelope) -> String { - var segments: [String] = [] - - if let primaryText = result.primaryText, !primaryText.isEmpty { - segments.append(primaryText) - } - - let imageURLs = result.content.compactMap { content -> URL? in - guard case let .image(url) = content else { - return nil - } - return url - } - if !imageURLs.isEmpty { - segments.append( - "Image URLs:\n" + imageURLs.map(\.absoluteString).joined(separator: "\n") - ) - } - - if !segments.isEmpty { - return segments.joined(separator: "\n\n") - } - if let errorMessage = result.errorMessage, !errorMessage.isEmpty { - return errorMessage - } - return result.success ? "Tool execution completed." : "Tool execution failed." - } - - private static func toolOutputImages( - from result: ToolResultEnvelope, - urlSession: URLSession - ) async -> [AgentImageAttachment] { - var attachments: [AgentImageAttachment] = [] - for content in result.content { - guard case let .image(url) = content else { - continue - } - if let attachment = await imageAttachment(from: url, urlSession: urlSession) { - attachments.append(attachment) - } - } - return attachments.uniqued() - } - - private static func imageAttachment( - from url: URL, - urlSession: URLSession - ) async -> AgentImageAttachment? { - if url.scheme?.lowercased() == "data" { - let decoded = url.absoluteString.removingPercentEncoding ?? url.absoluteString - return AgentImageAttachment(dataURLString: decoded) - } - - if url.isFileURL { - guard let mimeType = inferredImageMimeType(from: url.pathExtension), - let data = try? Data(contentsOf: url), - !data.isEmpty - else { - return nil - } - return AgentImageAttachment(mimeType: mimeType, data: data) - } - - do { - let (data, response) = try await urlSession.data(from: url) - guard !data.isEmpty else { - return nil - } - - let mimeType = response.mimeType?.lowercased() ?? inferredImageMimeType(from: url.pathExtension) - let normalized = mimeType ?? "image/png" - guard normalized.hasPrefix("image/") else { - return nil - } - return AgentImageAttachment(mimeType: normalized, data: data) - } catch { - return nil - } - } - - private static func inferredImageMimeType(from pathExtension: String) -> String? { - switch pathExtension.lowercased() { - case "png": - "image/png" - case "jpg", "jpeg": - "image/jpeg" - case "gif": - "image/gif" - case "webp": - "image/webp" - case "heic": - "image/heic" - case "heif": - "image/heif" - default: - nil - } - } -} - -private struct ResponsesRequestBody: Encodable { - let model: String - let reasoning: ResponsesReasoningConfiguration - let instructions: String - let text: ResponsesTextConfiguration - let input: [JSONValue] - let tools: [JSONValue] - let toolChoice: String - let parallelToolCalls: Bool - let store: Bool - let stream: Bool - let include: [String] - let promptCacheKey: String? - - enum CodingKeys: String, CodingKey { - case model - case reasoning - case instructions - case text - case input - case tools - case toolChoice = "tool_choice" - case parallelToolCalls = "parallel_tool_calls" - case store - case stream - case include - case promptCacheKey = "prompt_cache_key" - } -} - -private struct ResponsesReasoningConfiguration: Encodable { - let effort: String - - init(effort: ReasoningEffort) { - self.effort = effort.apiValue - } -} - -private struct ResponsesTextConfiguration: Encodable { - let format: ResponsesTextFormat -} - -private struct ResponsesTextFormat: Encodable { - let type: String - let name: String? - let description: String? - let schema: JSONValue? - let strict: Bool? - - init(responseFormat: AgentStructuredOutputFormat?) { - if let responseFormat { - type = "json_schema" - name = responseFormat.name - description = responseFormat.description - schema = responseFormat.schema.jsonValue - strict = responseFormat.strict - } else { - type = "text" - name = nil - description = nil - schema = nil - strict = nil - } - } -} - -private enum WorkingHistoryItem: Sendable { - case visibleMessage(AgentMessage) - case userMessage(AgentMessage) - case assistantMessage(AgentMessage) - case functionCall(FunctionCallRecord) - case functionCallOutput(callID: String, output: String) - - var jsonValue: JSONValue { - switch self { - case let .visibleMessage(message): - Self.messageJSONValue(for: message) - case let .userMessage(message): - Self.messageJSONValue(for: message) - case let .assistantMessage(message): - Self.messageJSONValue(for: message) - case let .functionCall(functionCall): - .object([ - "type": .string("function_call"), - "name": .string(functionCall.name), - "arguments": .string(functionCall.argumentsRaw), - "call_id": .string(functionCall.callID), - ]) - case let .functionCallOutput(callID, output): - .object([ - "type": .string("function_call_output"), - "call_id": .string(callID), - "output": .string(output), - ]) - } - } - - private static func messageJSONValue(for message: AgentMessage) -> JSONValue { - let roleValue: String = switch message.role { - case .assistant: - "assistant" - case .system: - "system" - case .tool: - "assistant" - case .user: - "user" - } - - var content: [JSONValue] = [] - - switch message.role { - case .assistant: - if !message.text.isEmpty { - content.append(.object([ - "type": .string("output_text"), - "text": .string(message.text), - ])) - } - content.append(contentsOf: message.images.map { image in - .object([ - "type": .string("output_image"), - "image_url": .string(image.dataURLString), - ]) - }) - - default: - if !message.text.isEmpty { - content.append(.object([ - "type": .string("input_text"), - "text": .string(message.text), - ])) - } - - if message.role == .user { - content.append(contentsOf: message.images.map { image in - .object([ - "type": .string("input_image"), - "image_url": .string(image.dataURLString), - ]) - }) - } - } - - return .object([ - "type": .string("message"), - "role": .string(roleValue), - "content": .array(content), - ]) - } -} - -private struct FunctionCallRecord: Sendable { - let name: String - let callID: String - let argumentsRaw: String - - var arguments: JSONValue { - guard let data = argumentsRaw.data(using: .utf8), - let value = try? JSONDecoder().decode(JSONValue.self, from: data) - else { - return .string(argumentsRaw) - } - return value - } -} - -private enum CodexResponsesStreamEvent: Sendable { - case assistantTextDelta(String) - case assistantMessage(AgentMessage) - case functionCall(FunctionCallRecord) - case completed(AgentUsage) -} - -private struct PendingToolResults: Sendable { - private actor Storage { - private var waiting: [String: CheckedContinuation] = [:] - private var resolved: [String: ToolResultEnvelope] = [:] - - func wait(for invocationID: String) async throws -> ToolResultEnvelope { - if let resolved = resolved.removeValue(forKey: invocationID) { - return resolved - } - - return try await withCheckedThrowingContinuation { continuation in - waiting[invocationID] = continuation - } - } - - func resolve(_ result: ToolResultEnvelope, for invocationID: String) { - if let continuation = waiting.removeValue(forKey: invocationID) { - continuation.resume(returning: result) - } else { - resolved[invocationID] = result - } - } - } - - private let storage = Storage() - - func wait(for invocationID: String) async throws -> ToolResultEnvelope { - try await storage.wait(for: invocationID) - } - - func resolve(_ result: ToolResultEnvelope, for invocationID: String) async { - await storage.resolve(result, for: invocationID) - } -} - -private extension ToolDefinition { - var responsesJSONValue: JSONValue { - .object([ - "type": .string("function"), - "name": .string(name), - "description": .string(description), - "strict": .bool(false), - "parameters": normalizedSchema, - ]) - } - - var normalizedSchema: JSONValue { - guard case var .object(schema) = inputSchema else { - return inputSchema - } - if schema["properties"] == nil { - schema["properties"] = .object([:]) - } - return .object(schema) - } -} - -private struct SSEEventPayload { - let event: String? - let data: String -} - -private struct SSEEventParser { - private var eventName: String? - private var dataLines: [String] = [] - - mutating func consume(line: String) -> SSEEventPayload? { - if line.isEmpty { - return flush() - } - - if line.hasPrefix("event:") { - eventName = Self.trimmedFieldValue(from: line) - } else if line.hasPrefix("data:") { - dataLines.append(Self.trimmedFieldValue(from: line)) - } - - return nil - } - - mutating func finish() -> SSEEventPayload? { - flush() - } - - private mutating func flush() -> SSEEventPayload? { - guard !dataLines.isEmpty else { - eventName = nil - return nil - } - - let payload = SSEEventPayload( - event: eventName, - data: dataLines.joined(separator: "\n") - ) - eventName = nil - dataLines.removeAll(keepingCapacity: true) - return payload - } - - private static func trimmedFieldValue(from line: String) -> String { - let value = line.drop { $0 != ":" } - return value.dropFirst().trimmingCharacters(in: .whitespaces) - } -} - -private struct StreamEnvelope: Decodable { - let type: String - let delta: String? - let item: StreamItem? - let response: StreamResponsePayload? -} - -private enum StreamItem: Decodable { - case message(StreamMessageItem) - case functionCall(StreamFunctionCallItem) - case other - - init(from decoder: Decoder) throws { - let container = try decoder.singleValueContainer() - let object = try container.decode([String: JSONValue].self) - let type = object["type"]?.stringValue - - switch type { - case "message": - let data = try JSONEncoder().encode(object) - self = .message(try JSONDecoder().decode(StreamMessageItem.self, from: data)) - case "function_call": - let data = try JSONEncoder().encode(object) - self = .functionCall(try JSONDecoder().decode(StreamFunctionCallItem.self, from: data)) - default: - self = .other - } - } -} - -private struct StreamMessageItem: Decodable { - let role: String - let content: [StreamMessageContent] -} - -private struct StreamMessageContent: Decodable { - let type: String - let displayText: String? - let imageAttachment: AgentImageAttachment? - - init(from decoder: Decoder) throws { - let container = try decoder.singleValueContainer() - let object = try container.decode([String: JSONValue].self) - type = object["type"]?.stringValue ?? "" - displayText = object["text"]?.stringValue ?? object["refusal"]?.stringValue - imageAttachment = Self.parseImageAttachment(from: object) - } - - private static func parseImageAttachment(from object: [String: JSONValue]) -> AgentImageAttachment? { - if let dataURL = object["image_url"]?.stringValue, - let attachment = AgentImageAttachment(dataURLString: dataURL) { - return attachment - } - - if let imageObject = object["image"]?.objectValue, - let dataURL = imageObject["image_url"]?.stringValue, - let attachment = AgentImageAttachment(dataURLString: dataURL) { - return attachment - } - - if let b64 = object["b64_json"]?.stringValue { - return AgentImageAttachment(base64String: b64) - } - - return nil - } -} - -private struct StreamFunctionCallItem: Decodable { - let name: String - let arguments: String - let callID: String - - enum CodingKeys: String, CodingKey { - case name - case arguments - case callID = "call_id" - } -} - -private struct StreamResponsePayload: Decodable { - let id: String? - let usage: StreamUsage? - let error: StreamErrorPayload? - let incompleteDetails: StreamIncompleteDetails? - - enum CodingKeys: String, CodingKey { - case id - case usage - case error - case incompleteDetails = "incomplete_details" - } -} - -private struct StreamUsage: Decodable { - let inputTokens: Int - let inputTokensDetails: StreamInputTokenDetails? - let outputTokens: Int - - enum CodingKeys: String, CodingKey { - case inputTokens = "input_tokens" - case inputTokensDetails = "input_tokens_details" - case outputTokens = "output_tokens" - } - - var assistantUsage: AgentUsage { - AgentUsage( - inputTokens: inputTokens, - cachedInputTokens: inputTokensDetails?.cachedTokens ?? 0, - outputTokens: outputTokens - ) - } -} - -private struct StreamInputTokenDetails: Decodable { - let cachedTokens: Int - - enum CodingKeys: String, CodingKey { - case cachedTokens = "cached_tokens" - } -} - -private struct StreamErrorPayload: Decodable { - let message: String? -} - -private struct StreamIncompleteDetails: Decodable { - let reason: String? -} - -private extension Array where Element: Hashable { - func uniqued() -> [Element] { - var seen = Set() - return filter { seen.insert($0).inserted } - } } diff --git a/Tests/CodexKitTests/AgentRuntimeMemoryTests.swift b/Tests/CodexKitTests/AgentRuntimeMemoryTests.swift new file mode 100644 index 0000000..1570416 --- /dev/null +++ b/Tests/CodexKitTests/AgentRuntimeMemoryTests.swift @@ -0,0 +1,503 @@ +import CodexKit +import XCTest + +// MARK: - Memory + +extension AgentRuntimeTests { + func testRuntimeInjectsRelevantMemoryIntoInstructionsAndPreviewMatches() async throws { + let backend = InMemoryAgentBackend( + baseInstructions: "Base host instructions." + ) + let store = InMemoryMemoryStore(initialRecords: [ + MemoryRecord( + namespace: "demo-assistant", + scope: "feature:health-coach", + kind: "preference", + summary: "Health Coach should use direct accountability when the user is behind on steps.", + evidence: ["The user responds better to blunt coaching than soft encouragement."], + importance: 0.9, + tags: ["steps"] + ), + MemoryRecord( + namespace: "demo-assistant", + scope: "feature:travel-planner", + kind: "preference", + summary: "Travel Planner should keep itineraries compact and transit-aware.", + importance: 0.8 + ), + ]) + let runtime = try AgentRuntime(configuration: .init( + authProvider: DemoChatGPTAuthProvider(), + secureStore: KeychainSessionSecureStore( + service: "CodexKitTests.ChatGPTSession", + account: UUID().uuidString + ), + backend: backend, + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore(), + memory: .init(store: store) + )) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread( + title: "Memory", + memoryContext: AgentMemoryContext( + namespace: "demo-assistant", + scopes: ["feature:health-coach"] + ) + ) + + let preview = try await runtime.memoryQueryPreview( + for: thread.id, + request: UserMessageRequest(text: "How should the health coach respond when the user is behind on steps?") + ) + XCTAssertEqual(preview?.matches.map(\.record.scope.rawValue), ["feature:health-coach"]) + + _ = try await runtime.sendMessage( + UserMessageRequest(text: "How should the health coach respond when the user is behind on steps?"), + in: thread.id + ) + + let instructions = await backend.receivedInstructions() + let resolved = try XCTUnwrap(instructions.last) + XCTAssertTrue(resolved.contains("Relevant Memory:")) + XCTAssertTrue(resolved.contains("Health Coach should use direct accountability when the user is behind on steps.")) + XCTAssertFalse(resolved.contains("Travel Planner should keep itineraries compact and transit-aware.")) + } + + func testRuntimeMemorySelectionCanReplaceOrDisableThreadDefaults() async throws { + let backend = InMemoryAgentBackend( + baseInstructions: "Base host instructions." + ) + let store = InMemoryMemoryStore(initialRecords: [ + MemoryRecord( + namespace: "demo-assistant", + scope: "feature:health-coach", + kind: "preference", + summary: "Health Coach preference." + ), + MemoryRecord( + namespace: "demo-assistant", + scope: "feature:travel-planner", + kind: "preference", + summary: "Travel Planner preference." + ), + ]) + let runtime = try AgentRuntime(configuration: .init( + authProvider: DemoChatGPTAuthProvider(), + secureStore: KeychainSessionSecureStore( + service: "CodexKitTests.ChatGPTSession", + account: UUID().uuidString + ), + backend: backend, + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore(), + memory: .init(store: store) + )) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread( + title: "Scoped Memory", + memoryContext: AgentMemoryContext( + namespace: "demo-assistant", + scopes: ["feature:health-coach"] + ) + ) + + _ = try await runtime.sendMessage( + UserMessageRequest( + text: "Use travel planner memory instead.", + memorySelection: MemorySelection( + mode: .replace, + scopes: ["feature:travel-planner"] + ) + ), + in: thread.id + ) + + _ = try await runtime.sendMessage( + UserMessageRequest( + text: "Now disable memory.", + memorySelection: MemorySelection(mode: .disable) + ), + in: thread.id + ) + + let instructions = await backend.receivedInstructions() + XCTAssertEqual(instructions.count, 2) + XCTAssertTrue(instructions[0].contains("Travel Planner preference.")) + XCTAssertFalse(instructions[0].contains("Health Coach preference.")) + XCTAssertFalse(instructions[1].contains("Relevant Memory:")) + } + + func testRuntimeCanAutomaticallyCaptureMemoriesFromTranscript() async throws { + let backend = InMemoryAgentBackend( + structuredResponseText: """ + {"memories":[{"summary":"Health Coach should use direct accountability when step pace is low.","scope":"feature:health-coach","kind":"preference","evidence":["The user asked for blunt reminders when behind on steps."],"importance":0.92,"tags":["steps","tone"],"relatedIDs":["goal-10000"],"dedupeKey":"health-coach-direct-accountability"},{"summary":"Travel Planner should keep itineraries compact and transit-aware.","scope":"feature:travel-planner","kind":"preference","evidence":["The user dislikes sprawling travel plans."],"importance":0.81,"tags":["travel"],"relatedIDs":["travel-style-compact"],"dedupeKey":"travel-planner-compact-itinerary"}]} + """ + ) + let store = InMemoryMemoryStore() + let runtime = try AgentRuntime(configuration: .init( + authProvider: DemoChatGPTAuthProvider(), + secureStore: KeychainSessionSecureStore( + service: "CodexKitTests.ChatGPTSession", + account: UUID().uuidString + ), + backend: backend, + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore(), + memory: .init(store: store) + )) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread( + title: "Auto Memory", + memoryContext: AgentMemoryContext( + namespace: "demo-assistant", + scopes: ["feature:health-coach", "feature:travel-planner"] + ) + ) + + let capture = try await runtime.captureMemories( + from: .text(""" + User: Be direct when I am behind on steps. + User: Keep travel itineraries compact and transit-aware. + """), + for: thread.id, + options: .init( + defaults: .init(namespace: "demo-assistant"), + maxMemories: 3 + ) + ) + + XCTAssertEqual(capture.records.count, 2) + XCTAssertEqual(capture.records.map(\.scope.rawValue).sorted(), ["feature:health-coach", "feature:travel-planner"]) + + let stored = try await store.query( + MemoryQuery( + namespace: "demo-assistant", + scopes: ["feature:health-coach", "feature:travel-planner"], + text: "direct steps transit itinerary", + limit: 10, + maxCharacters: 1000 + ) + ) + XCTAssertEqual(stored.matches.count, 2) + + let formats = await backend.receivedResponseFormats() + XCTAssertEqual(formats.last??.name, "memory_extraction_batch") + } + + func testRuntimeCanAutomaticallyCaptureMemoryAfterSuccessfulTurn() async throws { + let backend = InMemoryAgentBackend( + structuredResponseText: """ + {"memories":[{"summary":"Health Coach should use direct accountability when the user falls behind on steps.","scope":"feature:health-coach","kind":"preference","evidence":["The user said blunt reminders work better than soft encouragement."],"importance":0.94,"tags":["steps","tone"],"relatedIDs":["goal-10000"],"dedupeKey":"health-coach-auto-capture"}]} + """ + ) + let store = InMemoryMemoryStore() + let observer = RecordingMemoryObserver() + let runtime = try AgentRuntime(configuration: .init( + authProvider: DemoChatGPTAuthProvider(), + secureStore: KeychainSessionSecureStore( + service: "CodexKitTests.ChatGPTSession", + account: UUID().uuidString + ), + backend: backend, + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore(), + memory: .init( + store: store, + observer: observer, + automaticCapturePolicy: .init( + source: .lastTurn, + options: .init( + defaults: .init(namespace: "demo-assistant"), + maxMemories: 2 + ) + ) + ) + )) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread( + title: "Auto Policy", + memoryContext: AgentMemoryContext( + namespace: "demo-assistant", + scopes: ["feature:health-coach"] + ) + ) + + _ = try await runtime.sendMessage( + UserMessageRequest(text: "If I am behind on steps, be direct and blunt with me."), + in: thread.id + ) + + let stored = try await store.query( + MemoryQuery( + namespace: "demo-assistant", + scopes: ["feature:health-coach"], + text: "direct blunt steps", + limit: 10, + maxCharacters: 1000 + ) + ) + XCTAssertEqual(stored.matches.count, 1) + XCTAssertEqual(stored.matches[0].record.dedupeKey, "health-coach-auto-capture") + + let formats = await backend.receivedResponseFormats() + XCTAssertEqual(formats.count, 2) + XCTAssertNil(formats.first!) + XCTAssertEqual(formats.last??.name, "memory_extraction_batch") + + let events = await observer.events() + let captureEvents = events.compactMap { event -> (String, String?, Int?)? in + switch event { + case let .captureStarted(threadID, sourceDescription): + return (threadID, sourceDescription, nil) + case let .captureSucceeded(threadID, result): + return (threadID, nil, result.records.count) + default: + return nil + } + } + + XCTAssertEqual(captureEvents.count, 2) + XCTAssertEqual(captureEvents[0].0, thread.id) + XCTAssertEqual(captureEvents[0].1, "last_turn") + XCTAssertEqual(captureEvents[1].0, thread.id) + XCTAssertEqual(captureEvents[1].2, 1) + } + + func testRuntimeGracefullyDegradesWhenMemoryStoreFails() async throws { + let backend = InMemoryAgentBackend( + baseInstructions: "Base host instructions." + ) + let runtime = try AgentRuntime(configuration: .init( + authProvider: DemoChatGPTAuthProvider(), + secureStore: KeychainSessionSecureStore( + service: "CodexKitTests.ChatGPTSession", + account: UUID().uuidString + ), + backend: backend, + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore(), + memory: .init(store: ThrowingMemoryStore()) + )) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread( + title: "Graceful", + memoryContext: AgentMemoryContext( + namespace: "demo-assistant", + scopes: ["feature:health-coach"] + ) + ) + + _ = try await runtime.sendMessage( + UserMessageRequest(text: "This should still work."), + in: thread.id + ) + + let instructions = await backend.receivedInstructions() + let resolved = try XCTUnwrap(instructions.last) + XCTAssertFalse(resolved.contains("Relevant Memory:")) + } + + func testRuntimeReportsMemoryObservationEvents() async throws { + let backend = InMemoryAgentBackend( + baseInstructions: "Base host instructions." + ) + let store = InMemoryMemoryStore(initialRecords: [ + MemoryRecord( + namespace: "demo-assistant", + scope: "feature:health-coach", + kind: "preference", + summary: "Observed memory." + ), + ]) + let observer = RecordingMemoryObserver() + let runtime = try AgentRuntime(configuration: .init( + authProvider: DemoChatGPTAuthProvider(), + secureStore: KeychainSessionSecureStore( + service: "CodexKitTests.ChatGPTSession", + account: UUID().uuidString + ), + backend: backend, + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore(), + memory: .init( + store: store, + observer: observer + ) + )) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread( + title: "Observed", + memoryContext: AgentMemoryContext( + namespace: "demo-assistant", + scopes: ["feature:health-coach"] + ) + ) + + _ = try await runtime.sendMessage( + UserMessageRequest(text: "Use memory."), + in: thread.id + ) + + let events = await observer.events() + XCTAssertEqual(events.count, 2) + guard case let .queryStarted(startedQuery) = events[0] else { + return XCTFail("Expected queryStarted event.") + } + XCTAssertEqual(startedQuery.namespace, "demo-assistant") + guard case let .querySucceeded(_, result) = events[1] else { + return XCTFail("Expected querySucceeded event.") + } + XCTAssertEqual(result.matches.count, 1) + } + + func testRuntimeProvidesThreadAwareMemoryWriterDefaults() async throws { + let store = InMemoryMemoryStore() + let runtime = try AgentRuntime(configuration: .init( + authProvider: DemoChatGPTAuthProvider(), + secureStore: KeychainSessionSecureStore( + service: "CodexKitTests.ChatGPTSession", + account: UUID().uuidString + ), + backend: InMemoryAgentBackend(), + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore(), + memory: .init(store: store) + )) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread( + title: "Thread Memory Writer", + memoryContext: AgentMemoryContext( + namespace: "demo-assistant", + scopes: ["feature:health-coach"], + kinds: ["preference"], + tags: ["steps"], + relatedIDs: ["goal-10000"] + ) + ) + + let writer = try await runtime.memoryWriter(for: thread.id) + let record = try await writer.put( + MemoryDraft( + summary: "The demo user responds better to direct step reminders." + ) + ) + + XCTAssertEqual(record.namespace, "demo-assistant") + XCTAssertEqual(record.scope, "feature:health-coach") + XCTAssertEqual(record.kind, "preference") + XCTAssertEqual(record.tags, ["steps"]) + XCTAssertEqual(record.relatedIDs, ["goal-10000"]) + } + + func testRuntimeMemoryWriterThrowsWhenMemoryIsNotConfigured() async throws { + let runtime = try AgentRuntime(configuration: .init( + authProvider: DemoChatGPTAuthProvider(), + secureStore: KeychainSessionSecureStore( + service: "CodexKitTests.ChatGPTSession", + account: UUID().uuidString + ), + backend: InMemoryAgentBackend(), + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore() + )) + + _ = try await runtime.restore() + + await XCTAssertThrowsErrorAsync(try await runtime.memoryWriter()) { error in + XCTAssertEqual(error as? AgentRuntimeError, .memoryNotConfigured()) + } + } + + func testResolvedInstructionsPreviewThrowsForMissingThread() async throws { + let runtime = try AgentRuntime(configuration: .init( + authProvider: DemoChatGPTAuthProvider(), + secureStore: KeychainSessionSecureStore( + service: "CodexKitTests.ChatGPTSession", + account: UUID().uuidString + ), + backend: InMemoryAgentBackend(), + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore() + )) + + _ = try await runtime.restore() + + await XCTAssertThrowsErrorAsync( + try await runtime.resolvedInstructionsPreview( + for: "missing-thread", + request: UserMessageRequest(text: "hello") + ) + ) { error in + XCTAssertEqual(error as? AgentRuntimeError, .threadNotFound("missing-thread")) + } + } + + func testThreadMemoryContextPersistsAcrossRestore() async throws { + let runtimeStore = InMemoryRuntimeStateStore() + let memoryContext = AgentMemoryContext( + namespace: "demo-assistant", + scopes: ["feature:health-coach"], + kinds: ["preference"], + tags: ["steps"], + relatedIDs: ["goal-10000"], + readBudget: .init(maxItems: 3, maxCharacters: 500) + ) + let runtime = try AgentRuntime(configuration: .init( + authProvider: DemoChatGPTAuthProvider(), + secureStore: KeychainSessionSecureStore( + service: "CodexKitTests.ChatGPTSession", + account: UUID().uuidString + ), + backend: InMemoryAgentBackend(), + approvalPresenter: AutoApprovalPresenter(), + stateStore: runtimeStore + )) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread( + title: "Memory Context", + memoryContext: memoryContext + ) + + let restoredRuntime = try AgentRuntime(configuration: .init( + authProvider: DemoChatGPTAuthProvider(), + secureStore: KeychainSessionSecureStore( + service: "CodexKitTests.ChatGPTSession", + account: UUID().uuidString + ), + backend: InMemoryAgentBackend(), + approvalPresenter: AutoApprovalPresenter(), + stateStore: runtimeStore + )) + + let restoredThreads = try await restoredRuntime.restore().threads + let restoredContext = try await restoredRuntime.memoryContext(for: thread.id) + + XCTAssertEqual(restoredContext, memoryContext) + XCTAssertEqual(restoredThreads.first?.memoryContext, memoryContext) + } +} diff --git a/Tests/CodexKitTests/AgentRuntimeMessageTests.swift b/Tests/CodexKitTests/AgentRuntimeMessageTests.swift new file mode 100644 index 0000000..ce2c4a8 --- /dev/null +++ b/Tests/CodexKitTests/AgentRuntimeMessageTests.swift @@ -0,0 +1,334 @@ +import CodexKit +import XCTest + +// MARK: - Messaging + +extension AgentRuntimeTests { + func testSendMessageReturnsFinalAssistantMessageText() async throws { + let runtime = try AgentRuntime(configuration: .init( + authProvider: DemoChatGPTAuthProvider(), + secureStore: KeychainSessionSecureStore( + service: "CodexKitTests.ChatGPTSession", + account: UUID().uuidString + ), + backend: InMemoryAgentBackend(), + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore() + )) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(title: "Complete") + let reply = try await runtime.sendMessage( + UserMessageRequest(text: "Hello there"), + in: thread.id + ) + + XCTAssertEqual(reply, "Echo: Hello there") + } + + func testSendMessageExpectingStructuredTypeDecodesTypedResponse() async throws { + let backend = InMemoryAgentBackend( + structuredResponseText: #"{"reply":"Your order is already in transit.","priority":"high"}"# + ) + let runtime = try AgentRuntime(configuration: .init( + authProvider: DemoChatGPTAuthProvider(), + secureStore: KeychainSessionSecureStore( + service: "CodexKitTests.ChatGPTSession", + account: UUID().uuidString + ), + backend: backend, + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore() + )) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(title: "Structured") + let reply = try await runtime.sendMessage( + UserMessageRequest(text: "Draft a shipping reply."), + in: thread.id, + expecting: ShippingReplyDraft.self + ) + + XCTAssertEqual( + reply, + ShippingReplyDraft( + reply: "Your order is already in transit.", + priority: "high" + ) + ) + + let formats = await backend.receivedResponseFormats() + XCTAssertEqual(formats.last??.name, "shipping_reply_draft") + } + + func testStructuredDecodeFailureThrowsRuntimeError() async throws { + let runtime = try AgentRuntime(configuration: .init( + authProvider: DemoChatGPTAuthProvider(), + secureStore: KeychainSessionSecureStore( + service: "CodexKitTests.ChatGPTSession", + account: UUID().uuidString + ), + backend: InMemoryAgentBackend( + structuredResponseText: #"{"unexpected":"value"}"# + ), + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore() + )) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(title: "Structured Failure") + + await XCTAssertThrowsErrorAsync( + try await runtime.sendMessage( + UserMessageRequest(text: "Draft a shipping reply."), + in: thread.id, + expecting: ShippingReplyDraft.self + ) + ) { error in + let runtimeError = error as? AgentRuntimeError + XCTAssertEqual(runtimeError?.code, "structured_output_decoding_failed") + XCTAssertTrue(runtimeError?.message.contains("ShippingReplyDraft") == true) + } + } + + func testImportedContentInitializerBuildsMessageWithSharedURLs() async throws { + let importedContent = AgentImportedContent( + textSnippets: ["Customer says the package arrived damaged."], + urls: [URL(string: "https://example.com/delivery-update")!] + ) + + let request = UserMessageRequest( + prompt: "Summarize and draft a reply.", + importedContent: importedContent + ) + + XCTAssertTrue(request.text.contains("Summarize and draft a reply.")) + XCTAssertTrue(request.text.contains("https://example.com/delivery-update")) + XCTAssertTrue(request.text.contains("Customer says the package arrived damaged.")) + } + + func testImageOnlyMessageIsAcceptedAndPersisted() async throws { + let runtime = try AgentRuntime(configuration: .init( + authProvider: DemoChatGPTAuthProvider(), + secureStore: KeychainSessionSecureStore( + service: "CodexKitTests.ChatGPTSession", + account: UUID().uuidString + ), + backend: InMemoryAgentBackend(), + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore() + )) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(title: "Images") + let image = AgentImageAttachment.png(Data([0x89, 0x50, 0x4E, 0x47])) + + _ = try await runtime.sendMessage( + UserMessageRequest(text: "", images: [image]), + in: thread.id + ) + + let messages = await runtime.messages(for: thread.id) + XCTAssertEqual(messages.first?.images.count, 1) + XCTAssertEqual(messages.first?.role, .user) + } + + func testAssistantImagesAreCommittedToThreadHistory() async throws { + let runtime = try AgentRuntime(configuration: .init( + authProvider: DemoChatGPTAuthProvider(), + secureStore: KeychainSessionSecureStore( + service: "CodexKitTests.ChatGPTSession", + account: UUID().uuidString + ), + backend: ImageReplyAgentBackend(), + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore() + )) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(title: "Assistant Images") + let reply = try await runtime.sendMessage( + UserMessageRequest(text: "show me an image"), + in: thread.id + ) + + XCTAssertEqual(reply, "Attached 1 image") + + let messages = await runtime.messages(for: thread.id) + XCTAssertEqual(messages.last?.role, .assistant) + XCTAssertEqual(messages.last?.images.count, 1) + } + + func testRuntimeStreamsToolApprovalAndCompletion() async throws { + let runtime = try AgentRuntime(configuration: .init( + authProvider: DemoChatGPTAuthProvider(), + secureStore: KeychainSessionSecureStore( + service: "CodexKitTests.ChatGPTSession", + account: UUID().uuidString + ), + backend: InMemoryAgentBackend(), + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore(), + tools: [ + .init( + definition: ToolDefinition( + name: "demo_lookup_profile", + description: "Lookup profile", + inputSchema: .object([:]), + approvalPolicy: .requiresApproval + ), + executor: AnyToolExecutor { invocation, _ in + .success(invocation: invocation, text: "demo-result") + } + ), + ] + )) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread() + let stream = try await runtime.streamMessage( + UserMessageRequest(text: "please use the tool"), + in: thread.id + ) + + var sawApproval = false + var sawToolResult = false + + for try await event in stream { + switch event { + case .approvalRequested: + sawApproval = true + case let .toolCallFinished(result): + sawToolResult = true + XCTAssertEqual(result.primaryText, "demo-result") + default: + break + } + } + + XCTAssertTrue(sawApproval) + XCTAssertTrue(sawToolResult) + } + + func testSendMessageRetriesUnauthorizedByRefreshingSession() async throws { + let authProvider = RotatingDemoAuthProvider() + let backend = UnauthorizedThenSuccessBackend() + let runtime = try AgentRuntime(configuration: .init( + authProvider: authProvider, + secureStore: KeychainSessionSecureStore( + service: "CodexKitTests.ChatGPTSession", + account: UUID().uuidString + ), + backend: backend, + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore() + )) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(title: "Recovered Thread") + _ = try await runtime.sendMessage( + UserMessageRequest(text: "Hello after refresh"), + in: thread.id + ) + + let refreshCount = await authProvider.refreshCount() + XCTAssertEqual(refreshCount, 1) + + let attemptedTokens = await backend.attemptedAccessTokens() + XCTAssertEqual(attemptedTokens.count, 2) + XCTAssertEqual(attemptedTokens[0], "demo-access-token-initial") + XCTAssertEqual(attemptedTokens[1], "demo-access-token-refreshed-1") + + let messages = await runtime.messages(for: thread.id) + XCTAssertEqual(messages.filter { $0.role == .assistant }.count, 1) + } + + func testCreateThreadRetriesUnauthorizedByRefreshingSession() async throws { + let authProvider = RotatingDemoAuthProvider() + let backend = UnauthorizedOnCreateThenSuccessBackend() + let runtime = try AgentRuntime(configuration: .init( + authProvider: authProvider, + secureStore: KeychainSessionSecureStore( + service: "CodexKitTests.ChatGPTSession", + account: UUID().uuidString + ), + backend: backend, + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore() + )) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(title: "Recovered Thread") + XCTAssertEqual(thread.title, "Recovered Thread") + + let refreshCount = await authProvider.refreshCount() + XCTAssertEqual(refreshCount, 1) + + let attemptedTokens = await backend.attemptedAccessTokens() + XCTAssertEqual(attemptedTokens.count, 2) + XCTAssertEqual(attemptedTokens[0], "demo-access-token-initial") + XCTAssertEqual(attemptedTokens[1], "demo-access-token-refreshed-1") + } + + func testConfigurationRegistersInitialTools() async throws { + let runtime = try AgentRuntime(configuration: .init( + authProvider: DemoChatGPTAuthProvider(), + secureStore: KeychainSessionSecureStore( + service: "CodexKitTests.ChatGPTSession", + account: UUID().uuidString + ), + backend: InMemoryAgentBackend(), + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore(), + tools: [ + .init( + definition: ToolDefinition( + name: "demo_lookup_profile", + description: "Lookup profile", + inputSchema: .object([:]), + approvalPolicy: .requiresApproval + ), + executor: AnyToolExecutor { invocation, _ in + .success(invocation: invocation, text: "demo-result") + } + ), + ] + )) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread() + let stream = try await runtime.streamMessage( + UserMessageRequest(text: "please use the tool"), + in: thread.id + ) + + var sawToolResult = false + + for try await event in stream { + if case let .toolCallFinished(result) = event { + sawToolResult = true + XCTAssertEqual(result.primaryText, "demo-result") + } + } + + XCTAssertTrue(sawToolResult) + } +} diff --git a/Tests/CodexKitTests/AgentRuntimePersonaSkillTests.swift b/Tests/CodexKitTests/AgentRuntimePersonaSkillTests.swift new file mode 100644 index 0000000..5ecd58a --- /dev/null +++ b/Tests/CodexKitTests/AgentRuntimePersonaSkillTests.swift @@ -0,0 +1,613 @@ +import CodexKit +import XCTest + +// MARK: - Personas And Skills + +extension AgentRuntimeTests { + func testThreadSkillsAreResolvedIntoInstructions() async throws { + let backend = InMemoryAgentBackend( + baseInstructions: "Base host instructions." + ) + let runtime = try AgentRuntime(configuration: .init( + authProvider: DemoChatGPTAuthProvider(), + secureStore: KeychainSessionSecureStore( + service: "CodexKitTests.ChatGPTSession", + account: UUID().uuidString + ), + backend: backend, + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore(), + skills: [ + .init( + id: "health_coach", + name: "Health Coach", + instructions: "Coach users toward their daily step goals." + ), + ] + )) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread( + title: "Skills", + skillIDs: ["health_coach"] + ) + + _ = try await runtime.sendMessage( + UserMessageRequest(text: "Give me a plan."), + in: thread.id + ) + + let instructions = await backend.receivedInstructions() + let resolved = try XCTUnwrap(instructions.last) + XCTAssertTrue(resolved.contains("Base host instructions.")) + XCTAssertTrue(resolved.contains("Thread Skills:")) + XCTAssertTrue(resolved.contains("[health_coach: Health Coach]")) + } + + func testTurnSkillOverrideAppliesOnlyToCurrentTurn() async throws { + let backend = InMemoryAgentBackend( + baseInstructions: "Base host instructions." + ) + let runtime = try AgentRuntime(configuration: .init( + authProvider: DemoChatGPTAuthProvider(), + secureStore: KeychainSessionSecureStore( + service: "CodexKitTests.ChatGPTSession", + account: UUID().uuidString + ), + backend: backend, + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore(), + skills: [ + .init( + id: "health_coach", + name: "Health Coach", + instructions: "Coach users toward their daily step goals." + ), + ] + )) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(title: "Turn Skills") + + _ = try await runtime.sendMessage( + UserMessageRequest( + text: "First turn", + skillOverrideIDs: ["health_coach"] + ), + in: thread.id + ) + + _ = try await runtime.sendMessage( + UserMessageRequest(text: "Second turn"), + in: thread.id + ) + + let instructions = await backend.receivedInstructions() + XCTAssertEqual(instructions.count, 2) + XCTAssertTrue(instructions[0].contains("[health_coach: Health Coach]")) + XCTAssertFalse(instructions[1].contains("[health_coach: Health Coach]")) + } + + func testSetSkillIDsAffectsFutureTurnsOnly() async throws { + let backend = InMemoryAgentBackend( + baseInstructions: "Base host instructions." + ) + let runtime = try AgentRuntime(configuration: .init( + authProvider: DemoChatGPTAuthProvider(), + secureStore: KeychainSessionSecureStore( + service: "CodexKitTests.ChatGPTSession", + account: UUID().uuidString + ), + backend: backend, + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore(), + skills: [ + .init( + id: "health_coach", + name: "Health Coach", + instructions: "Coach users toward their daily step goals." + ), + ] + )) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(title: "Skill IDs") + + _ = try await runtime.sendMessage( + UserMessageRequest(text: "Before"), + in: thread.id + ) + + try await runtime.setSkillIDs(["health_coach"], for: thread.id) + + _ = try await runtime.sendMessage( + UserMessageRequest(text: "After"), + in: thread.id + ) + + let instructions = await backend.receivedInstructions() + XCTAssertEqual(instructions.count, 2) + XCTAssertFalse(instructions[0].contains("[health_coach: Health Coach]")) + XCTAssertTrue(instructions[1].contains("[health_coach: Health Coach]")) + } + + func testThreadPersonaUsesBackendBaseInstructionsWhenRuntimeBaseIsUnset() async throws { + let backend = InMemoryAgentBackend( + baseInstructions: "Base host instructions." + ) + let supportPersona = AgentPersonaStack(layers: [ + .init(name: "support", instructions: "Act as a support specialist.") + ]) + let runtime = try AgentRuntime(configuration: .init( + authProvider: DemoChatGPTAuthProvider(), + secureStore: KeychainSessionSecureStore( + service: "CodexKitTests.ChatGPTSession", + account: UUID().uuidString + ), + backend: backend, + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore() + )) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(personaStack: supportPersona) + + _ = try await runtime.sendMessage( + UserMessageRequest(text: "Need help"), + in: thread.id + ) + + let instructions = await backend.receivedInstructions() + let resolved = try XCTUnwrap(instructions.last) + XCTAssertTrue(resolved.contains("Base host instructions.")) + XCTAssertTrue(resolved.contains("[support]")) + } + + func testTurnPersonaOverrideAppliesOnlyToCurrentTurn() async throws { + let backend = InMemoryAgentBackend( + baseInstructions: "Base host instructions." + ) + let supportPersona = AgentPersonaStack(layers: [ + .init(name: "support", instructions: "Act as a support specialist.") + ]) + let reviewerOverride = AgentPersonaStack(layers: [ + .init(name: "reviewer", instructions: "Call out risks first.") + ]) + let runtime = try AgentRuntime(configuration: .init( + authProvider: DemoChatGPTAuthProvider(), + secureStore: KeychainSessionSecureStore( + service: "CodexKitTests.ChatGPTSession", + account: UUID().uuidString + ), + backend: backend, + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore() + )) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(personaStack: supportPersona) + + _ = try await runtime.sendMessage( + UserMessageRequest( + text: "First", + personaOverride: reviewerOverride + ), + in: thread.id + ) + + _ = try await runtime.sendMessage( + UserMessageRequest(text: "Second"), + in: thread.id + ) + + let instructions = await backend.receivedInstructions() + XCTAssertEqual(instructions.count, 2) + XCTAssertTrue(instructions[0].contains("[reviewer]")) + XCTAssertTrue(instructions[1].contains("[support]")) + XCTAssertFalse(instructions[1].contains("[reviewer]")) + } + + func testSetPersonaStackAffectsFutureTurnsOnly() async throws { + let backend = InMemoryAgentBackend( + baseInstructions: "Base host instructions." + ) + let supportPersona = AgentPersonaStack(layers: [ + .init(name: "support", instructions: "Act as a support specialist.") + ]) + let plannerPersona = AgentPersonaStack(layers: [ + .init(name: "planner", instructions: "Focus on sequencing and tradeoffs.") + ]) + let runtime = try AgentRuntime(configuration: .init( + authProvider: DemoChatGPTAuthProvider(), + secureStore: KeychainSessionSecureStore( + service: "CodexKitTests.ChatGPTSession", + account: UUID().uuidString + ), + backend: backend, + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore() + )) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(personaStack: supportPersona) + + _ = try await runtime.sendMessage( + UserMessageRequest(text: "Before"), + in: thread.id + ) + + try await runtime.setPersonaStack(plannerPersona, for: thread.id) + + _ = try await runtime.sendMessage( + UserMessageRequest(text: "After"), + in: thread.id + ) + + let instructions = await backend.receivedInstructions() + XCTAssertEqual(instructions.count, 2) + XCTAssertTrue(instructions[0].contains("[support]")) + XCTAssertTrue(instructions[1].contains("[planner]")) + XCTAssertFalse(instructions[1].contains("[support]")) + } + + func testThreadPersonaStackPersistsAcrossRestore() async throws { + let runtimeStore = InMemoryRuntimeStateStore() + let supportPersona = AgentPersonaStack(layers: [ + .init(name: "support", instructions: "Act as a support specialist.") + ]) + let runtime = try AgentRuntime(configuration: .init( + authProvider: DemoChatGPTAuthProvider(), + secureStore: KeychainSessionSecureStore( + service: "CodexKitTests.ChatGPTSession", + account: UUID().uuidString + ), + backend: InMemoryAgentBackend(), + approvalPresenter: AutoApprovalPresenter(), + stateStore: runtimeStore + )) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread(personaStack: supportPersona) + + let restoredRuntime = try AgentRuntime(configuration: .init( + authProvider: DemoChatGPTAuthProvider(), + secureStore: KeychainSessionSecureStore( + service: "CodexKitTests.ChatGPTSession", + account: UUID().uuidString + ), + backend: InMemoryAgentBackend(), + approvalPresenter: AutoApprovalPresenter(), + stateStore: runtimeStore + )) + + let restoredState = try await restoredRuntime.restore() + let restoredThread = try XCTUnwrap(restoredState.threads.first(where: { $0.id == thread.id })) + XCTAssertEqual(restoredThread.personaStack, supportPersona) + } + + func testSetPersonaStackThrowsForMissingThread() async throws { + let runtime = try AgentRuntime(configuration: .init( + authProvider: DemoChatGPTAuthProvider(), + secureStore: KeychainSessionSecureStore( + service: "CodexKitTests.ChatGPTSession", + account: UUID().uuidString + ), + backend: InMemoryAgentBackend(), + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore() + )) + + _ = try await runtime.restore() + + await XCTAssertThrowsErrorAsync( + try await runtime.setPersonaStack( + AgentPersonaStack(layers: [.init(name: "missing", instructions: "nope")]), + for: "missing-thread" + ) + ) { error in + XCTAssertEqual(error as? AgentRuntimeError, .threadNotFound("missing-thread")) + } + } + + func testSetSkillIDsThrowsWhenSkillIsNotRegistered() async throws { + let runtime = try AgentRuntime(configuration: .init( + authProvider: DemoChatGPTAuthProvider(), + secureStore: KeychainSessionSecureStore( + service: "CodexKitTests.ChatGPTSession", + account: UUID().uuidString + ), + backend: InMemoryAgentBackend(), + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore() + )) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + let thread = try await runtime.createThread() + + await XCTAssertThrowsErrorAsync( + try await runtime.setSkillIDs(["travel_planner"], for: thread.id) + ) { error in + XCTAssertEqual( + error as? AgentRuntimeError, + .skillsNotFound(["travel_planner"]) + ) + } + } + + func testSkillPolicyBlocksDisallowedToolCalls() async throws { + let runtime = try AgentRuntime(configuration: .init( + authProvider: DemoChatGPTAuthProvider(), + secureStore: KeychainSessionSecureStore( + service: "CodexKitTests.ChatGPTSession", + account: UUID().uuidString + ), + backend: InMemoryAgentBackend(), + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore(), + skills: [ + .init( + id: "strict_support", + name: "Strict Support", + instructions: "Answer directly.", + executionPolicy: .init( + allowedToolNames: ["allowed_tool"] + ) + ), + ] + )) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + try await runtime.registerTool( + ToolDefinition( + name: "demo_lookup_profile", + description: "Lookup profile", + inputSchema: .object([:]), + approvalPolicy: .automatic + ), + executor: AnyToolExecutor { invocation, _ in + .success(invocation: invocation, text: "profile-ok") + } + ) + + let thread = try await runtime.createThread( + title: "Strict Tool Policy", + skillIDs: ["strict_support"] + ) + + let stream = try await runtime.streamMessage( + UserMessageRequest(text: "please use the tool"), + in: thread.id + ) + for try await _ in stream {} + + let messages = await runtime.messages(for: thread.id) + let assistantText = messages + .filter { $0.role == .assistant } + .map(\.text) + .joined(separator: "\n") + XCTAssertTrue(assistantText.contains("not allowed by the active skill policy")) + } + + func testSkillPolicyFailsTurnWhenRequiredToolIsMissing() async throws { + let runtime = try AgentRuntime(configuration: .init( + authProvider: DemoChatGPTAuthProvider(), + secureStore: KeychainSessionSecureStore( + service: "CodexKitTests.ChatGPTSession", + account: UUID().uuidString + ), + backend: InMemoryAgentBackend(), + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore(), + skills: [ + .init( + id: "requires_tool", + name: "Requires Tool", + instructions: "Use the required tool.", + executionPolicy: .init( + requiredToolNames: ["demo_lookup_profile"] + ) + ), + ] + )) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread( + title: "Required Tool", + skillIDs: ["requires_tool"] + ) + + let stream = try await runtime.streamMessage( + UserMessageRequest(text: "hello without tool"), + in: thread.id + ) + + var sawTurnFailed = false + var failureError: AgentRuntimeError? + + do { + for try await event in stream { + if case let .turnFailed(error) = event { + sawTurnFailed = true + failureError = error + } + } + XCTFail("Expected turn stream to throw when required tools are missing.") + } catch { + XCTAssertEqual((error as? AgentRuntimeError)?.code, "skill_required_tools_missing") + } + + XCTAssertTrue(sawTurnFailed) + XCTAssertEqual(failureError?.code, "skill_required_tools_missing") + } + + func testRuntimeRejectsSkillWithInvalidPolicyToolName() async throws { + XCTAssertThrowsError( + try AgentRuntime(configuration: .init( + authProvider: DemoChatGPTAuthProvider(), + secureStore: KeychainSessionSecureStore( + service: "CodexKitTests.ChatGPTSession", + account: UUID().uuidString + ), + backend: InMemoryAgentBackend(), + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore(), + skills: [ + .init( + id: "invalid_policy", + name: "Invalid Policy", + instructions: "Invalid tool name policy.", + executionPolicy: .init( + requiredToolNames: ["bad tool name"] + ) + ), + ] + )) + ) { error in + XCTAssertEqual((error as? AgentRuntimeError)?.code, "invalid_skill_tool_name") + } + } + + func testResolvedInstructionsPreviewIncludesThreadPersonaAndSkills() async throws { + let runtime = try AgentRuntime(configuration: .init( + authProvider: DemoChatGPTAuthProvider(), + secureStore: KeychainSessionSecureStore( + service: "CodexKitTests.ChatGPTSession", + account: UUID().uuidString + ), + backend: InMemoryAgentBackend(baseInstructions: "Base host instructions."), + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore(), + skills: [ + .init( + id: "health_coach", + name: "Health Coach", + instructions: "Coach users toward their daily step goals." + ), + ] + )) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let thread = try await runtime.createThread( + title: "Preview", + personaStack: AgentPersonaStack(layers: [ + .init(name: "planner", instructions: "Act as a planning specialist.") + ]), + skillIDs: ["health_coach"] + ) + + let preview = try await runtime.resolvedInstructionsPreview( + for: thread.id, + request: UserMessageRequest(text: "Give me a plan.") + ) + + XCTAssertTrue(preview.contains("Base host instructions.")) + XCTAssertTrue(preview.contains("Thread Persona Layers:")) + XCTAssertTrue(preview.contains("[planner]")) + XCTAssertTrue(preview.contains("Thread Skills:")) + XCTAssertTrue(preview.contains("[health_coach: Health Coach]")) + } + + func testCreateThreadLoadsPersonaFromFileSource() async throws { + let backend = InMemoryAgentBackend( + baseInstructions: "Base host instructions." + ) + let runtime = try AgentRuntime(configuration: .init( + authProvider: DemoChatGPTAuthProvider(), + secureStore: KeychainSessionSecureStore( + service: "CodexKitTests.ChatGPTSession", + account: UUID().uuidString + ), + backend: backend, + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore() + )) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let personaText = "Act as a migration planning assistant focused on sequencing." + let personaFile = try temporaryFile( + with: personaText, + pathExtension: "txt" + ) + + let thread = try await runtime.createThread( + title: "Dynamic Persona", + personaSource: .file(personaFile) + ) + + let personaStack = try XCTUnwrap(thread.personaStack) + XCTAssertEqual(personaStack.layers.count, 1) + XCTAssertEqual(personaStack.layers[0].instructions, personaText) + } + + func testRegisterSkillFromFileSourceCanBeUsedInThread() async throws { + let backend = InMemoryAgentBackend( + baseInstructions: "Base host instructions." + ) + let runtime = try AgentRuntime(configuration: .init( + authProvider: DemoChatGPTAuthProvider(), + secureStore: KeychainSessionSecureStore( + service: "CodexKitTests.ChatGPTSession", + account: UUID().uuidString + ), + backend: backend, + approvalPresenter: AutoApprovalPresenter(), + stateStore: InMemoryRuntimeStateStore() + )) + + _ = try await runtime.restore() + _ = try await runtime.signIn() + + let skillJSON = """ + { + "id": "travel_planner", + "name": "Travel Planner", + "instructions": "Keep itineraries compact and transit-aware." + } + """ + let skillFile = try temporaryFile( + with: skillJSON, + pathExtension: "json" + ) + + _ = try await runtime.registerSkill( + from: .file(skillFile) + ) + + let thread = try await runtime.createThread( + title: "Travel", + skillIDs: ["travel_planner"] + ) + + _ = try await runtime.sendMessage( + UserMessageRequest(text: "Plan my day"), + in: thread.id + ) + + let instructions = await backend.receivedInstructions() + let resolved = try XCTUnwrap(instructions.last) + XCTAssertTrue(resolved.contains("[travel_planner: Travel Planner]")) + } +} diff --git a/Tests/CodexKitTests/AgentRuntimeTests.swift b/Tests/CodexKitTests/AgentRuntimeTests.swift index 2e81706..8e3b0b7 100644 --- a/Tests/CodexKitTests/AgentRuntimeTests.swift +++ b/Tests/CodexKitTests/AgentRuntimeTests.swift @@ -1,14 +1,16 @@ import CodexKit import XCTest -private struct AutoApprovalPresenter: ApprovalPresenting { +// MARK: - Shared Fixtures + +struct AutoApprovalPresenter: ApprovalPresenting { func requestApproval(_ request: ApprovalRequest) async throws -> ApprovalDecision { XCTAssertEqual(request.toolInvocation.toolName, "demo_lookup_profile") return .approved } } -private struct ShippingReplyDraft: AgentStructuredOutput, Equatable { +struct ShippingReplyDraft: AgentStructuredOutput, Equatable { let reply: String let priority: String @@ -27,1308 +29,7 @@ private struct ShippingReplyDraft: AgentStructuredOutput, Equatable { } final class AgentRuntimeTests: XCTestCase { - func testSendMessageReturnsFinalAssistantMessageText() async throws { - let runtime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: InMemoryAgentBackend(), - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore() - )) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - let thread = try await runtime.createThread(title: "Complete") - let reply = try await runtime.sendMessage( - UserMessageRequest(text: "Hello there"), - in: thread.id - ) - - XCTAssertEqual(reply, "Echo: Hello there") - } - - func testSendMessageExpectingStructuredTypeDecodesTypedResponse() async throws { - let backend = InMemoryAgentBackend( - structuredResponseText: #"{"reply":"Your order is already in transit.","priority":"high"}"# - ) - let runtime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: backend, - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore() - )) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - let thread = try await runtime.createThread(title: "Structured") - let reply = try await runtime.sendMessage( - UserMessageRequest(text: "Draft a shipping reply."), - in: thread.id, - expecting: ShippingReplyDraft.self - ) - - XCTAssertEqual( - reply, - ShippingReplyDraft( - reply: "Your order is already in transit.", - priority: "high" - ) - ) - - let formats = await backend.receivedResponseFormats() - XCTAssertEqual(formats.last??.name, "shipping_reply_draft") - } - - func testStructuredDecodeFailureThrowsRuntimeError() async throws { - let runtime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: InMemoryAgentBackend( - structuredResponseText: #"{"unexpected":"value"}"# - ), - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore() - )) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - let thread = try await runtime.createThread(title: "Structured Failure") - - await XCTAssertThrowsErrorAsync( - try await runtime.sendMessage( - UserMessageRequest(text: "Draft a shipping reply."), - in: thread.id, - expecting: ShippingReplyDraft.self - ) - ) { error in - let runtimeError = error as? AgentRuntimeError - XCTAssertEqual(runtimeError?.code, "structured_output_decoding_failed") - XCTAssertTrue(runtimeError?.message.contains("ShippingReplyDraft") == true) - } - } - - func testImportedContentInitializerBuildsMessageWithSharedURLs() async throws { - let content = AgentImportedContent( - textSnippets: ["Summarize this article."], - urls: [try XCTUnwrap(URL(string: "https://example.com/story"))], - images: [.png(Data([0x89, 0x50, 0x4E, 0x47]))] - ) - - let request = UserMessageRequest( - prompt: "Give me a concise summary.", - importedContent: content - ) - - XCTAssertTrue(request.hasContent) - XCTAssertEqual(request.images.count, 1) - XCTAssertTrue(request.text.contains("Give me a concise summary.")) - XCTAssertTrue(request.text.contains("Summarize this article.")) - XCTAssertTrue(request.text.contains("Shared URLs:")) - XCTAssertTrue(request.text.contains("https://example.com/story")) - } - - func testThreadSkillsAreResolvedIntoInstructions() async throws { - let backend = InMemoryAgentBackend( - baseInstructions: "Base host instructions." - ) - let runtime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: backend, - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore(), - skills: [ - .init( - id: "health_coach", - name: "Health Coach", - instructions: "Coach users toward their daily step goals." - ), - ] - )) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - let thread = try await runtime.createThread( - title: "Health", - skillIDs: ["health_coach"] - ) - let stream = try await runtime.streamMessage( - UserMessageRequest(text: "Give me my plan."), - in: thread.id - ) - for try await _ in stream {} - - let instructions = await backend.receivedInstructions() - let resolvedInstructions = try XCTUnwrap(instructions.last) - XCTAssertTrue(resolvedInstructions.contains("Thread Skills:")) - XCTAssertTrue(resolvedInstructions.contains("[health_coach: Health Coach]")) - } - - func testTurnSkillOverrideAppliesOnlyToCurrentTurn() async throws { - let backend = InMemoryAgentBackend( - baseInstructions: "Base host instructions." - ) - let runtime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: backend, - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore(), - skills: [ - .init( - id: "travel_planner", - name: "Travel Planner", - instructions: "Plan practical itineraries." - ), - ] - )) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - let thread = try await runtime.createThread() - - let firstStream = try await runtime.streamMessage( - UserMessageRequest( - text: "Plan my trip.", - skillOverrideIDs: ["travel_planner"] - ), - in: thread.id - ) - for try await _ in firstStream {} - - let secondStream = try await runtime.streamMessage( - UserMessageRequest(text: "Now answer normally."), - in: thread.id - ) - for try await _ in secondStream {} - - let instructions = await backend.receivedInstructions() - XCTAssertEqual(instructions.count, 2) - XCTAssertTrue(instructions[0].contains("Turn Skill Override:")) - XCTAssertTrue(instructions[0].contains("[travel_planner: Travel Planner]")) - XCTAssertFalse(instructions[1].contains("Turn Skill Override:")) - XCTAssertFalse(instructions[1].contains("[travel_planner: Travel Planner]")) - } - - func testSetSkillIDsAffectsFutureTurnsOnly() async throws { - let backend = InMemoryAgentBackend( - baseInstructions: "Base host instructions." - ) - let runtime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: backend, - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore(), - skills: [ - .init( - id: "health_coach", - name: "Health Coach", - instructions: "Coach users toward step goals." - ), - .init( - id: "travel_planner", - name: "Travel Planner", - instructions: "Plan practical itineraries." - ), - ] - )) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - let thread = try await runtime.createThread( - title: "Skills", - skillIDs: ["health_coach"] - ) - - let firstStream = try await runtime.streamMessage( - UserMessageRequest(text: "What should I walk today?"), - in: thread.id - ) - for try await _ in firstStream {} - - try await runtime.setSkillIDs(["travel_planner"], for: thread.id) - - let secondStream = try await runtime.streamMessage( - UserMessageRequest(text: "Plan a weekend trip."), - in: thread.id - ) - for try await _ in secondStream {} - - let instructions = await backend.receivedInstructions() - XCTAssertEqual(instructions.count, 2) - XCTAssertTrue(instructions[0].contains("[health_coach: Health Coach]")) - XCTAssertFalse(instructions[0].contains("[travel_planner: Travel Planner]")) - XCTAssertTrue(instructions[1].contains("[travel_planner: Travel Planner]")) - XCTAssertFalse(instructions[1].contains("[health_coach: Health Coach]")) - } - - func testThreadPersonaUsesBackendBaseInstructionsWhenRuntimeBaseIsUnset() async throws { - let backend = InMemoryAgentBackend( - baseInstructions: "Base host instructions." - ) - let runtime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: backend, - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore() - )) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - let supportPersona = AgentPersonaStack(layers: [ - .init( - name: "domain", - instructions: "You are an expert customer support agent for a shipping app." - ), - .init( - name: "style", - instructions: "Be concise, calm, and action-oriented." - ), - ]) - - let thread = try await runtime.createThread( - title: "Support Chat", - personaStack: supportPersona - ) - let stream = try await runtime.streamMessage( - UserMessageRequest(text: "How do I track my order?"), - in: thread.id - ) - - for try await _ in stream {} - - let receivedInstructions = await backend.receivedInstructions() - let resolvedInstructions = try XCTUnwrap(receivedInstructions.last) - XCTAssertTrue(resolvedInstructions.contains("Base host instructions.")) - XCTAssertTrue(resolvedInstructions.contains("Thread Persona Layers:")) - XCTAssertTrue(resolvedInstructions.contains("[domain]")) - XCTAssertTrue(resolvedInstructions.contains("[style]")) - XCTAssertLessThan( - try XCTUnwrap(resolvedInstructions.range(of: "Base host instructions.")?.lowerBound), - try XCTUnwrap(resolvedInstructions.range(of: "Thread Persona Layers:")?.lowerBound) - ) - } - - func testTurnPersonaOverrideAppliesOnlyToCurrentTurn() async throws { - let backend = InMemoryAgentBackend( - baseInstructions: "Base host instructions." - ) - let runtime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: backend, - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore() - )) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - let supportPersona = AgentPersonaStack(layers: [ - .init( - name: "support", - instructions: "Act as a calm support specialist." - ), - ]) - let reviewerOverride = AgentPersonaStack(layers: [ - .init( - name: "reviewer", - instructions: "For this reply only, act as a strict reviewer and call out risks first." - ), - ]) - - let thread = try await runtime.createThread(personaStack: supportPersona) - - let firstStream = try await runtime.streamMessage( - UserMessageRequest( - text: "Review this architecture.", - personaOverride: reviewerOverride - ), - in: thread.id - ) - for try await _ in firstStream {} - - let secondStream = try await runtime.streamMessage( - UserMessageRequest(text: "Now just answer normally."), - in: thread.id - ) - for try await _ in secondStream {} - - let instructions = await backend.receivedInstructions() - XCTAssertEqual(instructions.count, 2) - XCTAssertTrue(instructions[0].contains("Turn Persona Override:")) - XCTAssertTrue(instructions[0].contains("[reviewer]")) - XCTAssertFalse(instructions[1].contains("Turn Persona Override:")) - XCTAssertFalse(instructions[1].contains("[reviewer]")) - XCTAssertTrue(instructions[1].contains("[support]")) - } - - func testSetPersonaStackAffectsFutureTurnsOnly() async throws { - let backend = InMemoryAgentBackend( - baseInstructions: "Base host instructions." - ) - let runtime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: backend, - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore() - )) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - let supportPersona = AgentPersonaStack(layers: [ - .init(name: "support", instructions: "Act as a support agent.") - ]) - let plannerPersona = AgentPersonaStack(layers: [ - .init(name: "planner", instructions: "Act as a careful technical planner.") - ]) - - let thread = try await runtime.createThread(personaStack: supportPersona) - - let firstStream = try await runtime.streamMessage( - UserMessageRequest(text: "Help me with support."), - in: thread.id - ) - for try await _ in firstStream {} - - try await runtime.setPersonaStack(plannerPersona, for: thread.id) - - let secondStream = try await runtime.streamMessage( - UserMessageRequest(text: "Plan the migration."), - in: thread.id - ) - for try await _ in secondStream {} - - let instructions = await backend.receivedInstructions() - XCTAssertEqual(instructions.count, 2) - XCTAssertTrue(instructions[0].contains("[support]")) - XCTAssertFalse(instructions[0].contains("[planner]")) - XCTAssertTrue(instructions[1].contains("[planner]")) - XCTAssertFalse(instructions[1].contains("[support]")) - } - - func testThreadPersonaStackPersistsAcrossRestore() async throws { - let stateStore = InMemoryRuntimeStateStore() - let secureStore = KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ) - let personaStack = AgentPersonaStack(layers: [ - .init(name: "planner", instructions: "Act as a careful technical planner.") - ]) - - let runtime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: secureStore, - backend: InMemoryAgentBackend(), - approvalPresenter: AutoApprovalPresenter(), - stateStore: stateStore - )) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - let thread = try await runtime.createThread( - title: "Planning", - personaStack: personaStack - ) - - let restoredRuntime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: secureStore, - backend: InMemoryAgentBackend(), - approvalPresenter: AutoApprovalPresenter(), - stateStore: stateStore - )) - - _ = try await restoredRuntime.restore() - - let restoredStack = try await restoredRuntime.personaStack(for: thread.id) - let restoredThreads = await restoredRuntime.threads() - XCTAssertEqual(restoredStack, personaStack) - XCTAssertEqual(restoredThreads.first?.personaStack, personaStack) - } - - func testSetPersonaStackThrowsForMissingThread() async throws { - let runtime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: InMemoryAgentBackend(), - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore() - )) - - _ = try await runtime.restore() - - await XCTAssertThrowsErrorAsync( - try await runtime.setPersonaStack( - AgentPersonaStack(layers: [ - .init(name: "planner", instructions: "Act as a planner.") - ]), - for: "missing-thread" - ) - ) { error in - XCTAssertEqual(error as? AgentRuntimeError, .threadNotFound("missing-thread")) - } - } - - func testSetSkillIDsThrowsWhenSkillIsNotRegistered() async throws { - let runtime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: InMemoryAgentBackend(), - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore() - )) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - let thread = try await runtime.createThread() - - await XCTAssertThrowsErrorAsync( - try await runtime.setSkillIDs(["travel_planner"], for: thread.id) - ) { error in - XCTAssertEqual( - error as? AgentRuntimeError, - .skillsNotFound(["travel_planner"]) - ) - } - } - - func testSkillPolicyBlocksDisallowedToolCalls() async throws { - let runtime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: InMemoryAgentBackend(), - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore(), - skills: [ - .init( - id: "strict_support", - name: "Strict Support", - instructions: "Answer directly.", - executionPolicy: .init( - allowedToolNames: ["allowed_tool"] - ) - ), - ] - )) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - try await runtime.registerTool( - ToolDefinition( - name: "demo_lookup_profile", - description: "Lookup profile", - inputSchema: .object([:]), - approvalPolicy: .automatic - ), - executor: AnyToolExecutor { invocation, _ in - .success(invocation: invocation, text: "profile-ok") - } - ) - - let thread = try await runtime.createThread( - title: "Strict Tool Policy", - skillIDs: ["strict_support"] - ) - - let stream = try await runtime.streamMessage( - UserMessageRequest(text: "please use the tool"), - in: thread.id - ) - for try await _ in stream {} - - let messages = await runtime.messages(for: thread.id) - let assistantText = messages - .filter { $0.role == .assistant } - .map(\.text) - .joined(separator: "\n") - XCTAssertTrue(assistantText.contains("not allowed by the active skill policy")) - } - - func testSkillPolicyFailsTurnWhenRequiredToolIsMissing() async throws { - let runtime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: InMemoryAgentBackend(), - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore(), - skills: [ - .init( - id: "requires_tool", - name: "Requires Tool", - instructions: "Use the required tool.", - executionPolicy: .init( - requiredToolNames: ["demo_lookup_profile"] - ) - ), - ] - )) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - let thread = try await runtime.createThread( - title: "Required Tool", - skillIDs: ["requires_tool"] - ) - - let stream = try await runtime.streamMessage( - UserMessageRequest(text: "hello without tool"), - in: thread.id - ) - - var sawTurnFailed = false - var failureError: AgentRuntimeError? - - do { - for try await event in stream { - if case let .turnFailed(error) = event { - sawTurnFailed = true - failureError = error - } - } - XCTFail("Expected turn stream to throw when required tools are missing.") - } catch { - XCTAssertEqual((error as? AgentRuntimeError)?.code, "skill_required_tools_missing") - } - - XCTAssertTrue(sawTurnFailed) - XCTAssertEqual(failureError?.code, "skill_required_tools_missing") - } - - func testRuntimeRejectsSkillWithInvalidPolicyToolName() async throws { - XCTAssertThrowsError( - try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: InMemoryAgentBackend(), - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore(), - skills: [ - .init( - id: "invalid_policy", - name: "Invalid Policy", - instructions: "Invalid tool name policy.", - executionPolicy: .init( - requiredToolNames: ["bad tool name"] - ) - ), - ] - )) - ) { error in - XCTAssertEqual((error as? AgentRuntimeError)?.code, "invalid_skill_tool_name") - } - } - - func testResolvedInstructionsPreviewIncludesThreadPersonaAndSkills() async throws { - let runtime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: InMemoryAgentBackend(baseInstructions: "Base host instructions."), - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore(), - skills: [ - .init( - id: "health_coach", - name: "Health Coach", - instructions: "Coach users toward their daily step goals." - ), - ] - )) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - let thread = try await runtime.createThread( - title: "Preview", - personaStack: AgentPersonaStack(layers: [ - .init(name: "planner", instructions: "Act as a planning specialist.") - ]), - skillIDs: ["health_coach"] - ) - - let preview = try await runtime.resolvedInstructionsPreview( - for: thread.id, - request: UserMessageRequest(text: "Give me a plan.") - ) - - XCTAssertTrue(preview.contains("Base host instructions.")) - XCTAssertTrue(preview.contains("Thread Persona Layers:")) - XCTAssertTrue(preview.contains("[planner]")) - XCTAssertTrue(preview.contains("Thread Skills:")) - XCTAssertTrue(preview.contains("[health_coach: Health Coach]")) - } - - func testRuntimeInjectsRelevantMemoryIntoInstructionsAndPreviewMatches() async throws { - let backend = InMemoryAgentBackend( - baseInstructions: "Base host instructions." - ) - let store = InMemoryMemoryStore(initialRecords: [ - MemoryRecord( - namespace: "demo-assistant", - scope: "feature:health-coach", - kind: "preference", - summary: "Health Coach should use direct accountability when the user is behind on steps.", - evidence: ["The user responds better to blunt coaching than soft encouragement."], - importance: 0.9, - tags: ["steps"] - ), - MemoryRecord( - namespace: "demo-assistant", - scope: "feature:travel-planner", - kind: "preference", - summary: "Travel Planner should keep itineraries compact and transit-aware.", - importance: 0.8 - ), - ]) - let runtime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: backend, - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore(), - memory: .init(store: store) - )) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - let thread = try await runtime.createThread( - title: "Memory", - memoryContext: AgentMemoryContext( - namespace: "demo-assistant", - scopes: ["feature:health-coach"] - ) - ) - - let preview = try await runtime.memoryQueryPreview( - for: thread.id, - request: UserMessageRequest(text: "How should the health coach respond when the user is behind on steps?") - ) - XCTAssertEqual(preview?.matches.map(\.record.scope.rawValue), ["feature:health-coach"]) - - _ = try await runtime.sendMessage( - UserMessageRequest(text: "How should the health coach respond when the user is behind on steps?"), - in: thread.id - ) - - let instructions = await backend.receivedInstructions() - let resolved = try XCTUnwrap(instructions.last) - XCTAssertTrue(resolved.contains("Relevant Memory:")) - XCTAssertTrue(resolved.contains("Health Coach should use direct accountability when the user is behind on steps.")) - XCTAssertFalse(resolved.contains("Travel Planner should keep itineraries compact and transit-aware.")) - } - - func testRuntimeMemorySelectionCanReplaceOrDisableThreadDefaults() async throws { - let backend = InMemoryAgentBackend( - baseInstructions: "Base host instructions." - ) - let store = InMemoryMemoryStore(initialRecords: [ - MemoryRecord( - namespace: "demo-assistant", - scope: "feature:health-coach", - kind: "preference", - summary: "Health Coach preference." - ), - MemoryRecord( - namespace: "demo-assistant", - scope: "feature:travel-planner", - kind: "preference", - summary: "Travel Planner preference." - ), - ]) - let runtime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: backend, - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore(), - memory: .init(store: store) - )) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - let thread = try await runtime.createThread( - title: "Scoped Memory", - memoryContext: AgentMemoryContext( - namespace: "demo-assistant", - scopes: ["feature:health-coach"] - ) - ) - - _ = try await runtime.sendMessage( - UserMessageRequest( - text: "Use travel planner memory instead.", - memorySelection: MemorySelection( - mode: .replace, - scopes: ["feature:travel-planner"] - ) - ), - in: thread.id - ) - - _ = try await runtime.sendMessage( - UserMessageRequest( - text: "Now disable memory.", - memorySelection: MemorySelection(mode: .disable) - ), - in: thread.id - ) - - let instructions = await backend.receivedInstructions() - XCTAssertEqual(instructions.count, 2) - XCTAssertTrue(instructions[0].contains("Travel Planner preference.")) - XCTAssertFalse(instructions[0].contains("Health Coach preference.")) - XCTAssertFalse(instructions[1].contains("Relevant Memory:")) - } - - func testRuntimeCanAutomaticallyCaptureMemoriesFromTranscript() async throws { - let backend = InMemoryAgentBackend( - structuredResponseText: """ - {"memories":[{"summary":"Health Coach should use direct accountability when step pace is low.","scope":"feature:health-coach","kind":"preference","evidence":["The user asked for blunt reminders when behind on steps."],"importance":0.92,"tags":["steps","tone"],"relatedIDs":["goal-10000"],"dedupeKey":"health-coach-direct-accountability"},{"summary":"Travel Planner should keep itineraries compact and transit-aware.","scope":"feature:travel-planner","kind":"preference","evidence":["The user dislikes sprawling travel plans."],"importance":0.81,"tags":["travel"],"relatedIDs":["travel-style-compact"],"dedupeKey":"travel-planner-compact-itinerary"}]} - """ - ) - let store = InMemoryMemoryStore() - let runtime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: backend, - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore(), - memory: .init(store: store) - )) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - let thread = try await runtime.createThread( - title: "Auto Memory", - memoryContext: AgentMemoryContext( - namespace: "demo-assistant", - scopes: ["feature:health-coach", "feature:travel-planner"] - ) - ) - - let capture = try await runtime.captureMemories( - from: .text(""" - User: Be direct when I am behind on steps. - User: Keep travel itineraries compact and transit-aware. - """), - for: thread.id, - options: .init( - defaults: .init(namespace: "demo-assistant"), - maxMemories: 3 - ) - ) - - XCTAssertEqual(capture.records.count, 2) - XCTAssertEqual(capture.records.map(\.scope.rawValue).sorted(), ["feature:health-coach", "feature:travel-planner"]) - - let stored = try await store.query( - MemoryQuery( - namespace: "demo-assistant", - scopes: ["feature:health-coach", "feature:travel-planner"], - text: "direct steps transit itinerary", - limit: 10, - maxCharacters: 1000 - ) - ) - XCTAssertEqual(stored.matches.count, 2) - - let formats = await backend.receivedResponseFormats() - XCTAssertEqual(formats.last??.name, "memory_extraction_batch") - } - - func testRuntimeCanAutomaticallyCaptureMemoryAfterSuccessfulTurn() async throws { - let backend = InMemoryAgentBackend( - structuredResponseText: """ - {"memories":[{"summary":"Health Coach should use direct accountability when the user falls behind on steps.","scope":"feature:health-coach","kind":"preference","evidence":["The user said blunt reminders work better than soft encouragement."],"importance":0.94,"tags":["steps","tone"],"relatedIDs":["goal-10000"],"dedupeKey":"health-coach-auto-capture"}]} - """ - ) - let store = InMemoryMemoryStore() - let observer = RecordingMemoryObserver() - let runtime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: backend, - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore(), - memory: .init( - store: store, - observer: observer, - automaticCapturePolicy: .init( - source: .lastTurn, - options: .init( - defaults: .init(namespace: "demo-assistant"), - maxMemories: 2 - ) - ) - ) - )) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - let thread = try await runtime.createThread( - title: "Auto Policy", - memoryContext: AgentMemoryContext( - namespace: "demo-assistant", - scopes: ["feature:health-coach"] - ) - ) - - _ = try await runtime.sendMessage( - UserMessageRequest(text: "If I am behind on steps, be direct and blunt with me."), - in: thread.id - ) - - let stored = try await store.query( - MemoryQuery( - namespace: "demo-assistant", - scopes: ["feature:health-coach"], - text: "direct blunt steps", - limit: 10, - maxCharacters: 1000 - ) - ) - XCTAssertEqual(stored.matches.count, 1) - XCTAssertEqual(stored.matches[0].record.dedupeKey, "health-coach-auto-capture") - - let formats = await backend.receivedResponseFormats() - XCTAssertEqual(formats.count, 2) - XCTAssertNil(formats.first!) - XCTAssertEqual(formats.last??.name, "memory_extraction_batch") - - let events = await observer.events() - let captureEvents = events.compactMap { event -> (String, String?, Int?)? in - switch event { - case let .captureStarted(threadID, sourceDescription): - return (threadID, sourceDescription, nil) - case let .captureSucceeded(threadID, result): - return (threadID, nil, result.records.count) - default: - return nil - } - } - - XCTAssertEqual(captureEvents.count, 2) - XCTAssertEqual(captureEvents[0].0, thread.id) - XCTAssertEqual(captureEvents[0].1, "last_turn") - XCTAssertEqual(captureEvents[1].0, thread.id) - XCTAssertEqual(captureEvents[1].2, 1) - } - - func testRuntimeGracefullyDegradesWhenMemoryStoreFails() async throws { - let backend = InMemoryAgentBackend( - baseInstructions: "Base host instructions." - ) - let runtime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: backend, - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore(), - memory: .init(store: ThrowingMemoryStore()) - )) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - let thread = try await runtime.createThread( - title: "Graceful", - memoryContext: AgentMemoryContext( - namespace: "demo-assistant", - scopes: ["feature:health-coach"] - ) - ) - - _ = try await runtime.sendMessage( - UserMessageRequest(text: "This should still work."), - in: thread.id - ) - - let instructions = await backend.receivedInstructions() - let resolved = try XCTUnwrap(instructions.last) - XCTAssertFalse(resolved.contains("Relevant Memory:")) - } - - func testRuntimeReportsMemoryObservationEvents() async throws { - let backend = InMemoryAgentBackend( - baseInstructions: "Base host instructions." - ) - let store = InMemoryMemoryStore(initialRecords: [ - MemoryRecord( - namespace: "demo-assistant", - scope: "feature:health-coach", - kind: "preference", - summary: "Observed memory." - ), - ]) - let observer = RecordingMemoryObserver() - let runtime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: backend, - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore(), - memory: .init( - store: store, - observer: observer - ) - )) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - let thread = try await runtime.createThread( - title: "Observed", - memoryContext: AgentMemoryContext( - namespace: "demo-assistant", - scopes: ["feature:health-coach"] - ) - ) - - _ = try await runtime.sendMessage( - UserMessageRequest(text: "Use memory."), - in: thread.id - ) - - let events = await observer.events() - XCTAssertEqual(events.count, 2) - guard case let .queryStarted(startedQuery) = events[0] else { - return XCTFail("Expected queryStarted event.") - } - XCTAssertEqual(startedQuery.namespace, "demo-assistant") - guard case let .querySucceeded(_, result) = events[1] else { - return XCTFail("Expected querySucceeded event.") - } - XCTAssertEqual(result.matches.count, 1) - } - - func testRuntimeProvidesThreadAwareMemoryWriterDefaults() async throws { - let store = InMemoryMemoryStore() - let runtime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: InMemoryAgentBackend(), - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore(), - memory: .init(store: store) - )) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - let thread = try await runtime.createThread( - title: "Thread Memory Writer", - memoryContext: AgentMemoryContext( - namespace: "demo-assistant", - scopes: ["feature:health-coach"], - kinds: ["preference"], - tags: ["steps"], - relatedIDs: ["goal-10000"] - ) - ) - - let writer = try await runtime.memoryWriter(for: thread.id) - let record = try await writer.put( - MemoryDraft( - summary: "The demo user responds better to direct step reminders." - ) - ) - - XCTAssertEqual(record.namespace, "demo-assistant") - XCTAssertEqual(record.scope, "feature:health-coach") - XCTAssertEqual(record.kind, "preference") - XCTAssertEqual(record.tags, ["steps"]) - XCTAssertEqual(record.relatedIDs, ["goal-10000"]) - } - - func testRuntimeMemoryWriterThrowsWhenMemoryIsNotConfigured() async throws { - let runtime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: InMemoryAgentBackend(), - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore() - )) - - _ = try await runtime.restore() - - await XCTAssertThrowsErrorAsync(try await runtime.memoryWriter()) { error in - XCTAssertEqual(error as? AgentRuntimeError, .memoryNotConfigured()) - } - } - - func testResolvedInstructionsPreviewThrowsForMissingThread() async throws { - let runtime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: InMemoryAgentBackend(), - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore() - )) - - _ = try await runtime.restore() - - await XCTAssertThrowsErrorAsync( - try await runtime.resolvedInstructionsPreview( - for: "missing-thread", - request: UserMessageRequest(text: "hello") - ) - ) { error in - XCTAssertEqual(error as? AgentRuntimeError, .threadNotFound("missing-thread")) - } - } - - func testCreateThreadLoadsPersonaFromFileSource() async throws { - let backend = InMemoryAgentBackend( - baseInstructions: "Base host instructions." - ) - let runtime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: backend, - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore() - )) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - let personaText = "Act as a migration planning assistant focused on sequencing." - let personaFile = try temporaryFile( - with: personaText, - pathExtension: "txt" - ) - - let thread = try await runtime.createThread( - title: "Dynamic Persona", - personaSource: .file(personaFile) - ) - - let personaStack = try XCTUnwrap(thread.personaStack) - XCTAssertEqual(personaStack.layers.count, 1) - XCTAssertEqual(personaStack.layers[0].instructions, personaText) - - let stream = try await runtime.streamMessage( - UserMessageRequest(text: "Plan this migration."), - in: thread.id - ) - for try await _ in stream {} - - let instructions = await backend.receivedInstructions() - let resolved = try XCTUnwrap(instructions.last) - XCTAssertTrue(resolved.contains("Thread Persona Layers:")) - XCTAssertTrue(resolved.contains(personaText)) - } - - func testRegisterSkillFromFileSourceCanBeUsedInThread() async throws { - let backend = InMemoryAgentBackend( - baseInstructions: "Base host instructions." - ) - let runtime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: backend, - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore() - )) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - let skillJSON = """ - { - "id": "hydration_coach", - "name": "Hydration Coach", - "instructions": "Drive hydration execution with concrete water targets." - } - """ - let skillFile = try temporaryFile( - with: skillJSON, - pathExtension: "json" - ) - - _ = try await runtime.registerSkill(from: .file(skillFile)) - let registeredSkill = await runtime.skill(for: "hydration_coach") - XCTAssertNotNil(registeredSkill) - - let thread = try await runtime.createThread( - title: "Hydration", - skillIDs: ["hydration_coach"] - ) - - let stream = try await runtime.streamMessage( - UserMessageRequest(text: "Give me today's hydration plan."), - in: thread.id - ) - for try await _ in stream {} - - let instructions = await backend.receivedInstructions() - let resolved = try XCTUnwrap(instructions.last) - XCTAssertTrue(resolved.contains("Thread Skills:")) - XCTAssertTrue(resolved.contains("[hydration_coach: Hydration Coach]")) - } - - func testImageOnlyMessageIsAcceptedAndPersisted() async throws { - let runtime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: InMemoryAgentBackend(), - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore() - )) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - let thread = try await runtime.createThread() - let image = AgentImageAttachment.png(Data([0x89, 0x50, 0x4E, 0x47])) - - let stream = try await runtime.streamMessage( - UserMessageRequest( - text: "", - images: [image] - ), - in: thread.id - ) - for try await _ in stream {} - - let messages = await runtime.messages(for: thread.id) - let userMessage = try XCTUnwrap(messages.first(where: { $0.role == .user })) - XCTAssertEqual(userMessage.images, [image]) - XCTAssertEqual(userMessage.text, "") - XCTAssertEqual(userMessage.displayText, "Attached 1 image") - - let threads = await runtime.threads() - let updatedThread = try XCTUnwrap(threads.first(where: { $0.id == thread.id })) - XCTAssertEqual(updatedThread.title, "Image message") - } - - func testAssistantImagesAreCommittedToThreadHistory() async throws { - let runtime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: ImageReplyAgentBackend(), - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore() - )) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - let thread = try await runtime.createThread() - let stream = try await runtime.streamMessage( - UserMessageRequest(text: "Generate an image"), - in: thread.id - ) - for try await _ in stream {} - - let messages = await runtime.messages(for: thread.id) - let assistantMessage = try XCTUnwrap(messages.first(where: { $0.role == .assistant })) - XCTAssertEqual(assistantMessage.images.count, 1) - XCTAssertEqual(assistantMessage.images.first?.mimeType, "image/png") - } + // MARK: Legacy State func testRestoreDecodesLegacyStateWithoutPersonaOrImages() async throws { let legacyStateJSON = """ @@ -1367,220 +68,7 @@ final class AgentRuntimeTests: XCTestCase { XCTAssertEqual(state.messagesByThread["thread-1"]?.first?.text, "Hello from legacy state") } - func testThreadMemoryContextPersistsAcrossRestore() async throws { - let stateStore = InMemoryRuntimeStateStore() - let secureStore = KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ) - let memoryContext = AgentMemoryContext( - namespace: "demo-assistant", - scopes: ["feature:health-coach", "feature:travel-planner"], - readBudget: .init(maxItems: 4, maxCharacters: 800) - ) - - let runtime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: secureStore, - backend: InMemoryAgentBackend(), - approvalPresenter: AutoApprovalPresenter(), - stateStore: stateStore - )) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - let thread = try await runtime.createThread( - title: "Memory Restore", - memoryContext: memoryContext - ) - - let restoredRuntime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: secureStore, - backend: InMemoryAgentBackend(), - approvalPresenter: AutoApprovalPresenter(), - stateStore: stateStore - )) - - _ = try await restoredRuntime.restore() - - let restoredContext = try await restoredRuntime.memoryContext(for: thread.id) - let restoredThreads = await restoredRuntime.threads() - XCTAssertEqual(restoredContext, memoryContext) - XCTAssertEqual(restoredThreads.first?.memoryContext, memoryContext) - } - - func testRuntimeStreamsToolApprovalAndCompletion() async throws { - let runtime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: InMemoryAgentBackend(), - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore() - )) - _ = try await runtime.restore() - _ = try await runtime.signIn() - - try await runtime.replaceTool( - ToolDefinition( - name: "demo_lookup_profile", - description: "Lookup profile", - inputSchema: .object([:]), - approvalPolicy: .requiresApproval - ), - executor: AnyToolExecutor { invocation, _ in - .success(invocation: invocation, text: "demo-result") - } - ) - - let thread = try await runtime.createThread() - let stream = try await runtime.streamMessage( - UserMessageRequest(text: "please use the tool"), - in: thread.id - ) - - var sawApproval = false - var sawToolResult = false - var sawTurnCompleted = false - - for try await event in stream { - switch event { - case .approvalRequested: - sawApproval = true - case let .toolCallFinished(result): - sawToolResult = true - XCTAssertEqual(result.primaryText, "demo-result") - case .turnCompleted: - sawTurnCompleted = true - default: - break - } - } - - XCTAssertTrue(sawApproval) - XCTAssertTrue(sawToolResult) - XCTAssertTrue(sawTurnCompleted) - - let messages = await runtime.messages(for: thread.id) - XCTAssertEqual(messages.filter { $0.role == .user }.count, 1) - XCTAssertEqual(messages.filter { $0.role == .assistant }.count, 1) - } - - func testSendMessageRetriesUnauthorizedByRefreshingSession() async throws { - let authProvider = RotatingDemoAuthProvider() - let backend = UnauthorizedThenSuccessBackend() - let runtime = try AgentRuntime(configuration: .init( - authProvider: authProvider, - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: backend, - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore() - )) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - let thread = try await runtime.createThread() - - let stream = try await runtime.streamMessage( - UserMessageRequest(text: "Hello"), - in: thread.id - ) - for try await _ in stream {} - - let refreshCount = await authProvider.refreshCount() - XCTAssertEqual(refreshCount, 1) - - let attemptedTokens = await backend.attemptedAccessTokens() - XCTAssertEqual(attemptedTokens.count, 2) - XCTAssertEqual(attemptedTokens[0], "demo-access-token-initial") - XCTAssertEqual(attemptedTokens[1], "demo-access-token-refreshed-1") - - let messages = await runtime.messages(for: thread.id) - XCTAssertEqual(messages.filter { $0.role == .assistant }.count, 1) - } - - func testCreateThreadRetriesUnauthorizedByRefreshingSession() async throws { - let authProvider = RotatingDemoAuthProvider() - let backend = UnauthorizedOnCreateThenSuccessBackend() - let runtime = try AgentRuntime(configuration: .init( - authProvider: authProvider, - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: backend, - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore() - )) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - let thread = try await runtime.createThread(title: "Recovered Thread") - XCTAssertEqual(thread.title, "Recovered Thread") - - let refreshCount = await authProvider.refreshCount() - XCTAssertEqual(refreshCount, 1) - - let attemptedTokens = await backend.attemptedAccessTokens() - XCTAssertEqual(attemptedTokens.count, 2) - XCTAssertEqual(attemptedTokens[0], "demo-access-token-initial") - XCTAssertEqual(attemptedTokens[1], "demo-access-token-refreshed-1") - } - - func testConfigurationRegistersInitialTools() async throws { - let runtime = try AgentRuntime(configuration: .init( - authProvider: DemoChatGPTAuthProvider(), - secureStore: KeychainSessionSecureStore( - service: "CodexKitTests.ChatGPTSession", - account: UUID().uuidString - ), - backend: InMemoryAgentBackend(), - approvalPresenter: AutoApprovalPresenter(), - stateStore: InMemoryRuntimeStateStore(), - tools: [ - .init( - definition: ToolDefinition( - name: "demo_lookup_profile", - description: "Lookup profile", - inputSchema: .object([:]), - approvalPolicy: .requiresApproval - ), - executor: AnyToolExecutor { invocation, _ in - .success(invocation: invocation, text: "demo-result") - } - ), - ] - )) - - _ = try await runtime.restore() - _ = try await runtime.signIn() - - let thread = try await runtime.createThread() - let stream = try await runtime.streamMessage( - UserMessageRequest(text: "please use the tool"), - in: thread.id - ) - - var sawToolResult = false - - for try await event in stream { - if case let .toolCallFinished(result) = event { - sawToolResult = true - XCTAssertEqual(result.primaryText, "demo-result") - } - } - - XCTAssertTrue(sawToolResult) - } - - private func temporaryFile( + func temporaryFile( with content: String, pathExtension: String ) throws -> URL { @@ -1592,7 +80,9 @@ final class AgentRuntimeTests: XCTestCase { } } -private actor RotatingDemoAuthProvider: ChatGPTAuthProviding { +// MARK: - Backend/Test Doubles + +actor RotatingDemoAuthProvider: ChatGPTAuthProviding { private var refreshInvocationCount = 0 func signInInteractively() async throws -> ChatGPTSession { @@ -1629,7 +119,7 @@ private actor RotatingDemoAuthProvider: ChatGPTAuthProviding { } } -private actor UnauthorizedThenSuccessBackend: AgentBackend { +actor UnauthorizedThenSuccessBackend: AgentBackend { private var didThrowUnauthorized = false private var accessTokensByAttempt: [String] = [] @@ -1669,7 +159,7 @@ private actor UnauthorizedThenSuccessBackend: AgentBackend { } } -private actor UnauthorizedOnCreateThenSuccessBackend: AgentBackend { +actor UnauthorizedOnCreateThenSuccessBackend: AgentBackend { private var didThrowUnauthorized = false private var accessTokensByAttempt: [String] = [] @@ -1709,7 +199,7 @@ private actor UnauthorizedOnCreateThenSuccessBackend: AgentBackend { } } -private actor ImageReplyAgentBackend: AgentBackend { +actor ImageReplyAgentBackend: AgentBackend { func createThread(session _: ChatGPTSession) async throws -> AgentThread { AgentThread(id: UUID().uuidString) } @@ -1731,7 +221,7 @@ private actor ImageReplyAgentBackend: AgentBackend { } } -private final class ImageReplyTurn: AgentTurnStreaming, @unchecked Sendable { +final class ImageReplyTurn: AgentTurnStreaming, @unchecked Sendable { let events: AsyncThrowingStream init(threadID: String) { @@ -1769,7 +259,7 @@ private final class ImageReplyTurn: AgentTurnStreaming, @unchecked Sendable { ) async throws {} } -private actor ThrowingMemoryStore: MemoryStoring { +actor ThrowingMemoryStore: MemoryStoring { func put(_ record: MemoryRecord) async throws {} func putMany(_ records: [MemoryRecord]) async throws {} @@ -1816,7 +306,7 @@ private actor ThrowingMemoryStore: MemoryStoring { } } -private actor RecordingMemoryObserver: MemoryObserving { +actor RecordingMemoryObserver: MemoryObserving { private var observedEvents: [MemoryObservationEvent] = [] func handle(event: MemoryObservationEvent) async {