From 7bd74bc107861d2ec468faf42997d7bb29f40f42 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 02:48:04 +0000 Subject: [PATCH 01/75] Add initial pkg/ai scaffold and parity unit tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/ai/api_registry.go | 75 +++ pkg/ai/env_api_keys.go | 53 +++ pkg/ai/event_stream.go | 84 ++++ pkg/ai/event_stream_test.go | 39 ++ pkg/ai/models.go | 73 +++ pkg/ai/models_test.go | 43 ++ pkg/ai/providers/google_gemini_cli.go | 113 +++++ pkg/ai/providers/google_gemini_cli_test.go | 46 ++ pkg/ai/providers/google_shared.go | 247 ++++++++++ pkg/ai/providers/google_shared_test.go | 92 ++++ pkg/ai/providers/openai_completions.go | 450 ++++++++++++++++++ .../openai_completions_convert_test.go | 88 ++++ pkg/ai/providers/openai_completions_test.go | 87 ++++ pkg/ai/providers/register_builtins.go | 14 + pkg/ai/providers/transform_messages.go | 148 ++++++ pkg/ai/providers/transform_messages_test.go | 122 +++++ pkg/ai/stream.go | 56 +++ pkg/ai/types.go | 236 +++++++++ pkg/ai/utils/json_parse.go | 32 ++ pkg/ai/utils/json_parse_test.go | 15 + pkg/ai/utils/overflow.go | 50 ++ pkg/ai/utils/overflow_test.go | 32 ++ pkg/ai/utils/sanitize_unicode.go | 45 ++ pkg/ai/utils/sanitize_unicode_test.go | 16 + 24 files changed, 2256 insertions(+) create mode 100644 pkg/ai/api_registry.go create mode 100644 pkg/ai/env_api_keys.go create mode 100644 pkg/ai/event_stream.go create mode 100644 pkg/ai/event_stream_test.go create mode 100644 pkg/ai/models.go create mode 100644 pkg/ai/models_test.go create mode 100644 pkg/ai/providers/google_gemini_cli.go create mode 100644 pkg/ai/providers/google_gemini_cli_test.go create mode 100644 pkg/ai/providers/google_shared.go create mode 100644 pkg/ai/providers/google_shared_test.go create mode 100644 pkg/ai/providers/openai_completions.go create mode 100644 pkg/ai/providers/openai_completions_convert_test.go create mode 100644 
pkg/ai/providers/openai_completions_test.go create mode 100644 pkg/ai/providers/register_builtins.go create mode 100644 pkg/ai/providers/transform_messages.go create mode 100644 pkg/ai/providers/transform_messages_test.go create mode 100644 pkg/ai/stream.go create mode 100644 pkg/ai/types.go create mode 100644 pkg/ai/utils/json_parse.go create mode 100644 pkg/ai/utils/json_parse_test.go create mode 100644 pkg/ai/utils/overflow.go create mode 100644 pkg/ai/utils/overflow_test.go create mode 100644 pkg/ai/utils/sanitize_unicode.go create mode 100644 pkg/ai/utils/sanitize_unicode_test.go diff --git a/pkg/ai/api_registry.go b/pkg/ai/api_registry.go new file mode 100644 index 00000000..006377b1 --- /dev/null +++ b/pkg/ai/api_registry.go @@ -0,0 +1,75 @@ +package ai + +import ( + "fmt" + "sync" +) + +type StreamFn func(model Model, context Context, options *StreamOptions) *AssistantMessageEventStream +type StreamSimpleFn func(model Model, context Context, options *SimpleStreamOptions) *AssistantMessageEventStream + +type APIProvider struct { + API Api + Stream StreamFn + StreamSimple StreamSimpleFn +} + +type registeredProvider struct { + provider APIProvider + sourceID string +} + +var ( + registryMu sync.RWMutex + registry = map[Api]registeredProvider{} +) + +func RegisterAPIProvider(provider APIProvider, sourceID string) { + registryMu.Lock() + defer registryMu.Unlock() + registry[provider.API] = registeredProvider{ + provider: provider, + sourceID: sourceID, + } +} + +func GetAPIProvider(api Api) (APIProvider, bool) { + registryMu.RLock() + defer registryMu.RUnlock() + entry, ok := registry[api] + return entry.provider, ok +} + +func GetAPIProviders() []APIProvider { + registryMu.RLock() + defer registryMu.RUnlock() + out := make([]APIProvider, 0, len(registry)) + for _, entry := range registry { + out = append(out, entry.provider) + } + return out +} + +func UnregisterAPIProviders(sourceID string) { + registryMu.Lock() + defer registryMu.Unlock() + for api, entry := 
range registry { + if entry.sourceID == sourceID { + delete(registry, api) + } + } +} + +func ClearAPIProviders() { + registryMu.Lock() + defer registryMu.Unlock() + registry = map[Api]registeredProvider{} +} + +func ResolveAPIProvider(api Api) (APIProvider, error) { + provider, ok := GetAPIProvider(api) + if !ok { + return APIProvider{}, fmt.Errorf("no API provider registered for api: %s", api) + } + return provider, nil +} diff --git a/pkg/ai/env_api_keys.go b/pkg/ai/env_api_keys.go new file mode 100644 index 00000000..b65cb87e --- /dev/null +++ b/pkg/ai/env_api_keys.go @@ -0,0 +1,53 @@ +package ai + +import "os" + +func GetEnvAPIKey(provider string) string { + switch provider { + case "github-copilot": + if v := os.Getenv("COPILOT_GITHUB_TOKEN"); v != "" { + return v + } + if v := os.Getenv("GH_TOKEN"); v != "" { + return v + } + return os.Getenv("GITHUB_TOKEN") + case "anthropic": + if v := os.Getenv("ANTHROPIC_OAUTH_TOKEN"); v != "" { + return v + } + return os.Getenv("ANTHROPIC_API_KEY") + case "openai": + return os.Getenv("OPENAI_API_KEY") + case "azure-openai-responses": + return os.Getenv("AZURE_OPENAI_API_KEY") + case "google": + return os.Getenv("GEMINI_API_KEY") + case "groq": + return os.Getenv("GROQ_API_KEY") + case "cerebras": + return os.Getenv("CEREBRAS_API_KEY") + case "xai": + return os.Getenv("XAI_API_KEY") + case "openrouter": + return os.Getenv("OPENROUTER_API_KEY") + case "vercel-ai-gateway": + return os.Getenv("AI_GATEWAY_API_KEY") + case "zai": + return os.Getenv("ZAI_API_KEY") + case "mistral": + return os.Getenv("MISTRAL_API_KEY") + case "minimax": + return os.Getenv("MINIMAX_API_KEY") + case "minimax-cn": + return os.Getenv("MINIMAX_CN_API_KEY") + case "huggingface": + return os.Getenv("HF_TOKEN") + case "opencode", "opencode-go": + return os.Getenv("OPENCODE_API_KEY") + case "kimi-coding": + return os.Getenv("KIMI_API_KEY") + default: + return "" + } +} diff --git a/pkg/ai/event_stream.go b/pkg/ai/event_stream.go new file mode 100644 
index 00000000..01b26ef0 --- /dev/null +++ b/pkg/ai/event_stream.go @@ -0,0 +1,84 @@ +package ai + +import ( + "context" + "errors" + "io" + "sync" +) + +var ErrStreamClosed = errors.New("assistant message event stream closed") + +type AssistantMessageEventStream struct { + ch chan AssistantMessageEvent + done chan struct{} + once sync.Once + mu sync.Mutex + result Message + hasResult bool +} + +func NewAssistantMessageEventStream(buffer int) *AssistantMessageEventStream { + if buffer <= 0 { + buffer = 32 + } + return &AssistantMessageEventStream{ + ch: make(chan AssistantMessageEvent, buffer), + done: make(chan struct{}), + } +} + +func (s *AssistantMessageEventStream) Push(evt AssistantMessageEvent) { + select { + case <-s.done: + return + default: + } + + if evt.Type == EventDone { + s.mu.Lock() + s.result = evt.Message + s.hasResult = true + s.mu.Unlock() + } + if evt.Type == EventError { + s.mu.Lock() + s.result = evt.Error + s.hasResult = true + s.mu.Unlock() + } + + select { + case <-s.done: + case s.ch <- evt: + } +} + +func (s *AssistantMessageEventStream) Close() { + s.once.Do(func() { + close(s.done) + close(s.ch) + }) +} + +func (s *AssistantMessageEventStream) Next(ctx context.Context) (AssistantMessageEvent, error) { + select { + case <-ctx.Done(): + return AssistantMessageEvent{}, ctx.Err() + case evt, ok := <-s.ch: + if !ok { + return AssistantMessageEvent{}, io.EOF + } + return evt, nil + } +} + +func (s *AssistantMessageEventStream) Result() (Message, error) { + <-s.done + s.mu.Lock() + defer s.mu.Unlock() + if !s.hasResult { + return Message{}, ErrStreamClosed + } + return s.result, nil +} diff --git a/pkg/ai/event_stream_test.go b/pkg/ai/event_stream_test.go new file mode 100644 index 00000000..cd12e4c5 --- /dev/null +++ b/pkg/ai/event_stream_test.go @@ -0,0 +1,39 @@ +package ai + +import ( + "context" + "io" + "testing" + "time" +) + +func TestAssistantMessageEventStream_ResultFromDone(t *testing.T) { + s := NewAssistantMessageEventStream(4) + 
doneMsg := Message{Role: RoleAssistant, StopReason: StopReasonStop, Timestamp: 1} + + go func() { + s.Push(AssistantMessageEvent{Type: EventStart}) + s.Push(AssistantMessageEvent{Type: EventDone, Message: doneMsg, Reason: StopReasonStop}) + s.Close() + }() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + for { + _, err := s.Next(ctx) + if err == io.EOF { + break + } + if err != nil { + t.Fatalf("unexpected stream error: %v", err) + } + } + + result, err := s.Result() + if err != nil { + t.Fatalf("unexpected result error: %v", err) + } + if result.StopReason != StopReasonStop { + t.Fatalf("expected stop reason stop, got %s", result.StopReason) + } +} diff --git a/pkg/ai/models.go b/pkg/ai/models.go new file mode 100644 index 00000000..5546b9fc --- /dev/null +++ b/pkg/ai/models.go @@ -0,0 +1,73 @@ +package ai + +import "strings" + +var modelRegistry = map[string]map[string]Model{} + +func RegisterModels(provider string, models []Model) { + key := strings.TrimSpace(provider) + if key == "" { + return + } + if modelRegistry[key] == nil { + modelRegistry[key] = map[string]Model{} + } + for _, model := range models { + modelRegistry[key][model.ID] = model + } +} + +func GetModel(provider, modelID string) (Model, bool) { + models, ok := modelRegistry[provider] + if !ok { + return Model{}, false + } + model, ok := models[modelID] + return model, ok +} + +func GetProviders() []string { + out := make([]string, 0, len(modelRegistry)) + for provider := range modelRegistry { + out = append(out, provider) + } + return out +} + +func GetModels(provider string) []Model { + models, ok := modelRegistry[provider] + if !ok { + return nil + } + out := make([]Model, 0, len(models)) + for _, model := range models { + out = append(out, model) + } + return out +} + +func CalculateCost(model Model, usage Usage) UsageCost { + usage.Cost.Input = (model.Cost.Input / 1_000_000) * float64(usage.Input) + usage.Cost.Output = (model.Cost.Output / 
1_000_000) * float64(usage.Output) + usage.Cost.CacheRead = (model.Cost.CacheRead / 1_000_000) * float64(usage.CacheRead) + usage.Cost.CacheWrite = (model.Cost.CacheWrite / 1_000_000) * float64(usage.CacheWrite) + usage.Cost.Total = usage.Cost.Input + usage.Cost.Output + usage.Cost.CacheRead + usage.Cost.CacheWrite + return usage.Cost +} + +func SupportsXhigh(model Model) bool { + if strings.Contains(model.ID, "gpt-5.2") || strings.Contains(model.ID, "gpt-5.3") { + return true + } + if model.API == APIAnthropicMessages { + return strings.Contains(model.ID, "opus-4-6") || strings.Contains(model.ID, "opus-4.6") + } + return false +} + +func ModelsAreEqual(a, b *Model) bool { + if a == nil || b == nil { + return false + } + return a.ID == b.ID && a.Provider == b.Provider +} diff --git a/pkg/ai/models_test.go b/pkg/ai/models_test.go new file mode 100644 index 00000000..d4ab202b --- /dev/null +++ b/pkg/ai/models_test.go @@ -0,0 +1,43 @@ +package ai + +import "testing" + +func TestSupportsXhigh(t *testing.T) { + anthropicOpus := Model{ + ID: "claude-opus-4-6", + API: APIAnthropicMessages, + Provider: "anthropic", + } + if !SupportsXhigh(anthropicOpus) { + t.Fatalf("expected anthropic opus 4.6 to support xhigh") + } + + anthropicSonnet := Model{ + ID: "claude-sonnet-4-5", + API: APIAnthropicMessages, + Provider: "anthropic", + } + if SupportsXhigh(anthropicSonnet) { + t.Fatalf("expected anthropic sonnet 4.5 not to support xhigh") + } + + openRouterOpus := Model{ + ID: "anthropic/claude-opus-4.6", + API: APIOpenAICompletions, + Provider: "openrouter", + } + if SupportsXhigh(openRouterOpus) { + t.Fatalf("expected openrouter opus to not support xhigh") + } +} + +func TestModelsAreEqual(t *testing.T) { + a := &Model{ID: "gpt-4o", Provider: "openai"} + b := &Model{ID: "gpt-4o", Provider: "openai"} + if !ModelsAreEqual(a, b) { + t.Fatalf("expected models to be equal") + } + if ModelsAreEqual(a, nil) { + t.Fatalf("expected nil model comparison to be false") + } +} diff --git 
a/pkg/ai/providers/google_gemini_cli.go b/pkg/ai/providers/google_gemini_cli.go new file mode 100644 index 00000000..d5c8fa12 --- /dev/null +++ b/pkg/ai/providers/google_gemini_cli.go @@ -0,0 +1,113 @@ +package providers + +import ( + "net/http" + "regexp" + "strconv" + "strings" + "time" +) + +func ExtractRetryDelay(errorText string, headers http.Header) (int, bool) { + return extractRetryDelayAt(errorText, headers, time.Now()) +} + +func extractRetryDelayAt(errorText string, headers http.Header, now time.Time) (int, bool) { + normalizeDelay := func(ms float64) (int, bool) { + if ms <= 0 { + return 0, false + } + return int(ms + 1000), true + } + + if headers != nil { + retryAfter := headerGetCI(headers, "Retry-After") + if retryAfter != "" { + if secs, err := strconv.ParseFloat(strings.TrimSpace(retryAfter), 64); err == nil { + if delay, ok := normalizeDelay(secs * 1000); ok { + return delay, true + } + } + if retryAt, err := http.ParseTime(retryAfter); err == nil { + if delay, ok := normalizeDelay(float64(retryAt.Sub(now).Milliseconds())); ok { + return delay, true + } + } + } + + if reset := headerGetCI(headers, "x-ratelimit-reset"); reset != "" { + if sec, err := strconv.ParseInt(strings.TrimSpace(reset), 10, 64); err == nil { + resetAt := time.Unix(sec, 0) + if delay, ok := normalizeDelay(float64(resetAt.Sub(now).Milliseconds())); ok { + return delay, true + } + } + } + + if resetAfter := headerGetCI(headers, "x-ratelimit-reset-after"); resetAfter != "" { + if secs, err := strconv.ParseFloat(strings.TrimSpace(resetAfter), 64); err == nil { + if delay, ok := normalizeDelay(secs * 1000); ok { + return delay, true + } + } + } + } + + durationPattern := regexp.MustCompile(`(?i)reset after (?:(\d+)h)?(?:(\d+)m)?(\d+(?:\.\d+)?)s`) + if matches := durationPattern.FindStringSubmatch(errorText); len(matches) == 4 { + hours, _ := strconv.ParseFloat(orZero(matches[1]), 64) + minutes, _ := strconv.ParseFloat(orZero(matches[2]), 64) + seconds, _ := 
strconv.ParseFloat(orZero(matches[3]), 64) + ms := ((hours*60+minutes)*60 + seconds) * 1000 + if delay, ok := normalizeDelay(ms); ok { + return delay, true + } + } + + retryInPattern := regexp.MustCompile(`(?i)Please retry in ([0-9.]+)(ms|s)`) + if matches := retryInPattern.FindStringSubmatch(errorText); len(matches) == 3 { + value, _ := strconv.ParseFloat(matches[1], 64) + if strings.EqualFold(matches[2], "s") { + value *= 1000 + } + if delay, ok := normalizeDelay(value); ok { + return delay, true + } + } + + retryDelayPattern := regexp.MustCompile(`(?i)"retryDelay":\s*"([0-9.]+)(ms|s)"`) + if matches := retryDelayPattern.FindStringSubmatch(errorText); len(matches) == 3 { + value, _ := strconv.ParseFloat(matches[1], 64) + if strings.EqualFold(matches[2], "s") { + value *= 1000 + } + if delay, ok := normalizeDelay(value); ok { + return delay, true + } + } + + return 0, false +} + +func orZero(in string) string { + if strings.TrimSpace(in) == "" { + return "0" + } + return in +} + +func headerGetCI(headers http.Header, key string) string { + if headers == nil { + return "" + } + if v := headers.Get(key); v != "" { + return v + } + for k, values := range headers { + if !strings.EqualFold(k, key) || len(values) == 0 { + continue + } + return values[0] + } + return "" +} diff --git a/pkg/ai/providers/google_gemini_cli_test.go b/pkg/ai/providers/google_gemini_cli_test.go new file mode 100644 index 00000000..9f54771b --- /dev/null +++ b/pkg/ai/providers/google_gemini_cli_test.go @@ -0,0 +1,46 @@ +package providers + +import ( + "net/http" + "strconv" + "testing" + "time" +) + +func TestExtractRetryDelay(t *testing.T) { + now := time.Date(2025, time.January, 1, 0, 0, 0, 0, time.UTC) + + t.Run("prefers Retry-After seconds header", func(t *testing.T) { + headers := http.Header{"Retry-After": []string{"5"}} + delay, ok := extractRetryDelayAt("Please retry in 1s", headers, now) + if !ok || delay != 6000 { + t.Fatalf("expected 6000ms, got %d (ok=%v)", delay, ok) + } + }) + + 
t.Run("parses Retry-After HTTP date header", func(t *testing.T) { + retryAt := now.Add(12 * time.Second).UTC().Format(http.TimeFormat) + headers := http.Header{"Retry-After": []string{retryAt}} + delay, ok := extractRetryDelayAt("", headers, now) + if !ok || delay != 13000 { + t.Fatalf("expected 13000ms, got %d (ok=%v)", delay, ok) + } + }) + + t.Run("parses x-ratelimit-reset header", func(t *testing.T) { + resetAt := now.Add(20 * time.Second).Unix() + headers := http.Header{"x-ratelimit-reset": []string{strconv.FormatInt(resetAt, 10)}} + delay, ok := extractRetryDelayAt("", headers, now) + if !ok || delay != 21000 { + t.Fatalf("expected 21000ms, got %d (ok=%v)", delay, ok) + } + }) + + t.Run("parses x-ratelimit-reset-after header", func(t *testing.T) { + headers := http.Header{"x-ratelimit-reset-after": []string{"30"}} + delay, ok := extractRetryDelayAt("", headers, now) + if !ok || delay != 31000 { + t.Fatalf("expected 31000ms, got %d (ok=%v)", delay, ok) + } + }) +} diff --git a/pkg/ai/providers/google_shared.go b/pkg/ai/providers/google_shared.go new file mode 100644 index 00000000..3c9c52a1 --- /dev/null +++ b/pkg/ai/providers/google_shared.go @@ -0,0 +1,247 @@ +package providers + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/beeper/ai-bridge/pkg/ai" + "github.com/beeper/ai-bridge/pkg/ai/utils" +) + +type GoogleContent struct { + Role string `json:"role"` + Parts []GooglePart `json:"parts"` +} + +type GooglePart struct { + Text string `json:"text,omitempty"` + Thought bool `json:"thought,omitempty"` + ThoughtSignature string `json:"thoughtSignature,omitempty"` + FunctionCall *GoogleFunctionCall `json:"functionCall,omitempty"` + FunctionResponse *GoogleFunctionResponse `json:"functionResponse,omitempty"` + InlineData *GoogleInlineData `json:"inlineData,omitempty"` +} + +type GoogleFunctionCall struct { + Name string `json:"name"` + Args map[string]any `json:"args,omitempty"` + ID string `json:"id,omitempty"` +} + +type GoogleFunctionResponse 
struct { + Name string `json:"name"` + Response map[string]any `json:"response,omitempty"` + ID string `json:"id,omitempty"` + Parts []GooglePart `json:"parts,omitempty"` +} + +type GoogleInlineData struct { + MimeType string `json:"mimeType"` + Data string `json:"data"` +} + +func IsThinkingPart(part GooglePart) bool { + return part.Thought +} + +func RetainThoughtSignature(existing string, incoming string) string { + if strings.TrimSpace(incoming) != "" { + return incoming + } + return existing +} + +func RequiresToolCallID(modelID string) bool { + return strings.HasPrefix(modelID, "claude-") || strings.HasPrefix(modelID, "gpt-oss-") +} + +func ConvertGoogleMessages(model ai.Model, c ai.Context) []GoogleContent { + normalized := TransformMessages(c.Messages, model, func(id string, _ ai.Model, _ ai.Message) string { + if !RequiresToolCallID(model.ID) { + return id + } + sanitized := strings.Map(func(r rune) rune { + if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '_' || r == '-' { + return r + } + return '_' + }, id) + if len(sanitized) > 64 { + return sanitized[:64] + } + return sanitized + }) + + out := make([]GoogleContent, 0, len(normalized)) + for _, msg := range normalized { + switch msg.Role { + case ai.RoleUser: + parts := make([]GooglePart, 0, max(1, len(msg.Content))) + if strings.TrimSpace(msg.Text) != "" { + parts = append(parts, GooglePart{Text: utils.SanitizeSurrogates(msg.Text)}) + } + for _, block := range msg.Content { + switch block.Type { + case ai.ContentTypeText: + if strings.TrimSpace(block.Text) == "" { + continue + } + parts = append(parts, GooglePart{Text: utils.SanitizeSurrogates(block.Text)}) + case ai.ContentTypeImage: + if !supportsImageInput(model) { + continue + } + parts = append(parts, GooglePart{ + InlineData: &GoogleInlineData{ + MimeType: block.MimeType, + Data: block.Data, + }, + }) + } + } + if len(parts) == 0 { + continue + } + out = append(out, GoogleContent{Role: "user", Parts: parts}) 
+ case ai.RoleAssistant: + parts := make([]GooglePart, 0, len(msg.Content)) + isSameProviderAndModel := msg.Provider == model.Provider && msg.Model == model.ID + for _, block := range msg.Content { + switch block.Type { + case ai.ContentTypeText: + if strings.TrimSpace(block.Text) == "" { + continue + } + part := GooglePart{Text: utils.SanitizeSurrogates(block.Text)} + if isSameProviderAndModel && isValidThoughtSignature(block.TextSignature) { + part.ThoughtSignature = block.TextSignature + } + parts = append(parts, part) + case ai.ContentTypeThinking: + if strings.TrimSpace(block.Thinking) == "" { + continue + } + if isSameProviderAndModel { + part := GooglePart{ + Thought: true, + Text: utils.SanitizeSurrogates(block.Thinking), + } + if isValidThoughtSignature(block.ThinkingSignature) { + part.ThoughtSignature = block.ThinkingSignature + } + parts = append(parts, part) + } else { + parts = append(parts, GooglePart{Text: utils.SanitizeSurrogates(block.Thinking)}) + } + case ai.ContentTypeToolCall: + sig := "" + if isSameProviderAndModel && isValidThoughtSignature(block.ThoughtSignature) { + sig = block.ThoughtSignature + } + isGemini3 := strings.Contains(strings.ToLower(model.ID), "gemini-3") + if isGemini3 && sig == "" { + argsBytes, _ := json.MarshalIndent(block.Arguments, "", " ") + parts = append(parts, GooglePart{ + Text: fmt.Sprintf( + `[Historical context: a different model called tool "%s" with arguments: %s. 
Do not mimic this format - use proper function calling.]`, + block.Name, + string(argsBytes), + ), + }) + continue + } + part := GooglePart{ + FunctionCall: &GoogleFunctionCall{ + Name: block.Name, + Args: block.Arguments, + }, + } + if RequiresToolCallID(model.ID) { + part.FunctionCall.ID = block.ID + } + if sig != "" { + part.ThoughtSignature = sig + } + parts = append(parts, part) + } + } + if len(parts) == 0 { + continue + } + out = append(out, GoogleContent{Role: "model", Parts: parts}) + case ai.RoleToolResult: + textResult := "" + imageParts := make([]GooglePart, 0) + for _, block := range msg.Content { + if block.Type == ai.ContentTypeText { + if textResult != "" { + textResult += "\n" + } + textResult += block.Text + } + if block.Type == ai.ContentTypeImage && supportsImageInput(model) { + imageParts = append(imageParts, GooglePart{ + InlineData: &GoogleInlineData{ + MimeType: block.MimeType, + Data: block.Data, + }, + }) + } + } + respValue := textResult + if strings.TrimSpace(respValue) == "" && len(imageParts) > 0 { + respValue = "(see attached image)" + } + responseMap := map[string]any{"output": utils.SanitizeSurrogates(respValue)} + if msg.IsError { + responseMap = map[string]any{"error": utils.SanitizeSurrogates(respValue)} + } + part := GooglePart{ + FunctionResponse: &GoogleFunctionResponse{ + Name: msg.ToolName, + Response: responseMap, + }, + } + if RequiresToolCallID(model.ID) { + part.FunctionResponse.ID = msg.ToolCallID + } + out = append(out, GoogleContent{ + Role: "user", + Parts: []GooglePart{part}, + }) + if len(imageParts) > 0 && !strings.Contains(strings.ToLower(model.ID), "gemini-3") { + out = append(out, GoogleContent{ + Role: "user", + Parts: append([]GooglePart{{Text: "Tool result image:"}}, imageParts...), + }) + } + } + } + return out +} + +func supportsImageInput(model ai.Model) bool { + for _, input := range model.Input { + if input == "image" { + return true + } + } + return false +} + +func isValidThoughtSignature(signature 
string) bool { + if signature == "" { + return false + } + if len(signature)%4 != 0 { + return false + } + for _, r := range signature { + if (r >= 'A' && r <= 'Z') || (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '+' || r == '/' || r == '=' { + continue + } + return false + } + return true +} diff --git a/pkg/ai/providers/google_shared_test.go b/pkg/ai/providers/google_shared_test.go new file mode 100644 index 00000000..98e2a76c --- /dev/null +++ b/pkg/ai/providers/google_shared_test.go @@ -0,0 +1,92 @@ +package providers + +import ( + "strings" + "testing" + "time" + + "github.com/beeper/ai-bridge/pkg/ai" +) + +func TestIsThinkingPartAndRetainSignature(t *testing.T) { + if !IsThinkingPart(GooglePart{Thought: true}) { + t.Fatalf("expected thought=true to be thinking part") + } + if IsThinkingPart(GooglePart{Thought: false, ThoughtSignature: "opaque"}) { + t.Fatalf("thoughtSignature alone must not mark part as thinking") + } + + first := RetainThoughtSignature("", "sig-1") + if first != "sig-1" { + t.Fatalf("expected initial signature to be retained") + } + second := RetainThoughtSignature(first, "") + if second != "sig-1" { + t.Fatalf("expected previous signature retained when incoming empty") + } + third := RetainThoughtSignature(second, "sig-2") + if third != "sig-2" { + t.Fatalf("expected signature to update when incoming non-empty") + } +} + +func TestConvertMessages_ConvertsUnsignedToolCallsToHistoricalTextForGemini3(t *testing.T) { + now := time.Now().UnixMilli() + model := ai.Model{ + ID: "gemini-3-pro-preview", + Name: "Gemini 3 Pro Preview", + API: ai.APIGoogleGenerativeAI, + Provider: "google", + BaseURL: "https://generativelanguage.googleapis.com", + Reasoning: true, + Input: []string{"text"}, + } + context := ai.Context{ + Messages: []ai.Message{ + {Role: ai.RoleUser, Text: "Hi", Timestamp: now}, + { + Role: ai.RoleAssistant, + Provider: "google-antigravity", + API: ai.APIGoogleGeminiCLI, + Model: "claude-sonnet-4-20250514", + StopReason: 
ai.StopReasonStop, + Content: []ai.ContentBlock{ + { + Type: ai.ContentTypeToolCall, + ID: "call_1", + Name: "bash", + Arguments: map[string]any{"command": "ls -la"}, + }, + }, + Timestamp: now, + }, + }, + } + + contents := ConvertGoogleMessages(model, context) + var toolTurn *GoogleContent + for i := len(contents) - 1; i >= 0; i-- { + if contents[i].Role == "model" { + toolTurn = &contents[i] + break + } + } + if toolTurn == nil { + t.Fatalf("expected model content turn") + } + for _, part := range toolTurn.Parts { + if part.FunctionCall != nil { + t.Fatalf("expected no function call for unsigned tool call in gemini-3") + } + } + joined := "" + for _, part := range toolTurn.Parts { + joined += part.Text + "\n" + } + if !strings.Contains(joined, "Historical context") || + !strings.Contains(joined, "bash") || + !strings.Contains(joined, "ls -la") || + !strings.Contains(joined, "Do not mimic this format") { + t.Fatalf("unexpected historical context text: %s", joined) + } +} diff --git a/pkg/ai/providers/openai_completions.go b/pkg/ai/providers/openai_completions.go new file mode 100644 index 00000000..bfbcff9b --- /dev/null +++ b/pkg/ai/providers/openai_completions.go @@ -0,0 +1,450 @@ +package providers + +import ( + "encoding/json" + "strings" + + "github.com/beeper/ai-bridge/pkg/ai" + "github.com/beeper/ai-bridge/pkg/ai/utils" +) + +type OpenAICompletionsOptions struct { + ToolChoice any + ReasoningEffort ai.ThinkingLevel + StreamOptions ai.StreamOptions +} + +type OpenAIMessage struct { + Role string `json:"role"` + Content any `json:"content,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + Name string `json:"name,omitempty"` + ToolCalls []map[string]any `json:"tool_calls,omitempty"` + Extra map[string]any `json:"-"` +} + +func BuildOpenAICompletionsParams(model ai.Model, context ai.Context, options OpenAICompletionsOptions) map[string]any { + compat := GetCompat(model) + params := map[string]any{ + "model": model.ID, + "stream": true, + 
"messages": ConvertOpenAICompletionsMessages(model, context, compat), + } + if compat.SupportsUsageInStreaming { + params["stream_options"] = map[string]any{"include_usage": true} + } + if compat.SupportsStore { + params["store"] = false + } + if options.StreamOptions.MaxTokens > 0 { + if compat.MaxTokensField == "max_tokens" { + params["max_tokens"] = options.StreamOptions.MaxTokens + } else { + params["max_completion_tokens"] = options.StreamOptions.MaxTokens + } + } + if options.StreamOptions.Temperature != nil { + params["temperature"] = *options.StreamOptions.Temperature + } + if len(context.Tools) > 0 { + params["tools"] = convertOpenAICompletionsTools(context.Tools, compat) + } + if options.ToolChoice != nil { + params["tool_choice"] = options.ToolChoice + } + if options.ReasoningEffort != "" && compat.SupportsReasoningEffort { + params["reasoning_effort"] = mapReasoningEffort(options.ReasoningEffort, compat.ReasoningEffortMap) + } + return params +} + +func mapReasoningEffort(level ai.ThinkingLevel, custom map[ai.ThinkingLevel]string) string { + if custom != nil { + if mapped, ok := custom[level]; ok && mapped != "" { + return mapped + } + } + return string(level) +} + +func convertOpenAICompletionsTools(tools []ai.Tool, compat OpenAICompatResolved) []map[string]any { + out := make([]map[string]any, 0, len(tools)) + for _, tool := range tools { + fn := map[string]any{ + "name": tool.Name, + "description": tool.Description, + "parameters": tool.Parameters, + } + if compat.SupportsStrictMode { + fn["strict"] = false + } + out = append(out, map[string]any{ + "type": "function", + "function": fn, + }) + } + return out +} + +type OpenAICompatResolved struct { + SupportsStore bool + SupportsDeveloperRole bool + SupportsReasoningEffort bool + ReasoningEffortMap map[ai.ThinkingLevel]string + SupportsUsageInStreaming bool + MaxTokensField string + RequiresToolResultName bool + RequiresAssistantAfterToolResult bool + RequiresThinkingAsText bool + 
RequiresMistralToolIDs bool + ThinkingFormat string + SupportsStrictMode bool +} + +func DetectCompat(model ai.Model) OpenAICompatResolved { + provider := string(model.Provider) + baseURL := model.BaseURL + isZai := provider == "zai" || strings.Contains(baseURL, "api.z.ai") + isMistral := provider == "mistral" || strings.Contains(baseURL, "mistral.ai") + isGrok := provider == "xai" || strings.Contains(baseURL, "api.x.ai") + isGroq := provider == "groq" || strings.Contains(baseURL, "groq.com") + + isNonStandard := provider == "cerebras" || + strings.Contains(baseURL, "cerebras.ai") || + isGrok || + isMistral || + strings.Contains(baseURL, "chutes.ai") || + strings.Contains(baseURL, "deepseek.com") || + isZai || + provider == "opencode" || + strings.Contains(baseURL, "opencode.ai") + + reasoningEffortMap := map[ai.ThinkingLevel]string{} + if isGroq && model.ID == "qwen/qwen3-32b" { + reasoningEffortMap[ai.ThinkingMinimal] = "default" + reasoningEffortMap[ai.ThinkingLow] = "default" + reasoningEffortMap[ai.ThinkingMedium] = "default" + reasoningEffortMap[ai.ThinkingHigh] = "default" + reasoningEffortMap[ai.ThinkingXHigh] = "default" + } + + return OpenAICompatResolved{ + SupportsStore: !isNonStandard, + SupportsDeveloperRole: !isNonStandard, + SupportsReasoningEffort: !isGrok && !isZai, + ReasoningEffortMap: reasoningEffortMap, + SupportsUsageInStreaming: true, + MaxTokensField: chooseMaxTokensField(isMistral, baseURL), + RequiresToolResultName: isMistral, + RequiresAssistantAfterToolResult: false, + RequiresThinkingAsText: isMistral, + RequiresMistralToolIDs: isMistral, + ThinkingFormat: chooseThinkingFormat(isZai), + SupportsStrictMode: true, + } +} + +func chooseMaxTokensField(isMistral bool, baseURL string) string { + if isMistral || strings.Contains(baseURL, "chutes.ai") { + return "max_tokens" + } + return "max_completion_tokens" +} + +func chooseThinkingFormat(isZai bool) string { + if isZai { + return "zai" + } + return "openai" +} + +func GetCompat(model 
// ConvertOpenAICompletionsMessages converts a provider-agnostic ai.Context
// into the OpenAI Chat Completions "messages" payload, applying the
// per-endpoint quirks resolved in compat.
//
// Behaviour visible in this block:
//   - The system prompt becomes a "system" message, or "developer" when the
//     model is a reasoning model and the endpoint supports the developer role.
//   - Tool-call IDs are normalized (Mistral 9-char IDs, "call|suffix"
//     composite IDs, OpenAI's 40-char cap) via TransformMessages so that the
//     assistant tool_calls and the matching tool results stay consistent.
//   - Consecutive toolResult messages are flattened into "tool" messages; any
//     images they carry are batched into one trailing "user" message, since
//     the Completions API does not accept images inside "tool" messages.
//   - When compat.RequiresAssistantAfterToolResult is set, a filler assistant
//     message is inserted between tool results and the next user turn.
func ConvertOpenAICompletionsMessages(model ai.Model, context ai.Context, compat OpenAICompatResolved) []OpenAIMessage {
	params := make([]OpenAIMessage, 0, len(context.Messages)+1)
	if strings.TrimSpace(context.SystemPrompt) != "" {
		role := "system"
		if model.Reasoning && compat.SupportsDeveloperRole {
			role = "developer"
		}
		params = append(params, OpenAIMessage{
			Role:    role,
			Content: utils.SanitizeSurrogates(context.SystemPrompt),
		})
	}

	// normalizeToolCallID adapts a tool-call ID to this endpoint's rules:
	// Mistral gets its fixed 9-char format; "id|suffix" composites keep only
	// the sanitized part before '|' (capped at 40 chars); plain IDs are only
	// capped at 40 chars for the first-party openai provider.
	normalizeToolCallID := func(id string) string {
		if compat.RequiresMistralToolIDs {
			return normalizeMistralToolID(id)
		}
		if strings.Contains(id, "|") {
			callID := strings.SplitN(id, "|", 2)[0]
			sanitized := sanitizeToolID(callID)
			if len(sanitized) > 40 {
				return sanitized[:40]
			}
			return sanitized
		}
		if model.Provider == "openai" && len(id) > 40 {
			return id[:40]
		}
		return id
	}

	transformed := TransformMessages(context.Messages, model, func(id string, _ ai.Model, _ ai.Message) string {
		return normalizeToolCallID(id)
	})

	lastRole := ""
	for i := 0; i < len(transformed); i++ {
		msg := transformed[i]
		if compat.RequiresAssistantAfterToolResult && lastRole == string(ai.RoleToolResult) && msg.Role == ai.RoleUser {
			params = append(params, OpenAIMessage{Role: "assistant", Content: "I have processed the tool results."})
		}

		switch msg.Role {
		case ai.RoleUser:
			// Collect text parts plus (when the model accepts them) images.
			userParts := make([]map[string]any, 0)
			if strings.TrimSpace(msg.Text) != "" {
				userParts = append(userParts, map[string]any{
					"type": "text",
					"text": utils.SanitizeSurrogates(msg.Text),
				})
			}
			for _, part := range msg.Content {
				switch part.Type {
				case ai.ContentTypeText:
					if strings.TrimSpace(part.Text) == "" {
						continue
					}
					userParts = append(userParts, map[string]any{
						"type": "text",
						"text": utils.SanitizeSurrogates(part.Text),
					})
				case ai.ContentTypeImage:
					if !supportsImage(model) {
						continue
					}
					userParts = append(userParts, map[string]any{
						"type": "image_url",
						"image_url": map[string]any{
							"url": "data:" + part.MimeType + ";base64," + part.Data,
						},
					})
				}
			}
			if len(userParts) == 0 {
				continue // nothing to send; lastRole deliberately left as-is
			}
			// A lone text part collapses to a plain string for broader
			// endpoint compatibility.
			if len(userParts) == 1 && userParts[0]["type"] == "text" {
				params = append(params, OpenAIMessage{
					Role:    "user",
					Content: userParts[0]["text"],
				})
			} else {
				params = append(params, OpenAIMessage{
					Role:    "user",
					Content: userParts,
				})
			}
		case ai.RoleAssistant:
			a := OpenAIMessage{
				Role:    "assistant",
				Content: nil,
			}
			textParts := make([]map[string]any, 0)
			toolCalls := make([]map[string]any, 0)
			for _, part := range msg.Content {
				switch part.Type {
				case ai.ContentTypeText:
					if strings.TrimSpace(part.Text) == "" {
						continue
					}
					textParts = append(textParts, map[string]any{"type": "text", "text": utils.SanitizeSurrogates(part.Text)})
				case ai.ContentTypeThinking:
					if strings.TrimSpace(part.Thinking) == "" {
						continue
					}
					// Replay thinking as a leading text part only when the
					// endpoint requires it; other endpoints drop thinking
					// blocks entirely here.
					if compat.RequiresThinkingAsText {
						textParts = append([]map[string]any{{
							"type": "text",
							"text": part.Thinking,
						}}, textParts...)
					}
				case ai.ContentTypeToolCall:
					// The Completions API requires arguments serialized as a
					// JSON string; a marshal failure yields "null" arguments.
					argsBytes, _ := json.Marshal(part.Arguments)
					toolCalls = append(toolCalls, map[string]any{
						"id":   part.ID,
						"type": "function",
						"function": map[string]any{
							"name":      part.Name,
							"arguments": string(argsBytes),
						},
					})
				}
			}
			if len(textParts) > 0 {
				// NOTE(review): github-copilot gets assistant content as one
				// concatenated string (no separator between parts) —
				// presumably the endpoint rejects content arrays; confirm.
				if model.Provider == "github-copilot" {
					text := ""
					for _, p := range textParts {
						text += p["text"].(string)
					}
					a.Content = text
				} else {
					a.Content = textParts
				}
			}
			if len(toolCalls) > 0 {
				a.ToolCalls = toolCalls
			}
			// Drop assistant messages that ended up with neither content nor
			// tool calls (e.g. thinking-only messages on strict endpoints).
			hasContent := false
			switch value := a.Content.(type) {
			case string:
				hasContent = strings.TrimSpace(value) != ""
			case []map[string]any:
				hasContent = len(value) > 0
			}
			if !hasContent && len(a.ToolCalls) == 0 {
				continue
			}
			params = append(params, a)
		case ai.RoleToolResult:
			// Consume the whole run of consecutive tool results: one "tool"
			// message each, with their images accumulated for a single
			// trailing user message.
			imageBlocks := make([]map[string]any, 0)
			j := i
			for ; j < len(transformed) && transformed[j].Role == ai.RoleToolResult; j++ {
				toolMsg := transformed[j]
				textResult := ""
				for _, block := range toolMsg.Content {
					if block.Type == ai.ContentTypeText {
						if textResult != "" {
							textResult += "\n"
						}
						textResult += block.Text
					}
				}
				hasText := strings.TrimSpace(textResult) != ""
				msgPayload := OpenAIMessage{
					Role:       "tool",
					ToolCallID: toolMsg.ToolCallID,
					Content:    utils.SanitizeSurrogates(textResult),
				}
				if !hasText {
					// Tool messages must not be empty; point at the images.
					msgPayload.Content = "(see attached image)"
				}
				if compat.RequiresToolResultName && toolMsg.ToolName != "" {
					msgPayload.Name = toolMsg.ToolName
				}
				params = append(params, msgPayload)

				if supportsImage(model) {
					for _, block := range toolMsg.Content {
						if block.Type == ai.ContentTypeImage {
							imageBlocks = append(imageBlocks, map[string]any{
								"type": "image_url",
								"image_url": map[string]any{
									"url": "data:" + block.MimeType + ";base64," + block.Data,
								},
							})
						}
					}
				}
			}
			i = j - 1 // resume the outer loop after the run
			if len(imageBlocks) > 0 {
				if compat.RequiresAssistantAfterToolResult {
					params = append(params, OpenAIMessage{Role: "assistant", Content: "I have processed the tool results."})
				}
				content := []map[string]any{{"type": "text", "text": "Attached image(s) from tool result:"}}
				content = append(content, imageBlocks...)
				params = append(params, OpenAIMessage{Role: "user", Content: content})
				lastRole = "user"
			} else {
				lastRole = string(ai.RoleToolResult)
			}
			continue
		}
		lastRole = string(msg.Role)
	}
	return params
}

// supportsImage reports whether the model's declared inputs include "image".
func supportsImage(model ai.Model) bool {
	for _, in := range model.Input {
		if in == "image" {
			return true
		}
	}
	return false
}

// sanitizeToolID replaces every character outside [A-Za-z0-9_-] with '_',
// preserving length.
func sanitizeToolID(id string) string {
	var b strings.Builder
	for _, r := range id {
		if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '_' || r == '-' {
			b.WriteRune(r)
		} else {
			b.WriteRune('_')
		}
	}
	return b.String()
}

// normalizeMistralToolID coerces an ID into the shape this code targets for
// Mistral: exactly 9 alphanumeric characters (strip non-alphanumerics, then
// pad with "ABCDEFGHI" or truncate).
func normalizeMistralToolID(id string) string {
	normalized := strings.Map(func(r rune) rune {
		if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') {
			return r
		}
		return -1
	}, id)
	if len(normalized) < 9 {
		padding := "ABCDEFGHI"
		return normalized + padding[:9-len(normalized)]
	}
	if len(normalized) > 9 {
		return normalized[:9]
	}
	return normalized
}
// TestConvertOpenAICompletionsMessages_BatchesToolResultImages verifies that a
// run of image-carrying tool results becomes "tool" messages followed by a
// single user message batching all of the images.
func TestConvertOpenAICompletionsMessages_BatchesToolResultImages(t *testing.T) {
	model := ai.Model{
		ID:       "gpt-4o-mini",
		API:      ai.APIOpenAICompletions,
		Provider: "openai",
		BaseURL:  "https://api.openai.com/v1",
		Input:    []string{"text", "image"},
	}
	now := time.Now().UnixMilli()
	context := ai.Context{
		Messages: []ai.Message{
			{Role: ai.RoleUser, Text: "Read the images", Timestamp: now - 2},
			{
				Role: ai.RoleAssistant,
				Content: []ai.ContentBlock{
					{Type: ai.ContentTypeToolCall, ID: "tool-1", Name: "read", Arguments: map[string]any{"path": "img-1.png"}},
					{Type: ai.ContentTypeToolCall, ID: "tool-2", Name: "read", Arguments: map[string]any{"path": "img-2.png"}},
				},
				Provider:   "openai",
				API:        ai.APIOpenAICompletions,
				Model:      "gpt-4o-mini",
				StopReason: ai.StopReasonToolUse,
				Timestamp:  now,
			},
			{
				Role:       ai.RoleToolResult,
				ToolCallID: "tool-1",
				ToolName:   "read",
				Content: []ai.ContentBlock{
					{Type: ai.ContentTypeText, Text: "Read image file [image/png]"},
					{Type: ai.ContentTypeImage, Data: "ZmFrZQ==", MimeType: "image/png"},
				},
				Timestamp: now + 1,
			},
			{
				Role:       ai.RoleToolResult,
				ToolCallID: "tool-2",
				ToolName:   "read",
				Content: []ai.ContentBlock{
					{Type: ai.ContentTypeText, Text: "Read image file [image/png]"},
					{Type: ai.ContentTypeImage, Data: "ZmFrZQ==", MimeType: "image/png"},
				},
				Timestamp: now + 2,
			},
		},
	}

	compat := GetCompat(model)
	messages := ConvertOpenAICompletionsMessages(model, context, compat)
	if len(messages) != 5 {
		t.Fatalf("expected 5 messages, got %d", len(messages))
	}
	roles := []string{
		messages[0].Role,
		messages[1].Role,
		messages[2].Role,
		messages[3].Role,
		messages[4].Role,
	}
	// System prompt is empty, so the sequence starts at the user message.
	expected := []string{"user", "assistant", "tool", "tool", "user"}
	for i := range expected {
		if roles[i] != expected[i] {
			t.Fatalf("unexpected roles: %+v", roles)
		}
	}
	content, ok := messages[4].Content.([]map[string]any)
	if !ok {
		t.Fatalf("expected final user content array, got %T", messages[4].Content)
	}
	imageCount := 0
	for _, part := range content {
		if part["type"] == "image_url" {
			imageCount++
		}
	}
	if imageCount != 2 {
		t.Fatalf("expected 2 image parts, got %d", imageCount)
	}
}

// boolPtr builds a *bool for compat overrides in tests.
func boolPtr(v bool) *bool { return &v }

// TestBuildOpenAICompletionsParams_ToolChoiceAndStrict checks that tool_choice
// is passed through and that the per-function "strict" flag follows the
// resolved SupportsStrictMode compat setting.
func TestBuildOpenAICompletionsParams_ToolChoiceAndStrict(t *testing.T) {
	model := ai.Model{
		ID:       "gpt-4o-mini",
		API:      ai.APIOpenAICompletions,
		Provider: "openai",
		BaseURL:  "https://api.openai.com/v1",
		Input:    []string{"text"},
	}
	context := ai.Context{
		Messages: []ai.Message{{Role: ai.RoleUser, Text: "Call ping with ok=true", Timestamp: 1}},
		Tools: []ai.Tool{{
			Name:        "ping",
			Description: "Ping tool",
			Parameters: map[string]any{
				"type": "object",
				"properties": map[string]any{
					"ok": map[string]any{"type": "boolean"},
				},
			},
		}},
	}

	params := BuildOpenAICompletionsParams(model, context, OpenAICompletionsOptions{
		ToolChoice: "required",
	})
	if params["tool_choice"] != "required" {
		t.Fatalf("expected tool_choice=required, got %v", params["tool_choice"])
	}
	tools, ok := params["tools"].([]map[string]any)
	if !ok || len(tools) == 0 {
		t.Fatalf("expected non-empty tools payload")
	}
	fn, _ := tools[0]["function"].(map[string]any)
	if _, ok := fn["strict"]; !ok {
		t.Fatalf("expected strict in function payload when supported")
	}

	// Explicit compat override must suppress the strict flag.
	model.Compat = &ai.OpenAICompletionsCompat{
		SupportsStrictMode: boolPtr(false),
	}
	params = BuildOpenAICompletionsParams(model, context, OpenAICompletionsOptions{})
	tools, ok = params["tools"].([]map[string]any)
	if !ok || len(tools) == 0 {
		t.Fatalf("expected non-empty tools payload")
	}
	fn, _ = tools[0]["function"].(map[string]any)
	if _, ok := fn["strict"]; ok {
		t.Fatalf("expected strict to be omitted when supportsStrictMode=false")
	}
}

// TestBuildOpenAICompletionsParams_ReasoningEffortGroqMapping checks the
// Groq-specific reasoning_effort mapping: qwen models map "medium" to
// "default", while gpt-oss models keep the standard value.
func TestBuildOpenAICompletionsParams_ReasoningEffortGroqMapping(t *testing.T) {
	model := ai.Model{
		ID:       "qwen/qwen3-32b",
		API:      ai.APIOpenAICompletions,
		Provider: "groq",
		BaseURL:  "https://api.groq.com/openai/v1",
		Input:    []string{"text"},
	}
	context := ai.Context{
		Messages: []ai.Message{{Role: ai.RoleUser, Text: "Hi", Timestamp: 1}},
	}
	params := BuildOpenAICompletionsParams(model, context, OpenAICompletionsOptions{
		ReasoningEffort: ai.ThinkingMedium,
	})
	if params["reasoning_effort"] != "default" {
		t.Fatalf("expected reasoning_effort=default, got %v", params["reasoning_effort"])
	}

	model.ID = "openai/gpt-oss-20b"
	params = BuildOpenAICompletionsParams(model, context, OpenAICompletionsOptions{
		ReasoningEffort: ai.ThinkingMedium,
	})
	if params["reasoning_effort"] != "medium" {
		t.Fatalf("expected reasoning_effort=medium, got %v", params["reasoning_effort"])
	}
}
// RegisterBuiltInAPIProviders registers providers implemented in this package.
// Initial scaffold keeps registry empty until concrete provider streamers are ported.
func RegisterBuiltInAPIProviders() {}

// ResetAPIProviders clears the registry and re-registers the built-ins.
func ResetAPIProviders() {
	ai.ClearAPIProviders()
	RegisterBuiltInAPIProviders()
}

// TransformMessages prepares a possibly cross-provider message history for the
// target model.
//
// Pass 1 rewrites assistant content depending on whether the message was
// produced by the same model (same provider, API, and model ID):
//   - thinking blocks from another model are downgraded to plain text (or
//     dropped when empty/redacted); same-model thinking is kept verbatim;
//   - tool calls from another model lose their ThoughtSignature and have
//     their IDs normalized via normalizeToolCallID, with the old→new mapping
//     recorded so matching toolResult messages are remapped too.
//
// Pass 2 synthesizes "No result provided" error tool results for tool calls
// that never received one before the next assistant or user turn, and drops
// assistant messages that ended in error/aborted.
//
// NOTE(review): tool calls pending at the very end of the history are not
// flushed — presumably the caller is about to supply their results; confirm.
func TransformMessages(
	messages []ai.Message,
	model ai.Model,
	normalizeToolCallID func(id string, model ai.Model, source ai.Message) string,
) []ai.Message {
	// Maps original tool-call IDs to their normalized form so later tool
	// results can be remapped consistently.
	toolCallIDMap := map[string]string{}
	transformed := make([]ai.Message, 0, len(messages))

	for _, msg := range messages {
		switch msg.Role {
		case ai.RoleUser:
			transformed = append(transformed, msg)
		case ai.RoleToolResult:
			if normalized, ok := toolCallIDMap[msg.ToolCallID]; ok && normalized != "" {
				msg.ToolCallID = normalized
			}
			transformed = append(transformed, msg)
		case ai.RoleAssistant:
			isSameModel := msg.Provider == model.Provider && msg.API == model.API && msg.Model == model.ID
			nextContent := make([]ai.ContentBlock, 0, len(msg.Content))
			for _, block := range msg.Content {
				switch block.Type {
				case ai.ContentTypeThinking:
					if block.Redacted {
						// Redacted thinking is only replayable to the model
						// that produced it.
						if isSameModel {
							nextContent = append(nextContent, block)
						}
						continue
					}
					if strings.TrimSpace(block.Thinking) == "" {
						// Keep empty thinking only when its signature still
						// matters to the originating model.
						if isSameModel && block.ThinkingSignature != "" {
							nextContent = append(nextContent, block)
						}
						continue
					}
					if isSameModel {
						nextContent = append(nextContent, block)
					} else {
						nextContent = append(nextContent, ai.ContentBlock{
							Type: ai.ContentTypeText,
							Text: block.Thinking,
						})
					}
				case ai.ContentTypeText:
					nextContent = append(nextContent, block)
				case ai.ContentTypeToolCall:
					if !isSameModel {
						block.ThoughtSignature = ""
						if normalizeToolCallID != nil {
							normalized := normalizeToolCallID(block.ID, model, msg)
							if normalized != "" && normalized != block.ID {
								toolCallIDMap[block.ID] = normalized
								block.ID = normalized
							}
						}
					}
					nextContent = append(nextContent, block)
				default:
					nextContent = append(nextContent, block)
				}
			}
			msg.Content = nextContent
			transformed = append(transformed, msg)
		default:
			transformed = append(transformed, msg)
		}
	}

	// Second pass: synthesize missing tool results for orphaned tool calls.
	result := make([]ai.Message, 0, len(transformed))
	var pendingToolCalls []ai.ContentBlock
	existingToolResultIDs := map[string]struct{}{}
	for _, msg := range transformed {
		switch msg.Role {
		case ai.RoleAssistant:
			// Flush tool calls from the previous assistant turn that never
			// got a result.
			if len(pendingToolCalls) > 0 {
				for _, tc := range pendingToolCalls {
					if _, ok := existingToolResultIDs[tc.ID]; ok {
						continue
					}
					result = append(result, ai.Message{
						Role:       ai.RoleToolResult,
						ToolCallID: tc.ID,
						ToolName:   tc.Name,
						Content: []ai.ContentBlock{{
							Type: ai.ContentTypeText,
							Text: "No result provided",
						}},
						IsError:   true,
						Timestamp: time.Now().UnixMilli(),
					})
				}
				pendingToolCalls = nil
				existingToolResultIDs = map[string]struct{}{}
			}

			// Errored/aborted assistant turns are dropped entirely.
			if msg.StopReason == ai.StopReasonError || msg.StopReason == ai.StopReasonAborted {
				continue
			}

			pendingToolCalls = nil
			existingToolResultIDs = map[string]struct{}{}
			for _, block := range msg.Content {
				if block.Type == ai.ContentTypeToolCall {
					pendingToolCalls = append(pendingToolCalls, block)
				}
			}
			result = append(result, msg)
		case ai.RoleToolResult:
			existingToolResultIDs[msg.ToolCallID] = struct{}{}
			result = append(result, msg)
		case ai.RoleUser:
			// A user turn also terminates any pending tool calls.
			if len(pendingToolCalls) > 0 {
				for _, tc := range pendingToolCalls {
					if _, ok := existingToolResultIDs[tc.ID]; ok {
						continue
					}
					result = append(result, ai.Message{
						Role:       ai.RoleToolResult,
						ToolCallID: tc.ID,
						ToolName:   tc.Name,
						Content: []ai.ContentBlock{{
							Type: ai.ContentTypeText,
							Text: "No result provided",
						}},
						IsError:   true,
						Timestamp: time.Now().UnixMilli(),
					})
				}
				pendingToolCalls = nil
				existingToolResultIDs = map[string]struct{}{}
			}
			result = append(result, msg)
		default:
			result = append(result, msg)
		}
	}
	return result
}
// TestTransformMessages_RemovesThoughtSignatureAcrossModels verifies that a
// tool call produced by a different model loses its ThoughtSignature while
// its ID and matching tool result survive the transform.
func TestTransformMessages_RemovesThoughtSignatureAcrossModels(t *testing.T) {
	model := ai.Model{
		ID:       "claude-sonnet-4",
		API:      ai.APIAnthropicMessages,
		Provider: "github-copilot",
	}
	now := time.Now().UnixMilli()
	messages := []ai.Message{
		{Role: ai.RoleUser, Text: "run command", Timestamp: now},
		{
			Role: ai.RoleAssistant,
			Content: []ai.ContentBlock{{
				Type:             ai.ContentTypeToolCall,
				ID:               "call_123",
				Name:             "bash",
				Arguments:        map[string]any{"command": "ls"},
				ThoughtSignature: `{"type":"reasoning.encrypted"}`,
			}},
			API:        ai.APIOpenAIResponses,
			Provider:   "github-copilot",
			Model:      "gpt-5",
			StopReason: ai.StopReasonToolUse,
			Timestamp:  now,
		},
		{
			Role:       ai.RoleToolResult,
			ToolCallID: "call_123",
			ToolName:   "bash",
			Content: []ai.ContentBlock{{
				Type: ai.ContentTypeText,
				Text: "output",
			}},
			Timestamp: now,
		},
	}

	result := TransformMessages(messages, model, anthropicNormalizeID)
	for _, msg := range result {
		if msg.Role != ai.RoleAssistant {
			continue
		}
		for _, block := range msg.Content {
			if block.Type == ai.ContentTypeToolCall && block.ThoughtSignature != "" {
				t.Fatalf("expected thoughtSignature to be removed across model handoff")
			}
		}
	}
}

// Stream starts a raw streaming request against the provider registered for
// model.API. It fails fast when no provider (or no stream function) is
// registered for that API.
func Stream(model Model, c Context, options *StreamOptions) (*AssistantMessageEventStream, error) {
	provider, err := ResolveAPIProvider(model.API)
	if err != nil {
		return nil, err
	}
	if provider.Stream == nil {
		return nil, fmt.Errorf("provider %s has no stream function", model.API)
	}
	return provider.Stream(model, c, options), nil
}
error) { + s, err := Stream(model, c, options) + if err != nil { + return Message{}, err + } + for { + _, nextErr := s.Next(context.Background()) + if nextErr != nil { + break + } + } + return s.Result() +} + +func StreamSimple(model Model, c Context, options *SimpleStreamOptions) (*AssistantMessageEventStream, error) { + provider, err := ResolveAPIProvider(model.API) + if err != nil { + return nil, err + } + if provider.StreamSimple == nil { + return nil, fmt.Errorf("provider %s has no streamSimple function", model.API) + } + return provider.StreamSimple(model, c, options), nil +} + +func CompleteSimple(model Model, c Context, options *SimpleStreamOptions) (Message, error) { + s, err := StreamSimple(model, c, options) + if err != nil { + return Message{}, err + } + for { + _, nextErr := s.Next(context.Background()) + if nextErr != nil { + break + } + } + return s.Result() +} diff --git a/pkg/ai/types.go b/pkg/ai/types.go new file mode 100644 index 00000000..a05c82a4 --- /dev/null +++ b/pkg/ai/types.go @@ -0,0 +1,236 @@ +package ai + +import "context" + +type Api string + +const ( + APIOpenAICompletions Api = "openai-completions" + APIOpenAIResponses Api = "openai-responses" + APIAzureOpenAIResponse Api = "azure-openai-responses" + APIOpenAICodexResponse Api = "openai-codex-responses" + APIAnthropicMessages Api = "anthropic-messages" + APIBedrockConverse Api = "bedrock-converse-stream" + APIGoogleGenerativeAI Api = "google-generative-ai" + APIGoogleGeminiCLI Api = "google-gemini-cli" + APIGoogleVertex Api = "google-vertex" +) + +type Provider string + +type ThinkingLevel string + +const ( + ThinkingMinimal ThinkingLevel = "minimal" + ThinkingLow ThinkingLevel = "low" + ThinkingMedium ThinkingLevel = "medium" + ThinkingHigh ThinkingLevel = "high" + ThinkingXHigh ThinkingLevel = "xhigh" +) + +type ThinkingBudgets struct { + Minimal int + Low int + Medium int + High int +} + +type CacheRetention string + +const ( + CacheRetentionNone CacheRetention = "none" + 
// CacheRetention levels: None disables prompt caching; Short and Long select
// the provider's shorter or longer cache lifetime where supported.
const (
	CacheRetentionNone  CacheRetention = "none"
	CacheRetentionShort CacheRetention = "short"
	CacheRetentionLong  CacheRetention = "long"
)

// Transport selects how streaming responses are carried.
type Transport string

const (
	TransportSSE       Transport = "sse"
	TransportWebSocket Transport = "websocket"
	TransportAuto      Transport = "auto"
)

// StreamOptions carries per-request settings shared by all API providers.
type StreamOptions struct {
	Temperature     *float64        // nil means provider default
	MaxTokens       int             // 0 means provider default
	Ctx             context.Context // cancels the request/stream when done
	APIKey          string
	Transport       Transport
	CacheRetention  CacheRetention
	SessionID       string
	OnPayload       func(any)         // observer for raw provider payloads
	Headers         map[string]string // extra HTTP headers
	MaxRetryDelayMs int
	Metadata        map[string]any
}

// SimpleStreamOptions extends StreamOptions with the simplified reasoning
// controls used by StreamSimple/CompleteSimple.
type SimpleStreamOptions struct {
	StreamOptions
	Reasoning       ThinkingLevel
	ThinkingBudgets ThinkingBudgets
}

// ContentType discriminates the variants stored in a ContentBlock.
type ContentType string

const (
	ContentTypeText     ContentType = "text"
	ContentTypeThinking ContentType = "thinking"
	ContentTypeImage    ContentType = "image"
	ContentTypeToolCall ContentType = "toolCall"
)

// ContentBlock is a tagged union; only the field group matching Type is set.
type ContentBlock struct {
	Type ContentType `json:"type"`

	// Type == ContentTypeText
	Text          string `json:"text,omitempty"`
	TextSignature string `json:"textSignature,omitempty"`

	// Type == ContentTypeThinking
	Thinking          string `json:"thinking,omitempty"`
	ThinkingSignature string `json:"thinkingSignature,omitempty"`
	Redacted          bool   `json:"redacted,omitempty"`

	// Type == ContentTypeImage (base64 payload plus MIME type)
	Data     string `json:"data,omitempty"`
	MimeType string `json:"mimeType,omitempty"`

	// Type == ContentTypeToolCall
	ID               string         `json:"id,omitempty"`
	Name             string         `json:"name,omitempty"`
	Arguments        map[string]any `json:"arguments,omitempty"`
	ThoughtSignature string         `json:"thoughtSignature,omitempty"`
}

// UsageCost breaks a request's monetary cost down by token category.
type UsageCost struct {
	Input      float64 `json:"input"`
	Output     float64 `json:"output"`
	CacheRead  float64 `json:"cacheRead"`
	CacheWrite float64 `json:"cacheWrite"`
	Total      float64 `json:"total"`
}

// Usage records token counts and the derived cost for one request.
type Usage struct {
	Input       int       `json:"input"`
	Output      int       `json:"output"`
	CacheRead   int       `json:"cacheRead"`
	CacheWrite  int       `json:"cacheWrite"`
	TotalTokens int       `json:"totalTokens"`
	Cost        UsageCost `json:"cost"`
}

// StopReason says why an assistant turn ended.
type StopReason string

const (
	StopReasonStop    StopReason = "stop"
	StopReasonLength  StopReason = "length"
	StopReasonToolUse StopReason = "toolUse"
	StopReasonError   StopReason = "error"
	StopReasonAborted StopReason = "aborted"
)

// MessageRole identifies who produced a Message.
type MessageRole string

const (
	RoleUser       MessageRole = "user"
	RoleAssistant  MessageRole = "assistant"
	RoleToolResult MessageRole = "toolResult"
)

// Message is one turn in a conversation. Which field groups are meaningful
// depends on Role.
type Message struct {
	Role MessageRole `json:"role"`

	// user message can be string or blocks
	Text    string         `json:"text,omitempty"`
	Content []ContentBlock `json:"content,omitempty"`

	// assistant metadata
	API          Api        `json:"api,omitempty"`
	Provider     Provider   `json:"provider,omitempty"`
	Model        string     `json:"model,omitempty"`
	Usage        Usage      `json:"usage,omitempty"`
	StopReason   StopReason `json:"stopReason,omitempty"`
	ErrorMessage string     `json:"errorMessage,omitempty"`

	// toolResult metadata
	ToolCallID string `json:"toolCallId,omitempty"`
	ToolName   string `json:"toolName,omitempty"`
	IsError    bool   `json:"isError,omitempty"`

	Timestamp int64 `json:"timestamp"`
}

// Tool describes a callable tool exposed to the model; Parameters is a JSON
// Schema object.
type Tool struct {
	Name        string         `json:"name"`
	Description string         `json:"description"`
	Parameters  map[string]any `json:"parameters"`
}

// Context is the full request payload handed to a provider.
type Context struct {
	SystemPrompt string    `json:"systemPrompt,omitempty"`
	Messages     []Message `json:"messages"`
	Tools        []Tool    `json:"tools,omitempty"`
}

// OpenAICompletionsCompat overrides auto-detected endpoint quirks for
// OpenAI-compatible Completions servers. nil pointer fields mean "use the
// detected default".
type OpenAICompletionsCompat struct {
	SupportsStore                    *bool                    `json:"supportsStore,omitempty"`
	SupportsDeveloperRole            *bool                    `json:"supportsDeveloperRole,omitempty"`
	SupportsReasoningEffort          *bool                    `json:"supportsReasoningEffort,omitempty"`
	ReasoningEffortMap               map[ThinkingLevel]string `json:"reasoningEffortMap,omitempty"`
	SupportsUsageInStreaming         *bool                    `json:"supportsUsageInStreaming,omitempty"`
	MaxTokensField                   string                   `json:"maxTokensField,omitempty"` // max_completion_tokens|max_tokens
	RequiresToolResultName           *bool                    `json:"requiresToolResultName,omitempty"`
	RequiresAssistantAfterToolResult *bool                    `json:"requiresAssistantAfterToolResult,omitempty"`
	RequiresThinkingAsText           *bool                    `json:"requiresThinkingAsText,omitempty"`
	RequiresMistralToolIDs           *bool                    `json:"requiresMistralToolIds,omitempty"`
	ThinkingFormat                   string                   `json:"thinkingFormat,omitempty"` // openai|zai|qwen
	SupportsStrictMode               *bool                    `json:"supportsStrictMode,omitempty"`
}

// ModelCost holds per-token pricing by category.
type ModelCost struct {
	Input      float64 `json:"input"`
	Output     float64 `json:"output"`
	CacheRead  float64 `json:"cacheRead"`
	CacheWrite float64 `json:"cacheWrite"`
}

// Model describes one selectable model and how to reach it.
type Model struct {
	ID            string                   `json:"id"`
	Name          string                   `json:"name"`
	API           Api                      `json:"api"`
	Provider      Provider                 `json:"provider"`
	BaseURL       string                   `json:"baseUrl"`
	Reasoning     bool                     `json:"reasoning"`
	Input         []string                 `json:"input"` // text,image
	Cost          ModelCost                `json:"cost"`
	ContextWindow int                      `json:"contextWindow"`
	MaxTokens     int                      `json:"maxTokens"`
	Headers       map[string]string        `json:"headers,omitempty"`
	Compat        *OpenAICompletionsCompat `json:"compat,omitempty"`
}

// AssistantMessageEventType enumerates streaming event kinds.
type AssistantMessageEventType string

const (
	EventStart         AssistantMessageEventType = "start"
	EventTextStart     AssistantMessageEventType = "text_start"
	EventTextDelta     AssistantMessageEventType = "text_delta"
	EventTextEnd       AssistantMessageEventType = "text_end"
	EventThinkingStart AssistantMessageEventType = "thinking_start"
	EventThinkingDelta AssistantMessageEventType = "thinking_delta"
	EventThinkingEnd   AssistantMessageEventType = "thinking_end"
	EventToolCallStart AssistantMessageEventType = "toolcall_start"
	EventToolCallDelta AssistantMessageEventType = "toolcall_delta"
	EventToolCallEnd   AssistantMessageEventType = "toolcall_end"
	EventDone          AssistantMessageEventType = "done"
	EventError         AssistantMessageEventType = "error"
)

// AssistantMessageEvent is one streaming event; which fields are populated
// depends on Type.
type AssistantMessageEvent struct {
	Type         AssistantMessageEventType
	ContentIndex int
	Delta        string
	Content      string
	ToolCall     *ContentBlock
	Partial      Message
	Message      Message
	Error        Message
	Reason       StopReason
}
// ParseStreamingJSON attempts to decode possibly incomplete JSON.
// If parsing fails it best-effort trims incomplete suffixes and retries,
// returning an empty object when no parse can succeed.
func ParseStreamingJSON(partial string) map[string]any {
	if partial == "" {
		return map[string]any{}
	}

	// tryParse returns the decoded object, or nil when s is not a complete
	// JSON object.
	tryParse := func(s string) map[string]any {
		var m map[string]any
		if json.Unmarshal([]byte(s), &m) == nil && m != nil {
			return m
		}
		return nil
	}

	if m := tryParse(partial); m != nil {
		return m
	}

	// Fallback: peel characters off the tail and retry, skipping cut points
	// that would obviously leave a dangling structural character.
	for i := len(partial) - 1; i > 1; i-- {
		switch partial[i] {
		case '{', '[', ',', ':':
			continue
		}
		if m := tryParse(partial[:i]); m != nil {
			return m
		}
	}
	return map[string]any{}
}
regexp.MustCompile(`(?i)input token count.*exceeds the maximum`), + regexp.MustCompile(`(?i)maximum prompt length is \d+`), + regexp.MustCompile(`(?i)reduce the length of the messages`), + regexp.MustCompile(`(?i)maximum context length is \d+ tokens`), + regexp.MustCompile(`(?i)exceeds the limit of \d+`), + regexp.MustCompile(`(?i)exceeds the available context size`), + regexp.MustCompile(`(?i)greater than the context length`), + regexp.MustCompile(`(?i)context window exceeds limit`), + regexp.MustCompile(`(?i)exceeded model token limit`), + regexp.MustCompile(`(?i)context[_ ]length[_ ]exceeded`), + regexp.MustCompile(`(?i)too many tokens`), + regexp.MustCompile(`(?i)token limit exceeded`), +} + +func IsContextOverflow(message ai.Message, contextWindow int) bool { + if message.StopReason == ai.StopReasonError && message.ErrorMessage != "" { + for _, p := range overflowPatterns { + if p.MatchString(message.ErrorMessage) { + return true + } + } + if regexp.MustCompile(`(?i)^4(00|13)\s*(status code)?\s*\(no body\)`).MatchString(message.ErrorMessage) { + return true + } + } + + if contextWindow > 0 && message.StopReason == ai.StopReasonStop { + inputTokens := message.Usage.Input + message.Usage.CacheRead + return inputTokens > contextWindow + } + return false +} + +func GetOverflowPatterns() []*regexp.Regexp { + result := make([]*regexp.Regexp, 0, len(overflowPatterns)) + result = append(result, overflowPatterns...) 
+ return result +} diff --git a/pkg/ai/utils/overflow_test.go b/pkg/ai/utils/overflow_test.go new file mode 100644 index 00000000..6a07b36c --- /dev/null +++ b/pkg/ai/utils/overflow_test.go @@ -0,0 +1,32 @@ +package utils + +import ( + "testing" + + "github.com/beeper/ai-bridge/pkg/ai" +) + +func TestIsContextOverflow_ErrorPattern(t *testing.T) { + msg := ai.Message{ + Role: ai.RoleAssistant, + StopReason: ai.StopReasonError, + ErrorMessage: "prompt is too long: 213462 tokens > 200000 maximum", + } + if !IsContextOverflow(msg, 0) { + t.Fatalf("expected context overflow for anthropic style error") + } +} + +func TestIsContextOverflow_SilentOverflow(t *testing.T) { + msg := ai.Message{ + Role: ai.RoleAssistant, + StopReason: ai.StopReasonStop, + Usage: ai.Usage{ + Input: 1200, + CacheRead: 100, + }, + } + if !IsContextOverflow(msg, 1000) { + t.Fatalf("expected silent overflow when usage exceeds context window") + } +} diff --git a/pkg/ai/utils/sanitize_unicode.go b/pkg/ai/utils/sanitize_unicode.go new file mode 100644 index 00000000..b1ff2a82 --- /dev/null +++ b/pkg/ai/utils/sanitize_unicode.go @@ -0,0 +1,45 @@ +package utils + +import ( + "strings" + "unicode/utf8" +) + +// SanitizeSurrogates removes unpaired UTF-16 surrogate code points. 
+func SanitizeSurrogates(text string) string { + if text == "" { + return text + } + + var out strings.Builder + runes := make([]rune, 0, len(text)) + for len(text) > 0 { + r, size := utf8.DecodeRuneInString(text) + if r == utf8.RuneError && size == 1 { + text = text[size:] + continue + } + runes = append(runes, r) + text = text[size:] + } + + for i := 0; i < len(runes); i++ { + r := runes[i] + if r >= 0xD800 && r <= 0xDBFF { + if i+1 < len(runes) { + next := runes[i+1] + if next >= 0xDC00 && next <= 0xDFFF { + out.WriteRune(r) + out.WriteRune(next) + i++ + } + } + continue + } + if r >= 0xDC00 && r <= 0xDFFF { + continue + } + out.WriteRune(r) + } + return out.String() +} diff --git a/pkg/ai/utils/sanitize_unicode_test.go b/pkg/ai/utils/sanitize_unicode_test.go new file mode 100644 index 00000000..cd4ea7ec --- /dev/null +++ b/pkg/ai/utils/sanitize_unicode_test.go @@ -0,0 +1,16 @@ +package utils + +import "testing" + +func TestSanitizeSurrogates(t *testing.T) { + in := "Hello 🙈 World" + if got := SanitizeSurrogates(in); got != in { + t.Fatalf("expected valid emoji pair unchanged, got %q", got) + } + + invalidSurrogateBytes := string([]byte{0xED, 0xA0, 0xBD}) // UTF-8 bytes for surrogate range (invalid) + got := SanitizeSurrogates("Text " + invalidSurrogateBytes + " here") + if got != "Text here" { + t.Fatalf("expected invalid surrogate bytes removed, got %q", got) + } +} From 2f369c237bba188e1000ef827d55cdb859993e67 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 02:51:35 +0000 Subject: [PATCH 02/75] Add OpenAI/Anthropic cache retention payload builders MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/ai/providers/anthropic.go | 217 +++++++++++++++++++++ pkg/ai/providers/cache_retention_test.go | 162 ++++++++++++++++ pkg/ai/providers/openai_responses.go | 234 +++++++++++++++++++++++ 3 files changed, 613 insertions(+) create mode 100644 pkg/ai/providers/anthropic.go 
create mode 100644 pkg/ai/providers/cache_retention_test.go create mode 100644 pkg/ai/providers/openai_responses.go diff --git a/pkg/ai/providers/anthropic.go b/pkg/ai/providers/anthropic.go new file mode 100644 index 00000000..1c898e91 --- /dev/null +++ b/pkg/ai/providers/anthropic.go @@ -0,0 +1,217 @@ +package providers + +import ( + "os" + "strings" + + "github.com/beeper/ai-bridge/pkg/ai" + "github.com/beeper/ai-bridge/pkg/ai/utils" +) + +type AnthropicOptions struct { + StreamOptions ai.StreamOptions + ThinkingEnabled bool + ThinkingBudgetTokens int + Effort string + InterleavedThinking bool + ToolChoice string +} + +type cacheControl struct { + Type string `json:"type"` + TTL string `json:"ttl,omitempty"` +} + +func resolveAnthropicCacheRetention(cacheRetention ai.CacheRetention) ai.CacheRetention { + if cacheRetention != "" { + return cacheRetention + } + if strings.EqualFold(os.Getenv("PI_CACHE_RETENTION"), "long") { + return ai.CacheRetentionLong + } + return ai.CacheRetentionShort +} + +func GetAnthropicCacheControl(baseURL string, cacheRetention ai.CacheRetention) (ai.CacheRetention, *cacheControl) { + retention := resolveAnthropicCacheRetention(cacheRetention) + if retention == ai.CacheRetentionNone { + return retention, nil + } + cc := &cacheControl{Type: "ephemeral"} + if retention == ai.CacheRetentionLong && strings.Contains(baseURL, "api.anthropic.com") { + cc.TTL = "1h" + } + return retention, cc +} + +func BuildAnthropicParams(model ai.Model, context ai.Context, options AnthropicOptions) map[string]any { + params := map[string]any{ + "model": model.ID, + "stream": true, + "max_tokens": max(1024, options.StreamOptions.MaxTokens), + "messages": convertAnthropicMessages(model, context), + } + + _, cache := GetAnthropicCacheControl(model.BaseURL, options.StreamOptions.CacheRetention) + if strings.TrimSpace(context.SystemPrompt) != "" { + systemPart := map[string]any{ + "type": "text", + "text": utils.SanitizeSurrogates(context.SystemPrompt), + } + if 
cache != nil { + systemPart["cache_control"] = map[string]any{ + "type": cache.Type, + } + if cache.TTL != "" { + systemPart["cache_control"].(map[string]any)["ttl"] = cache.TTL + } + } + params["system"] = []map[string]any{systemPart} + } + + if options.StreamOptions.Temperature != nil { + params["temperature"] = *options.StreamOptions.Temperature + } + if options.ToolChoice != "" { + params["tool_choice"] = map[string]any{"type": options.ToolChoice} + } + if len(context.Tools) > 0 { + params["tools"] = convertAnthropicTools(context.Tools) + } + if options.ThinkingEnabled { + thinking := map[string]any{"type": "enabled"} + if options.ThinkingBudgetTokens > 0 { + thinking["budget_tokens"] = options.ThinkingBudgetTokens + } + if strings.TrimSpace(options.Effort) != "" { + thinking["effort"] = options.Effort + } + params["thinking"] = thinking + } + if options.InterleavedThinking { + params["anthropic-beta"] = "interleaved-thinking-2025-05-14" + } + return params +} + +func convertAnthropicMessages(model ai.Model, context ai.Context) []map[string]any { + transformed := TransformMessages(context.Messages, model, func(id string, _ ai.Model, _ ai.Message) string { + sanitized := strings.Map(func(r rune) rune { + if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '_' || r == '-' { + return r + } + return '_' + }, id) + if len(sanitized) > 64 { + return sanitized[:64] + } + return sanitized + }) + + out := make([]map[string]any, 0, len(transformed)) + for _, msg := range transformed { + switch msg.Role { + case ai.RoleUser: + parts := []map[string]any{} + if strings.TrimSpace(msg.Text) != "" { + parts = append(parts, map[string]any{ + "type": "text", + "text": utils.SanitizeSurrogates(msg.Text), + }) + } + for _, block := range msg.Content { + if block.Type == ai.ContentTypeText && strings.TrimSpace(block.Text) != "" { + parts = append(parts, map[string]any{ + "type": "text", + "text": utils.SanitizeSurrogates(block.Text), + }) + } + if 
block.Type == ai.ContentTypeImage { + parts = append(parts, map[string]any{ + "type": "image", + "source": map[string]any{ + "type": "base64", + "media_type": block.MimeType, + "data": block.Data, + }, + }) + } + } + if len(parts) == 0 { + continue + } + out = append(out, map[string]any{ + "role": "user", + "content": parts, + }) + case ai.RoleAssistant: + parts := []map[string]any{} + for _, block := range msg.Content { + switch block.Type { + case ai.ContentTypeText: + if strings.TrimSpace(block.Text) == "" { + continue + } + parts = append(parts, map[string]any{"type": "text", "text": utils.SanitizeSurrogates(block.Text)}) + case ai.ContentTypeThinking: + if strings.TrimSpace(block.Thinking) == "" { + continue + } + parts = append(parts, map[string]any{"type": "thinking", "thinking": utils.SanitizeSurrogates(block.Thinking)}) + case ai.ContentTypeToolCall: + parts = append(parts, map[string]any{ + "type": "tool_use", + "id": block.ID, + "name": block.Name, + "input": block.Arguments, + }) + } + } + if len(parts) == 0 { + continue + } + out = append(out, map[string]any{ + "role": "assistant", + "content": parts, + }) + case ai.RoleToolResult: + resultText := "" + for _, block := range msg.Content { + if block.Type == ai.ContentTypeText { + if resultText != "" { + resultText += "\n" + } + resultText += block.Text + } + } + if strings.TrimSpace(resultText) == "" { + resultText = "(see attached image)" + } + out = append(out, map[string]any{ + "role": "user", + "content": []map[string]any{{ + "type": "tool_result", + "tool_use_id": msg.ToolCallID, + "is_error": msg.IsError, + "content": []map[string]any{{ + "type": "text", + "text": utils.SanitizeSurrogates(resultText), + }}, + }}, + }) + } + } + return out +} + +func convertAnthropicTools(tools []ai.Tool) []map[string]any { + out := make([]map[string]any, 0, len(tools)) + for _, tool := range tools { + out = append(out, map[string]any{ + "name": tool.Name, + "description": tool.Description, + "input_schema": 
tool.Parameters, + }) + } + return out +} diff --git a/pkg/ai/providers/cache_retention_test.go b/pkg/ai/providers/cache_retention_test.go new file mode 100644 index 00000000..83d653f0 --- /dev/null +++ b/pkg/ai/providers/cache_retention_test.go @@ -0,0 +1,162 @@ +package providers + +import ( + "testing" + + "github.com/beeper/ai-bridge/pkg/ai" +) + +func TestCacheRetentionAnthropicAndOpenAIResponses(t *testing.T) { + baseContext := ai.Context{ + SystemPrompt: "You are a helpful assistant.", + Messages: []ai.Message{ + {Role: ai.RoleUser, Text: "Hello", Timestamp: 1}, + }, + } + + t.Run("anthropic default short has ephemeral cache no ttl", func(t *testing.T) { + t.Setenv("PI_CACHE_RETENTION", "") + model := ai.Model{ + ID: "claude-3-5-haiku-20241022", + API: ai.APIAnthropicMessages, + Provider: "anthropic", + BaseURL: "https://api.anthropic.com", + } + params := BuildAnthropicParams(model, baseContext, AnthropicOptions{}) + system := params["system"].([]map[string]any) + cc := system[0]["cache_control"].(map[string]any) + if cc["type"] != "ephemeral" { + t.Fatalf("expected ephemeral cache control, got %#v", cc) + } + if _, ok := cc["ttl"]; ok { + t.Fatalf("expected no ttl for default retention") + } + }) + + t.Run("anthropic long retention adds 1h ttl for direct api", func(t *testing.T) { + t.Setenv("PI_CACHE_RETENTION", "long") + model := ai.Model{ + ID: "claude-3-5-haiku-20241022", + API: ai.APIAnthropicMessages, + Provider: "anthropic", + BaseURL: "https://api.anthropic.com", + } + params := BuildAnthropicParams(model, baseContext, AnthropicOptions{}) + system := params["system"].([]map[string]any) + cc := system[0]["cache_control"].(map[string]any) + if cc["ttl"] != "1h" { + t.Fatalf("expected ttl=1h, got %#v", cc["ttl"]) + } + }) + + t.Run("anthropic long retention omits ttl on proxy base url", func(t *testing.T) { + model := ai.Model{ + ID: "claude-3-5-haiku-20241022", + API: ai.APIAnthropicMessages, + Provider: "anthropic", + BaseURL: 
"https://my-proxy.example.com/v1", + } + params := BuildAnthropicParams(model, baseContext, AnthropicOptions{ + StreamOptions: ai.StreamOptions{ + CacheRetention: ai.CacheRetentionLong, + }, + }) + system := params["system"].([]map[string]any) + cc := system[0]["cache_control"].(map[string]any) + if _, ok := cc["ttl"]; ok { + t.Fatalf("expected ttl omitted for proxy base url") + } + }) + + t.Run("anthropic cache retention none omits cache_control", func(t *testing.T) { + model := ai.Model{ + ID: "claude-3-5-haiku-20241022", + API: ai.APIAnthropicMessages, + Provider: "anthropic", + BaseURL: "https://api.anthropic.com", + } + params := BuildAnthropicParams(model, baseContext, AnthropicOptions{ + StreamOptions: ai.StreamOptions{ + CacheRetention: ai.CacheRetentionNone, + }, + }) + system := params["system"].([]map[string]any) + if _, ok := system[0]["cache_control"]; ok { + t.Fatalf("expected cache_control omitted for cacheRetention=none") + } + }) + + t.Run("openai responses default has no prompt_cache_retention", func(t *testing.T) { + t.Setenv("PI_CACHE_RETENTION", "") + model := ai.Model{ + ID: "gpt-4o-mini", + API: ai.APIOpenAIResponses, + Provider: "openai", + BaseURL: "https://api.openai.com/v1", + } + params := BuildOpenAIResponsesParams(model, baseContext, OpenAIResponsesOptions{}) + if _, ok := params["prompt_cache_retention"]; ok { + t.Fatalf("expected prompt_cache_retention omitted by default") + } + }) + + t.Run("openai responses long sets retention and key", func(t *testing.T) { + model := ai.Model{ + ID: "gpt-4o-mini", + API: ai.APIOpenAIResponses, + Provider: "openai", + BaseURL: "https://api.openai.com/v1", + } + params := BuildOpenAIResponsesParams(model, baseContext, OpenAIResponsesOptions{ + StreamOptions: ai.StreamOptions{ + CacheRetention: ai.CacheRetentionLong, + SessionID: "session-2", + }, + }) + if params["prompt_cache_key"] != "session-2" { + t.Fatalf("expected prompt_cache_key=session-2, got %v", params["prompt_cache_key"]) + } + if 
params["prompt_cache_retention"] != "24h" { + t.Fatalf("expected prompt_cache_retention=24h, got %v", params["prompt_cache_retention"]) + } + }) + + t.Run("openai responses long proxy base omits prompt_cache_retention", func(t *testing.T) { + model := ai.Model{ + ID: "gpt-4o-mini", + API: ai.APIOpenAIResponses, + Provider: "openai", + BaseURL: "https://my-proxy.example.com/v1", + } + params := BuildOpenAIResponsesParams(model, baseContext, OpenAIResponsesOptions{ + StreamOptions: ai.StreamOptions{ + CacheRetention: ai.CacheRetentionLong, + SessionID: "session-2", + }, + }) + if _, ok := params["prompt_cache_retention"]; ok { + t.Fatalf("expected prompt_cache_retention omitted for proxy base URL") + } + }) + + t.Run("openai responses cache retention none omits key and retention", func(t *testing.T) { + model := ai.Model{ + ID: "gpt-4o-mini", + API: ai.APIOpenAIResponses, + Provider: "openai", + BaseURL: "https://api.openai.com/v1", + } + params := BuildOpenAIResponsesParams(model, baseContext, OpenAIResponsesOptions{ + StreamOptions: ai.StreamOptions{ + CacheRetention: ai.CacheRetentionNone, + SessionID: "session-1", + }, + }) + if _, ok := params["prompt_cache_key"]; ok { + t.Fatalf("expected prompt_cache_key omitted for cacheRetention=none") + } + if _, ok := params["prompt_cache_retention"]; ok { + t.Fatalf("expected prompt_cache_retention omitted for cacheRetention=none") + } + }) +} diff --git a/pkg/ai/providers/openai_responses.go b/pkg/ai/providers/openai_responses.go new file mode 100644 index 00000000..ec793524 --- /dev/null +++ b/pkg/ai/providers/openai_responses.go @@ -0,0 +1,234 @@ +package providers + +import ( + "encoding/json" + "os" + "strings" + + "github.com/beeper/ai-bridge/pkg/ai" + "github.com/beeper/ai-bridge/pkg/ai/utils" +) + +type OpenAIResponsesOptions struct { + StreamOptions ai.StreamOptions + ReasoningEffort ai.ThinkingLevel + ReasoningSummary string + ServiceTier string +} + +func ResolveCacheRetention(cacheRetention ai.CacheRetention) 
ai.CacheRetention { + if cacheRetention != "" { + return cacheRetention + } + if strings.EqualFold(os.Getenv("PI_CACHE_RETENTION"), "long") { + return ai.CacheRetentionLong + } + return ai.CacheRetentionShort +} + +func GetPromptCacheRetention(baseURL string, cacheRetention ai.CacheRetention) string { + if cacheRetention != ai.CacheRetentionLong { + return "" + } + if strings.Contains(baseURL, "api.openai.com") { + return "24h" + } + return "" +} + +func BuildOpenAIResponsesParams(model ai.Model, context ai.Context, options OpenAIResponsesOptions) map[string]any { + messages := ConvertOpenAIResponsesMessages(model, context) + retention := ResolveCacheRetention(options.StreamOptions.CacheRetention) + params := map[string]any{ + "model": model.ID, + "input": messages, + "stream": true, + "store": false, + } + if options.StreamOptions.MaxTokens > 0 { + params["max_output_tokens"] = options.StreamOptions.MaxTokens + } + if options.StreamOptions.Temperature != nil { + params["temperature"] = *options.StreamOptions.Temperature + } + if options.ServiceTier != "" { + params["service_tier"] = options.ServiceTier + } + if context.Tools != nil { + params["tools"] = convertResponsesTools(context.Tools) + } + if retention != ai.CacheRetentionNone && strings.TrimSpace(options.StreamOptions.SessionID) != "" { + params["prompt_cache_key"] = options.StreamOptions.SessionID + } + if cache := GetPromptCacheRetention(model.BaseURL, retention); cache != "" { + params["prompt_cache_retention"] = cache + } + if model.Reasoning { + if options.ReasoningEffort != "" || strings.TrimSpace(options.ReasoningSummary) != "" { + summary := options.ReasoningSummary + if summary == "" { + summary = "auto" + } + effort := options.ReasoningEffort + if effort == "" { + effort = ai.ThinkingMedium + } + params["reasoning"] = map[string]any{ + "effort": string(effort), + "summary": summary, + } + params["include"] = []string{"reasoning.encrypted_content"} + } else if 
strings.HasPrefix(strings.ToLower(model.Name), "gpt-5") { + messages = append(messages, map[string]any{ + "role": "developer", + "content": []map[string]any{{ + "type": "input_text", + "text": "# Juice: 0 !important", + }}, + }) + params["input"] = messages + } + } + return params +} + +func ConvertOpenAIResponsesMessages(model ai.Model, context ai.Context) []map[string]any { + messages := make([]map[string]any, 0, len(context.Messages)+1) + if strings.TrimSpace(context.SystemPrompt) != "" { + role := "system" + if model.Reasoning { + role = "developer" + } + messages = append(messages, map[string]any{ + "role": role, + "content": utils.SanitizeSurrogates(context.SystemPrompt), + }) + } + + transformed := TransformMessages(context.Messages, model, nil) + for _, msg := range transformed { + switch msg.Role { + case ai.RoleUser: + content := []map[string]any{} + if strings.TrimSpace(msg.Text) != "" { + content = append(content, map[string]any{ + "type": "input_text", + "text": utils.SanitizeSurrogates(msg.Text), + }) + } + for _, block := range msg.Content { + if block.Type == ai.ContentTypeText && strings.TrimSpace(block.Text) != "" { + content = append(content, map[string]any{ + "type": "input_text", + "text": utils.SanitizeSurrogates(block.Text), + }) + } + if block.Type == ai.ContentTypeImage { + content = append(content, map[string]any{ + "type": "input_image", + "detail": "auto", + "image_url": "data:" + block.MimeType + ";base64," + block.Data, + }) + } + } + if len(content) == 0 { + continue + } + messages = append(messages, map[string]any{ + "role": "user", + "content": content, + }) + case ai.RoleAssistant: + for _, block := range msg.Content { + switch block.Type { + case ai.ContentTypeText: + messages = append(messages, map[string]any{ + "type": "message", + "role": "assistant", + "status": "completed", + "id": fallbackTextID(block.TextSignature), + "content": []map[string]any{{ + "type": "output_text", + "text": utils.SanitizeSurrogates(block.Text), + 
"annotations": []any{}, + }}, + }) + case ai.ContentTypeThinking: + if block.ThinkingSignature != "" { + // signature payload is already serialized response item. + // best-effort keep as text fallback when opaque. + messages = append(messages, map[string]any{ + "type": "reasoning", + "summary": []map[string]any{{"type": "summary_text", "text": block.Thinking}}, + }) + } + case ai.ContentTypeToolCall: + parts := strings.SplitN(block.ID, "|", 2) + callID := block.ID + itemID := "" + if len(parts) == 2 { + callID = parts[0] + itemID = parts[1] + } + args := "{}" + if block.Arguments != nil { + b, _ := json.Marshal(block.Arguments) + args = string(b) + } + messages = append(messages, map[string]any{ + "type": "function_call", + "id": itemID, + "call_id": callID, + "name": block.Name, + "arguments": args, + }) + } + } + case ai.RoleToolResult: + callID := msg.ToolCallID + if strings.Contains(callID, "|") { + callID = strings.SplitN(callID, "|", 2)[0] + } + output := "(see attached image)" + var textParts []string + for _, block := range msg.Content { + if block.Type == ai.ContentTypeText { + textParts = append(textParts, block.Text) + } + } + if len(textParts) > 0 { + output = strings.Join(textParts, "\n") + } + messages = append(messages, map[string]any{ + "type": "function_call_output", + "call_id": callID, + "output": utils.SanitizeSurrogates(output), + }) + } + } + return messages +} + +func fallbackTextID(signature string) string { + if strings.TrimSpace(signature) != "" { + if len(signature) > 64 { + return "msg_" + signature[:16] + } + return signature + } + return "msg_0" +} + +func convertResponsesTools(tools []ai.Tool) []map[string]any { + out := make([]map[string]any, 0, len(tools)) + for _, tool := range tools { + out = append(out, map[string]any{ + "type": "function", + "name": tool.Name, + "description": tool.Description, + "parameters": tool.Parameters, + "strict": false, + }) + } + return out +} From 9338c1101b70edc956bb8ba3661b1e2914242012 Mon Sep 17 
00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 02:54:21 +0000 Subject: [PATCH 03/75] Add codex SSE mapper and gemini tool-call normalizer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/ai/providers/google_gemini_cli.go | 19 +++ .../google_tool_call_missing_args_test.go | 23 +++ pkg/ai/providers/openai_codex_responses.go | 148 ++++++++++++++++++ .../providers/openai_codex_responses_test.go | 65 ++++++++ 4 files changed, 255 insertions(+) create mode 100644 pkg/ai/providers/google_tool_call_missing_args_test.go create mode 100644 pkg/ai/providers/openai_codex_responses.go create mode 100644 pkg/ai/providers/openai_codex_responses_test.go diff --git a/pkg/ai/providers/google_gemini_cli.go b/pkg/ai/providers/google_gemini_cli.go index d5c8fa12..90598645 100644 --- a/pkg/ai/providers/google_gemini_cli.go +++ b/pkg/ai/providers/google_gemini_cli.go @@ -6,6 +6,8 @@ import ( "strconv" "strings" "time" + + "github.com/beeper/ai-bridge/pkg/ai" ) func ExtractRetryDelay(errorText string, headers http.Header) (int, bool) { @@ -111,3 +113,20 @@ func headerGetCI(headers http.Header, key string) string { } return "" } + +func NormalizeGoogleToolCall(name string, args map[string]any, id string, thoughtSignature string) ai.ContentBlock { + normalizedArgs := args + if normalizedArgs == nil { + normalizedArgs = map[string]any{} + } + block := ai.ContentBlock{ + Type: ai.ContentTypeToolCall, + ID: id, + Name: name, + Arguments: normalizedArgs, + } + if strings.TrimSpace(thoughtSignature) != "" { + block.ThoughtSignature = thoughtSignature + } + return block +} diff --git a/pkg/ai/providers/google_tool_call_missing_args_test.go b/pkg/ai/providers/google_tool_call_missing_args_test.go new file mode 100644 index 00000000..2f05453e --- /dev/null +++ b/pkg/ai/providers/google_tool_call_missing_args_test.go @@ -0,0 +1,23 @@ +package providers + +import ( + "testing" + + 
"github.com/beeper/ai-bridge/pkg/ai" +) + +func TestNormalizeGoogleToolCall_DefaultsMissingArgsToEmptyObject(t *testing.T) { + toolCall := NormalizeGoogleToolCall("get_status", nil, "call_1", "") + if toolCall.Type != ai.ContentTypeToolCall { + t.Fatalf("expected tool call block type, got %s", toolCall.Type) + } + if toolCall.Name != "get_status" { + t.Fatalf("unexpected tool name: %s", toolCall.Name) + } + if toolCall.Arguments == nil { + t.Fatalf("expected arguments map, got nil") + } + if len(toolCall.Arguments) != 0 { + t.Fatalf("expected empty arguments map, got %#v", toolCall.Arguments) + } +} diff --git a/pkg/ai/providers/openai_codex_responses.go b/pkg/ai/providers/openai_codex_responses.go new file mode 100644 index 00000000..5ec752a2 --- /dev/null +++ b/pkg/ai/providers/openai_codex_responses.go @@ -0,0 +1,148 @@ +package providers + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/beeper/ai-bridge/pkg/ai" +) + +// ProcessCodexSSEPayload maps Codex SSE payload chunks into unified stream events. +// This is a deterministic helper used by tests while the full transport integration +// is being ported. 
+func ProcessCodexSSEPayload(payload string, output *ai.Message) ([]ai.AssistantMessageEvent, error) { + if output == nil { + return nil, fmt.Errorf("output message is required") + } + lines := strings.Split(payload, "\n") + events := make([]ai.AssistantMessageEvent, 0) + currentTextIndex := -1 + + emit := func(evt ai.AssistantMessageEvent) { + evt.Partial = *output + events = append(events, evt) + } + + for _, rawLine := range lines { + line := strings.TrimSpace(rawLine) + if line == "" || !strings.HasPrefix(line, "data:") { + continue + } + data := strings.TrimSpace(strings.TrimPrefix(line, "data:")) + if data == "" || data == "[DONE]" { + continue + } + var event map[string]any + if err := json.Unmarshal([]byte(data), &event); err != nil { + return nil, err + } + eventType, _ := event["type"].(string) + switch eventType { + case "response.output_item.added": + item, _ := event["item"].(map[string]any) + itemType, _ := item["type"].(string) + if itemType != "message" { + continue + } + output.Content = append(output.Content, ai.ContentBlock{ + Type: ai.ContentTypeText, + }) + currentTextIndex = len(output.Content) - 1 + emit(ai.AssistantMessageEvent{ + Type: ai.EventTextStart, + ContentIndex: currentTextIndex, + }) + case "response.output_text.delta": + delta, _ := event["delta"].(string) + if currentTextIndex >= 0 && currentTextIndex < len(output.Content) { + output.Content[currentTextIndex].Text += delta + emit(ai.AssistantMessageEvent{ + Type: ai.EventTextDelta, + ContentIndex: currentTextIndex, + Delta: delta, + }) + } + case "response.output_item.done": + item, _ := event["item"].(map[string]any) + itemType, _ := item["type"].(string) + if itemType != "message" { + continue + } + content, _ := item["content"].([]any) + finalText := "" + for _, raw := range content { + part, _ := raw.(map[string]any) + if part["type"] == "output_text" { + if text, ok := part["text"].(string); ok { + finalText += text + } + } + } + if currentTextIndex >= 0 && currentTextIndex 
< len(output.Content) { + output.Content[currentTextIndex].Text = finalText + emit(ai.AssistantMessageEvent{ + Type: ai.EventTextEnd, + ContentIndex: currentTextIndex, + Content: finalText, + }) + } + case "response.completed": + response, _ := event["response"].(map[string]any) + status, _ := response["status"].(string) + usage, _ := response["usage"].(map[string]any) + inputTokens := int(numberValue(usage["input_tokens"])) + outputTokens := int(numberValue(usage["output_tokens"])) + totalTokens := int(numberValue(usage["total_tokens"])) + cacheRead := 0 + if details, ok := usage["input_tokens_details"].(map[string]any); ok { + cacheRead = int(numberValue(details["cached_tokens"])) + } + output.Usage = ai.Usage{ + Input: inputTokens - cacheRead, + Output: outputTokens, + CacheRead: cacheRead, + CacheWrite: 0, + TotalTokens: totalTokens, + } + if status == "completed" { + output.StopReason = ai.StopReasonStop + } else if status == "incomplete" { + output.StopReason = ai.StopReasonLength + } else { + output.StopReason = ai.StopReasonError + } + emit(ai.AssistantMessageEvent{ + Type: ai.EventDone, + Reason: output.StopReason, + Message: *output, + }) + case "error": + msg, _ := event["message"].(string) + output.StopReason = ai.StopReasonError + output.ErrorMessage = msg + emit(ai.AssistantMessageEvent{ + Type: ai.EventError, + Reason: ai.StopReasonError, + Error: *output, + }) + } + } + + return events, nil +} + +func numberValue(v any) float64 { + switch value := v.(type) { + case float64: + return value + case float32: + return float64(value) + case int: + return float64(value) + case int64: + return float64(value) + default: + return 0 + } +} diff --git a/pkg/ai/providers/openai_codex_responses_test.go b/pkg/ai/providers/openai_codex_responses_test.go new file mode 100644 index 00000000..9df5e81d --- /dev/null +++ b/pkg/ai/providers/openai_codex_responses_test.go @@ -0,0 +1,65 @@ +package providers + +import ( + "strings" + "testing" + + 
"github.com/beeper/ai-bridge/pkg/ai" +) + +func TestProcessCodexSSEPayload_MapsToAssistantEvents(t *testing.T) { + payload := strings.Join([]string{ + `data: {"type":"response.output_item.added","item":{"type":"message","id":"msg_1","role":"assistant","status":"in_progress","content":[]}}`, + ``, + `data: {"type":"response.content_part.added","part":{"type":"output_text","text":""}}`, + ``, + `data: {"type":"response.output_text.delta","delta":"Hello"}`, + ``, + `data: {"type":"response.output_item.done","item":{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"Hello"}]}}`, + ``, + `data: {"type":"response.completed","response":{"status":"completed","usage":{"input_tokens":5,"output_tokens":3,"total_tokens":8,"input_tokens_details":{"cached_tokens":0}}}}`, + ``, + }, "\n") + + output := ai.Message{ + Role: ai.RoleAssistant, + API: ai.APIOpenAICodexResponse, + Provider: "openai-codex", + Model: "gpt-5.1-codex", + } + events, err := ProcessCodexSSEPayload(payload, &output) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(events) == 0 { + t.Fatalf("expected events from payload") + } + + var sawTextDelta, sawDone bool + for _, evt := range events { + if evt.Type == ai.EventTextDelta { + sawTextDelta = true + if evt.Delta != "Hello" { + t.Fatalf("expected text delta Hello, got %q", evt.Delta) + } + } + if evt.Type == ai.EventDone { + sawDone = true + if evt.Message.StopReason != ai.StopReasonStop { + t.Fatalf("expected done stop reason stop, got %s", evt.Message.StopReason) + } + } + } + if !sawTextDelta { + t.Fatalf("expected text delta event") + } + if !sawDone { + t.Fatalf("expected done event") + } + if output.Usage.TotalTokens != 8 || output.Usage.Input != 5 || output.Usage.Output != 3 { + t.Fatalf("unexpected usage: %+v", output.Usage) + } + if len(output.Content) != 1 || output.Content[0].Text != "Hello" { + t.Fatalf("unexpected output content: %+v", output.Content) + } +} From 
415c8a5fcd47778af03919d52613b2daa4df8f7e Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 02:57:03 +0000 Subject: [PATCH 04/75] Add OAuth registry and PKCE scaffolding for pkg/ai MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/ai/oauth/index.go | 95 +++++++++++++++++++++++++++++++++++++ pkg/ai/oauth/index_test.go | 97 ++++++++++++++++++++++++++++++++++++++ pkg/ai/oauth/pkce.go | 26 ++++++++++ pkg/ai/oauth/pkce_test.go | 23 +++++++++ pkg/ai/oauth/types.go | 38 +++++++++++++++ 5 files changed, 279 insertions(+) create mode 100644 pkg/ai/oauth/index.go create mode 100644 pkg/ai/oauth/index_test.go create mode 100644 pkg/ai/oauth/pkce.go create mode 100644 pkg/ai/oauth/pkce_test.go create mode 100644 pkg/ai/oauth/types.go diff --git a/pkg/ai/oauth/index.go b/pkg/ai/oauth/index.go new file mode 100644 index 00000000..015cf8c6 --- /dev/null +++ b/pkg/ai/oauth/index.go @@ -0,0 +1,95 @@ +package oauth + +import ( + "fmt" + "sync" + "time" +) + +type providerInfo struct { + provider Provider + builtin bool +} + +var ( + providersMu sync.RWMutex + providers = map[ProviderID]providerInfo{} +) + +func GetProvider(id ProviderID) (Provider, bool) { + providersMu.RLock() + defer providersMu.RUnlock() + entry, ok := providers[id] + return entry.provider, ok +} + +func RegisterProvider(provider Provider) { + providersMu.Lock() + defer providersMu.Unlock() + providers[provider.ID()] = providerInfo{provider: provider} +} + +func RegisterBuiltinProvider(provider Provider) { + providersMu.Lock() + defer providersMu.Unlock() + providers[provider.ID()] = providerInfo{provider: provider, builtin: true} +} + +func UnregisterProvider(id ProviderID) { + providersMu.Lock() + defer providersMu.Unlock() + if entry, ok := providers[id]; ok && entry.builtin { + return + } + delete(providers, id) +} + +func ResetProviders() { + providersMu.Lock() + defer providersMu.Unlock() + next := 
map[ProviderID]providerInfo{} + for id, entry := range providers { + if entry.builtin { + next[id] = entry + } + } + providers = next +} + +func GetProviders() []Provider { + providersMu.RLock() + defer providersMu.RUnlock() + out := make([]Provider, 0, len(providers)) + for _, entry := range providers { + out = append(out, entry.provider) + } + return out +} + +func RefreshToken(providerID ProviderID, credentials Credentials) (Credentials, error) { + provider, ok := GetProvider(providerID) + if !ok { + return Credentials{}, fmt.Errorf("unknown OAuth provider: %s", providerID) + } + return provider.RefreshToken(credentials) +} + +func GetAPIKey(providerID ProviderID, credentials map[ProviderID]Credentials) (*Credentials, string, error) { + provider, ok := GetProvider(providerID) + if !ok { + return nil, "", fmt.Errorf("unknown OAuth provider: %s", providerID) + } + creds, ok := credentials[providerID] + if !ok { + return nil, "", nil + } + if time.Now().UnixMilli() >= creds.Expires { + refreshed, err := provider.RefreshToken(creds) + if err != nil { + return nil, "", fmt.Errorf("failed to refresh OAuth token for %s", providerID) + } + creds = refreshed + } + key := provider.GetAPIKey(creds) + return &creds, key, nil +} diff --git a/pkg/ai/oauth/index_test.go b/pkg/ai/oauth/index_test.go new file mode 100644 index 00000000..a94d039c --- /dev/null +++ b/pkg/ai/oauth/index_test.go @@ -0,0 +1,97 @@ +package oauth + +import ( + "errors" + "testing" + "time" +) + +type testProvider struct { + id ProviderID + name string + apiKey string + refreshed Credentials + refreshErr error + usesCallback bool +} + +func (p *testProvider) ID() ProviderID { return p.id } +func (p *testProvider) Name() string { return p.name } +func (p *testProvider) Login(callbacks LoginCallbacks) (Credentials, error) { + return Credentials{}, nil +} +func (p *testProvider) UsesCallbackServer() bool { return p.usesCallback } +func (p *testProvider) RefreshToken(credentials Credentials) (Credentials, 
error) { + if p.refreshErr != nil { + return Credentials{}, p.refreshErr + } + return p.refreshed, nil +} +func (p *testProvider) GetAPIKey(credentials Credentials) string { + return p.apiKey +} + +func TestProviderRegistryAndGetAPIKey(t *testing.T) { + ResetProviders() + p := &testProvider{ + id: "test", + name: "Test", + apiKey: "key-1", + refreshed: Credentials{ + Refresh: "r2", + Access: "a2", + Expires: time.Now().Add(time.Hour).UnixMilli(), + }, + } + RegisterProvider(p) + + got, ok := GetProvider("test") + if !ok || got.Name() != "Test" { + t.Fatalf("expected provider in registry") + } + + credsMap := map[ProviderID]Credentials{ + "test": { + Refresh: "r1", + Access: "a1", + Expires: time.Now().Add(-time.Minute).UnixMilli(), + }, + } + newCreds, apiKey, err := GetAPIKey("test", credsMap) + if err != nil { + t.Fatalf("unexpected get api key error: %v", err) + } + if newCreds == nil || newCreds.Access != "a2" { + t.Fatalf("expected refreshed credentials, got %#v", newCreds) + } + if apiKey != "key-1" { + t.Fatalf("expected api key key-1, got %s", apiKey) + } + + UnregisterProvider("test") + if _, ok := GetProvider("test"); ok { + t.Fatalf("expected provider removed") + } +} + +func TestGetAPIKey_RefreshFailure(t *testing.T) { + ResetProviders() + p := &testProvider{ + id: "broken", + name: "Broken", + apiKey: "x", + refreshErr: errors.New("boom"), + } + RegisterProvider(p) + + _, _, err := GetAPIKey("broken", map[ProviderID]Credentials{ + "broken": { + Refresh: "r1", + Access: "a1", + Expires: time.Now().Add(-time.Minute).UnixMilli(), + }, + }) + if err == nil { + t.Fatalf("expected refresh failure error") + } +} diff --git a/pkg/ai/oauth/pkce.go b/pkg/ai/oauth/pkce.go new file mode 100644 index 00000000..00298394 --- /dev/null +++ b/pkg/ai/oauth/pkce.go @@ -0,0 +1,26 @@ +package oauth + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" +) + +type PKCE struct { + Verifier string + Challenge string +} + +func GeneratePKCE() (PKCE, error) { + 
verifierBytes := make([]byte, 32) + if _, err := rand.Read(verifierBytes); err != nil { + return PKCE{}, err + } + verifier := base64.RawURLEncoding.EncodeToString(verifierBytes) + hash := sha256.Sum256([]byte(verifier)) + challenge := base64.RawURLEncoding.EncodeToString(hash[:]) + return PKCE{ + Verifier: verifier, + Challenge: challenge, + }, nil +} diff --git a/pkg/ai/oauth/pkce_test.go b/pkg/ai/oauth/pkce_test.go new file mode 100644 index 00000000..3efb795c --- /dev/null +++ b/pkg/ai/oauth/pkce_test.go @@ -0,0 +1,23 @@ +package oauth + +import ( + "regexp" + "testing" +) + +func TestGeneratePKCE(t *testing.T) { + pkce, err := GeneratePKCE() + if err != nil { + t.Fatalf("unexpected generate pkce error: %v", err) + } + if pkce.Verifier == "" || pkce.Challenge == "" { + t.Fatalf("expected verifier and challenge to be non-empty") + } + base64url := regexp.MustCompile(`^[A-Za-z0-9_-]+$`) + if !base64url.MatchString(pkce.Verifier) { + t.Fatalf("verifier must be base64url: %s", pkce.Verifier) + } + if !base64url.MatchString(pkce.Challenge) { + t.Fatalf("challenge must be base64url: %s", pkce.Challenge) + } +} diff --git a/pkg/ai/oauth/types.go b/pkg/ai/oauth/types.go new file mode 100644 index 00000000..3ea2d103 --- /dev/null +++ b/pkg/ai/oauth/types.go @@ -0,0 +1,38 @@ +package oauth + +type Credentials struct { + Refresh string `json:"refresh"` + Access string `json:"access"` + Expires int64 `json:"expires"` + Extra map[string]any `json:"extra,omitempty"` +} + +type ProviderID string + +type Prompt struct { + Message string `json:"message"` + Placeholder string `json:"placeholder,omitempty"` + AllowEmpty bool `json:"allowEmpty,omitempty"` +} + +type AuthInfo struct { + URL string `json:"url"` + Instructions string `json:"instructions,omitempty"` +} + +type LoginCallbacks struct { + OnAuth func(info AuthInfo) + OnPrompt func(prompt Prompt) (string, error) + OnProgress func(message string) + OnManualCodeInput func() (string, error) +} + +type Provider interface { + 
ID() ProviderID + Name() string + + Login(callbacks LoginCallbacks) (Credentials, error) + UsesCallbackServer() bool + RefreshToken(credentials Credentials) (Credentials, error) + GetAPIKey(credentials Credentials) string +} From 1821b226cdb401362d7bdb8276d180b708c7640a Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 03:00:54 +0000 Subject: [PATCH 05/75] Add connector-to-pkg-ai context adapter scaffold MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/connector/ai_adapter.go | 115 +++++++++++++++++++++++++++++++ pkg/connector/ai_adapter_test.go | 70 +++++++++++++++++++ 2 files changed, 185 insertions(+) create mode 100644 pkg/connector/ai_adapter.go create mode 100644 pkg/connector/ai_adapter_test.go diff --git a/pkg/connector/ai_adapter.go b/pkg/connector/ai_adapter.go new file mode 100644 index 00000000..8b384862 --- /dev/null +++ b/pkg/connector/ai_adapter.go @@ -0,0 +1,115 @@ +package connector + +import ( + "encoding/json" + "strings" + + aipkg "github.com/beeper/ai-bridge/pkg/ai" +) + +// toAIContext maps connector runtime request data into pkg/ai portable context. +// This adapter is intentionally side-effect free to enable incremental migration. 
+func toAIContext(systemPrompt string, messages []UnifiedMessage, tools []ToolDefinition) aipkg.Context { + aiMessages := make([]aipkg.Message, 0, len(messages)) + for _, msg := range messages { + converted := aipkg.Message{ + Role: mapAIRole(msg.Role), + Timestamp: 0, + } + + switch msg.Role { + case RoleUser: + blocks := make([]aipkg.ContentBlock, 0, len(msg.Content)) + for _, part := range msg.Content { + switch part.Type { + case ContentTypeText: + converted.Text = part.Text + blocks = append(blocks, aipkg.ContentBlock{ + Type: aipkg.ContentTypeText, + Text: part.Text, + }) + case ContentTypeImage: + data := part.ImageB64 + if data == "" && strings.HasPrefix(part.ImageURL, "data:") { + data = strings.TrimPrefix(part.ImageURL, "data:") + } + blocks = append(blocks, aipkg.ContentBlock{ + Type: aipkg.ContentTypeImage, + Data: data, + MimeType: part.MimeType, + }) + } + } + converted.Content = blocks + case RoleAssistant: + blocks := make([]aipkg.ContentBlock, 0, len(msg.Content)+len(msg.ToolCalls)) + for _, part := range msg.Content { + if part.Type == ContentTypeText { + blocks = append(blocks, aipkg.ContentBlock{ + Type: aipkg.ContentTypeText, + Text: part.Text, + }) + } + } + for _, tc := range msg.ToolCalls { + blocks = append(blocks, aipkg.ContentBlock{ + Type: aipkg.ContentTypeToolCall, + ID: tc.ID, + Name: tc.Name, + Arguments: parseToolArguments(tc.Arguments), + }) + } + converted.Content = blocks + case RoleTool: + converted.ToolCallID = msg.ToolCallID + converted.ToolName = msg.Name + text := msg.Text() + converted.Content = []aipkg.ContentBlock{{ + Type: aipkg.ContentTypeText, + Text: text, + }} + converted.Text = text + } + + aiMessages = append(aiMessages, converted) + } + + aiTools := make([]aipkg.Tool, 0, len(tools)) + for _, tool := range tools { + aiTools = append(aiTools, aipkg.Tool{ + Name: tool.Name, + Description: tool.Description, + Parameters: tool.Parameters, + }) + } + + return aipkg.Context{ + SystemPrompt: systemPrompt, + Messages: 
aiMessages, + Tools: aiTools, + } +} + +func mapAIRole(role MessageRole) aipkg.MessageRole { + switch role { + case RoleUser: + return aipkg.RoleUser + case RoleAssistant: + return aipkg.RoleAssistant + case RoleTool: + return aipkg.RoleToolResult + default: + return aipkg.RoleUser + } +} + +func parseToolArguments(raw string) map[string]any { + if strings.TrimSpace(raw) == "" { + return map[string]any{} + } + var parsed map[string]any + if err := json.Unmarshal([]byte(raw), &parsed); err != nil || parsed == nil { + return map[string]any{} + } + return parsed +} diff --git a/pkg/connector/ai_adapter_test.go b/pkg/connector/ai_adapter_test.go new file mode 100644 index 00000000..4ba25c02 --- /dev/null +++ b/pkg/connector/ai_adapter_test.go @@ -0,0 +1,70 @@ +package connector + +import "testing" + +func TestToAIContext_MapsMessagesAndTools(t *testing.T) { + inputMessages := []UnifiedMessage{ + { + Role: RoleUser, + Content: []ContentPart{ + {Type: ContentTypeText, Text: "hello"}, + }, + }, + { + Role: RoleAssistant, + Content: []ContentPart{ + {Type: ContentTypeText, Text: "use tool"}, + }, + ToolCalls: []ToolCallResult{{ + ID: "call_1", + Name: "echo", + Arguments: `{"message":"hi"}`, + }}, + }, + { + Role: RoleTool, + ToolCallID: "call_1", + Name: "echo", + Content: []ContentPart{ + {Type: ContentTypeText, Text: "hi"}, + }, + }, + } + tools := []ToolDefinition{{ + Name: "echo", + Description: "Echo message", + Parameters: map[string]any{ + "type": "object", + }, + }} + + ctx := toAIContext("system prompt", inputMessages, tools) + if ctx.SystemPrompt != "system prompt" { + t.Fatalf("unexpected system prompt: %s", ctx.SystemPrompt) + } + if len(ctx.Messages) != 3 { + t.Fatalf("expected 3 mapped messages, got %d", len(ctx.Messages)) + } + if ctx.Messages[0].Role != "user" { + t.Fatalf("expected first role user, got %s", ctx.Messages[0].Role) + } + if ctx.Messages[1].Role != "assistant" { + t.Fatalf("expected second role assistant, got %s", ctx.Messages[1].Role) + } + 
if len(ctx.Messages[1].Content) < 2 { + t.Fatalf("expected assistant content to include text and tool call, got %+v", ctx.Messages[1].Content) + } + toolCall := ctx.Messages[1].Content[1] + if toolCall.Type != "toolCall" || toolCall.Name != "echo" { + t.Fatalf("unexpected tool call mapping: %+v", toolCall) + } + if toolCall.Arguments["message"] != "hi" { + t.Fatalf("unexpected parsed tool args: %+v", toolCall.Arguments) + } + if ctx.Messages[2].Role != "toolResult" { + t.Fatalf("expected tool role mapped to toolResult, got %s", ctx.Messages[2].Role) + } + if len(ctx.Tools) != 1 || ctx.Tools[0].Name != "echo" { + t.Fatalf("expected mapped tools in context, got %+v", ctx.Tools) + } +} From d75ca515ddcdb9b9d4cad85830c5e9d35211557c Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 03:06:45 +0000 Subject: [PATCH 06/75] Add simple options and tool validation helpers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/ai/providers/simple_options.go | 107 +++++++++++++++ pkg/ai/providers/simple_options_test.go | 61 +++++++++ pkg/ai/utils/type_helpers.go | 28 ++++ pkg/ai/utils/type_helpers_test.go | 20 +++ pkg/ai/utils/validation.go | 174 ++++++++++++++++++++++++ pkg/ai/utils/validation_test.go | 75 ++++++++++ 6 files changed, 465 insertions(+) create mode 100644 pkg/ai/providers/simple_options.go create mode 100644 pkg/ai/providers/simple_options_test.go create mode 100644 pkg/ai/utils/type_helpers.go create mode 100644 pkg/ai/utils/type_helpers_test.go create mode 100644 pkg/ai/utils/validation.go create mode 100644 pkg/ai/utils/validation_test.go diff --git a/pkg/ai/providers/simple_options.go b/pkg/ai/providers/simple_options.go new file mode 100644 index 00000000..63e31bb9 --- /dev/null +++ b/pkg/ai/providers/simple_options.go @@ -0,0 +1,107 @@ +package providers + +import "github.com/beeper/ai-bridge/pkg/ai" + +func BuildBaseOptions(model ai.Model, options 
*ai.SimpleStreamOptions, apiKey string) ai.StreamOptions { + if options == nil { + return ai.StreamOptions{ + MaxTokens: minInt(model.MaxTokens, 32000), + APIKey: apiKey, + } + } + maxTokens := options.MaxTokens + if maxTokens <= 0 { + maxTokens = minInt(model.MaxTokens, 32000) + } + return ai.StreamOptions{ + Temperature: options.Temperature, + MaxTokens: maxTokens, + Ctx: options.Ctx, + APIKey: coalesce(apiKey, options.APIKey), + Transport: options.Transport, + CacheRetention: options.CacheRetention, + SessionID: options.SessionID, + OnPayload: options.OnPayload, + Headers: options.Headers, + MaxRetryDelayMs: options.MaxRetryDelayMs, + Metadata: options.Metadata, + } +} + +func ClampReasoning(effort ai.ThinkingLevel) ai.ThinkingLevel { + if effort == ai.ThinkingXHigh { + return ai.ThinkingHigh + } + return effort +} + +func AdjustMaxTokensForThinking( + baseMaxTokens int, + modelMaxTokens int, + reasoningLevel ai.ThinkingLevel, + customBudgets ai.ThinkingBudgets, +) (maxTokens int, thinkingBudget int) { + defaultBudgets := ai.ThinkingBudgets{ + Minimal: 1024, + Low: 2048, + Medium: 8192, + High: 16384, + } + budgets := mergeThinkingBudgets(defaultBudgets, customBudgets) + level := ClampReasoning(reasoningLevel) + switch level { + case ai.ThinkingMinimal: + thinkingBudget = budgets.Minimal + case ai.ThinkingLow: + thinkingBudget = budgets.Low + case ai.ThinkingMedium: + thinkingBudget = budgets.Medium + default: + thinkingBudget = budgets.High + } + + maxTokens = minInt(baseMaxTokens+thinkingBudget, modelMaxTokens) + minOutput := 1024 + if maxTokens <= thinkingBudget { + thinkingBudget = maxInt(0, maxTokens-minOutput) + } + return maxTokens, thinkingBudget +} + +func mergeThinkingBudgets(base, custom ai.ThinkingBudgets) ai.ThinkingBudgets { + out := base + if custom.Minimal > 0 { + out.Minimal = custom.Minimal + } + if custom.Low > 0 { + out.Low = custom.Low + } + if custom.Medium > 0 { + out.Medium = custom.Medium + } + if custom.High > 0 { + out.High = 
custom.High + } + return out +} + +func minInt(a, b int) int { + if a < b { + return a + } + return b +} + +func maxInt(a, b int) int { + if a > b { + return a + } + return b +} + +func coalesce(primary, fallback string) string { + if primary != "" { + return primary + } + return fallback +} diff --git a/pkg/ai/providers/simple_options_test.go b/pkg/ai/providers/simple_options_test.go new file mode 100644 index 00000000..02ae5616 --- /dev/null +++ b/pkg/ai/providers/simple_options_test.go @@ -0,0 +1,61 @@ +package providers + +import ( + "testing" + + "github.com/beeper/ai-bridge/pkg/ai" +) + +func TestClampReasoning(t *testing.T) { + if got := ClampReasoning(ai.ThinkingXHigh); got != ai.ThinkingHigh { + t.Fatalf("expected xhigh to clamp to high, got %s", got) + } + if got := ClampReasoning(ai.ThinkingLow); got != ai.ThinkingLow { + t.Fatalf("expected low to stay low, got %s", got) + } +} + +func TestAdjustMaxTokensForThinking(t *testing.T) { + maxTokens, budget := AdjustMaxTokensForThinking(4000, 10000, ai.ThinkingMedium, ai.ThinkingBudgets{}) + if maxTokens != 10000 { + t.Fatalf("expected maxTokens capped to model max 10000, got %d", maxTokens) + } + if budget != 8192 { + t.Fatalf("expected medium budget 8192, got %d", budget) + } + + maxTokens, budget = AdjustMaxTokensForThinking(1200, 1500, ai.ThinkingHigh, ai.ThinkingBudgets{}) + if budget < 0 || maxTokens <= 0 { + t.Fatalf("expected non-negative budget and positive max tokens, got max=%d budget=%d", maxTokens, budget) + } + + maxTokens, budget = AdjustMaxTokensForThinking(1000, 9000, ai.ThinkingLow, ai.ThinkingBudgets{Low: 3333}) + if budget != 3333 { + t.Fatalf("expected custom low budget 3333, got %d", budget) + } + if maxTokens != 4333 { + t.Fatalf("expected max tokens 4333, got %d", maxTokens) + } +} + +func TestBuildBaseOptions(t *testing.T) { + model := ai.Model{MaxTokens: 8000} + temp := 0.2 + opts := &ai.SimpleStreamOptions{ + StreamOptions: ai.StreamOptions{ + Temperature: &temp, + MaxTokens: 0, + 
APIKey: "from-options", + }, + } + base := BuildBaseOptions(model, opts, "from-param") + if base.APIKey != "from-param" { + t.Fatalf("expected explicit apiKey to win, got %s", base.APIKey) + } + if base.MaxTokens != 8000 { + t.Fatalf("expected default maxTokens=min(model,32000)=8000, got %d", base.MaxTokens) + } + if base.Temperature == nil || *base.Temperature != 0.2 { + t.Fatalf("unexpected temperature in base options") + } +} diff --git a/pkg/ai/utils/type_helpers.go b/pkg/ai/utils/type_helpers.go new file mode 100644 index 00000000..d1058385 --- /dev/null +++ b/pkg/ai/utils/type_helpers.go @@ -0,0 +1,28 @@ +package utils + +type StringEnumSchema struct { + Type string `json:"type"` + Enum []string `json:"enum"` + Description string `json:"description,omitempty"` + Default string `json:"default,omitempty"` +} + +func StringEnum(values []string, description string, defaultValue string) map[string]any { + schema := StringEnumSchema{ + Type: "string", + Enum: append([]string(nil), values...), + Description: description, + Default: defaultValue, + } + out := map[string]any{ + "type": "string", + "enum": schema.Enum, + } + if schema.Description != "" { + out["description"] = schema.Description + } + if schema.Default != "" { + out["default"] = schema.Default + } + return out +} diff --git a/pkg/ai/utils/type_helpers_test.go b/pkg/ai/utils/type_helpers_test.go new file mode 100644 index 00000000..1a08e75c --- /dev/null +++ b/pkg/ai/utils/type_helpers_test.go @@ -0,0 +1,20 @@ +package utils + +import "testing" + +func TestStringEnum(t *testing.T) { + schema := StringEnum([]string{"add", "subtract"}, "operation", "add") + if schema["type"] != "string" { + t.Fatalf("expected type string, got %v", schema["type"]) + } + enumVals, ok := schema["enum"].([]string) + if !ok || len(enumVals) != 2 { + t.Fatalf("expected enum values in schema, got %#v", schema["enum"]) + } + if schema["description"] != "operation" { + t.Fatalf("expected description set") + } + if 
schema["default"] != "add" { + t.Fatalf("expected default add") + } +} diff --git a/pkg/ai/utils/validation.go b/pkg/ai/utils/validation.go new file mode 100644 index 00000000..bb4b3725 --- /dev/null +++ b/pkg/ai/utils/validation.go @@ -0,0 +1,174 @@ +package utils + +import ( + "fmt" + "strconv" + "strings" + + "github.com/beeper/ai-bridge/pkg/ai" +) + +func ValidateToolCall(tools []ai.Tool, toolCall ai.ContentBlock) (map[string]any, error) { + for _, tool := range tools { + if tool.Name == toolCall.Name { + return ValidateToolArguments(tool, toolCall) + } + } + return nil, fmt.Errorf(`tool "%s" not found`, toolCall.Name) +} + +func ValidateToolArguments(tool ai.Tool, toolCall ai.ContentBlock) (map[string]any, error) { + args := cloneMap(toolCall.Arguments) + if args == nil { + args = map[string]any{} + } + if tool.Parameters == nil { + return args, nil + } + required, _ := tool.Parameters["required"].([]any) + for _, raw := range required { + key, ok := raw.(string) + if !ok { + continue + } + if _, exists := args[key]; !exists { + return nil, fmt.Errorf("validation failed for tool %q: missing required field %q", toolCall.Name, key) + } + } + + props, _ := tool.Parameters["properties"].(map[string]any) + for key, propSchemaRaw := range props { + propSchema, ok := propSchemaRaw.(map[string]any) + if !ok { + continue + } + expectedType, _ := propSchema["type"].(string) + if expectedType == "" { + continue + } + value, exists := args[key] + if !exists { + continue + } + coerced, err := coerceJSONType(expectedType, value) + if err != nil { + return nil, fmt.Errorf("validation failed for tool %q: %s: %w", toolCall.Name, key, err) + } + args[key] = coerced + } + return args, nil +} + +func cloneMap(input map[string]any) map[string]any { + if input == nil { + return nil + } + out := make(map[string]any, len(input)) + for k, v := range input { + out[k] = v + } + return out +} + +func coerceJSONType(expected string, value any) (any, error) { + switch expected { + case 
"string": + switch v := value.(type) { + case string: + return v, nil + default: + return fmt.Sprintf("%v", v), nil + } + case "number": + switch v := value.(type) { + case float64, float32, int, int64, uint64, uint32, int32: + return toFloat64(v), nil + case string: + f, err := strconv.ParseFloat(strings.TrimSpace(v), 64) + if err != nil { + return nil, fmt.Errorf("expected number, got %T", value) + } + return f, nil + default: + return nil, fmt.Errorf("expected number, got %T", value) + } + case "integer": + switch v := value.(type) { + case int, int64, int32, uint64, uint32: + return toInt64(v), nil + case float64: + return int64(v), nil + case string: + i, err := strconv.ParseInt(strings.TrimSpace(v), 10, 64) + if err != nil { + return nil, fmt.Errorf("expected integer, got %T", value) + } + return i, nil + default: + return nil, fmt.Errorf("expected integer, got %T", value) + } + case "boolean": + switch v := value.(type) { + case bool: + return v, nil + case string: + b, err := strconv.ParseBool(strings.TrimSpace(v)) + if err != nil { + return nil, fmt.Errorf("expected boolean, got %T", value) + } + return b, nil + default: + return nil, fmt.Errorf("expected boolean, got %T", value) + } + case "object": + if _, ok := value.(map[string]any); ok { + return value, nil + } + return nil, fmt.Errorf("expected object, got %T", value) + case "array": + if _, ok := value.([]any); ok { + return value, nil + } + return nil, fmt.Errorf("expected array, got %T", value) + default: + return value, nil + } +} + +func toFloat64(v any) float64 { + switch n := v.(type) { + case float64: + return n + case float32: + return float64(n) + case int: + return float64(n) + case int64: + return float64(n) + case int32: + return float64(n) + case uint64: + return float64(n) + case uint32: + return float64(n) + default: + return 0 + } +} + +func toInt64(v any) int64 { + switch n := v.(type) { + case int: + return int64(n) + case int64: + return n + case int32: + return int64(n) + case 
uint64: + return int64(n) + case uint32: + return int64(n) + default: + return 0 + } +} diff --git a/pkg/ai/utils/validation_test.go b/pkg/ai/utils/validation_test.go new file mode 100644 index 00000000..f957dceb --- /dev/null +++ b/pkg/ai/utils/validation_test.go @@ -0,0 +1,75 @@ +package utils + +import ( + "testing" + + "github.com/beeper/ai-bridge/pkg/ai" +) + +func TestValidateToolCallAndArguments(t *testing.T) { + tool := ai.Tool{ + Name: "calculate", + Description: "calc", + Parameters: map[string]any{ + "type": "object", + "required": []any{ + "expression", + "strict", + }, + "properties": map[string]any{ + "expression": map[string]any{"type": "string"}, + "strict": map[string]any{"type": "boolean"}, + "count": map[string]any{"type": "number"}, + }, + }, + } + call := ai.ContentBlock{ + Type: ai.ContentTypeToolCall, + Name: "calculate", + Arguments: map[string]any{ + "expression": 123, + "strict": "true", + "count": "10.5", + }, + } + validated, err := ValidateToolCall([]ai.Tool{tool}, call) + if err != nil { + t.Fatalf("unexpected validation error: %v", err) + } + if validated["expression"] != "123" { + t.Fatalf("expected expression coerced to string, got %#v", validated["expression"]) + } + if validated["strict"] != true { + t.Fatalf("expected strict coerced to bool, got %#v", validated["strict"]) + } + if validated["count"] != 10.5 { + t.Fatalf("expected count coerced to float64, got %#v", validated["count"]) + } +} + +func TestValidateToolCall_MissingToolAndRequiredField(t *testing.T) { + _, err := ValidateToolCall(nil, ai.ContentBlock{Name: "missing"}) + if err == nil { + t.Fatalf("expected error for missing tool") + } + + tool := ai.Tool{ + Name: "echo", + Parameters: map[string]any{ + "type": "object", + "required": []any{ + "message", + }, + "properties": map[string]any{ + "message": map[string]any{"type": "string"}, + }, + }, + } + _, err = ValidateToolCall([]ai.Tool{tool}, ai.ContentBlock{ + Name: "echo", + Arguments: map[string]any{}, + }) + if 
err == nil { + t.Fatalf("expected missing required field error") + } +} From de4ad3d7a8601156df84cd071c8c5c46e070e2c5 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 03:09:32 +0000 Subject: [PATCH 07/75] Add env-gated ai e2e parity scaffolding tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/ai/e2e/abort_test.go | 14 ++++++++++++++ pkg/ai/e2e/context_overflow_test.go | 14 ++++++++++++++ pkg/ai/e2e/stream_test.go | 15 +++++++++++++++ 3 files changed, 43 insertions(+) create mode 100644 pkg/ai/e2e/abort_test.go create mode 100644 pkg/ai/e2e/context_overflow_test.go create mode 100644 pkg/ai/e2e/stream_test.go diff --git a/pkg/ai/e2e/abort_test.go b/pkg/ai/e2e/abort_test.go new file mode 100644 index 00000000..b4a657f4 --- /dev/null +++ b/pkg/ai/e2e/abort_test.go @@ -0,0 +1,14 @@ +package e2e + +import ( + "os" + "testing" +) + +// Scaffolding parity target for pi-mono/packages/ai/test/abort.test.ts. +func TestAbortE2EParityScaffold(t *testing.T) { + if os.Getenv("PI_AI_E2E") == "" { + t.Skip("set PI_AI_E2E=1 to enable ai package e2e tests") + } + t.Skip("abort e2e parity test pending full provider runtime port") +} diff --git a/pkg/ai/e2e/context_overflow_test.go b/pkg/ai/e2e/context_overflow_test.go new file mode 100644 index 00000000..75e144a7 --- /dev/null +++ b/pkg/ai/e2e/context_overflow_test.go @@ -0,0 +1,14 @@ +package e2e + +import ( + "os" + "testing" +) + +// Scaffolding parity target for pi-mono/packages/ai/test/context-overflow.test.ts. 
+func TestContextOverflowE2EParityScaffold(t *testing.T) { + if os.Getenv("PI_AI_E2E") == "" { + t.Skip("set PI_AI_E2E=1 to enable ai package e2e tests") + } + t.Skip("context overflow e2e parity test pending full provider runtime port") +} diff --git a/pkg/ai/e2e/stream_test.go b/pkg/ai/e2e/stream_test.go new file mode 100644 index 00000000..eea5cfe9 --- /dev/null +++ b/pkg/ai/e2e/stream_test.go @@ -0,0 +1,15 @@ +package e2e + +import ( + "os" + "testing" +) + +// Scaffolding parity target for pi-mono/packages/ai/test/stream.test.ts. +// This is intentionally env-gated while provider runtime integration is in progress. +func TestGenerateE2EParityScaffold(t *testing.T) { + if os.Getenv("PI_AI_E2E") == "" { + t.Skip("set PI_AI_E2E=1 to enable ai package e2e tests") + } + t.Skip("stream e2e parity test pending full provider runtime port") +} From 5bef3d3e527dea375e4fffef909fc2c195c0b2e4 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 03:11:08 +0000 Subject: [PATCH 08/75] Auto-close assistant event stream on terminal events MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/ai/event_stream.go | 6 ++++++ pkg/ai/event_stream_test.go | 33 ++++++++++++++++++++++++++++++++- 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/pkg/ai/event_stream.go b/pkg/ai/event_stream.go index 01b26ef0..fc18b41c 100644 --- a/pkg/ai/event_stream.go +++ b/pkg/ai/event_stream.go @@ -35,23 +35,29 @@ func (s *AssistantMessageEventStream) Push(evt AssistantMessageEvent) { default: } + isComplete := false if evt.Type == EventDone { s.mu.Lock() s.result = evt.Message s.hasResult = true s.mu.Unlock() + isComplete = true } if evt.Type == EventError { s.mu.Lock() s.result = evt.Error s.hasResult = true s.mu.Unlock() + isComplete = true } select { case <-s.done: case s.ch <- evt: } + if isComplete { + s.Close() + } } func (s *AssistantMessageEventStream) Close() { diff --git 
a/pkg/ai/event_stream_test.go b/pkg/ai/event_stream_test.go index cd12e4c5..8f145338 100644 --- a/pkg/ai/event_stream_test.go +++ b/pkg/ai/event_stream_test.go @@ -14,7 +14,6 @@ func TestAssistantMessageEventStream_ResultFromDone(t *testing.T) { go func() { s.Push(AssistantMessageEvent{Type: EventStart}) s.Push(AssistantMessageEvent{Type: EventDone, Message: doneMsg, Reason: StopReasonStop}) - s.Close() }() ctx, cancel := context.WithTimeout(context.Background(), time.Second) @@ -37,3 +36,35 @@ func TestAssistantMessageEventStream_ResultFromDone(t *testing.T) { t.Fatalf("expected stop reason stop, got %s", result.StopReason) } } + +func TestAssistantMessageEventStream_ResultFromError(t *testing.T) { + s := NewAssistantMessageEventStream(2) + errMsg := Message{Role: RoleAssistant, ErrorMessage: "boom", Timestamp: 2} + + go func() { + s.Push(AssistantMessageEvent{Type: EventError, Error: errMsg}) + }() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + evt, err := s.Next(ctx) + if err != nil { + t.Fatalf("unexpected next error: %v", err) + } + if evt.Type != EventError { + t.Fatalf("expected error event, got %s", evt.Type) + } + + _, err = s.Next(ctx) + if err != io.EOF { + t.Fatalf("expected EOF after terminal event, got %v", err) + } + + result, err := s.Result() + if err != nil { + t.Fatalf("unexpected result error: %v", err) + } + if result.ErrorMessage != "boom" { + t.Fatalf("expected result from error event, got %#v", result) + } +} From 71b187f9be7931d718c64be1c051992c475b8157 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 03:12:15 +0000 Subject: [PATCH 09/75] Sort provider and model registry output deterministically MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/ai/models.go | 9 ++++++++- pkg/ai/models_test.go | 21 +++++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/pkg/ai/models.go 
b/pkg/ai/models.go index 5546b9fc..a30f7560 100644 --- a/pkg/ai/models.go +++ b/pkg/ai/models.go @@ -1,6 +1,9 @@ package ai -import "strings" +import ( + "slices" + "strings" +) var modelRegistry = map[string]map[string]Model{} @@ -31,6 +34,7 @@ func GetProviders() []string { for provider := range modelRegistry { out = append(out, provider) } + slices.Sort(out) return out } @@ -43,6 +47,9 @@ func GetModels(provider string) []Model { for _, model := range models { out = append(out, model) } + slices.SortFunc(out, func(a, b Model) int { + return strings.Compare(a.ID, b.ID) + }) return out } diff --git a/pkg/ai/models_test.go b/pkg/ai/models_test.go index d4ab202b..5468283a 100644 --- a/pkg/ai/models_test.go +++ b/pkg/ai/models_test.go @@ -41,3 +41,24 @@ func TestModelsAreEqual(t *testing.T) { t.Fatalf("expected nil model comparison to be false") } } + +func TestModelRegistryDeterministicOrdering(t *testing.T) { + previous := modelRegistry + modelRegistry = map[string]map[string]Model{} + defer func() { + modelRegistry = previous + }() + + RegisterModels("z-provider", []Model{{ID: "b-model"}, {ID: "a-model"}}) + RegisterModels("a-provider", []Model{{ID: "z-model"}, {ID: "x-model"}}) + + providers := GetProviders() + if len(providers) != 2 || providers[0] != "a-provider" || providers[1] != "z-provider" { + t.Fatalf("expected sorted providers, got %#v", providers) + } + + models := GetModels("z-provider") + if len(models) != 2 || models[0].ID != "a-model" || models[1].ID != "b-model" { + t.Fatalf("expected sorted model IDs, got %#v", models) + } +} From d090380a08b62013f78043c320411c5a6a3d3e5e Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 03:14:15 +0000 Subject: [PATCH 10/75] Add ai model generator command groundwork MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- cmd/generate-ai-models/main.go | 98 ++++++++++++++++++++++++++++++++++ pkg/ai/models.go | 4 ++ 
pkg/ai/models_generated.go | 38 +++++++++++++ pkg/ai/models_test.go | 10 ++++ 4 files changed, 150 insertions(+) create mode 100644 cmd/generate-ai-models/main.go create mode 100644 pkg/ai/models_generated.go diff --git a/cmd/generate-ai-models/main.go b/cmd/generate-ai-models/main.go new file mode 100644 index 00000000..9be7e39a --- /dev/null +++ b/cmd/generate-ai-models/main.go @@ -0,0 +1,98 @@ +package main + +import ( + "bytes" + "fmt" + "go/format" + "os" + "slices" + "strings" +) + +type seedModel struct { + Provider string + ID string + Name string + API string + MaxTokens int +} + +func main() { + models := []seedModel{ + { + Provider: "openai", + ID: "gpt-5", + Name: "GPT-5", + API: "APIOpenAIResponses", + MaxTokens: 128000, + }, + { + Provider: "openai", + ID: "gpt-5-mini", + Name: "GPT-5 mini", + API: "APIOpenAIResponses", + MaxTokens: 128000, + }, + { + Provider: "anthropic", + ID: "claude-sonnet-4-5", + Name: "Claude Sonnet 4.5", + API: "APIAnthropicMessages", + MaxTokens: 64000, + }, + { + Provider: "anthropic", + ID: "claude-opus-4-6", + Name: "Claude Opus 4.6", + API: "APIAnthropicMessages", + MaxTokens: 64000, + }, + } + + slices.SortFunc(models, func(a, b seedModel) int { + if cmp := strings.Compare(a.Provider, b.Provider); cmp != 0 { + return cmp + } + return strings.Compare(a.ID, b.ID) + }) + + var buf bytes.Buffer + buf.WriteString("// Code generated by cmd/generate-ai-models. 
DO NOT EDIT.\n\n") + buf.WriteString("package ai\n\n") + buf.WriteString("func registerGeneratedModels() {\n") + + byProvider := map[string][]seedModel{} + for _, model := range models { + byProvider[model.Provider] = append(byProvider[model.Provider], model) + } + providers := make([]string, 0, len(byProvider)) + for provider := range byProvider { + providers = append(providers, provider) + } + slices.Sort(providers) + + for _, provider := range providers { + buf.WriteString(fmt.Sprintf("\tRegisterModels(%q, []Model{\n", provider)) + for _, model := range byProvider[provider] { + buf.WriteString("\t\t{\n") + buf.WriteString(fmt.Sprintf("\t\t\tID: %q,\n", model.ID)) + buf.WriteString(fmt.Sprintf("\t\t\tName: %q,\n", model.Name)) + buf.WriteString(fmt.Sprintf("\t\t\tProvider: %q,\n", model.Provider)) + buf.WriteString(fmt.Sprintf("\t\t\tAPI: %s,\n", model.API)) + buf.WriteString(fmt.Sprintf("\t\t\tMaxTokens: %d,\n", model.MaxTokens)) + buf.WriteString("\t\t},\n") + } + buf.WriteString("\t})\n") + } + + buf.WriteString("}\n") + + out, err := format.Source(buf.Bytes()) + if err != nil { + panic(err) + } + + if err := os.WriteFile("pkg/ai/models_generated.go", out, 0o644); err != nil { + panic(err) + } +} diff --git a/pkg/ai/models.go b/pkg/ai/models.go index a30f7560..ec7528fc 100644 --- a/pkg/ai/models.go +++ b/pkg/ai/models.go @@ -7,6 +7,10 @@ import ( var modelRegistry = map[string]map[string]Model{} +func init() { + registerGeneratedModels() +} + func RegisterModels(provider string, models []Model) { key := strings.TrimSpace(provider) if key == "" { diff --git a/pkg/ai/models_generated.go b/pkg/ai/models_generated.go new file mode 100644 index 00000000..9798238c --- /dev/null +++ b/pkg/ai/models_generated.go @@ -0,0 +1,38 @@ +// Code generated by cmd/generate-ai-models. DO NOT EDIT. 
+ +package ai + +func registerGeneratedModels() { + RegisterModels("anthropic", []Model{ + { + ID: "claude-opus-4-6", + Name: "Claude Opus 4.6", + Provider: "anthropic", + API: APIAnthropicMessages, + MaxTokens: 64000, + }, + { + ID: "claude-sonnet-4-5", + Name: "Claude Sonnet 4.5", + Provider: "anthropic", + API: APIAnthropicMessages, + MaxTokens: 64000, + }, + }) + RegisterModels("openai", []Model{ + { + ID: "gpt-5", + Name: "GPT-5", + Provider: "openai", + API: APIOpenAIResponses, + MaxTokens: 128000, + }, + { + ID: "gpt-5-mini", + Name: "GPT-5 mini", + Provider: "openai", + API: APIOpenAIResponses, + MaxTokens: 128000, + }, + }) +} diff --git a/pkg/ai/models_test.go b/pkg/ai/models_test.go index 5468283a..0fbe6d17 100644 --- a/pkg/ai/models_test.go +++ b/pkg/ai/models_test.go @@ -31,6 +31,16 @@ func TestSupportsXhigh(t *testing.T) { } } +func TestGeneratedModelsRegisteredOnInit(t *testing.T) { + model, ok := GetModel("openai", "gpt-5") + if !ok { + t.Fatalf("expected generated openai gpt-5 model to be registered") + } + if model.API != APIOpenAIResponses { + t.Fatalf("expected openai gpt-5 to use responses api, got %s", model.API) + } +} + func TestModelsAreEqual(t *testing.T) { a := &Model{ID: "gpt-4o", Provider: "openai"} b := &Model{ID: "gpt-4o", Provider: "openai"} From 2b5222b584235c0504e8999852708d36d475bc1b Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 03:16:54 +0000 Subject: [PATCH 11/75] Add Azure OpenAI responses payload/config helpers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/ai/providers/azure_openai_responses.go | 146 ++++++++++++++++++ .../providers/azure_openai_responses_test.go | 124 +++++++++++++++ 2 files changed, 270 insertions(+) create mode 100644 pkg/ai/providers/azure_openai_responses.go create mode 100644 pkg/ai/providers/azure_openai_responses_test.go diff --git a/pkg/ai/providers/azure_openai_responses.go 
b/pkg/ai/providers/azure_openai_responses.go new file mode 100644 index 00000000..1741b3a4 --- /dev/null +++ b/pkg/ai/providers/azure_openai_responses.go @@ -0,0 +1,146 @@ +package providers + +import ( + "errors" + "os" + "strings" + + "github.com/beeper/ai-bridge/pkg/ai" +) + +const defaultAzureAPIVersion = "v1" + +var ErrMissingAzureBaseURL = errors.New("azure openai base url is required") + +type AzureOpenAIResponsesOptions struct { + OpenAIResponsesOptions + AzureAPIVersion string + AzureResourceName string + AzureBaseURL string + AzureDeploymentName string +} + +func ParseDeploymentNameMap(value string) map[string]string { + out := map[string]string{} + for _, entry := range strings.Split(value, ",") { + trimmed := strings.TrimSpace(entry) + if trimmed == "" { + continue + } + parts := strings.SplitN(trimmed, "=", 2) + if len(parts) != 2 { + continue + } + modelID := strings.TrimSpace(parts[0]) + deployment := strings.TrimSpace(parts[1]) + if modelID == "" || deployment == "" { + continue + } + out[modelID] = deployment + } + return out +} + +func ResolveDeploymentName(model ai.Model, options *AzureOpenAIResponsesOptions) string { + if options != nil && strings.TrimSpace(options.AzureDeploymentName) != "" { + return strings.TrimSpace(options.AzureDeploymentName) + } + mapped := ParseDeploymentNameMap(os.Getenv("AZURE_OPENAI_DEPLOYMENT_NAME_MAP")) + if deployment, ok := mapped[model.ID]; ok { + return deployment + } + return model.ID +} + +func normalizeAzureBaseURL(baseURL string) string { + return strings.TrimRight(baseURL, "/") +} + +func buildDefaultAzureBaseURL(resourceName string) string { + return "https://" + resourceName + ".openai.azure.com/openai/v1" +} + +func ResolveAzureConfig(model ai.Model, options *AzureOpenAIResponsesOptions) (baseURL string, apiVersion string, err error) { + apiVersion = defaultAzureAPIVersion + if envVersion := strings.TrimSpace(os.Getenv("AZURE_OPENAI_API_VERSION")); envVersion != "" { + apiVersion = envVersion + } + if 
options != nil && strings.TrimSpace(options.AzureAPIVersion) != "" { + apiVersion = strings.TrimSpace(options.AzureAPIVersion) + } + + var resolvedBaseURL string + if options != nil && strings.TrimSpace(options.AzureBaseURL) != "" { + resolvedBaseURL = strings.TrimSpace(options.AzureBaseURL) + } + if resolvedBaseURL == "" { + resolvedBaseURL = strings.TrimSpace(os.Getenv("AZURE_OPENAI_BASE_URL")) + } + + resourceName := strings.TrimSpace(os.Getenv("AZURE_OPENAI_RESOURCE_NAME")) + if options != nil && strings.TrimSpace(options.AzureResourceName) != "" { + resourceName = strings.TrimSpace(options.AzureResourceName) + } + if resolvedBaseURL == "" && resourceName != "" { + resolvedBaseURL = buildDefaultAzureBaseURL(resourceName) + } + if resolvedBaseURL == "" { + resolvedBaseURL = strings.TrimSpace(model.BaseURL) + } + if resolvedBaseURL == "" { + return "", "", ErrMissingAzureBaseURL + } + return normalizeAzureBaseURL(resolvedBaseURL), apiVersion, nil +} + +func BuildAzureOpenAIResponsesParams( + model ai.Model, + context ai.Context, + options AzureOpenAIResponsesOptions, +) map[string]any { + deploymentName := ResolveDeploymentName(model, &options) + messages := ConvertOpenAIResponsesMessages(model, context) + + params := map[string]any{ + "model": deploymentName, + "input": messages, + "stream": true, + "prompt_cache_key": options.StreamOptions.SessionID, + } + if options.StreamOptions.MaxTokens > 0 { + params["max_output_tokens"] = options.StreamOptions.MaxTokens + } + if options.StreamOptions.Temperature != nil { + params["temperature"] = *options.StreamOptions.Temperature + } + if len(context.Tools) > 0 { + params["tools"] = convertResponsesTools(context.Tools) + } + if model.Reasoning { + if options.ReasoningEffort != "" || strings.TrimSpace(options.ReasoningSummary) != "" { + effort := options.ReasoningEffort + if effort == "" { + effort = ai.ThinkingMedium + } + summary := strings.TrimSpace(options.ReasoningSummary) + if summary == "" { + summary = "auto" + } 
+ params["reasoning"] = map[string]any{ + "effort": string(effort), + "summary": summary, + } + params["include"] = []string{"reasoning.encrypted_content"} + } else if strings.HasPrefix(strings.ToLower(model.Name), "gpt-5") { + params["input"] = append(messages, map[string]any{ + "role": "developer", + "content": []map[string]any{{ + "type": "input_text", + "text": "# Juice: 0 !important", + }}, + }) + } + } + + return params +} diff --git a/pkg/ai/providers/azure_openai_responses_test.go b/pkg/ai/providers/azure_openai_responses_test.go new file mode 100644 index 00000000..1dc928fd --- /dev/null +++ b/pkg/ai/providers/azure_openai_responses_test.go @@ -0,0 +1,124 @@ +package providers + +import ( + "os" + "testing" + + "github.com/beeper/ai-bridge/pkg/ai" +) + +func TestParseDeploymentNameMap(t *testing.T) { + mapped := ParseDeploymentNameMap("gpt-5=my-gpt5, claude-opus-4-6 = claude-deploy , invalid") + if mapped["gpt-5"] != "my-gpt5" { + t.Fatalf("expected gpt-5 mapping, got %#v", mapped) + } + if mapped["claude-opus-4-6"] != "claude-deploy" { + t.Fatalf("expected claude mapping, got %#v", mapped) + } + if _, ok := mapped["invalid"]; ok { + t.Fatalf("expected invalid entry to be ignored") + } +} + +func TestResolveAzureConfig(t *testing.T) { + t.Setenv("AZURE_OPENAI_API_VERSION", "") + t.Setenv("AZURE_OPENAI_BASE_URL", "") + t.Setenv("AZURE_OPENAI_RESOURCE_NAME", "") + + _, _, err := ResolveAzureConfig(ai.Model{}, nil) + if err == nil { + t.Fatalf("expected missing base url error") + } + + baseURL, version, err := ResolveAzureConfig(ai.Model{ + BaseURL: "https://custom.openai.azure.com/openai/v1/", + }, nil) + if err != nil { + t.Fatalf("unexpected error resolving from model base URL: %v", err) + } + if baseURL != "https://custom.openai.azure.com/openai/v1" { + t.Fatalf("expected trimmed base URL, got %s", baseURL) + } + if version != "v1" { + t.Fatalf("expected default api version v1, got %s", version) + } + + t.Setenv("AZURE_OPENAI_RESOURCE_NAME", 
"my-resource") + baseURL, _, err = ResolveAzureConfig(ai.Model{}, nil) + if err != nil { + t.Fatalf("unexpected error resolving from resource name: %v", err) + } + if baseURL != "https://my-resource.openai.azure.com/openai/v1" { + t.Fatalf("unexpected default resource base URL: %s", baseURL) + } +} + +func TestBuildAzureOpenAIResponsesParams(t *testing.T) { + t.Setenv("AZURE_OPENAI_DEPLOYMENT_NAME_MAP", "gpt-5=my-deployment") + + temp := 0.5 + params := BuildAzureOpenAIResponsesParams( + ai.Model{ + ID: "gpt-5", + Name: "GPT-5", + Provider: "azure-openai-responses", + API: ai.APIAzureOpenAIResponse, + Reasoning: true, + }, + ai.Context{ + SystemPrompt: "system", + Messages: []ai.Message{ + {Role: ai.RoleUser, Text: "hello"}, + }, + }, + AzureOpenAIResponsesOptions{ + OpenAIResponsesOptions: OpenAIResponsesOptions{ + StreamOptions: ai.StreamOptions{ + SessionID: "session-id", + Temperature: &temp, + MaxTokens: 2048, + }, + ReasoningEffort: ai.ThinkingHigh, + ReasoningSummary: "auto", + }, + }, + ) + + if params["model"] != "my-deployment" { + t.Fatalf("expected deployment name from env map, got %#v", params["model"]) + } + if params["prompt_cache_key"] != "session-id" { + t.Fatalf("expected prompt cache key") + } + if params["max_output_tokens"] != 2048 { + t.Fatalf("expected max output tokens 2048") + } + if params["temperature"] != 0.5 { + t.Fatalf("expected temperature 0.5") + } + reasoning, ok := params["reasoning"].(map[string]any) + if !ok || reasoning["effort"] != "high" || reasoning["summary"] != "auto" { + t.Fatalf("unexpected reasoning payload: %#v", params["reasoning"]) + } + if _, ok := params["include"]; !ok { + t.Fatalf("expected include reasoning encrypted content") + } +} + +func TestResolveDeploymentNamePriority(t *testing.T) { + t.Setenv("AZURE_OPENAI_DEPLOYMENT_NAME_MAP", "gpt-5=map-deployment") + + model := ai.Model{ID: "gpt-5"} + name := ResolveDeploymentName(model, &AzureOpenAIResponsesOptions{ + AzureDeploymentName: "direct-deployment", + }) + 
if name != "direct-deployment" { + t.Fatalf("expected direct deployment override, got %s", name) + } + + os.Unsetenv("AZURE_OPENAI_DEPLOYMENT_NAME_MAP") + name = ResolveDeploymentName(model, nil) + if name != "gpt-5" { + t.Fatalf("expected fallback to model id, got %s", name) + } +} From 2749434e6cb25d03cb5208aa41f313ce48c938bf Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 03:17:53 +0000 Subject: [PATCH 12/75] Add deterministic API registry ordering and lifecycle tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/ai/api_registry.go | 10 +++++++++ pkg/ai/api_registry_test.go | 42 +++++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+) create mode 100644 pkg/ai/api_registry_test.go diff --git a/pkg/ai/api_registry.go b/pkg/ai/api_registry.go index 006377b1..b16975d3 100644 --- a/pkg/ai/api_registry.go +++ b/pkg/ai/api_registry.go @@ -2,6 +2,7 @@ package ai import ( "fmt" + "slices" "sync" ) @@ -47,6 +48,15 @@ func GetAPIProviders() []APIProvider { for _, entry := range registry { out = append(out, entry.provider) } + slices.SortFunc(out, func(a, b APIProvider) int { + if a.API == b.API { + return 0 + } + if a.API < b.API { + return -1 + } + return 1 + }) return out } diff --git a/pkg/ai/api_registry_test.go b/pkg/ai/api_registry_test.go new file mode 100644 index 00000000..bdf6dfb8 --- /dev/null +++ b/pkg/ai/api_registry_test.go @@ -0,0 +1,42 @@ +package ai + +import "testing" + +func TestAPIRegistryLifecycle(t *testing.T) { + ClearAPIProviders() + t.Cleanup(ClearAPIProviders) + + RegisterAPIProvider(APIProvider{ + API: APIOpenAIResponses, + StreamSimple: func(model Model, context Context, options *SimpleStreamOptions) *AssistantMessageEventStream { + return NewAssistantMessageEventStream(1) + }, + }, "source-a") + + RegisterAPIProvider(APIProvider{ + API: APIAnthropicMessages, + Stream: func(model Model, context Context, options *StreamOptions) 
*AssistantMessageEventStream { + return NewAssistantMessageEventStream(1) + }, + }, "source-b") + + providers := GetAPIProviders() + if len(providers) != 2 { + t.Fatalf("expected 2 providers, got %d", len(providers)) + } + if providers[0].API != APIAnthropicMessages || providers[1].API != APIOpenAIResponses { + t.Fatalf("expected providers sorted by api, got %#v", providers) + } + + if _, ok := GetAPIProvider(APIOpenAIResponses); !ok { + t.Fatalf("expected openai responses provider in registry") + } + + UnregisterAPIProviders("source-a") + if _, ok := GetAPIProvider(APIOpenAIResponses); ok { + t.Fatalf("expected source-a providers to be removed") + } + if _, ok := GetAPIProvider(APIAnthropicMessages); !ok { + t.Fatalf("expected source-b provider to remain") + } +} From 701fea6b443c381b85fd5406116914dda8351321 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 03:19:12 +0000 Subject: [PATCH 13/75] Add anthropic payload conversion parity tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/ai/providers/anthropic_test.go | 118 +++++++++++++++++++++++++++++ 1 file changed, 118 insertions(+) create mode 100644 pkg/ai/providers/anthropic_test.go diff --git a/pkg/ai/providers/anthropic_test.go b/pkg/ai/providers/anthropic_test.go new file mode 100644 index 00000000..5173a84a --- /dev/null +++ b/pkg/ai/providers/anthropic_test.go @@ -0,0 +1,118 @@ +package providers + +import ( + "testing" + + "github.com/beeper/ai-bridge/pkg/ai" +) + +func TestBuildAnthropicParams_WithThinkingAndCacheControl(t *testing.T) { + temp := 0.3 + params := BuildAnthropicParams( + ai.Model{ + ID: "claude-sonnet-4-5", + BaseURL: "https://api.anthropic.com", + }, + ai.Context{ + SystemPrompt: "You are helpful", + Messages: []ai.Message{ + {Role: ai.RoleUser, Text: "hello"}, + }, + Tools: []ai.Tool{ + {Name: "calc", Description: "Calculate", Parameters: map[string]any{"type": "object"}}, + }, + }, + 
AnthropicOptions{ + StreamOptions: ai.StreamOptions{ + Temperature: &temp, + MaxTokens: 4096, + CacheRetention: ai.CacheRetentionLong, + }, + ThinkingEnabled: true, + ThinkingBudgetTokens: 2048, + Effort: "medium", + InterleavedThinking: true, + ToolChoice: "auto", + }, + ) + + if params["model"] != "claude-sonnet-4-5" { + t.Fatalf("expected model id set") + } + if params["max_tokens"] != 4096 { + t.Fatalf("expected max tokens 4096") + } + if params["temperature"] != 0.3 { + t.Fatalf("expected temperature 0.3") + } + if params["anthropic-beta"] != "interleaved-thinking-2025-05-14" { + t.Fatalf("expected interleaved thinking beta header") + } + systemBlocks, ok := params["system"].([]map[string]any) + if !ok || len(systemBlocks) != 1 { + t.Fatalf("expected one system block, got %#v", params["system"]) + } + cacheControl, ok := systemBlocks[0]["cache_control"].(map[string]any) + if !ok || cacheControl["ttl"] != "1h" { + t.Fatalf("expected anthropic cache control ttl 1h, got %#v", systemBlocks[0]["cache_control"]) + } + thinking, ok := params["thinking"].(map[string]any) + if !ok || thinking["budget_tokens"] != 2048 { + t.Fatalf("expected thinking budget tokens, got %#v", params["thinking"]) + } + tools, ok := params["tools"].([]map[string]any) + if !ok || len(tools) != 1 { + t.Fatalf("expected converted tools, got %#v", params["tools"]) + } + toolChoice, ok := params["tool_choice"].(map[string]any) + if !ok || toolChoice["type"] != "auto" { + t.Fatalf("expected auto tool_choice, got %#v", params["tool_choice"]) + } +} + +func TestConvertAnthropicMessages_ToolCallsAndToolResultFallback(t *testing.T) { + messages := convertAnthropicMessages( + ai.Model{ID: "claude-sonnet-4-5"}, + ai.Context{ + Messages: []ai.Message{ + { + Role: ai.RoleAssistant, + Content: []ai.ContentBlock{ + { + Type: ai.ContentTypeToolCall, + ID: "invalid:tool id!", + Name: "lookup", + Arguments: map[string]any{"q": "go"}, + }, + }, + }, + { + Role: ai.RoleToolResult, + ToolCallID: "invalid:tool 
id!", + Content: []ai.ContentBlock{ + {Type: ai.ContentTypeImage, MimeType: "image/png", Data: "abc"}, + }, + }, + }, + }, + ) + + if len(messages) != 2 { + t.Fatalf("expected two converted messages, got %d", len(messages)) + } + assistant := messages[0] + content := assistant["content"].([]map[string]any) + toolUse := content[0] + if toolUse["type"] != "tool_use" { + t.Fatalf("expected tool_use block, got %#v", toolUse) + } + if toolUse["id"] != "invalid_tool_id_" { + t.Fatalf("expected sanitized tool call id, got %#v", toolUse["id"]) + } + + toolResult := messages[1]["content"].([]map[string]any)[0] + innerContent := toolResult["content"].([]map[string]any) + if innerContent[0]["text"] != "(see attached image)" { + t.Fatalf("expected fallback text for non-text tool result, got %#v", innerContent[0]["text"]) + } +} From cd4caecdc26155c99d20fca93d9bb3a5e1ee296f Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 03:20:09 +0000 Subject: [PATCH 14/75] Add OpenAI responses parity tests for reasoning and tool IDs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/ai/providers/openai_responses_test.go | 79 +++++++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 pkg/ai/providers/openai_responses_test.go diff --git a/pkg/ai/providers/openai_responses_test.go b/pkg/ai/providers/openai_responses_test.go new file mode 100644 index 00000000..35f99682 --- /dev/null +++ b/pkg/ai/providers/openai_responses_test.go @@ -0,0 +1,79 @@ +package providers + +import ( + "testing" + + "github.com/beeper/ai-bridge/pkg/ai" +) + +func TestBuildOpenAIResponsesParams_ReasoningAndGPT5Fallback(t *testing.T) { + model := ai.Model{ + ID: "gpt-5", + Name: "GPT-5", + Provider: "openai", + API: ai.APIOpenAIResponses, + Reasoning: true, + BaseURL: "https://api.openai.com/v1", + } + + withReasoning := BuildOpenAIResponsesParams(model, ai.Context{ + Messages: []ai.Message{{Role: ai.RoleUser, 
Text: "hello"}}, + }, OpenAIResponsesOptions{ + ReasoningEffort: ai.ThinkingHigh, + ReasoningSummary: "detailed", + }) + reasoning, ok := withReasoning["reasoning"].(map[string]any) + if !ok || reasoning["effort"] != "high" || reasoning["summary"] != "detailed" { + t.Fatalf("expected explicit reasoning payload, got %#v", withReasoning["reasoning"]) + } + include, ok := withReasoning["include"].([]string) + if !ok || len(include) != 1 || include[0] != "reasoning.encrypted_content" { + t.Fatalf("expected include reasoning encrypted content, got %#v", withReasoning["include"]) + } + + noReasoning := BuildOpenAIResponsesParams(model, ai.Context{ + Messages: []ai.Message{{Role: ai.RoleUser, Text: "hello"}}, + }, OpenAIResponsesOptions{}) + input := noReasoning["input"].([]map[string]any) + last := input[len(input)-1] + if last["role"] != "developer" { + t.Fatalf("expected gpt-5 fallback developer hint when reasoning omitted, got %#v", last) + } +} + +func TestConvertOpenAIResponsesMessages_ToolCallIDPipeParsing(t *testing.T) { + messages := ConvertOpenAIResponsesMessages(ai.Model{}, ai.Context{ + Messages: []ai.Message{ + { + Role: ai.RoleAssistant, + Content: []ai.ContentBlock{ + { + Type: ai.ContentTypeToolCall, + ID: "call_123|item_456", + Name: "lookup", + Arguments: map[string]any{"q": "go"}, + }, + }, + }, + { + Role: ai.RoleToolResult, + ToolCallID: "call_123|item_456", + Content: []ai.ContentBlock{ + {Type: ai.ContentTypeText, Text: "result"}, + }, + }, + }, + }) + + if len(messages) != 2 { + t.Fatalf("expected 2 messages, got %d", len(messages)) + } + call := messages[0] + if call["call_id"] != "call_123" || call["id"] != "item_456" { + t.Fatalf("expected split call_id/item_id, got %#v", call) + } + result := messages[1] + if result["call_id"] != "call_123" { + t.Fatalf("expected tool result call_id to strip item_id suffix, got %#v", result["call_id"]) + } +} From 600bdbf8eedd3b2e26e5ddbe243150979a31b108 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 
Mar 2026 03:21:59 +0000 Subject: [PATCH 15/75] Register builtin api providers with terminal error streams MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/ai/providers/register_builtins.go | 67 +++++++++++++++++++++- pkg/ai/providers/register_builtins_test.go | 46 +++++++++++++++ 2 files changed, 110 insertions(+), 3 deletions(-) create mode 100644 pkg/ai/providers/register_builtins_test.go diff --git a/pkg/ai/providers/register_builtins.go b/pkg/ai/providers/register_builtins.go index 790db2ad..f5fc8f18 100644 --- a/pkg/ai/providers/register_builtins.go +++ b/pkg/ai/providers/register_builtins.go @@ -1,12 +1,73 @@ package providers -import "github.com/beeper/ai-bridge/pkg/ai" +import ( + "time" + + "github.com/beeper/ai-bridge/pkg/ai" +) const BuiltinProviderSourceID = "pkg/ai/providers/register_builtins" +func notImplementedStream(apiID ai.Api) ai.StreamFn { + return func(model ai.Model, _ ai.Context, _ *ai.StreamOptions) *ai.AssistantMessageEventStream { + stream := ai.NewAssistantMessageEventStream(2) + stream.Push(ai.AssistantMessageEvent{ + Type: ai.EventError, + Error: ai.Message{ + Role: ai.RoleAssistant, + API: apiID, + Provider: model.Provider, + Model: model.ID, + StopReason: ai.StopReasonError, + ErrorMessage: "provider runtime is not implemented yet", + Timestamp: time.Now().UnixMilli(), + }, + Reason: ai.StopReasonError, + }) + return stream + } +} + +func notImplementedSimpleStream(apiID ai.Api) ai.StreamSimpleFn { + return func(model ai.Model, _ ai.Context, _ *ai.SimpleStreamOptions) *ai.AssistantMessageEventStream { + stream := ai.NewAssistantMessageEventStream(2) + stream.Push(ai.AssistantMessageEvent{ + Type: ai.EventError, + Error: ai.Message{ + Role: ai.RoleAssistant, + API: apiID, + Provider: model.Provider, + Model: model.ID, + StopReason: ai.StopReasonError, + ErrorMessage: "provider runtime is not implemented yet", + Timestamp: time.Now().UnixMilli(), + }, + 
Reason: ai.StopReasonError, + }) + return stream + } +} + // RegisterBuiltInAPIProviders registers providers implemented in this package. -// Initial scaffold keeps registry empty until concrete provider streamers are ported. -func RegisterBuiltInAPIProviders() {} +func RegisterBuiltInAPIProviders() { + for _, apiID := range []ai.Api{ + ai.APIOpenAICompletions, + ai.APIOpenAIResponses, + ai.APIAzureOpenAIResponse, + ai.APIOpenAICodexResponse, + ai.APIAnthropicMessages, + ai.APIGoogleGenerativeAI, + ai.APIGoogleGeminiCLI, + ai.APIGoogleVertex, + ai.APIBedrockConverse, + } { + ai.RegisterAPIProvider(ai.APIProvider{ + API: apiID, + Stream: notImplementedStream(apiID), + StreamSimple: notImplementedSimpleStream(apiID), + }, BuiltinProviderSourceID) + } +} func ResetAPIProviders() { ai.ClearAPIProviders() diff --git a/pkg/ai/providers/register_builtins_test.go b/pkg/ai/providers/register_builtins_test.go new file mode 100644 index 00000000..42ea0ca6 --- /dev/null +++ b/pkg/ai/providers/register_builtins_test.go @@ -0,0 +1,46 @@ +package providers + +import ( + "context" + "io" + "testing" + "time" + + "github.com/beeper/ai-bridge/pkg/ai" +) + +func TestRegisterBuiltInAPIProviders(t *testing.T) { + ai.ClearAPIProviders() + t.Cleanup(ai.ClearAPIProviders) + + RegisterBuiltInAPIProviders() + providers := ai.GetAPIProviders() + if len(providers) < 9 { + t.Fatalf("expected builtin providers to be registered, got %d", len(providers)) + } + + stream, err := ai.Stream(ai.Model{ + ID: "gpt-5", + Provider: "openai", + API: ai.APIOpenAIResponses, + }, ai.Context{}, nil) + if err != nil { + t.Fatalf("unexpected stream resolve error: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + evt, err := stream.Next(ctx) + if err != nil { + t.Fatalf("expected terminal error event, got %v", err) + } + if evt.Type != ai.EventError { + t.Fatalf("expected error event, got %s", evt.Type) + } + if evt.Error.StopReason != ai.StopReasonError { 
+ t.Fatalf("expected stopReason=error, got %s", evt.Error.StopReason) + } + if _, err := stream.Next(ctx); err != io.EOF { + t.Fatalf("expected EOF after terminal event, got %v", err) + } +} From b8e24b91f3588f83f7df76ff3f2ebd34569b9474 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 03:46:14 +0000 Subject: [PATCH 16/75] Add tool ID and Claude Code tool-name normalization parity tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/ai/providers/anthropic.go | 74 ++++++++++++++++--- pkg/ai/providers/anthropic_test.go | 61 +++++++++++++++ .../openai_completions_convert_test.go | 62 ++++++++++++++++ pkg/ai/providers/transform_messages_test.go | 43 +++++++++++ 4 files changed, 229 insertions(+), 11 deletions(-) diff --git a/pkg/ai/providers/anthropic.go b/pkg/ai/providers/anthropic.go index 1c898e91..3f4a610a 100644 --- a/pkg/ai/providers/anthropic.go +++ b/pkg/ai/providers/anthropic.go @@ -9,12 +9,13 @@ import ( ) type AnthropicOptions struct { - StreamOptions ai.StreamOptions - ThinkingEnabled bool - ThinkingBudgetTokens int - Effort string - InterleavedThinking bool - ToolChoice string + StreamOptions ai.StreamOptions + ThinkingEnabled bool + ThinkingBudgetTokens int + Effort string + InterleavedThinking bool + ToolChoice string + UseClaudeCodeToolNames bool } type cacheControl struct { @@ -22,6 +23,45 @@ type cacheControl struct { TTL string `json:"ttl,omitempty"` } +var claudeCodeToolLookup = map[string]string{ + "read": "Read", + "write": "Write", + "edit": "Edit", + "bash": "Bash", + "grep": "Grep", + "glob": "Glob", + "askuserquestion": "AskUserQuestion", + "enterplanmode": "EnterPlanMode", + "exitplanmode": "ExitPlanMode", + "killshell": "KillShell", + "notebookedit": "NotebookEdit", + "skill": "Skill", + "task": "Task", + "taskoutput": "TaskOutput", + "todowrite": "TodoWrite", + "webfetch": "WebFetch", + "websearch": "WebSearch", +} + +func 
ToClaudeCodeToolName(name string) string { + if normalized, ok := claudeCodeToolLookup[strings.ToLower(strings.TrimSpace(name))]; ok { + return normalized + } + return name +} + +func FromClaudeCodeToolName(name string, tools []ai.Tool) string { + if len(tools) > 0 { + lower := strings.ToLower(strings.TrimSpace(name)) + for _, tool := range tools { + if strings.ToLower(strings.TrimSpace(tool.Name)) == lower { + return tool.Name + } + } + } + return name +} + func resolveAnthropicCacheRetention(cacheRetention ai.CacheRetention) ai.CacheRetention { if cacheRetention != "" { return cacheRetention @@ -49,7 +89,7 @@ func BuildAnthropicParams(model ai.Model, context ai.Context, options AnthropicO "model": model.ID, "stream": true, "max_tokens": max(1024, options.StreamOptions.MaxTokens), - "messages": convertAnthropicMessages(model, context), + "messages": convertAnthropicMessagesInternal(model, context, options.UseClaudeCodeToolNames), } _, cache := GetAnthropicCacheControl(model.BaseURL, options.StreamOptions.CacheRetention) @@ -76,7 +116,7 @@ func BuildAnthropicParams(model ai.Model, context ai.Context, options AnthropicO params["tool_choice"] = map[string]any{"type": options.ToolChoice} } if len(context.Tools) > 0 { - params["tools"] = convertAnthropicTools(context.Tools) + params["tools"] = convertAnthropicTools(context.Tools, options.UseClaudeCodeToolNames) } if options.ThinkingEnabled { thinking := map[string]any{"type": "enabled"} @@ -95,6 +135,10 @@ func BuildAnthropicParams(model ai.Model, context ai.Context, options AnthropicO } func convertAnthropicMessages(model ai.Model, context ai.Context) []map[string]any { + return convertAnthropicMessagesInternal(model, context, false) +} + +func convertAnthropicMessagesInternal(model ai.Model, context ai.Context, useClaudeCodeToolNames bool) []map[string]any { transformed := TransformMessages(context.Messages, model, func(id string, _ ai.Model, _ ai.Message) string { sanitized := strings.Map(func(r rune) rune { if (r 
>= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '_' || r == '-' { @@ -159,10 +203,14 @@ func convertAnthropicMessages(model ai.Model, context ai.Context) []map[string]a } parts = append(parts, map[string]any{"type": "thinking", "thinking": utils.SanitizeSurrogates(block.Thinking)}) case ai.ContentTypeToolCall: + toolName := block.Name + if useClaudeCodeToolNames { + toolName = ToClaudeCodeToolName(toolName) + } parts = append(parts, map[string]any{ "type": "tool_use", "id": block.ID, - "name": block.Name, + "name": toolName, "input": block.Arguments, }) } @@ -204,11 +252,15 @@ func convertAnthropicMessages(model ai.Model, context ai.Context) []map[string]a return out } -func convertAnthropicTools(tools []ai.Tool) []map[string]any { +func convertAnthropicTools(tools []ai.Tool, useClaudeCodeToolNames bool) []map[string]any { out := make([]map[string]any, 0, len(tools)) for _, tool := range tools { + toolName := tool.Name + if useClaudeCodeToolNames { + toolName = ToClaudeCodeToolName(tool.Name) + } out = append(out, map[string]any{ - "name": tool.Name, + "name": toolName, "description": tool.Description, "input_schema": tool.Parameters, }) diff --git a/pkg/ai/providers/anthropic_test.go b/pkg/ai/providers/anthropic_test.go index 5173a84a..5a36f69d 100644 --- a/pkg/ai/providers/anthropic_test.go +++ b/pkg/ai/providers/anthropic_test.go @@ -116,3 +116,64 @@ func TestConvertAnthropicMessages_ToolCallsAndToolResultFallback(t *testing.T) { t.Fatalf("expected fallback text for non-text tool result, got %#v", innerContent[0]["text"]) } } + +func TestClaudeCodeToolNameNormalization(t *testing.T) { + tools := []ai.Tool{ + {Name: "todowrite"}, + {Name: "read"}, + {Name: "find"}, + {Name: "my_custom_tool"}, + } + + if got := ToClaudeCodeToolName("todowrite"); got != "TodoWrite" { + t.Fatalf("expected todowrite -> TodoWrite, got %q", got) + } + if got := FromClaudeCodeToolName("TodoWrite", tools); got != "todowrite" { + t.Fatalf("expected 
TodoWrite to map back to user tool name todowrite, got %q", got) + } + + // find is not a Claude Code canonical tool name and must not map to Glob. + if got := ToClaudeCodeToolName("find"); got != "find" { + t.Fatalf("expected find to remain unchanged, got %q", got) + } + if got := FromClaudeCodeToolName("Glob", tools); got != "Glob" { + t.Fatalf("expected Glob to stay Glob when no matching user tool exists, got %q", got) + } + + if got := ToClaudeCodeToolName("my_custom_tool"); got != "my_custom_tool" { + t.Fatalf("expected custom tool to remain unchanged, got %q", got) + } +} + +func TestBuildAnthropicParams_UsesClaudeCodeCanonicalToolNames(t *testing.T) { + params := BuildAnthropicParams( + ai.Model{ID: "claude-sonnet-4-5"}, + ai.Context{ + Messages: []ai.Message{ + { + Role: ai.RoleAssistant, + Content: []ai.ContentBlock{ + {Type: ai.ContentTypeToolCall, ID: "call_1", Name: "read", Arguments: map[string]any{"path": "/tmp/test.txt"}}, + }, + }, + }, + Tools: []ai.Tool{ + {Name: "read", Description: "Read file", Parameters: map[string]any{"type": "object"}}, + }, + }, + AnthropicOptions{ + UseClaudeCodeToolNames: true, + }, + ) + + tools := params["tools"].([]map[string]any) + if tools[0]["name"] != "Read" { + t.Fatalf("expected outbound tool schema name to use Claude Code casing, got %#v", tools[0]["name"]) + } + + messages := params["messages"].([]map[string]any) + assistantContent := messages[0]["content"].([]map[string]any) + if assistantContent[0]["name"] != "Read" { + t.Fatalf("expected outbound assistant tool_use name to use Claude Code casing, got %#v", assistantContent[0]["name"]) + } +} diff --git a/pkg/ai/providers/openai_completions_convert_test.go b/pkg/ai/providers/openai_completions_convert_test.go index 6c183712..4ef0322c 100644 --- a/pkg/ai/providers/openai_completions_convert_test.go +++ b/pkg/ai/providers/openai_completions_convert_test.go @@ -86,3 +86,65 @@ func TestConvertOpenAICompletionsMessages_BatchesToolResultImages(t *testing.T) 
t.Fatalf("expected 2 image parts, got %d", imageCount) } } + +func TestConvertOpenAICompletionsMessages_NormalizesPipeSeparatedToolCallIDs(t *testing.T) { + model := ai.Model{ + ID: "gpt-4o-mini", + API: ai.APIOpenAICompletions, + Provider: "openrouter", + BaseURL: "https://openrouter.ai/api/v1", + Input: []string{"text"}, + } + now := time.Now().UnixMilli() + context := ai.Context{ + Messages: []ai.Message{ + {Role: ai.RoleUser, Text: "Use tool", Timestamp: now}, + { + Role: ai.RoleAssistant, + Content: []ai.ContentBlock{ + { + Type: ai.ContentTypeToolCall, + ID: "call_abc123|this-is-a-very-long-item-id-with-specials+/==", + Name: "echo", + Arguments: map[string]any{"message": "hello"}, + }, + }, + Provider: "github-copilot", + API: ai.APIOpenAIResponses, + Model: "gpt-5.2-codex", + StopReason: ai.StopReasonToolUse, + Timestamp: now + 1, + }, + { + Role: ai.RoleToolResult, + ToolCallID: "call_abc123|this-is-a-very-long-item-id-with-specials+/==", + ToolName: "echo", + Content: []ai.ContentBlock{ + {Type: ai.ContentTypeText, Text: "hello"}, + }, + Timestamp: now + 2, + }, + }, + } + + messages := ConvertOpenAICompletionsMessages(model, context, GetCompat(model)) + if len(messages) < 3 { + t.Fatalf("expected converted messages, got %d", len(messages)) + } + assistant := messages[1] + if len(assistant.ToolCalls) != 1 { + t.Fatalf("expected single assistant tool call, got %#v", assistant.ToolCalls) + } + normalizedID, _ := assistant.ToolCalls[0]["id"].(string) + if normalizedID == "" || normalizedID == "call_abc123|this-is-a-very-long-item-id-with-specials+/==" { + t.Fatalf("expected tool call id to be normalized, got %q", normalizedID) + } + + toolResult := messages[2] + if toolResult.Role != "tool" { + t.Fatalf("expected tool message, got %s", toolResult.Role) + } + if toolResult.ToolCallID != normalizedID { + t.Fatalf("expected tool result id to match normalized call id, got tool=%q assistant=%q", toolResult.ToolCallID, normalizedID) + } +} diff --git 
a/pkg/ai/providers/transform_messages_test.go b/pkg/ai/providers/transform_messages_test.go index 7d16dc78..7c8983e2 100644 --- a/pkg/ai/providers/transform_messages_test.go +++ b/pkg/ai/providers/transform_messages_test.go @@ -120,3 +120,46 @@ func TestTransformMessages_RemovesThoughtSignatureAcrossModels(t *testing.T) { } } } + +func TestTransformMessages_SynthesizesMissingToolResultBeforeNextUserTurn(t *testing.T) { + model := ai.Model{ + ID: "gpt-4o-mini", + API: ai.APIOpenAICompletions, + Provider: "openai", + } + now := time.Now().UnixMilli() + input := []ai.Message{ + { + Role: ai.RoleAssistant, + Content: []ai.ContentBlock{ + { + Type: ai.ContentTypeToolCall, + ID: "call_1", + Name: "calculate", + Arguments: map[string]any{"expression": "25*18"}, + }, + }, + StopReason: ai.StopReasonToolUse, + Timestamp: now, + }, + { + Role: ai.RoleUser, + Text: "Never mind, what is 2+2?", + Timestamp: now + 1, + }, + } + + result := TransformMessages(input, model, nil) + if len(result) != 3 { + t.Fatalf("expected synthesized tool result to be inserted, got %d messages", len(result)) + } + if result[1].Role != ai.RoleToolResult { + t.Fatalf("expected synthesized message at index 1 to be toolResult, got %s", result[1].Role) + } + if result[1].ToolCallID != "call_1" || result[1].ToolName != "calculate" || !result[1].IsError { + t.Fatalf("unexpected synthesized toolResult payload: %#v", result[1]) + } + if len(result[1].Content) == 0 || result[1].Content[0].Text != "No result provided" { + t.Fatalf("expected synthesized toolResult content fallback") + } +} From 87cd7bd990e3fe3f880f79fdba27011d51565b58 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 03:48:05 +0000 Subject: [PATCH 17/75] Add GitHub Copilot dynamic header helper parity MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/ai/providers/github_copilot_headers.go | 39 ++++++++++++ .../providers/github_copilot_headers_test.go 
| 59 +++++++++++++++++++ 2 files changed, 98 insertions(+) create mode 100644 pkg/ai/providers/github_copilot_headers.go create mode 100644 pkg/ai/providers/github_copilot_headers_test.go diff --git a/pkg/ai/providers/github_copilot_headers.go b/pkg/ai/providers/github_copilot_headers.go new file mode 100644 index 00000000..1963a9d6 --- /dev/null +++ b/pkg/ai/providers/github_copilot_headers.go @@ -0,0 +1,39 @@ +package providers + +import "github.com/beeper/ai-bridge/pkg/ai" + +func InferCopilotInitiator(messages []ai.Message) string { + if len(messages) == 0 { + return "user" + } + last := messages[len(messages)-1] + if last.Role != ai.RoleUser { + return "agent" + } + return "user" +} + +func HasCopilotVisionInput(messages []ai.Message) bool { + for _, msg := range messages { + switch msg.Role { + case ai.RoleUser, ai.RoleToolResult: + for _, block := range msg.Content { + if block.Type == ai.ContentTypeImage { + return true + } + } + } + } + return false +} + +func BuildCopilotDynamicHeaders(messages []ai.Message, hasImages bool) map[string]string { + headers := map[string]string{ + "X-Initiator": InferCopilotInitiator(messages), + "Openai-Intent": "conversation-edits", + } + if hasImages { + headers["Copilot-Vision-Request"] = "true" + } + return headers +} diff --git a/pkg/ai/providers/github_copilot_headers_test.go b/pkg/ai/providers/github_copilot_headers_test.go new file mode 100644 index 00000000..f9d4ac02 --- /dev/null +++ b/pkg/ai/providers/github_copilot_headers_test.go @@ -0,0 +1,59 @@ +package providers + +import ( + "testing" + + "github.com/beeper/ai-bridge/pkg/ai" +) + +func TestCopilotHeaderHelpers(t *testing.T) { + userOnly := []ai.Message{ + {Role: ai.RoleUser, Text: "hello"}, + } + if got := InferCopilotInitiator(userOnly); got != "user" { + t.Fatalf("expected user initiator, got %q", got) + } + + agentFollowUp := []ai.Message{ + {Role: ai.RoleUser, Text: "hello"}, + {Role: ai.RoleAssistant, Content: []ai.ContentBlock{{Type: 
ai.ContentTypeText, Text: "hi"}}}, + } + if got := InferCopilotInitiator(agentFollowUp); got != "agent" { + t.Fatalf("expected agent initiator for assistant tail, got %q", got) + } + + withImages := []ai.Message{ + { + Role: ai.RoleUser, + Content: []ai.ContentBlock{ + {Type: ai.ContentTypeImage, MimeType: "image/png", Data: "abc"}, + }, + }, + } + if !HasCopilotVisionInput(withImages) { + t.Fatalf("expected vision input detection for user image") + } + + toolImages := []ai.Message{ + { + Role: ai.RoleToolResult, + Content: []ai.ContentBlock{ + {Type: ai.ContentTypeImage, MimeType: "image/png", Data: "abc"}, + }, + }, + } + if !HasCopilotVisionInput(toolImages) { + t.Fatalf("expected vision input detection for tool result image") + } + + headers := BuildCopilotDynamicHeaders(agentFollowUp, true) + if headers["X-Initiator"] != "agent" { + t.Fatalf("expected X-Initiator=agent, got %#v", headers["X-Initiator"]) + } + if headers["Openai-Intent"] != "conversation-edits" { + t.Fatalf("expected Openai-Intent=conversation-edits") + } + if headers["Copilot-Vision-Request"] != "true" { + t.Fatalf("expected Copilot-Vision-Request=true") + } +} From 75ad89327ba9fb5bb3c686360e7d7bd8b4f9d473 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 03:52:15 +0000 Subject: [PATCH 18/75] Add shared OpenAI responses conversion helpers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/ai/providers/azure_openai_responses.go | 11 +- pkg/ai/providers/openai_responses.go | 141 +------------ pkg/ai/providers/openai_responses_shared.go | 195 ++++++++++++++++++ .../providers/openai_responses_shared_test.go | 102 +++++++++ 4 files changed, 315 insertions(+), 134 deletions(-) create mode 100644 pkg/ai/providers/openai_responses_shared.go create mode 100644 pkg/ai/providers/openai_responses_shared_test.go diff --git a/pkg/ai/providers/azure_openai_responses.go b/pkg/ai/providers/azure_openai_responses.go 
index 1741b3a4..373a2ee1 100644 --- a/pkg/ai/providers/azure_openai_responses.go +++ b/pkg/ai/providers/azure_openai_responses.go @@ -12,6 +12,13 @@ const defaultAzureAPIVersion = "v1" var ErrMissingAzureBaseURL = errors.New("azure openai base url is required") +var azureToolCallProviders = map[string]struct{}{ + "openai": {}, + "openai-codex": {}, + "opencode": {}, + "azure-openai-responses": {}, +} + type AzureOpenAIResponsesOptions struct { OpenAIResponsesOptions AzureAPIVersion string @@ -99,7 +106,7 @@ func BuildAzureOpenAIResponsesParams( options AzureOpenAIResponsesOptions, ) map[string]any { deploymentName := ResolveDeploymentName(model, &options) - messages := ConvertOpenAIResponsesMessages(model, context) + messages := ConvertResponsesMessages(model, context, azureToolCallProviders, nil) params := map[string]any{ "model": deploymentName, @@ -114,7 +121,7 @@ func BuildAzureOpenAIResponsesParams( params["temperature"] = *options.StreamOptions.Temperature } if len(context.Tools) > 0 { - params["tools"] = convertResponsesTools(context.Tools) + params["tools"] = ConvertResponsesTools(context.Tools, false) } if model.Reasoning { if options.ReasoningEffort != "" || strings.TrimSpace(options.ReasoningSummary) != "" { diff --git a/pkg/ai/providers/openai_responses.go b/pkg/ai/providers/openai_responses.go index ec793524..77bca9d4 100644 --- a/pkg/ai/providers/openai_responses.go +++ b/pkg/ai/providers/openai_responses.go @@ -1,12 +1,10 @@ package providers import ( - "encoding/json" "os" "strings" "github.com/beeper/ai-bridge/pkg/ai" - "github.com/beeper/ai-bridge/pkg/ai/utils" ) type OpenAIResponsesOptions struct { @@ -16,6 +14,12 @@ type OpenAIResponsesOptions struct { ServiceTier string } +var openAIToolCallProviders = map[string]struct{}{ + "openai": {}, + "openai-codex": {}, + "opencode": {}, +} + func ResolveCacheRetention(cacheRetention ai.CacheRetention) ai.CacheRetention { if cacheRetention != "" { return cacheRetention @@ -37,7 +41,7 @@ func 
GetPromptCacheRetention(baseURL string, cacheRetention ai.CacheRetention) s } func BuildOpenAIResponsesParams(model ai.Model, context ai.Context, options OpenAIResponsesOptions) map[string]any { - messages := ConvertOpenAIResponsesMessages(model, context) + messages := ConvertResponsesMessages(model, context, openAIToolCallProviders, nil) retention := ResolveCacheRetention(options.StreamOptions.CacheRetention) params := map[string]any{ "model": model.ID, @@ -55,7 +59,7 @@ func BuildOpenAIResponsesParams(model ai.Model, context ai.Context, options Open params["service_tier"] = options.ServiceTier } if context.Tools != nil { - params["tools"] = convertResponsesTools(context.Tools) + params["tools"] = ConvertResponsesTools(context.Tools, false) } if retention != ai.CacheRetentionNone && strings.TrimSpace(options.StreamOptions.SessionID) != "" { params["prompt_cache_key"] = options.StreamOptions.SessionID @@ -93,120 +97,7 @@ func BuildOpenAIResponsesParams(model ai.Model, context ai.Context, options Open } func ConvertOpenAIResponsesMessages(model ai.Model, context ai.Context) []map[string]any { - messages := make([]map[string]any, 0, len(context.Messages)+1) - if strings.TrimSpace(context.SystemPrompt) != "" { - role := "system" - if model.Reasoning { - role = "developer" - } - messages = append(messages, map[string]any{ - "role": role, - "content": utils.SanitizeSurrogates(context.SystemPrompt), - }) - } - - transformed := TransformMessages(context.Messages, model, nil) - for _, msg := range transformed { - switch msg.Role { - case ai.RoleUser: - content := []map[string]any{} - if strings.TrimSpace(msg.Text) != "" { - content = append(content, map[string]any{ - "type": "input_text", - "text": utils.SanitizeSurrogates(msg.Text), - }) - } - for _, block := range msg.Content { - if block.Type == ai.ContentTypeText && strings.TrimSpace(block.Text) != "" { - content = append(content, map[string]any{ - "type": "input_text", - "text": utils.SanitizeSurrogates(block.Text), - 
}) - } - if block.Type == ai.ContentTypeImage { - content = append(content, map[string]any{ - "type": "input_image", - "detail": "auto", - "image_url": "data:" + block.MimeType + ";base64," + block.Data, - }) - } - } - if len(content) == 0 { - continue - } - messages = append(messages, map[string]any{ - "role": "user", - "content": content, - }) - case ai.RoleAssistant: - for _, block := range msg.Content { - switch block.Type { - case ai.ContentTypeText: - messages = append(messages, map[string]any{ - "type": "message", - "role": "assistant", - "status": "completed", - "id": fallbackTextID(block.TextSignature), - "content": []map[string]any{{ - "type": "output_text", - "text": utils.SanitizeSurrogates(block.Text), - "annotations": []any{}, - }}, - }) - case ai.ContentTypeThinking: - if block.ThinkingSignature != "" { - // signature payload is already serialized response item. - // best-effort keep as text fallback when opaque. - messages = append(messages, map[string]any{ - "type": "reasoning", - "summary": []map[string]any{{"type": "summary_text", "text": block.Thinking}}, - }) - } - case ai.ContentTypeToolCall: - parts := strings.SplitN(block.ID, "|", 2) - callID := block.ID - itemID := "" - if len(parts) == 2 { - callID = parts[0] - itemID = parts[1] - } - args := "{}" - if block.Arguments != nil { - b, _ := json.Marshal(block.Arguments) - args = string(b) - } - messages = append(messages, map[string]any{ - "type": "function_call", - "id": itemID, - "call_id": callID, - "name": block.Name, - "arguments": args, - }) - } - } - case ai.RoleToolResult: - callID := msg.ToolCallID - if strings.Contains(callID, "|") { - callID = strings.SplitN(callID, "|", 2)[0] - } - output := "(see attached image)" - var textParts []string - for _, block := range msg.Content { - if block.Type == ai.ContentTypeText { - textParts = append(textParts, block.Text) - } - } - if len(textParts) > 0 { - output = strings.Join(textParts, "\n") - } - messages = append(messages, map[string]any{ 
- "type": "function_call_output", - "call_id": callID, - "output": utils.SanitizeSurrogates(output), - }) - } - } - return messages + return ConvertResponsesMessages(model, context, openAIToolCallProviders, nil) } func fallbackTextID(signature string) string { @@ -218,17 +109,3 @@ func fallbackTextID(signature string) string { } return "msg_0" } - -func convertResponsesTools(tools []ai.Tool) []map[string]any { - out := make([]map[string]any, 0, len(tools)) - for _, tool := range tools { - out = append(out, map[string]any{ - "type": "function", - "name": tool.Name, - "description": tool.Description, - "parameters": tool.Parameters, - "strict": false, - }) - } - return out -} diff --git a/pkg/ai/providers/openai_responses_shared.go b/pkg/ai/providers/openai_responses_shared.go new file mode 100644 index 00000000..72777c74 --- /dev/null +++ b/pkg/ai/providers/openai_responses_shared.go @@ -0,0 +1,195 @@ +package providers + +import ( + "encoding/json" + "strings" + + "github.com/beeper/ai-bridge/pkg/ai" + "github.com/beeper/ai-bridge/pkg/ai/utils" +) + +type ConvertResponsesMessagesOptions struct { + IncludeSystemPrompt bool +} + +func NormalizeResponsesToolCallID(id string) string { + if !strings.Contains(id, "|") { + return id + } + parts := strings.SplitN(id, "|", 2) + callID := sanitizeResponsesIDPart(parts[0], 64) + itemID := sanitizeResponsesIDPart(parts[1], 64) + if !strings.HasPrefix(itemID, "fc") { + itemID = "fc_" + itemID + } + callID = strings.TrimRight(callID, "_") + itemID = strings.TrimRight(itemID, "_") + if callID == "" { + callID = "call" + } + if itemID == "" { + itemID = "fc_item" + } + return callID + "|" + itemID +} + +func ConvertResponsesMessages( + model ai.Model, + context ai.Context, + allowedToolCallProviders map[string]struct{}, + options *ConvertResponsesMessagesOptions, +) []map[string]any { + includeSystemPrompt := true + if options != nil { + includeSystemPrompt = options.IncludeSystemPrompt + } + normalizeToolCallID := func(id string, 
_ ai.Model, _ ai.Message) string { + if _, ok := allowedToolCallProviders[string(model.Provider)]; !ok { + return id + } + return NormalizeResponsesToolCallID(id) + } + transformed := TransformMessages(context.Messages, model, normalizeToolCallID) + + messages := make([]map[string]any, 0, len(transformed)+1) + if includeSystemPrompt && strings.TrimSpace(context.SystemPrompt) != "" { + role := "system" + if model.Reasoning { + role = "developer" + } + messages = append(messages, map[string]any{ + "role": role, + "content": utils.SanitizeSurrogates(context.SystemPrompt), + }) + } + + for _, msg := range transformed { + switch msg.Role { + case ai.RoleUser: + content := []map[string]any{} + if strings.TrimSpace(msg.Text) != "" { + content = append(content, map[string]any{ + "type": "input_text", + "text": utils.SanitizeSurrogates(msg.Text), + }) + } + for _, block := range msg.Content { + if block.Type == ai.ContentTypeText && strings.TrimSpace(block.Text) != "" { + content = append(content, map[string]any{ + "type": "input_text", + "text": utils.SanitizeSurrogates(block.Text), + }) + } + if block.Type == ai.ContentTypeImage { + content = append(content, map[string]any{ + "type": "input_image", + "detail": "auto", + "image_url": "data:" + block.MimeType + ";base64," + block.Data, + }) + } + } + if len(content) == 0 { + continue + } + messages = append(messages, map[string]any{ + "role": "user", + "content": content, + }) + case ai.RoleAssistant: + for _, block := range msg.Content { + switch block.Type { + case ai.ContentTypeText: + messages = append(messages, map[string]any{ + "type": "message", + "role": "assistant", + "status": "completed", + "id": fallbackTextID(block.TextSignature), + "content": []map[string]any{{ + "type": "output_text", + "text": utils.SanitizeSurrogates(block.Text), + "annotations": []any{}, + }}, + }) + case ai.ContentTypeThinking: + if block.ThinkingSignature != "" { + messages = append(messages, map[string]any{ + "type": "reasoning", + 
"summary": []map[string]any{{"type": "summary_text", "text": block.Thinking}}, + }) + } + case ai.ContentTypeToolCall: + parts := strings.SplitN(block.ID, "|", 2) + callID := block.ID + itemID := "" + if len(parts) == 2 { + callID = parts[0] + itemID = parts[1] + } + args := "{}" + if block.Arguments != nil { + b, _ := json.Marshal(block.Arguments) + args = string(b) + } + messages = append(messages, map[string]any{ + "type": "function_call", + "id": itemID, + "call_id": callID, + "name": block.Name, + "arguments": args, + }) + } + } + case ai.RoleToolResult: + callID := msg.ToolCallID + if strings.Contains(callID, "|") { + callID = strings.SplitN(callID, "|", 2)[0] + } + output := "(see attached image)" + var textParts []string + for _, block := range msg.Content { + if block.Type == ai.ContentTypeText { + textParts = append(textParts, block.Text) + } + } + if len(textParts) > 0 { + output = strings.Join(textParts, "\n") + } + messages = append(messages, map[string]any{ + "type": "function_call_output", + "call_id": callID, + "output": utils.SanitizeSurrogates(output), + }) + } + } + return messages +} + +func ConvertResponsesTools(tools []ai.Tool, strict bool) []map[string]any { + out := make([]map[string]any, 0, len(tools)) + for _, tool := range tools { + out = append(out, map[string]any{ + "type": "function", + "name": tool.Name, + "description": tool.Description, + "parameters": tool.Parameters, + "strict": strict, + }) + } + return out +} + +func sanitizeResponsesIDPart(id string, maxLen int) string { + var b strings.Builder + for _, r := range id { + if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '_' || r == '-' { + b.WriteRune(r) + } else { + b.WriteRune('_') + } + } + out := b.String() + if len(out) > maxLen { + out = out[:maxLen] + } + return out +} diff --git a/pkg/ai/providers/openai_responses_shared_test.go b/pkg/ai/providers/openai_responses_shared_test.go new file mode 100644 index 00000000..9adf7c3a --- 
/dev/null +++ b/pkg/ai/providers/openai_responses_shared_test.go @@ -0,0 +1,102 @@ +package providers + +import ( + "strings" + "testing" + + "github.com/beeper/ai-bridge/pkg/ai" +) + +func TestNormalizeResponsesToolCallID(t *testing.T) { + got := NormalizeResponsesToolCallID("call_abc|item+/==") + if !strings.Contains(got, "|") { + t.Fatalf("expected normalized id to keep pipe separator, got %q", got) + } + parts := strings.SplitN(got, "|", 2) + if len(parts) != 2 { + t.Fatalf("expected two parts in normalized id, got %q", got) + } + if strings.ContainsAny(parts[0], "+/=") { + t.Fatalf("expected call id sanitized, got %q", parts[0]) + } + if !strings.HasPrefix(parts[1], "fc") { + t.Fatalf("expected item id to start with fc prefix, got %q", parts[1]) + } +} + +func TestConvertResponsesMessages_NormalizesAllowedProviderToolIDs(t *testing.T) { + model := ai.Model{ + ID: "gpt-5", + Provider: "openai", + API: ai.APIOpenAIResponses, + } + context := ai.Context{ + SystemPrompt: "system prompt", + Messages: []ai.Message{ + {Role: ai.RoleUser, Text: "hi"}, + { + Role: ai.RoleAssistant, + Content: []ai.ContentBlock{ + { + Type: ai.ContentTypeToolCall, + ID: "call_abc|item+/==", + Name: "echo", + Arguments: map[string]any{"message": "hello"}, + }, + }, + Provider: "github-copilot", + API: ai.APIOpenAIResponses, + Model: "gpt-5.2-codex", + StopReason: ai.StopReasonToolUse, + }, + { + Role: ai.RoleToolResult, + ToolCallID: "call_abc|item+/==", + ToolName: "echo", + Content: []ai.ContentBlock{ + {Type: ai.ContentTypeText, Text: "hello"}, + }, + }, + }, + } + + output := ConvertResponsesMessages(model, context, openAIToolCallProviders, nil) + if len(output) < 4 { + t.Fatalf("expected converted response input items, got %d", len(output)) + } + functionCall := output[2] + callID, _ := functionCall["call_id"].(string) + itemID, _ := functionCall["id"].(string) + if callID == "call_abc" && strings.Contains(itemID, "+") { + t.Fatalf("expected normalized function call ids, got call=%q 
item=%q", callID, itemID) + } + if !strings.HasPrefix(itemID, "fc") { + t.Fatalf("expected function call item id to start with fc, got %q", itemID) + } + + functionOutput := output[3] + if functionOutput["call_id"] != callID { + t.Fatalf("expected function_call_output call_id to match normalized call_id, got output=%q call=%q", functionOutput["call_id"], callID) + } +} + +func TestConvertResponsesMessages_CanOmitSystemPrompt(t *testing.T) { + output := ConvertResponsesMessages( + ai.Model{Provider: "openai", API: ai.APIOpenAIResponses}, + ai.Context{ + SystemPrompt: "system prompt", + Messages: []ai.Message{ + {Role: ai.RoleUser, Text: "hello"}, + }, + }, + openAIToolCallProviders, + &ConvertResponsesMessagesOptions{IncludeSystemPrompt: false}, + ) + if len(output) == 0 { + t.Fatalf("expected user message output") + } + first := output[0] + if role, _ := first["role"].(string); role == "system" || role == "developer" { + t.Fatalf("expected no system/developer prompt in output when omitted, got %#v", first) + } +} From f19abadcae3f824263d0ea0bda7637619c653257 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 04:01:44 +0000 Subject: [PATCH 19/75] Add Google Vertex and Bedrock provider helper parity MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/ai/providers/amazon_bedrock.go | 158 ++++++++++++++++++++++++ pkg/ai/providers/amazon_bedrock_test.go | 94 ++++++++++++++ pkg/ai/providers/google.go | 63 ++++++++++ pkg/ai/providers/google_shared.go | 48 +++++++ pkg/ai/providers/google_shared_test.go | 48 +++++++ pkg/ai/providers/google_test.go | 63 ++++++++++ pkg/ai/providers/google_vertex.go | 49 ++++++++ pkg/ai/providers/google_vertex_test.go | 38 ++++++ 8 files changed, 561 insertions(+) create mode 100644 pkg/ai/providers/amazon_bedrock.go create mode 100644 pkg/ai/providers/amazon_bedrock_test.go create mode 100644 pkg/ai/providers/google.go create mode 100644 
pkg/ai/providers/google_test.go create mode 100644 pkg/ai/providers/google_vertex.go create mode 100644 pkg/ai/providers/google_vertex_test.go diff --git a/pkg/ai/providers/amazon_bedrock.go b/pkg/ai/providers/amazon_bedrock.go new file mode 100644 index 00000000..879bda78 --- /dev/null +++ b/pkg/ai/providers/amazon_bedrock.go @@ -0,0 +1,158 @@ +package providers + +import ( + "os" + "strings" + + "github.com/beeper/ai-bridge/pkg/ai" + "github.com/beeper/ai-bridge/pkg/ai/utils" +) + +type BedrockToolChoice struct { + Type string + Name string +} + +type BedrockOptions struct { + StreamOptions ai.StreamOptions + Region string + Profile string + ToolChoice any + Reasoning ai.ThinkingLevel + ThinkingBudgets ai.ThinkingBudgets + InterleavedThinking *bool +} + +func SupportsAdaptiveThinking(modelID string) bool { + id := strings.ToLower(modelID) + return strings.Contains(id, "opus-4-6") || + strings.Contains(id, "opus-4.6") || + strings.Contains(id, "sonnet-4-6") || + strings.Contains(id, "sonnet-4.6") +} + +func ResolveBedrockCacheRetention(cacheRetention ai.CacheRetention) ai.CacheRetention { + if cacheRetention != "" { + return cacheRetention + } + if strings.EqualFold(os.Getenv("PI_CACHE_RETENTION"), "long") { + return ai.CacheRetentionLong + } + return ai.CacheRetentionShort +} + +func BuildBedrockSystemPrompt(systemPrompt string, model ai.Model, cacheRetention ai.CacheRetention) []map[string]any { + if strings.TrimSpace(systemPrompt) == "" { + return nil + } + blocks := []map[string]any{ + {"text": utils.SanitizeSurrogates(systemPrompt)}, + } + if cacheRetention != ai.CacheRetentionNone && supportsBedrockPromptCaching(model) { + cachePoint := map[string]any{"type": "default"} + if cacheRetention == ai.CacheRetentionLong { + cachePoint["ttl"] = "1h" + } + blocks = append(blocks, map[string]any{ + "cachePoint": cachePoint, + }) + } + return blocks +} + +func supportsBedrockPromptCaching(model ai.Model) bool { + if model.Cost.CacheRead > 0 || model.Cost.CacheWrite > 
0 { + return true + } + id := strings.ToLower(model.ID) + if strings.Contains(id, "claude") && (strings.Contains(id, "-4-") || strings.Contains(id, "-4.")) { + return true + } + if strings.Contains(id, "claude-3-7-sonnet") { + return true + } + if strings.Contains(id, "claude-3-5-haiku") { + return true + } + return false +} + +func MapBedrockStopReason(reason string) ai.StopReason { + switch strings.ToUpper(strings.TrimSpace(reason)) { + case "END_TURN", "STOP_SEQUENCE": + return ai.StopReasonStop + case "MAX_TOKENS", "MODEL_CONTEXT_WINDOW_EXCEEDED": + return ai.StopReasonLength + case "TOOL_USE": + return ai.StopReasonToolUse + default: + return ai.StopReasonError + } +} + +func BuildBedrockAdditionalModelRequestFields(model ai.Model, options BedrockOptions) map[string]any { + if options.Reasoning == "" || !model.Reasoning { + return nil + } + id := strings.ToLower(model.ID) + if !strings.Contains(id, "anthropic.claude") && !strings.Contains(id, "anthropic/claude") { + return nil + } + + if SupportsAdaptiveThinking(model.ID) { + effort := "high" + if options.Reasoning == ai.ThinkingXHigh && (strings.Contains(id, "opus-4-6") || strings.Contains(id, "opus-4.6")) { + effort = "max" + } else { + effort = mapBedrockThinkingEffort(options.Reasoning) + } + return map[string]any{ + "thinking": map[string]any{"type": "adaptive"}, + "output_config": map[string]any{"effort": effort}, + } + } + + level := ClampReasoning(options.Reasoning) + budgets := mergeThinkingBudgets(ai.ThinkingBudgets{ + Minimal: 1024, + Low: 2048, + Medium: 8192, + High: 16384, + }, options.ThinkingBudgets) + budget := budgets.High + switch level { + case ai.ThinkingMinimal: + budget = budgets.Minimal + case ai.ThinkingLow: + budget = budgets.Low + case ai.ThinkingMedium: + budget = budgets.Medium + } + result := map[string]any{ + "thinking": map[string]any{ + "type": "enabled", + "budget_tokens": budget, + }, + } + interleaved := true + if options.InterleavedThinking != nil { + interleaved = 
*options.InterleavedThinking + } + if interleaved { + result["anthropic_beta"] = []string{"interleaved-thinking-2025-05-14"} + } + return result +} + +func mapBedrockThinkingEffort(level ai.ThinkingLevel) string { + switch level { + case ai.ThinkingMinimal, ai.ThinkingLow: + return "low" + case ai.ThinkingMedium: + return "medium" + case ai.ThinkingHigh, ai.ThinkingXHigh: + return "high" + default: + return "high" + } +} diff --git a/pkg/ai/providers/amazon_bedrock_test.go b/pkg/ai/providers/amazon_bedrock_test.go new file mode 100644 index 00000000..bbaf967d --- /dev/null +++ b/pkg/ai/providers/amazon_bedrock_test.go @@ -0,0 +1,94 @@ +package providers + +import ( + "testing" + + "github.com/beeper/ai-bridge/pkg/ai" +) + +func TestBedrockHelperFunctions(t *testing.T) { + if !SupportsAdaptiveThinking("global.anthropic.claude-opus-4-6-v1") { + t.Fatalf("expected adaptive thinking support for opus 4.6") + } + if SupportsAdaptiveThinking("global.anthropic.claude-sonnet-4-5-v1") { + t.Fatalf("did not expect adaptive thinking support for sonnet 4.5") + } + + t.Setenv("PI_CACHE_RETENTION", "long") + if got := ResolveBedrockCacheRetention(""); got != ai.CacheRetentionLong { + t.Fatalf("expected env long cache retention, got %s", got) + } + if got := ResolveBedrockCacheRetention(ai.CacheRetentionNone); got != ai.CacheRetentionNone { + t.Fatalf("expected explicit cache retention none to win, got %s", got) + } + + system := BuildBedrockSystemPrompt( + "You are helpful", + ai.Model{ID: "global.anthropic.claude-sonnet-4-5-v1:0"}, + ai.CacheRetentionLong, + ) + if len(system) != 2 { + t.Fatalf("expected system + cache point for cacheable model, got %#v", system) + } + cachePoint := system[1]["cachePoint"].(map[string]any) + if cachePoint["ttl"] != "1h" { + t.Fatalf("expected long cache ttl=1h, got %#v", cachePoint) + } + + if got := MapBedrockStopReason("TOOL_USE"); got != ai.StopReasonToolUse { + t.Fatalf("expected TOOL_USE->toolUse, got %s", got) + } + if got := 
MapBedrockStopReason("MAX_TOKENS"); got != ai.StopReasonLength { + t.Fatalf("expected MAX_TOKENS->length, got %s", got) + } + if got := MapBedrockStopReason("OTHER_REASON"); got != ai.StopReasonError { + t.Fatalf("expected unknown->error, got %s", got) + } +} + +func TestBuildBedrockAdditionalModelRequestFields(t *testing.T) { + interleaved := true + fields := BuildBedrockAdditionalModelRequestFields( + ai.Model{ + ID: "global.anthropic.claude-sonnet-4-5-v1:0", + Provider: "amazon-bedrock", + API: ai.APIBedrockConverse, + Reasoning: true, + }, + BedrockOptions{ + Reasoning: ai.ThinkingMedium, + ThinkingBudgets: ai.ThinkingBudgets{Medium: 6000}, + InterleavedThinking: &interleaved, + }, + ) + if fields == nil { + t.Fatalf("expected additional fields for Claude model reasoning") + } + thinking := fields["thinking"].(map[string]any) + if thinking["type"] != "enabled" || thinking["budget_tokens"] != 6000 { + t.Fatalf("unexpected non-adaptive thinking payload: %#v", thinking) + } + beta := fields["anthropic_beta"].([]string) + if len(beta) != 1 || beta[0] != "interleaved-thinking-2025-05-14" { + t.Fatalf("expected interleaved thinking beta flag, got %#v", beta) + } + + adaptive := BuildBedrockAdditionalModelRequestFields( + ai.Model{ + ID: "global.anthropic.claude-opus-4-6-v1", + Provider: "amazon-bedrock", + API: ai.APIBedrockConverse, + Reasoning: true, + }, + BedrockOptions{ + Reasoning: ai.ThinkingXHigh, + }, + ) + if adaptive["thinking"].(map[string]any)["type"] != "adaptive" { + t.Fatalf("expected adaptive thinking payload for opus-4-6") + } + outputConfig := adaptive["output_config"].(map[string]any) + if outputConfig["effort"] != "max" { + t.Fatalf("expected xhigh on opus-4-6 to map to max effort, got %#v", outputConfig["effort"]) + } +} diff --git a/pkg/ai/providers/google.go b/pkg/ai/providers/google.go new file mode 100644 index 00000000..1cd3990c --- /dev/null +++ b/pkg/ai/providers/google.go @@ -0,0 +1,63 @@ +package providers + +import ( + "strings" + + 
"github.com/beeper/ai-bridge/pkg/ai" + "github.com/beeper/ai-bridge/pkg/ai/utils" +) + +type GoogleThinkingOptions struct { + Enabled bool + BudgetTokens *int + Level string +} + +type GoogleOptions struct { + StreamOptions ai.StreamOptions + ToolChoice string + Thinking *GoogleThinkingOptions +} + +func BuildGoogleGenerateContentParams(model ai.Model, context ai.Context, options GoogleOptions) map[string]any { + params := map[string]any{ + "model": model.ID, + "contents": ConvertGoogleMessages(model, context), + } + + config := map[string]any{} + if options.StreamOptions.Temperature != nil { + config["temperature"] = *options.StreamOptions.Temperature + } + if options.StreamOptions.MaxTokens > 0 { + config["maxOutputTokens"] = options.StreamOptions.MaxTokens + } + if strings.TrimSpace(context.SystemPrompt) != "" { + config["systemInstruction"] = utils.SanitizeSurrogates(context.SystemPrompt) + } + if len(context.Tools) > 0 { + config["tools"] = ConvertGoogleTools(context.Tools, false) + if strings.TrimSpace(options.ToolChoice) != "" { + config["toolConfig"] = map[string]any{ + "functionCallingConfig": map[string]any{ + "mode": MapGoogleToolChoice(options.ToolChoice), + }, + } + } + } + if options.Thinking != nil && options.Thinking.Enabled && model.Reasoning { + thinkingConfig := map[string]any{ + "includeThoughts": true, + } + if strings.TrimSpace(options.Thinking.Level) != "" { + thinkingConfig["thinkingLevel"] = strings.ToUpper(strings.TrimSpace(options.Thinking.Level)) + } else if options.Thinking.BudgetTokens != nil { + thinkingConfig["thinkingBudget"] = *options.Thinking.BudgetTokens + } + config["thinkingConfig"] = thinkingConfig + } + if len(config) > 0 { + params["config"] = config + } + return params +} diff --git a/pkg/ai/providers/google_shared.go b/pkg/ai/providers/google_shared.go index 3c9c52a1..d53a964e 100644 --- a/pkg/ai/providers/google_shared.go +++ b/pkg/ai/providers/google_shared.go @@ -41,6 +41,54 @@ type GoogleInlineData struct { Data 
string `json:"data"` } +func ConvertGoogleTools(tools []ai.Tool, useParameters bool) []map[string]any { + if len(tools) == 0 { + return nil + } + functions := make([]map[string]any, 0, len(tools)) + for _, tool := range tools { + declaration := map[string]any{ + "name": tool.Name, + "description": tool.Description, + } + if useParameters { + declaration["parameters"] = tool.Parameters + } else { + declaration["parametersJsonSchema"] = tool.Parameters + } + functions = append(functions, declaration) + } + return []map[string]any{ + { + "functionDeclarations": functions, + }, + } +} + +func MapGoogleToolChoice(choice string) string { + switch strings.ToLower(strings.TrimSpace(choice)) { + case "none": + return "NONE" + case "any": + return "ANY" + default: + return "AUTO" + } +} + +func MapGoogleStopReason(reason string) ai.StopReason { + switch strings.ToUpper(strings.TrimSpace(reason)) { + case "STOP": + return ai.StopReasonStop + case "MAX_TOKENS": + return ai.StopReasonLength + case "TOOL_USE": + return ai.StopReasonToolUse + default: + return ai.StopReasonError + } +} + func IsThinkingPart(part GooglePart) bool { return part.Thought } diff --git a/pkg/ai/providers/google_shared_test.go b/pkg/ai/providers/google_shared_test.go index 98e2a76c..82ba5efa 100644 --- a/pkg/ai/providers/google_shared_test.go +++ b/pkg/ai/providers/google_shared_test.go @@ -90,3 +90,51 @@ func TestConvertMessages_ConvertsUnsignedToolCallsToHistoricalTextForGemini3(t * t.Fatalf("unexpected historical context text: %s", joined) } } + +func TestGoogleSharedToolAndStopReasonHelpers(t *testing.T) { + tools := ConvertGoogleTools([]ai.Tool{ + { + Name: "search", + Description: "Search", + Parameters: map[string]any{"type": "object"}, + }, + }, false) + if len(tools) != 1 { + t.Fatalf("expected one Gemini tools wrapper entry, got %d", len(tools)) + } + declarations, _ := tools[0]["functionDeclarations"].([]map[string]any) + if len(declarations) != 1 { + t.Fatalf("expected one function 
declaration, got %#v", tools[0]["functionDeclarations"]) + } + if _, ok := declarations[0]["parametersJsonSchema"]; !ok { + t.Fatalf("expected parametersJsonSchema in default conversion") + } + + legacy := ConvertGoogleTools([]ai.Tool{ + {Name: "search", Parameters: map[string]any{"type": "object"}}, + }, true) + legacyDecls, _ := legacy[0]["functionDeclarations"].([]map[string]any) + if _, ok := legacyDecls[0]["parameters"]; !ok { + t.Fatalf("expected parameters field when useParameters=true") + } + + if got := MapGoogleToolChoice("any"); got != "ANY" { + t.Fatalf("expected any->ANY, got %q", got) + } + if got := MapGoogleToolChoice("none"); got != "NONE" { + t.Fatalf("expected none->NONE, got %q", got) + } + if got := MapGoogleToolChoice("unexpected"); got != "AUTO" { + t.Fatalf("expected unknown->AUTO, got %q", got) + } + + if got := MapGoogleStopReason("STOP"); got != ai.StopReasonStop { + t.Fatalf("expected STOP->stop, got %q", got) + } + if got := MapGoogleStopReason("MAX_TOKENS"); got != ai.StopReasonLength { + t.Fatalf("expected MAX_TOKENS->length, got %q", got) + } + if got := MapGoogleStopReason("other"); got != ai.StopReasonError { + t.Fatalf("expected unknown->error, got %q", got) + } +} diff --git a/pkg/ai/providers/google_test.go b/pkg/ai/providers/google_test.go new file mode 100644 index 00000000..95480350 --- /dev/null +++ b/pkg/ai/providers/google_test.go @@ -0,0 +1,63 @@ +package providers + +import ( + "testing" + + "github.com/beeper/ai-bridge/pkg/ai" +) + +func TestBuildGoogleGenerateContentParams(t *testing.T) { + temp := 0.4 + budget := 2048 + params := BuildGoogleGenerateContentParams( + ai.Model{ + ID: "gemini-2.5-flash", + Provider: "google", + API: ai.APIGoogleGenerativeAI, + Reasoning: true, + Input: []string{"text"}, + }, + ai.Context{ + SystemPrompt: "You are helpful", + Messages: []ai.Message{ + {Role: ai.RoleUser, Text: "hello"}, + }, + Tools: []ai.Tool{ + {Name: "search", Description: "Search", Parameters: map[string]any{"type": 
"object"}}, + }, + }, + GoogleOptions{ + StreamOptions: ai.StreamOptions{ + Temperature: &temp, + MaxTokens: 2048, + }, + ToolChoice: "any", + Thinking: &GoogleThinkingOptions{ + Enabled: true, + BudgetTokens: &budget, + }, + }, + ) + if params["model"] != "gemini-2.5-flash" { + t.Fatalf("expected model id in params") + } + config, ok := params["config"].(map[string]any) + if !ok { + t.Fatalf("expected config payload") + } + if config["temperature"] != 0.4 || config["maxOutputTokens"] != 2048 { + t.Fatalf("unexpected generation config: %#v", config) + } + toolConfig, ok := config["toolConfig"].(map[string]any) + if !ok { + t.Fatalf("expected toolConfig when tools+toolChoice are present") + } + mode := toolConfig["functionCallingConfig"].(map[string]any)["mode"] + if mode != "ANY" { + t.Fatalf("expected ANY tool mode, got %#v", mode) + } + thinking := config["thinkingConfig"].(map[string]any) + if thinking["includeThoughts"] != true || thinking["thinkingBudget"] != 2048 { + t.Fatalf("unexpected thinking config: %#v", thinking) + } +} diff --git a/pkg/ai/providers/google_vertex.go b/pkg/ai/providers/google_vertex.go new file mode 100644 index 00000000..d06b8855 --- /dev/null +++ b/pkg/ai/providers/google_vertex.go @@ -0,0 +1,49 @@ +package providers + +import ( + "errors" + "os" + "strings" + + "github.com/beeper/ai-bridge/pkg/ai" +) + +const GoogleVertexAPIVersion = "v1" + +var ( + ErrMissingVertexProject = errors.New("vertex ai project is required") + ErrMissingVertexLocation = errors.New("vertex ai location is required") +) + +type GoogleVertexOptions struct { + GoogleOptions + Project string + Location string +} + +func ResolveGoogleVertexProject(options *GoogleVertexOptions) (string, error) { + if options != nil && strings.TrimSpace(options.Project) != "" { + return strings.TrimSpace(options.Project), nil + } + if env := strings.TrimSpace(os.Getenv("GOOGLE_CLOUD_PROJECT")); env != "" { + return env, nil + } + if env := 
strings.TrimSpace(os.Getenv("GCLOUD_PROJECT")); env != "" { + return env, nil + } + return "", ErrMissingVertexProject +} + +func ResolveGoogleVertexLocation(options *GoogleVertexOptions) (string, error) { + if options != nil && strings.TrimSpace(options.Location) != "" { + return strings.TrimSpace(options.Location), nil + } + if env := strings.TrimSpace(os.Getenv("GOOGLE_CLOUD_LOCATION")); env != "" { + return env, nil + } + return "", ErrMissingVertexLocation +} + +func BuildGoogleVertexGenerateContentParams(model ai.Model, context ai.Context, options GoogleVertexOptions) map[string]any { + return BuildGoogleGenerateContentParams(model, context, options.GoogleOptions) +} diff --git a/pkg/ai/providers/google_vertex_test.go b/pkg/ai/providers/google_vertex_test.go new file mode 100644 index 00000000..1997c9e7 --- /dev/null +++ b/pkg/ai/providers/google_vertex_test.go @@ -0,0 +1,38 @@ +package providers + +import ( + "testing" +) + +func TestResolveGoogleVertexProjectAndLocation(t *testing.T) { + t.Setenv("GOOGLE_CLOUD_PROJECT", "") + t.Setenv("GCLOUD_PROJECT", "") + t.Setenv("GOOGLE_CLOUD_LOCATION", "") + + if _, err := ResolveGoogleVertexProject(nil); err == nil { + t.Fatalf("expected missing project error") + } + if _, err := ResolveGoogleVertexLocation(nil); err == nil { + t.Fatalf("expected missing location error") + } + + t.Setenv("GOOGLE_CLOUD_PROJECT", "env-project") + t.Setenv("GOOGLE_CLOUD_LOCATION", "us-central1") + project, err := ResolveGoogleVertexProject(nil) + if err != nil || project != "env-project" { + t.Fatalf("expected env project, got project=%q err=%v", project, err) + } + location, err := ResolveGoogleVertexLocation(nil) + if err != nil || location != "us-central1" { + t.Fatalf("expected env location, got location=%q err=%v", location, err) + } + + project, err = ResolveGoogleVertexProject(&GoogleVertexOptions{Project: "opt-project"}) + if err != nil || project != "opt-project" { + t.Fatalf("expected option project override, got project=%q 
err=%v", project, err) + } + location, err = ResolveGoogleVertexLocation(&GoogleVertexOptions{Location: "europe-west4"}) + if err != nil || location != "europe-west4" { + t.Fatalf("expected option location override, got location=%q err=%v", location, err) + } +} From 13fff3e0f4f6786a9e2ab9fc605505523b9cb8d9 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 04:06:16 +0000 Subject: [PATCH 20/75] Add feature-gated pkg-ai runtime selector in connector MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/connector/response_retry.go | 17 +++--- pkg/connector/streaming_runtime_selector.go | 53 +++++++++++++++++++ .../streaming_runtime_selector_test.go | 40 ++++++++++++++ 3 files changed, 100 insertions(+), 10 deletions(-) create mode 100644 pkg/connector/streaming_runtime_selector.go create mode 100644 pkg/connector/streaming_runtime_selector_test.go diff --git a/pkg/connector/response_retry.go b/pkg/connector/response_retry.go index 9d1e3ff2..612625b6 100644 --- a/pkg/connector/response_retry.go +++ b/pkg/connector/response_retry.go @@ -383,17 +383,14 @@ func (oc *AIClient) streamingResponseWithRetry( } func (oc *AIClient) selectResponseFn(meta *PortalMetadata, prompt []openai.ChatCompletionMessageParamUnion) (responseFunc, string) { - // Use Chat Completions API for audio (native support) - // SDK v3.16.0 has ResponseInputAudioParam but it's not wired into the union - if hasAudioContent(prompt) { - return oc.streamChatCompletions, "chat_completions" - } - switch oc.resolveModelAPI(meta) { - case ModelAPIChatCompletions: - return oc.streamChatCompletions, "chat_completions" + path := chooseStreamingRuntimePath(hasAudioContent(prompt), oc.resolveModelAPI(meta), pkgAIRuntimeEnabled()) + switch path { + case streamingRuntimePkgAI: + return oc.streamWithPkgAIBridge, string(streamingRuntimePkgAI) + case streamingRuntimeChatCompletions: + return oc.streamChatCompletions, 
string(streamingRuntimeChatCompletions) default: - // Use Responses API for other content (images, files, text) - return oc.streamingResponseWithToolSchemaFallback, "responses" + return oc.streamingResponseWithToolSchemaFallback, string(streamingRuntimeResponses) } } diff --git a/pkg/connector/streaming_runtime_selector.go b/pkg/connector/streaming_runtime_selector.go new file mode 100644 index 00000000..0d2c4418 --- /dev/null +++ b/pkg/connector/streaming_runtime_selector.go @@ -0,0 +1,53 @@ +package connector + +import ( + "context" + "os" + "strings" + + "github.com/openai/openai-go/v3" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/event" +) + +type streamingRuntimePath string + +const ( + streamingRuntimePkgAI streamingRuntimePath = "pkg_ai" + streamingRuntimeChatCompletions streamingRuntimePath = "chat_completions" + streamingRuntimeResponses streamingRuntimePath = "responses" +) + +func pkgAIRuntimeEnabled() bool { + value := strings.ToLower(strings.TrimSpace(os.Getenv("PI_USE_PKG_AI_RUNTIME"))) + return value == "1" || value == "true" || value == "yes" || value == "on" +} + +func chooseStreamingRuntimePath(hasAudio bool, modelAPI ModelAPI, preferPkgAI bool) streamingRuntimePath { + if hasAudio { + return streamingRuntimeChatCompletions + } + if preferPkgAI { + return streamingRuntimePkgAI + } + if modelAPI == ModelAPIChatCompletions { + return streamingRuntimeChatCompletions + } + return streamingRuntimeResponses +} + +func (oc *AIClient) streamWithPkgAIBridge( + ctx context.Context, + evt *event.Event, + portal *bridgev2.Portal, + meta *PortalMetadata, + prompt []openai.ChatCompletionMessageParamUnion, +) (bool, *ContextLengthError, error) { + oc.loggerForContext(ctx).Debug().Msg("pkg/ai runtime bridge flag enabled; delegating to existing runtime path") + switch oc.resolveModelAPI(meta) { + case ModelAPIChatCompletions: + return oc.streamChatCompletions(ctx, evt, portal, meta, prompt) + default: + return 
oc.streamingResponseWithToolSchemaFallback(ctx, evt, portal, meta, prompt) + } +} diff --git a/pkg/connector/streaming_runtime_selector_test.go b/pkg/connector/streaming_runtime_selector_test.go new file mode 100644 index 00000000..ffb43238 --- /dev/null +++ b/pkg/connector/streaming_runtime_selector_test.go @@ -0,0 +1,40 @@ +package connector + +import "testing" + +func TestPkgAIRuntimeEnabledFromEnv(t *testing.T) { + t.Setenv("PI_USE_PKG_AI_RUNTIME", "") + if pkgAIRuntimeEnabled() { + t.Fatalf("expected runtime flag disabled by default") + } + + t.Setenv("PI_USE_PKG_AI_RUNTIME", "1") + if !pkgAIRuntimeEnabled() { + t.Fatalf("expected runtime flag enabled for value 1") + } + + t.Setenv("PI_USE_PKG_AI_RUNTIME", "true") + if !pkgAIRuntimeEnabled() { + t.Fatalf("expected runtime flag enabled for value true") + } + + t.Setenv("PI_USE_PKG_AI_RUNTIME", "off") + if pkgAIRuntimeEnabled() { + t.Fatalf("expected runtime flag disabled for value off") + } +} + +func TestChooseStreamingRuntimePath(t *testing.T) { + if got := chooseStreamingRuntimePath(true, ModelAPIResponses, true); got != streamingRuntimeChatCompletions { + t.Fatalf("expected audio to force chat completions, got %s", got) + } + if got := chooseStreamingRuntimePath(false, ModelAPIResponses, true); got != streamingRuntimePkgAI { + t.Fatalf("expected pkg_ai path when preferred and no audio, got %s", got) + } + if got := chooseStreamingRuntimePath(false, ModelAPIChatCompletions, false); got != streamingRuntimeChatCompletions { + t.Fatalf("expected chat model api path, got %s", got) + } + if got := chooseStreamingRuntimePath(false, ModelAPIResponses, false); got != streamingRuntimeResponses { + t.Fatalf("expected responses path fallback, got %s", got) + } +} From d372f48e3ca8c595c4d40c8ec02efbd61db0ad68 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 04:20:56 +0000 Subject: [PATCH 21/75] Add usage total-token normalization and e2e parity scaffolds MIME-Version: 1.0 Content-Type: text/plain; 
charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/ai/e2e/parity_scaffolds_test.go | 66 +++++++++++++++++++++++++++++ pkg/ai/utils/usage.go | 17 ++++++++ pkg/ai/utils/usage_test.go | 41 ++++++++++++++++++ 3 files changed, 124 insertions(+) create mode 100644 pkg/ai/e2e/parity_scaffolds_test.go create mode 100644 pkg/ai/utils/usage.go create mode 100644 pkg/ai/utils/usage_test.go diff --git a/pkg/ai/e2e/parity_scaffolds_test.go b/pkg/ai/e2e/parity_scaffolds_test.go new file mode 100644 index 00000000..614f6b7c --- /dev/null +++ b/pkg/ai/e2e/parity_scaffolds_test.go @@ -0,0 +1,66 @@ +package e2e + +import ( + "os" + "testing" +) + +func requirePIAIE2E(t *testing.T) { + t.Helper() + if testing.Short() { + t.Skip("skipping e2e parity scaffolds in short mode") + } + if os.Getenv("PI_AI_E2E") == "" { + t.Skip("set PI_AI_E2E=1 to enable ai package e2e tests") + } +} + +func TestToolCallWithoutResultE2EParityScaffold(t *testing.T) { + requirePIAIE2E(t) + t.Skip("parity scaffold for tool-call-without-result.test.ts pending runtime implementation") +} + +func TestInterleavedThinkingE2EParityScaffold(t *testing.T) { + requirePIAIE2E(t) + t.Skip("parity scaffold for interleaved-thinking.test.ts pending runtime implementation") +} + +func TestBedrockModelsE2EParityScaffold(t *testing.T) { + requirePIAIE2E(t) + t.Skip("parity scaffold for bedrock-models.test.ts pending runtime implementation") +} + +func TestToolCallIDNormalizationE2EParityScaffold(t *testing.T) { + requirePIAIE2E(t) + t.Skip("parity scaffold for tool-call-id-normalization.test.ts pending runtime implementation") +} + +func TestAnthropicToolNameNormalizationE2EParityScaffold(t *testing.T) { + requirePIAIE2E(t) + t.Skip("parity scaffold for anthropic-tool-name-normalization.test.ts pending runtime implementation") +} + +func TestTokenStatsOnAbortE2EParityScaffold(t *testing.T) { + requirePIAIE2E(t) + t.Skip("parity scaffold for tokens.test.ts pending runtime implementation") +} + 
+func TestTotalTokensE2EParityScaffold(t *testing.T) { + requirePIAIE2E(t) + t.Skip("parity scaffold for total-tokens.test.ts pending runtime implementation") +} + +func TestCrossProviderHandoffE2EParityScaffold(t *testing.T) { + requirePIAIE2E(t) + t.Skip("parity scaffold for cross-provider-handoff.test.ts pending runtime implementation") +} + +func TestOpenAIResponsesReasoningReplayE2EParityScaffold(t *testing.T) { + requirePIAIE2E(t) + t.Skip("parity scaffold for openai-responses-reasoning-replay-e2e.test.ts pending runtime implementation") +} + +func TestGoogleGeminiCLIEmptyStreamE2EParityScaffold(t *testing.T) { + requirePIAIE2E(t) + t.Skip("parity scaffold for google-gemini-cli-empty-stream.test.ts pending runtime implementation") +} diff --git a/pkg/ai/utils/usage.go b/pkg/ai/utils/usage.go new file mode 100644 index 00000000..c0e28a6b --- /dev/null +++ b/pkg/ai/utils/usage.go @@ -0,0 +1,17 @@ +package utils + +import "github.com/beeper/ai-bridge/pkg/ai" + +// NormalizeUsageTotalTokens keeps usage.totalTokens coherent with component counters. +// Some providers omit total tokens in partial/aborted responses; this computes a safe fallback. 
+func NormalizeUsageTotalTokens(usage ai.Usage) ai.Usage { + computed := usage.Input + usage.Output + usage.CacheRead + usage.CacheWrite + if usage.TotalTokens <= 0 { + usage.TotalTokens = computed + return usage + } + if usage.TotalTokens < computed { + usage.TotalTokens = computed + } + return usage +} diff --git a/pkg/ai/utils/usage_test.go b/pkg/ai/utils/usage_test.go new file mode 100644 index 00000000..7d46dd1f --- /dev/null +++ b/pkg/ai/utils/usage_test.go @@ -0,0 +1,41 @@ +package utils + +import ( + "testing" + + "github.com/beeper/ai-bridge/pkg/ai" +) + +func TestNormalizeUsageTotalTokens(t *testing.T) { + usage := NormalizeUsageTotalTokens(ai.Usage{ + Input: 100, + Output: 50, + CacheRead: 10, + CacheWrite: 5, + }) + if usage.TotalTokens != 165 { + t.Fatalf("expected computed totalTokens=165, got %d", usage.TotalTokens) + } + + usage = NormalizeUsageTotalTokens(ai.Usage{ + Input: 100, + Output: 50, + CacheRead: 10, + CacheWrite: 5, + TotalTokens: 120, + }) + if usage.TotalTokens != 165 { + t.Fatalf("expected totalTokens uplifted to components sum=165, got %d", usage.TotalTokens) + } + + usage = NormalizeUsageTotalTokens(ai.Usage{ + Input: 100, + Output: 50, + CacheRead: 10, + CacheWrite: 5, + TotalTokens: 200, + }) + if usage.TotalTokens != 200 { + t.Fatalf("expected larger totalTokens to be preserved, got %d", usage.TotalTokens) + } +} From a6e655de67a773cc19967cbff0e39ecb58f7e475 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 04:25:39 +0000 Subject: [PATCH 22/75] Expand Bedrock message and tool config conversion parity MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/ai/providers/amazon_bedrock.go | 243 ++++++++++++++++++++++++ pkg/ai/providers/amazon_bedrock_test.go | 140 ++++++++++++++ 2 files changed, 383 insertions(+) diff --git a/pkg/ai/providers/amazon_bedrock.go b/pkg/ai/providers/amazon_bedrock.go index 879bda78..f69fcfa4 100644 --- 
a/pkg/ai/providers/amazon_bedrock.go +++ b/pkg/ai/providers/amazon_bedrock.go @@ -23,6 +23,23 @@ type BedrockOptions struct { InterleavedThinking *bool } +func BuildBedrockConverseInput(model ai.Model, context ai.Context, options BedrockOptions) map[string]any { + cacheRetention := ResolveBedrockCacheRetention(options.StreamOptions.CacheRetention) + input := map[string]any{ + "modelId": model.ID, + "messages": ConvertBedrockMessages(context, model, cacheRetention), + "system": BuildBedrockSystemPrompt(context.SystemPrompt, model, cacheRetention), + "inferenceConfig": map[string]any{"maxTokens": options.StreamOptions.MaxTokens, "temperature": options.StreamOptions.Temperature}, + } + if tc := ConvertBedrockToolConfig(context.Tools, options.ToolChoice); tc != nil { + input["toolConfig"] = tc + } + if extra := BuildBedrockAdditionalModelRequestFields(model, options); extra != nil { + input["additionalModelRequestFields"] = extra + } + return input +} + func SupportsAdaptiveThinking(modelID string) bool { id := strings.ToLower(modelID) return strings.Contains(id, "opus-4-6") || @@ -77,6 +94,183 @@ func supportsBedrockPromptCaching(model ai.Model) bool { return false } +func NormalizeBedrockToolCallID(id string) string { + sanitized := strings.Map(func(r rune) rune { + if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '_' || r == '-' { + return r + } + return '_' + }, id) + if len(sanitized) > 64 { + return sanitized[:64] + } + return sanitized +} + +func ConvertBedrockMessages(context ai.Context, model ai.Model, cacheRetention ai.CacheRetention) []map[string]any { + transformed := TransformMessages(context.Messages, model, func(id string, _ ai.Model, _ ai.Message) string { + return NormalizeBedrockToolCallID(id) + }) + result := make([]map[string]any, 0, len(transformed)) + supportsThinkingSignature := supportsBedrockThinkingSignature(model) + + for i := 0; i < len(transformed); i++ { + msg := transformed[i] + switch msg.Role { + 
case ai.RoleUser: + content := make([]map[string]any, 0, max(1, len(msg.Content))) + if strings.TrimSpace(msg.Text) != "" { + content = append(content, map[string]any{"text": utils.SanitizeSurrogates(msg.Text)}) + } + for _, c := range msg.Content { + switch c.Type { + case ai.ContentTypeText: + if strings.TrimSpace(c.Text) == "" { + continue + } + content = append(content, map[string]any{"text": utils.SanitizeSurrogates(c.Text)}) + case ai.ContentTypeImage: + content = append(content, map[string]any{ + "image": map[string]any{ + "format": imageFormatFromMIME(c.MimeType), + "source": map[string]any{"bytes": c.Data}, + }, + }) + } + } + if len(content) == 0 { + continue + } + result = append(result, map[string]any{ + "role": "user", + "content": content, + }) + case ai.RoleAssistant: + content := make([]map[string]any, 0, len(msg.Content)) + for _, c := range msg.Content { + switch c.Type { + case ai.ContentTypeText: + if strings.TrimSpace(c.Text) == "" { + continue + } + content = append(content, map[string]any{"text": utils.SanitizeSurrogates(c.Text)}) + case ai.ContentTypeToolCall: + content = append(content, map[string]any{ + "toolUse": map[string]any{ + "toolUseId": c.ID, + "name": c.Name, + "input": c.Arguments, + }, + }) + case ai.ContentTypeThinking: + if strings.TrimSpace(c.Thinking) == "" { + continue + } + reasoningText := map[string]any{"text": utils.SanitizeSurrogates(c.Thinking)} + if supportsThinkingSignature && strings.TrimSpace(c.ThinkingSignature) != "" { + reasoningText["signature"] = c.ThinkingSignature + } + content = append(content, map[string]any{ + "reasoningContent": map[string]any{ + "reasoningText": reasoningText, + }, + }) + } + } + if len(content) == 0 { + continue + } + result = append(result, map[string]any{ + "role": "assistant", + "content": content, + }) + case ai.RoleToolResult: + toolResults := make([]map[string]any, 0, 2) + toolResults = append(toolResults, map[string]any{ + "toolResult": map[string]any{ + "toolUseId": 
msg.ToolCallID, + "content": bedrockToolResultContent(msg.Content), + "status": bedrockToolResultStatus(msg.IsError), + }, + }) + + j := i + 1 + for ; j < len(transformed) && transformed[j].Role == ai.RoleToolResult; j++ { + next := transformed[j] + toolResults = append(toolResults, map[string]any{ + "toolResult": map[string]any{ + "toolUseId": next.ToolCallID, + "content": bedrockToolResultContent(next.Content), + "status": bedrockToolResultStatus(next.IsError), + }, + }) + } + i = j - 1 + result = append(result, map[string]any{ + "role": "user", + "content": toolResults, + }) + } + } + + if cacheRetention != ai.CacheRetentionNone && supportsBedrockPromptCaching(model) && len(result) > 0 { + last := result[len(result)-1] + if lastRole, _ := last["role"].(string); lastRole == "user" { + content, _ := last["content"].([]map[string]any) + cachePoint := map[string]any{ + "cachePoint": map[string]any{"type": "default"}, + } + if cacheRetention == ai.CacheRetentionLong { + cachePoint["cachePoint"].(map[string]any)["ttl"] = "1h" + } + last["content"] = append(content, cachePoint) + result[len(result)-1] = last + } + } + return result +} + +func ConvertBedrockToolConfig(tools []ai.Tool, toolChoice any) map[string]any { + if len(tools) == 0 { + return nil + } + if choice, ok := toolChoice.(string); ok && strings.EqualFold(choice, "none") { + return nil + } + bedrockTools := make([]map[string]any, 0, len(tools)) + for _, tool := range tools { + bedrockTools = append(bedrockTools, map[string]any{ + "toolSpec": map[string]any{ + "name": tool.Name, + "description": tool.Description, + "inputSchema": map[string]any{"json": tool.Parameters}, + }, + }) + } + + config := map[string]any{ + "tools": bedrockTools, + } + switch choice := toolChoice.(type) { + case string: + switch strings.ToLower(strings.TrimSpace(choice)) { + case "auto": + config["toolChoice"] = map[string]any{"auto": map[string]any{}} + case "any": + config["toolChoice"] = map[string]any{"any": map[string]any{}} + 
} + case BedrockToolChoice: + if strings.EqualFold(choice.Type, "tool") && strings.TrimSpace(choice.Name) != "" { + config["toolChoice"] = map[string]any{"tool": map[string]any{"name": choice.Name}} + } + case *BedrockToolChoice: + if choice != nil && strings.EqualFold(choice.Type, "tool") && strings.TrimSpace(choice.Name) != "" { + config["toolChoice"] = map[string]any{"tool": map[string]any{"name": choice.Name}} + } + } + return config +} + func MapBedrockStopReason(reason string) ai.StopReason { switch strings.ToUpper(strings.TrimSpace(reason)) { case "END_TURN", "STOP_SEQUENCE": @@ -156,3 +350,52 @@ func mapBedrockThinkingEffort(level ai.ThinkingLevel) string { return "high" } } + +func supportsBedrockThinkingSignature(model ai.Model) bool { + id := strings.ToLower(model.ID) + return strings.Contains(id, "anthropic.claude") || strings.Contains(id, "anthropic/claude") +} + +func imageFormatFromMIME(mimeType string) string { + switch strings.ToLower(strings.TrimSpace(mimeType)) { + case "image/jpeg", "image/jpg": + return "jpeg" + case "image/png": + return "png" + case "image/gif": + return "gif" + case "image/webp": + return "webp" + default: + return "png" + } +} + +func bedrockToolResultStatus(isError bool) string { + if isError { + return "error" + } + return "success" +} + +func bedrockToolResultContent(content []ai.ContentBlock) []map[string]any { + out := make([]map[string]any, 0, len(content)) + for _, c := range content { + if c.Type == ai.ContentTypeImage { + out = append(out, map[string]any{ + "image": map[string]any{ + "format": imageFormatFromMIME(c.MimeType), + "source": map[string]any{"bytes": c.Data}, + }, + }) + continue + } + if c.Type == ai.ContentTypeText { + out = append(out, map[string]any{"text": utils.SanitizeSurrogates(c.Text)}) + } + } + if len(out) == 0 { + return []map[string]any{{"text": ""}} + } + return out +} diff --git a/pkg/ai/providers/amazon_bedrock_test.go b/pkg/ai/providers/amazon_bedrock_test.go index bbaf967d..f278c6be 
100644 --- a/pkg/ai/providers/amazon_bedrock_test.go +++ b/pkg/ai/providers/amazon_bedrock_test.go @@ -92,3 +92,143 @@ func TestBuildBedrockAdditionalModelRequestFields(t *testing.T) { t.Fatalf("expected xhigh on opus-4-6 to map to max effort, got %#v", outputConfig["effort"]) } } + +func TestConvertBedrockMessages_GroupsToolResultsAndAddsCachePoint(t *testing.T) { + now := int64(1) + model := ai.Model{ + ID: "global.anthropic.claude-sonnet-4-5-v1:0", + Provider: "amazon-bedrock", + API: ai.APIBedrockConverse, + } + messages := ConvertBedrockMessages(ai.Context{ + Messages: []ai.Message{ + {Role: ai.RoleUser, Text: "run tools", Timestamp: now}, + { + Role: ai.RoleAssistant, + Content: []ai.ContentBlock{ + {Type: ai.ContentTypeToolCall, ID: "call:1", Name: "echo", Arguments: map[string]any{"x": 1}}, + }, + StopReason: ai.StopReasonToolUse, + Timestamp: now + 1, + }, + { + Role: ai.RoleToolResult, + ToolCallID: "call:1", + ToolName: "echo", + Content: []ai.ContentBlock{ + {Type: ai.ContentTypeText, Text: "first"}, + }, + Timestamp: now + 2, + }, + { + Role: ai.RoleToolResult, + ToolCallID: "call:2", + ToolName: "echo", + Content: []ai.ContentBlock{ + {Type: ai.ContentTypeText, Text: "second"}, + }, + Timestamp: now + 3, + }, + }, + }, model, ai.CacheRetentionLong) + + if len(messages) != 3 { + t.Fatalf("expected 3 bedrock messages, got %d", len(messages)) + } + last := messages[2] + if last["role"] != "user" { + t.Fatalf("expected grouped tool results as user message, got %#v", last["role"]) + } + content := last["content"].([]map[string]any) + if len(content) < 3 { + t.Fatalf("expected grouped tool results plus cache point, got %#v", content) + } + tr0 := content[0]["toolResult"].(map[string]any) + if tr0["toolUseId"] != "call_1" { + t.Fatalf("expected normalized toolUseId, got %#v", tr0["toolUseId"]) + } + cachePoint := content[len(content)-1]["cachePoint"].(map[string]any) + if cachePoint["ttl"] != "1h" { + t.Fatalf("expected long cache ttl on last user message, 
got %#v", cachePoint) + } +} + +func TestConvertBedrockMessages_ThinkingSignatureSupport(t *testing.T) { + anthropicModel := ai.Model{ + ID: "global.anthropic.claude-sonnet-4-5-v1:0", + Provider: "amazon-bedrock", + API: ai.APIBedrockConverse, + } + anthropicContext := ai.Context{ + Messages: []ai.Message{ + { + Role: ai.RoleAssistant, + Provider: anthropicModel.Provider, + API: anthropicModel.API, + Model: anthropicModel.ID, + Content: []ai.ContentBlock{ + { + Type: ai.ContentTypeThinking, + Thinking: "reasoning", + ThinkingSignature: "sig123", + }, + }, + }, + }, + } + + anthropicMsgs := ConvertBedrockMessages(anthropicContext, anthropicModel, ai.CacheRetentionShort) + anthropicReasoning := anthropicMsgs[0]["content"].([]map[string]any)[0]["reasoningContent"].(map[string]any) + anthropicText := anthropicReasoning["reasoningText"].(map[string]any) + if anthropicText["signature"] != "sig123" { + t.Fatalf("expected anthropic model to include thinking signature, got %#v", anthropicText) + } + + otherModel := ai.Model{ + ID: "meta.llama-4-maverick", + Provider: "amazon-bedrock", + API: ai.APIBedrockConverse, + } + otherContext := ai.Context{ + Messages: []ai.Message{ + { + Role: ai.RoleAssistant, + Provider: otherModel.Provider, + API: otherModel.API, + Model: otherModel.ID, + Content: []ai.ContentBlock{ + { + Type: ai.ContentTypeThinking, + Thinking: "reasoning", + ThinkingSignature: "sig123", + }, + }, + }, + }, + } + otherMsgs := ConvertBedrockMessages(otherContext, otherModel, ai.CacheRetentionShort) + otherReasoning := otherMsgs[0]["content"].([]map[string]any)[0]["reasoningContent"].(map[string]any) + otherText := otherReasoning["reasoningText"].(map[string]any) + if _, ok := otherText["signature"]; ok { + t.Fatalf("expected non-anthropic model to omit signature, got %#v", otherText) + } +} + +func TestConvertBedrockToolConfig(t *testing.T) { + tools := []ai.Tool{{Name: "echo", Description: "Echo", Parameters: map[string]any{"type": "object"}}} + if cfg := 
ConvertBedrockToolConfig(tools, "none"); cfg != nil { + t.Fatalf("expected nil tool config for choice none") + } + + autoCfg := ConvertBedrockToolConfig(tools, "auto") + choice := autoCfg["toolChoice"].(map[string]any) + if _, ok := choice["auto"]; !ok { + t.Fatalf("expected auto tool choice in config, got %#v", choice) + } + + toolCfg := ConvertBedrockToolConfig(tools, BedrockToolChoice{Type: "tool", Name: "echo"}) + chosen := toolCfg["toolChoice"].(map[string]any)["tool"].(map[string]any) + if chosen["name"] != "echo" { + t.Fatalf("expected explicit tool choice name echo, got %#v", chosen["name"]) + } +} From f2218e163353ebea20f931d55e734d673575bf77 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 04:26:30 +0000 Subject: [PATCH 23/75] Expand e2e parity scaffolds for remaining TS test suites MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/ai/e2e/parity_scaffolds_test.go | 30 +++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/pkg/ai/e2e/parity_scaffolds_test.go b/pkg/ai/e2e/parity_scaffolds_test.go index 614f6b7c..0997d136 100644 --- a/pkg/ai/e2e/parity_scaffolds_test.go +++ b/pkg/ai/e2e/parity_scaffolds_test.go @@ -64,3 +64,33 @@ func TestGoogleGeminiCLIEmptyStreamE2EParityScaffold(t *testing.T) { requirePIAIE2E(t) t.Skip("parity scaffold for google-gemini-cli-empty-stream.test.ts pending runtime implementation") } + +func TestXhighE2EParityScaffold(t *testing.T) { + requirePIAIE2E(t) + t.Skip("parity scaffold for xhigh.test.ts pending runtime implementation") +} + +func TestZenE2EParityScaffold(t *testing.T) { + requirePIAIE2E(t) + t.Skip("parity scaffold for zen.test.ts pending runtime implementation") +} + +func TestEmptyE2EParityScaffold(t *testing.T) { + requirePIAIE2E(t) + t.Skip("parity scaffold for empty.test.ts pending runtime implementation") +} + +func TestImageToolResultE2EParityScaffold(t *testing.T) { + requirePIAIE2E(t) + 
t.Skip("parity scaffold for image-tool-result.test.ts pending runtime implementation") +} + +func TestGoogleGeminiCliClaudeThinkingHeaderE2EParityScaffold(t *testing.T) { + requirePIAIE2E(t) + t.Skip("parity scaffold for google-gemini-cli-claude-thinking-header.test.ts pending runtime implementation") +} + +func TestGithubCopilotAnthropicE2EParityScaffold(t *testing.T) { + requirePIAIE2E(t) + t.Skip("parity scaffold for github-copilot-anthropic.test.ts pending runtime implementation") +} From 49e93284947f44faf16c173779384575968efe8a Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 04:36:17 +0000 Subject: [PATCH 24/75] Prepare pkg-ai bridge context mapping from chat prompt MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/connector/streaming_runtime_selector.go | 108 +++++++++++++++++- .../streaming_runtime_selector_test.go | 69 ++++++++++- 2 files changed, 175 insertions(+), 2 deletions(-) diff --git a/pkg/connector/streaming_runtime_selector.go b/pkg/connector/streaming_runtime_selector.go index 0d2c4418..82490f38 100644 --- a/pkg/connector/streaming_runtime_selector.go +++ b/pkg/connector/streaming_runtime_selector.go @@ -3,8 +3,12 @@ package connector import ( "context" "os" + "strconv" "strings" + "time" + aipkg "github.com/beeper/ai-bridge/pkg/ai" + airuntime "github.com/beeper/ai-bridge/pkg/runtime" "github.com/openai/openai-go/v3" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/event" @@ -43,7 +47,11 @@ func (oc *AIClient) streamWithPkgAIBridge( meta *PortalMetadata, prompt []openai.ChatCompletionMessageParamUnion, ) (bool, *ContextLengthError, error) { - oc.loggerForContext(ctx).Debug().Msg("pkg/ai runtime bridge flag enabled; delegating to existing runtime path") + aiContext := buildPkgAIContext(oc.effectivePrompt(meta), prompt) + oc.loggerForContext(ctx).Debug(). + Int("prompt_messages", len(prompt)). + Int("ai_messages", len(aiContext.Messages)). 
+ Msg("pkg/ai runtime bridge flag enabled; prepared adapter context and delegating to existing runtime path") switch oc.resolveModelAPI(meta) { case ModelAPIChatCompletions: return oc.streamChatCompletions(ctx, evt, portal, meta, prompt) @@ -51,3 +59,101 @@ func (oc *AIClient) streamWithPkgAIBridge( return oc.streamingResponseWithToolSchemaFallback(ctx, evt, portal, meta, prompt) } } + +func buildPkgAIContext(systemPrompt string, prompt []openai.ChatCompletionMessageParamUnion) aipkg.Context { + unified := chatPromptToUnifiedMessages(prompt) + return toAIContext(systemPrompt, unified, nil) +} + +func chatPromptToUnifiedMessages(prompt []openai.ChatCompletionMessageParamUnion) []UnifiedMessage { + out := make([]UnifiedMessage, 0, len(prompt)) + now := time.Now().UnixMilli() + + for _, msg := range prompt { + switch { + case msg.OfUser != nil: + parts := make([]ContentPart, 0, 2) + userText := strings.TrimSpace(airuntime.ExtractUserContent(msg.OfUser.Content)) + if userText != "" { + parts = append(parts, ContentPart{Type: ContentTypeText, Text: userText}) + } + for _, part := range msg.OfUser.Content.OfArrayOfContentParts { + if part.OfImageURL != nil && strings.TrimSpace(part.OfImageURL.ImageURL.URL) != "" { + parts = append(parts, ContentPart{ + Type: ContentTypeImage, + ImageURL: strings.TrimSpace(part.OfImageURL.ImageURL.URL), + }) + } + } + if len(parts) == 0 { + continue + } + out = append(out, UnifiedMessage{ + Role: RoleUser, + Content: parts, + }) + case msg.OfAssistant != nil: + parts := make([]ContentPart, 0, 1) + assistantText := strings.TrimSpace(airuntime.ExtractAssistantContent(msg.OfAssistant.Content)) + if assistantText != "" { + parts = append(parts, ContentPart{Type: ContentTypeText, Text: assistantText}) + } + toolCalls := make([]ToolCallResult, 0, len(msg.OfAssistant.ToolCalls)) + for _, toolCall := range msg.OfAssistant.ToolCalls { + if toolCall.OfFunction == nil { + continue + } + name := strings.TrimSpace(toolCall.OfFunction.Function.Name) + 
if name == "" { + continue + } + toolCalls = append(toolCalls, ToolCallResult{ + ID: strings.TrimSpace(toolCall.OfFunction.ID), + Name: name, + Arguments: strings.TrimSpace(toolCall.OfFunction.Function.Arguments), + }) + } + if len(parts) == 0 && len(toolCalls) == 0 { + continue + } + out = append(out, UnifiedMessage{ + Role: RoleAssistant, + Content: parts, + ToolCalls: toolCalls, + }) + case msg.OfTool != nil: + toolText := strings.TrimSpace(airuntime.ExtractToolContent(msg.OfTool.Content)) + parts := []ContentPart{} + if toolText != "" { + parts = append(parts, ContentPart{Type: ContentTypeText, Text: toolText}) + } + out = append(out, UnifiedMessage{ + Role: RoleTool, + ToolCallID: strings.TrimSpace(msg.OfTool.ToolCallID), + Content: parts, + }) + case msg.OfSystem != nil || msg.OfDeveloper != nil: + // System/developer content is carried separately via systemPrompt in buildPkgAIContext. + continue + default: + content, role := airuntime.ExtractMessageContent(msg) + content = strings.TrimSpace(content) + if content == "" { + continue + } + switch role { + case "user": + out = append(out, UnifiedMessage{Role: RoleUser, Content: []ContentPart{{Type: ContentTypeText, Text: content}}}) + case "assistant": + out = append(out, UnifiedMessage{Role: RoleAssistant, Content: []ContentPart{{Type: ContentTypeText, Text: content}}}) + case "tool": + out = append(out, UnifiedMessage{ + Role: RoleTool, + Content: []ContentPart{{Type: ContentTypeText, Text: content}}, + ToolCallID: "tool_" + strconv.FormatInt(now, 10), + }) + } + } + } + return out +} diff --git a/pkg/connector/streaming_runtime_selector_test.go b/pkg/connector/streaming_runtime_selector_test.go index ffb43238..d9c1208c 100644 --- a/pkg/connector/streaming_runtime_selector_test.go +++ b/pkg/connector/streaming_runtime_selector_test.go @@ -1,6 +1,10 @@ package connector -import "testing" +import ( + "testing" + + "github.com/openai/openai-go/v3" +) func TestPkgAIRuntimeEnabledFromEnv(t *testing.T) { 
t.Setenv("PI_USE_PKG_AI_RUNTIME", "") @@ -38,3 +42,66 @@ func TestChooseStreamingRuntimePath(t *testing.T) { t.Fatalf("expected responses path fallback, got %s", got) } } + +func TestChatPromptToUnifiedMessages_ConvertsRolesAndImages(t *testing.T) { + prompt := []openai.ChatCompletionMessageParamUnion{ + openai.SystemMessage("system guidance"), + { + OfUser: &openai.ChatCompletionUserMessageParam{ + Content: openai.ChatCompletionUserMessageParamContentUnion{ + OfArrayOfContentParts: []openai.ChatCompletionContentPartUnionParam{ + { + OfText: &openai.ChatCompletionContentPartTextParam{ + Text: "look at image", + }, + }, + { + OfImageURL: &openai.ChatCompletionContentPartImageParam{ + ImageURL: openai.ChatCompletionContentPartImageImageURLParam{ + URL: "https://example.com/image.png", + }, + }, + }, + }, + }, + }, + }, + openai.AssistantMessage("ack"), + openai.ToolMessage("tool output", "call_1"), + } + + unified := chatPromptToUnifiedMessages(prompt) + if len(unified) != 3 { + t.Fatalf("expected three non-system unified messages, got %d", len(unified)) + } + if unified[0].Role != RoleUser { + t.Fatalf("expected first role user, got %s", unified[0].Role) + } + if len(unified[0].Content) < 2 || unified[0].Content[1].Type != ContentTypeImage { + t.Fatalf("expected user message to include image content part, got %#v", unified[0].Content) + } + if unified[1].Role != RoleAssistant || unified[1].Text() != "ack" { + t.Fatalf("expected assistant text mapping, got %#v", unified[1]) + } + if unified[2].Role != RoleTool || unified[2].ToolCallID != "call_1" { + t.Fatalf("expected tool mapping with tool_call_id, got %#v", unified[2]) + } +} + +func TestBuildPkgAIContext_UsesSystemPromptAndMappedMessages(t *testing.T) { + prompt := []openai.ChatCompletionMessageParamUnion{ + openai.SystemMessage("inline system"), + openai.UserMessage("hello"), + openai.AssistantMessage("hi"), + } + ctx := buildPkgAIContext("effective system prompt", prompt) + if ctx.SystemPrompt != "effective 
system prompt" {
+		t.Fatalf("expected explicit effective system prompt in ai context, got %q", ctx.SystemPrompt)
+	}
+	if len(ctx.Messages) != 2 {
+		t.Fatalf("expected 2 mapped messages (system stripped), got %d", len(ctx.Messages))
+	}
+	if ctx.Messages[0].Role != "user" || ctx.Messages[1].Role != "assistant" {
+		t.Fatalf("unexpected mapped roles: %#v", ctx.Messages)
+	}
+}
From b2ef0fa026b1530f69889ba637de426a604291b0 Mon Sep 17 00:00:00 2001
From: Cursor Agent
Date: Wed, 4 Mar 2026 04:40:21 +0000
Subject: [PATCH 25/75] Add pkg-ai event-to-stream adapter mappings
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Co-authored-by: batuhan içöz
---
 pkg/connector/pkg_ai_event_adapter.go      | 133 ++++++++++++++++++++
 pkg/connector/pkg_ai_event_adapter_test.go | 131 +++++++++++++++++++
 2 files changed, 264 insertions(+)
 create mode 100644 pkg/connector/pkg_ai_event_adapter.go
 create mode 100644 pkg/connector/pkg_ai_event_adapter_test.go
diff --git a/pkg/connector/pkg_ai_event_adapter.go b/pkg/connector/pkg_ai_event_adapter.go
new file mode 100644
index 00000000..3123fea4
--- /dev/null
+++ b/pkg/connector/pkg_ai_event_adapter.go
@@ -0,0 +1,133 @@
+package connector
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"strings"
+
+	aipkg "github.com/beeper/ai-bridge/pkg/ai"
+)
+
+// aiUsageToConnectorUsage converts pkg/ai token accounting into the
+// connector's UsageInfo; it returns nil when there is nothing to report.
+func aiUsageToConnectorUsage(usage aipkg.Usage) *UsageInfo {
+	total := usage.TotalTokens
+	computed := usage.Input + usage.Output + usage.CacheRead + usage.CacheWrite
+	if total <= 0 || total < computed {
+		total = computed
+	}
+	if usage.Input == 0 && usage.Output == 0 && usage.CacheRead == 0 && usage.CacheWrite == 0 && total == 0 {
+		return nil
+	}
+	return &UsageInfo{
+		PromptTokens:     usage.Input + usage.CacheRead + usage.CacheWrite,
+		CompletionTokens: usage.Output,
+		TotalTokens:      total,
+	}
+}
+
+// aiEventToStreamEvent maps a single pkg/ai assistant event onto the
+// connector's StreamEvent; the bool reports whether the event is relevant.
+func aiEventToStreamEvent(event aipkg.AssistantMessageEvent) (StreamEvent, bool) {
+	switch event.Type {
+	case aipkg.EventTextDelta:
+		return StreamEvent{
+			Type:  StreamEventDelta,
+			Delta: event.Delta,
+		}, true
+	case aipkg.EventThinkingDelta:
+		return StreamEvent{
+			Type:           StreamEventReasoning,
+			ReasoningDelta: event.Delta,
+		}, true
+	case aipkg.EventToolCallEnd:
+		toolCall := event.ToolCall
+		if toolCall == nil && event.ContentIndex >= 0 && event.ContentIndex < len(event.Partial.Content) {
+			toolCall = &event.Partial.Content[event.ContentIndex]
+		}
+		if toolCall == nil {
+			return StreamEvent{}, false
+		}
+		args := "{}"
+		if toolCall.Arguments != nil {
+			if raw, err := json.Marshal(toolCall.Arguments); err == nil {
+				args = string(raw)
+			}
+		}
+		return StreamEvent{
+			Type: StreamEventToolCall,
+			ToolCall: &ToolCallResult{
+				ID:        strings.TrimSpace(toolCall.ID),
+				Name:      strings.TrimSpace(toolCall.Name),
+				Arguments: args,
+			},
+		}, true
+	case aipkg.EventDone:
+		reason := strings.TrimSpace(string(event.Reason))
+		if reason == "" {
+			reason = strings.TrimSpace(string(event.Message.StopReason))
+		}
+		return StreamEvent{
+			Type:         StreamEventComplete,
+			FinishReason: reason,
+			Usage:        aiUsageToConnectorUsage(event.Message.Usage),
+		}, true
+	case aipkg.EventError:
+		errText := strings.TrimSpace(event.Error.ErrorMessage)
+		if errText == "" {
+			errText = "pkg/ai stream error"
+		}
+		return StreamEvent{
+			Type:  StreamEventError,
+			Error: fmt.Errorf("%s", errText),
+		}, true
+	default:
+		return StreamEvent{}, false
+	}
+}
+
+// streamEventsFromAIStream pumps a pkg/ai event stream into a channel of
+// connector StreamEvents. The goroutine exits when the stream ends, errors,
+// or ctx is cancelled; sends never block past cancellation, so the goroutine
+// cannot leak if the consumer stops draining the channel.
+func streamEventsFromAIStream(ctx context.Context, stream *aipkg.AssistantMessageEventStream) <-chan StreamEvent {
+	events := make(chan StreamEvent, 64)
+	// send delivers evt unless ctx is cancelled first; it reports whether
+	// the consumer may still be listening.
+	send := func(evt StreamEvent) bool {
+		select {
+		case events <- evt:
+			return true
+		case <-ctx.Done():
+			return false
+		}
+	}
+	go func() {
+		defer close(events)
+		for {
+			event, err := stream.Next(ctx)
+			if err != nil {
+				// Normal termination (EOF) and caller cancellation are
+				// expected and not surfaced as stream errors.
+				if err != io.EOF && err != context.Canceled {
+					send(StreamEvent{
+						Type:  StreamEventError,
+						Error: err,
+					})
+				}
+				return
+			}
+			if mapped, ok := aiEventToStreamEvent(event); ok {
+				if !send(mapped) {
+					return
+				}
+			}
+			if event.Type == aipkg.EventDone || event.Type == aipkg.EventError {
+				return
+			}
+		}
+	}()
+	return events
+}
diff --git 
a/pkg/connector/pkg_ai_event_adapter_test.go b/pkg/connector/pkg_ai_event_adapter_test.go new file mode 100644 index 00000000..b1b89130 --- /dev/null +++ b/pkg/connector/pkg_ai_event_adapter_test.go @@ -0,0 +1,131 @@ +package connector + +import ( + "context" + "encoding/json" + "testing" + "time" + + aipkg "github.com/beeper/ai-bridge/pkg/ai" +) + +func TestAIUsageToConnectorUsage(t *testing.T) { + if usage := aiUsageToConnectorUsage(aipkg.Usage{}); usage != nil { + t.Fatalf("expected nil usage for all-zero usage input") + } + + usage := aiUsageToConnectorUsage(aipkg.Usage{ + Input: 100, + Output: 40, + CacheRead: 20, + CacheWrite: 10, + TotalTokens: 120, + }) + if usage == nil { + t.Fatalf("expected mapped usage") + } + if usage.PromptTokens != 130 { + t.Fatalf("expected prompt tokens input+cache=130, got %d", usage.PromptTokens) + } + if usage.CompletionTokens != 40 { + t.Fatalf("expected completion tokens 40, got %d", usage.CompletionTokens) + } + if usage.TotalTokens != 170 { + t.Fatalf("expected total tokens uplifted to computed sum 170, got %d", usage.TotalTokens) + } +} + +func TestAIEventToStreamEvent_Mapping(t *testing.T) { + if evt, ok := aiEventToStreamEvent(aipkg.AssistantMessageEvent{ + Type: aipkg.EventTextDelta, + Delta: "abc", + }); !ok || evt.Type != StreamEventDelta || evt.Delta != "abc" { + t.Fatalf("unexpected text delta mapping: ok=%v evt=%#v", ok, evt) + } + + if evt, ok := aiEventToStreamEvent(aipkg.AssistantMessageEvent{ + Type: aipkg.EventThinkingDelta, + Delta: "reason", + }); !ok || evt.Type != StreamEventReasoning || evt.ReasoningDelta != "reason" { + t.Fatalf("unexpected thinking delta mapping: ok=%v evt=%#v", ok, evt) + } + + toolEvent, ok := aiEventToStreamEvent(aipkg.AssistantMessageEvent{ + Type: aipkg.EventToolCallEnd, + ToolCall: &aipkg.ContentBlock{ + Type: aipkg.ContentTypeToolCall, + ID: "call_1", + Name: "search", + Arguments: map[string]any{"q": "golang"}, + }, + }) + if !ok || toolEvent.Type != StreamEventToolCall || 
toolEvent.ToolCall == nil { + t.Fatalf("unexpected tool call mapping: ok=%v evt=%#v", ok, toolEvent) + } + args := map[string]any{} + if err := json.Unmarshal([]byte(toolEvent.ToolCall.Arguments), &args); err != nil { + t.Fatalf("expected tool args JSON, got err=%v", err) + } + if args["q"] != "golang" { + t.Fatalf("expected tool arg q=golang, got %#v", args) + } + + doneEvent, ok := aiEventToStreamEvent(aipkg.AssistantMessageEvent{ + Type: aipkg.EventDone, + Reason: aipkg.StopReasonStop, + Message: aipkg.Message{ + Usage: aipkg.Usage{ + Input: 10, + Output: 5, + CacheRead: 2, + CacheWrite: 1, + }, + }, + }) + if !ok || doneEvent.Type != StreamEventComplete || doneEvent.FinishReason != "stop" { + t.Fatalf("unexpected done mapping: ok=%v evt=%#v", ok, doneEvent) + } + if doneEvent.Usage == nil || doneEvent.Usage.TotalTokens != 18 { + t.Fatalf("expected mapped usage total=18, got %#v", doneEvent.Usage) + } + + errEvent, ok := aiEventToStreamEvent(aipkg.AssistantMessageEvent{ + Type: aipkg.EventError, + Error: aipkg.Message{ErrorMessage: "boom"}, + }) + if !ok || errEvent.Type != StreamEventError || errEvent.Error == nil || errEvent.Error.Error() != "boom" { + t.Fatalf("unexpected error mapping: ok=%v evt=%#v", ok, errEvent) + } +} + +func TestStreamEventsFromAIStream(t *testing.T) { + stream := aipkg.NewAssistantMessageEventStream(8) + go func() { + stream.Push(aipkg.AssistantMessageEvent{Type: aipkg.EventStart}) + stream.Push(aipkg.AssistantMessageEvent{Type: aipkg.EventTextDelta, Delta: "hello"}) + stream.Push(aipkg.AssistantMessageEvent{ + Type: aipkg.EventDone, + Reason: aipkg.StopReasonStop, + Message: aipkg.Message{ + Usage: aipkg.Usage{Input: 1, Output: 1, TotalTokens: 2}, + }, + }) + }() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + out := streamEventsFromAIStream(ctx, stream) + var collected []StreamEvent + for evt := range out { + collected = append(collected, evt) + } + if len(collected) != 2 { + 
t.Fatalf("expected 2 mapped events (delta + complete), got %d", len(collected)) + } + if collected[0].Type != StreamEventDelta || collected[0].Delta != "hello" { + t.Fatalf("unexpected first mapped event: %#v", collected[0]) + } + if collected[1].Type != StreamEventComplete || collected[1].FinishReason != "stop" { + t.Fatalf("unexpected completion mapped event: %#v", collected[1]) + } +} From a3fe6820b8d9079fd883cd98faa6fcd4adc7891a Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 04:42:17 +0000 Subject: [PATCH 26/75] Add OAuth helper parity for Anthropic and GitHub Copilot MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/ai/oauth/anthropic.go | 53 +++++++++++++++++ .../oauth/anthropic_provider_helpers_test.go | 41 +++++++++++++ pkg/ai/oauth/github_copilot.go | 57 +++++++++++++++++++ .../github_copilot_provider_helpers_test.go | 33 +++++++++++ 4 files changed, 184 insertions(+) create mode 100644 pkg/ai/oauth/anthropic.go create mode 100644 pkg/ai/oauth/anthropic_provider_helpers_test.go create mode 100644 pkg/ai/oauth/github_copilot.go create mode 100644 pkg/ai/oauth/github_copilot_provider_helpers_test.go diff --git a/pkg/ai/oauth/anthropic.go b/pkg/ai/oauth/anthropic.go new file mode 100644 index 00000000..af4d3ed8 --- /dev/null +++ b/pkg/ai/oauth/anthropic.go @@ -0,0 +1,53 @@ +package oauth + +import ( + "net/url" + "strings" + "time" +) + +const ( + anthropicAuthorizeURL = "https://claude.ai/oauth/authorize" + anthropicTokenURL = "https://console.anthropic.com/v1/oauth/token" + anthropicRedirectURI = "https://console.anthropic.com/oauth/code/callback" + anthropicScopes = "org:create_api_key user:profile user:inference" + anthropicClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e" +) + +func AnthropicClientID() string { + return anthropicClientID +} + +func AnthropicTokenURL() string { + return anthropicTokenURL +} + +func BuildAnthropicAuthorizeURL(codeChallenge 
string, state string) string {
+	params := url.Values{}
+	params.Set("code", "true")
+	params.Set("client_id", anthropicClientID)
+	params.Set("response_type", "code")
+	params.Set("redirect_uri", anthropicRedirectURI)
+	params.Set("scope", anthropicScopes)
+	params.Set("code_challenge", codeChallenge)
+	params.Set("code_challenge_method", "S256")
+	params.Set("state", state)
+	return anthropicAuthorizeURL + "?" + params.Encode()
+}
+
+// ParseAnthropicAuthorizationCode splits the "code#state" string the user
+// pastes back after the OAuth redirect. The state portion is empty when
+// the input carries no "#" separator.
+func ParseAnthropicAuthorizationCode(input string) (code string, state string) {
+	raw := strings.TrimSpace(input)
+	if raw == "" {
+		return "", ""
+	}
+	before, after, found := strings.Cut(raw, "#")
+	code = strings.TrimSpace(before)
+	if found {
+		state = strings.TrimSpace(after)
+	}
+	return code, state
+}
+
+// OAuthExpiryWithBuffer converts an expires-in duration (seconds) into an
+// absolute Unix-millisecond deadline, shortened by five minutes so tokens
+// are refreshed before they actually lapse.
+func OAuthExpiryWithBuffer(now time.Time, expiresInSeconds int64) int64 {
+	return now.Add(time.Duration(expiresInSeconds)*time.Second - 5*time.Minute).UnixMilli()
+}
diff --git a/pkg/ai/oauth/anthropic_provider_helpers_test.go b/pkg/ai/oauth/anthropic_provider_helpers_test.go
new file mode 100644
index 00000000..4c4d01c7
--- /dev/null
+++ b/pkg/ai/oauth/anthropic_provider_helpers_test.go
@@ -0,0 +1,41 @@
+package oauth
+
+import (
+	"net/url"
+	"testing"
+	"time"
+)
+
+func TestAnthropicOAuthHelperFunctions(t *testing.T) {
+	authURL := BuildAnthropicAuthorizeURL("challenge", "state-verifier")
+	parsed, err := url.Parse(authURL)
+	if err != nil {
+		t.Fatalf("expected valid auth URL, got error: %v", err)
+	}
+	query := parsed.Query()
+	if query.Get("client_id") != AnthropicClientID() {
+		t.Fatalf("expected anthropic client id in query")
+	}
+	if query.Get("code_challenge") != "challenge" {
+		t.Fatalf("expected code_challenge in query")
+	}
+	if query.Get("state") != "state-verifier" {
+		t.Fatalf("expected state in query")
+	}
+
+	code, state := ParseAnthropicAuthorizationCode("abc123#xyz")
+	if code != "abc123" || state != "xyz" {
+		t.Fatalf("unexpected parsed code/state: code=%q state=%q", code, state)
+	}
+	code, state 
= ParseAnthropicAuthorizationCode("abc123")
+	if code != "abc123" || state != "" {
+		t.Fatalf("unexpected parsed code/state for no-state input: code=%q state=%q", code, state)
+	}
+
+	now := time.UnixMilli(1_700_000_000_000)
+	expires := OAuthExpiryWithBuffer(now, 3600)
+	expected := now.Add(55 * time.Minute).UnixMilli()
+	if expires != expected {
+		t.Fatalf("expected expiry with 5m buffer %d, got %d", expected, expires)
+	}
+}
diff --git a/pkg/ai/oauth/github_copilot.go b/pkg/ai/oauth/github_copilot.go
new file mode 100644
index 00000000..f4cdb301
--- /dev/null
+++ b/pkg/ai/oauth/github_copilot.go
@@ -0,0 +1,68 @@
+package oauth
+
+import (
+	"net/url"
+	"regexp"
+	"strings"
+)
+
+const defaultGitHubDomain = "github.com"
+
+// proxyEndpointPattern extracts the proxy endpoint field from a Copilot
+// session token of the "k=v;k=v;..." form.
+var proxyEndpointPattern = regexp.MustCompile(`proxy-ep=([^;]+)`)
+
+// NormalizeDomain reduces free-form user input (optionally carrying a
+// scheme, path, or surrounding whitespace) to a lowercase hostname.
+// It returns "" when no hostname can be parsed out of the input.
+func NormalizeDomain(input string) string {
+	candidate := strings.TrimSpace(input)
+	if candidate == "" {
+		return ""
+	}
+	if !strings.Contains(candidate, "://") {
+		candidate = "https://" + candidate
+	}
+	parsed, err := url.Parse(candidate)
+	if err != nil {
+		return ""
+	}
+	host := strings.TrimSpace(parsed.Hostname())
+	if host == "" {
+		return ""
+	}
+	return strings.ToLower(host)
+}
+
+// getBaseURLFromCopilotToken derives the API base URL from the proxy
+// endpoint advertised inside the Copilot token, or "" when absent.
+func getBaseURLFromCopilotToken(token string) string {
+	match := proxyEndpointPattern.FindStringSubmatch(token)
+	if len(match) != 2 {
+		return ""
+	}
+	host := strings.TrimSpace(match[1])
+	if host == "" {
+		return ""
+	}
+	return "https://api." + strings.TrimPrefix(host, "proxy.")
+}
+
+// GetGitHubCopilotBaseURL picks the Copilot API base URL: the endpoint
+// advertised in the token wins, then the enterprise domain, then the
+// individual-plan default.
+func GetGitHubCopilotBaseURL(token string, enterpriseDomain string) string {
+	if fromToken := getBaseURLFromCopilotToken(token); fromToken != "" {
+		return fromToken
+	}
+	if normalizedEnterprise := NormalizeDomain(enterpriseDomain); normalizedEnterprise != "" {
+		return "https://copilot-api." 
+ normalizedEnterprise + } + return "https://api.individual.githubcopilot.com" +} + +func ResolveGitHubDomain(enterpriseDomain string) string { + if normalized := NormalizeDomain(enterpriseDomain); normalized != "" { + return normalized + } + return defaultGitHubDomain +} diff --git a/pkg/ai/oauth/github_copilot_provider_helpers_test.go b/pkg/ai/oauth/github_copilot_provider_helpers_test.go new file mode 100644 index 00000000..7d9c9211 --- /dev/null +++ b/pkg/ai/oauth/github_copilot_provider_helpers_test.go @@ -0,0 +1,33 @@ +package oauth + +import "testing" + +func TestGitHubCopilotOAuthHelperFunctions(t *testing.T) { + if got := NormalizeDomain(" https://Company.GHE.com/path "); got != "company.ghe.com" { + t.Fatalf("expected normalized enterprise domain, got %q", got) + } + if got := NormalizeDomain("github.com"); got != "github.com" { + t.Fatalf("expected plain hostname normalization, got %q", got) + } + if got := NormalizeDomain("://bad-url"); got != "" { + t.Fatalf("expected invalid URL to normalize to empty string, got %q", got) + } + + token := "tid=abc;exp=123;proxy-ep=proxy.individual.githubcopilot.com;foo=bar" + if got := GetGitHubCopilotBaseURL(token, ""); got != "https://api.individual.githubcopilot.com" { + t.Fatalf("expected base URL from token proxy endpoint, got %q", got) + } + if got := GetGitHubCopilotBaseURL("", "ghe.example.com"); got != "https://copilot-api.ghe.example.com" { + t.Fatalf("expected enterprise fallback base URL, got %q", got) + } + if got := GetGitHubCopilotBaseURL("", ""); got != "https://api.individual.githubcopilot.com" { + t.Fatalf("expected default individual base URL, got %q", got) + } + + if got := ResolveGitHubDomain("enterprise.github.local"); got != "enterprise.github.local" { + t.Fatalf("expected enterprise domain resolution, got %q", got) + } + if got := ResolveGitHubDomain(""); got != "github.com" { + t.Fatalf("expected default github.com domain, got %q", got) + } +} From 5d325f079f920551e959ecd64b87a9ed87b60f6d Mon 
Sep 17 00:00:00 2001
From: Cursor Agent
Date: Wed, 4 Mar 2026 04:43:17 +0000
Subject: [PATCH 27/75] Add connector pkg-ai model descriptor mapping helpers
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Co-authored-by: batuhan içöz
---
 pkg/connector/pkg_ai_model_adapter.go       | 58 ++++++++++++++++++
 pkg/connector/pkg_ai_model_adapter_test.go  | 59 +++++++++++++++++++
 pkg/connector/streaming_runtime_selector.go | 16 +++++-
 3 files changed, 132 insertions(+), 1 deletion(-)
 create mode 100644 pkg/connector/pkg_ai_model_adapter.go
 create mode 100644 pkg/connector/pkg_ai_model_adapter_test.go
diff --git a/pkg/connector/pkg_ai_model_adapter.go b/pkg/connector/pkg_ai_model_adapter.go
new file mode 100644
index 00000000..5d81bee3
--- /dev/null
+++ b/pkg/connector/pkg_ai_model_adapter.go
@@ -0,0 +1,58 @@
+package connector
+
+import (
+	"strings"
+
+	aipkg "github.com/beeper/ai-bridge/pkg/ai"
+)
+
+// derivePkgAIAPI maps a connector provider name (plus the legacy ModelAPI
+// selector, used as a tie-breaker) onto the pkg/ai wire API identifier.
+func derivePkgAIAPI(provider string, modelAPI ModelAPI) aipkg.Api {
+	switch strings.ToLower(strings.TrimSpace(provider)) {
+	case "anthropic":
+		return aipkg.APIAnthropicMessages
+	case "google":
+		return aipkg.APIGoogleGenerativeAI
+	case "google-vertex":
+		return aipkg.APIGoogleVertex
+	case "google-gemini-cli", "google-antigravity":
+		return aipkg.APIGoogleGeminiCLI
+	case "azure-openai-responses":
+		return aipkg.APIAzureOpenAIResponse
+	case "amazon-bedrock":
+		return aipkg.APIBedrockConverse
+	case "openai-codex":
+		return aipkg.APIOpenAICodexResponse
+	default:
+		// Unknown providers are treated as OpenAI-compatible; the legacy
+		// ModelAPI decides between completions and responses.
+		if modelAPI == ModelAPIChatCompletions {
+			return aipkg.APIOpenAICompletions
+		}
+		return aipkg.APIOpenAIResponses
+	}
+}
+
+// derivePkgAIModelDescriptor builds the pkg/ai model descriptor from the
+// connector's effective model settings.
+func derivePkgAIModelDescriptor(
+	effectiveModel string,
+	effectiveModelForAPI string,
+	provider string,
+	modelAPI ModelAPI,
+	maxTokens int,
+) aipkg.Model {
+	name := strings.TrimSpace(effectiveModel)
+	if name == "" {
+		name = strings.TrimSpace(effectiveModelForAPI)
+	}
+	return aipkg.Model{
+		ID:   strings.TrimSpace(effectiveModelForAPI),
+		Name: name,
Provider: aipkg.Provider(strings.TrimSpace(provider)), + API: derivePkgAIAPI(provider, modelAPI), + Input: []string{"text"}, + MaxTokens: maxTokens, + } +} diff --git a/pkg/connector/pkg_ai_model_adapter_test.go b/pkg/connector/pkg_ai_model_adapter_test.go new file mode 100644 index 00000000..319aa90d --- /dev/null +++ b/pkg/connector/pkg_ai_model_adapter_test.go @@ -0,0 +1,59 @@ +package connector + +import ( + "testing" + + aipkg "github.com/beeper/ai-bridge/pkg/ai" +) + +func TestDerivePkgAIAPI(t *testing.T) { + cases := []struct { + provider string + modelAPI ModelAPI + want aipkg.Api + }{ + {provider: "anthropic", modelAPI: ModelAPIResponses, want: aipkg.APIAnthropicMessages}, + {provider: "google", modelAPI: ModelAPIResponses, want: aipkg.APIGoogleGenerativeAI}, + {provider: "google-vertex", modelAPI: ModelAPIResponses, want: aipkg.APIGoogleVertex}, + {provider: "google-gemini-cli", modelAPI: ModelAPIResponses, want: aipkg.APIGoogleGeminiCLI}, + {provider: "google-antigravity", modelAPI: ModelAPIResponses, want: aipkg.APIGoogleGeminiCLI}, + {provider: "azure-openai-responses", modelAPI: ModelAPIResponses, want: aipkg.APIAzureOpenAIResponse}, + {provider: "amazon-bedrock", modelAPI: ModelAPIResponses, want: aipkg.APIBedrockConverse}, + {provider: "openai-codex", modelAPI: ModelAPIResponses, want: aipkg.APIOpenAICodexResponse}, + {provider: "openai", modelAPI: ModelAPIChatCompletions, want: aipkg.APIOpenAICompletions}, + {provider: "openrouter", modelAPI: ModelAPIResponses, want: aipkg.APIOpenAIResponses}, + } + for _, tc := range cases { + if got := derivePkgAIAPI(tc.provider, tc.modelAPI); got != tc.want { + t.Fatalf("derivePkgAIAPI(%q,%q)=%q want=%q", tc.provider, tc.modelAPI, got, tc.want) + } + } +} + +func TestDerivePkgAIModelDescriptor(t *testing.T) { + model := derivePkgAIModelDescriptor( + "anthropic/claude-sonnet-4-5", + "claude-sonnet-4-5", + "anthropic", + ModelAPIResponses, + 64000, + ) + if model.ID != "claude-sonnet-4-5" { + t.Fatalf("expected 
api model id to be set, got %q", model.ID) + } + if model.Name != "anthropic/claude-sonnet-4-5" { + t.Fatalf("expected display model name from effectiveModel, got %q", model.Name) + } + if model.Provider != "anthropic" { + t.Fatalf("expected provider anthropic, got %q", model.Provider) + } + if model.API != aipkg.APIAnthropicMessages { + t.Fatalf("expected anthropic API mapping, got %q", model.API) + } + if model.MaxTokens != 64000 { + t.Fatalf("expected max tokens propagated, got %d", model.MaxTokens) + } + if len(model.Input) != 1 || model.Input[0] != "text" { + t.Fatalf("expected text-only input default, got %#v", model.Input) + } +} diff --git a/pkg/connector/streaming_runtime_selector.go b/pkg/connector/streaming_runtime_selector.go index 82490f38..60d04451 100644 --- a/pkg/connector/streaming_runtime_selector.go +++ b/pkg/connector/streaming_runtime_selector.go @@ -48,10 +48,24 @@ func (oc *AIClient) streamWithPkgAIBridge( prompt []openai.ChatCompletionMessageParamUnion, ) (bool, *ContextLengthError, error) { aiContext := buildPkgAIContext(oc.effectivePrompt(meta), prompt) + providerName := "" + if loginMeta := loginMetadata(oc.UserLogin); loginMeta != nil { + providerName = string(loginMeta.Provider) + } + aiModel := derivePkgAIModelDescriptor( + oc.effectiveModel(meta), + oc.effectiveModelForAPI(meta), + providerName, + oc.resolveModelAPI(meta), + oc.effectiveMaxTokens(meta), + ) oc.loggerForContext(ctx).Debug(). Int("prompt_messages", len(prompt)). Int("ai_messages", len(aiContext.Messages)). - Msg("pkg/ai runtime bridge flag enabled; prepared adapter context and delegating to existing runtime path") + Str("ai_model_api", string(aiModel.API)). + Str("ai_model_provider", string(aiModel.Provider)). + Str("ai_model_id", aiModel.ID). 
+ Msg("pkg/ai runtime bridge flag enabled; prepared adapter context/model and delegating to existing runtime path") switch oc.resolveModelAPI(meta) { case ModelAPIChatCompletions: return oc.streamChatCompletions(ctx, evt, portal, meta, prompt) From b785d25a939ac2baee2fab12770a1fae696a355f Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 04:53:28 +0000 Subject: [PATCH 28/75] Add OAuth helper scaffolds for Gemini CLI Antigravity and Codex MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/ai/oauth/google_antigravity.go | 72 ++++++++++++++ .../oauth/google_antigravity_helpers_test.go | 37 +++++++ pkg/ai/oauth/google_gemini_cli.go | 88 +++++++++++++++++ .../oauth/google_gemini_cli_helpers_test.go | 41 ++++++++ pkg/ai/oauth/openai_codex.go | 96 +++++++++++++++++++ pkg/ai/oauth/openai_codex_helpers_test.go | 56 +++++++++++ 6 files changed, 390 insertions(+) create mode 100644 pkg/ai/oauth/google_antigravity.go create mode 100644 pkg/ai/oauth/google_antigravity_helpers_test.go create mode 100644 pkg/ai/oauth/google_gemini_cli.go create mode 100644 pkg/ai/oauth/google_gemini_cli_helpers_test.go create mode 100644 pkg/ai/oauth/openai_codex.go create mode 100644 pkg/ai/oauth/openai_codex_helpers_test.go diff --git a/pkg/ai/oauth/google_antigravity.go b/pkg/ai/oauth/google_antigravity.go new file mode 100644 index 00000000..0cc621c1 --- /dev/null +++ b/pkg/ai/oauth/google_antigravity.go @@ -0,0 +1,72 @@ +package oauth + +import ( + "net/url" + "strings" +) + +const ( + antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" + antigravityClientSecret = "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" + antigravityRedirectURI = "http://localhost:51121/oauth-callback" + antigravityDefaultProject = "rising-fact-p41fc" +) + +var antigravityScopes = []string{ + "https://www.googleapis.com/auth/cloud-platform", + 
"https://www.googleapis.com/auth/userinfo.email", + "https://www.googleapis.com/auth/userinfo.profile", + "https://www.googleapis.com/auth/cclog", + "https://www.googleapis.com/auth/experimentsandconfigs", +} + +func AntigravityClientID() string { + return antigravityClientID +} + +func AntigravityClientSecret() string { + return antigravityClientSecret +} + +func AntigravityRedirectURI() string { + return antigravityRedirectURI +} + +func AntigravityDefaultProjectID() string { + return antigravityDefaultProject +} + +func BuildAntigravityAuthorizeURL(codeChallenge string, state string) string { + params := url.Values{} + params.Set("client_id", antigravityClientID) + params.Set("response_type", "code") + params.Set("redirect_uri", antigravityRedirectURI) + params.Set("scope", strings.Join(antigravityScopes, " ")) + params.Set("code_challenge", codeChallenge) + params.Set("code_challenge_method", "S256") + params.Set("state", state) + params.Set("access_type", "offline") + params.Set("prompt", "consent") + return "https://accounts.google.com/o/oauth2/v2/auth?" 
+ params.Encode() +} + +func ResolveAntigravityProjectID(loadCodeAssistPayload map[string]any) string { + if loadCodeAssistPayload == nil { + return antigravityDefaultProject + } + if raw, ok := loadCodeAssistPayload["cloudaicompanionProject"]; ok { + switch value := raw.(type) { + case string: + if trimmed := strings.TrimSpace(value); trimmed != "" { + return trimmed + } + case map[string]any: + if idRaw, ok := value["id"].(string); ok { + if trimmed := strings.TrimSpace(idRaw); trimmed != "" { + return trimmed + } + } + } + } + return antigravityDefaultProject +} diff --git a/pkg/ai/oauth/google_antigravity_helpers_test.go b/pkg/ai/oauth/google_antigravity_helpers_test.go new file mode 100644 index 00000000..2775f2a9 --- /dev/null +++ b/pkg/ai/oauth/google_antigravity_helpers_test.go @@ -0,0 +1,37 @@ +package oauth + +import ( + "net/url" + "testing" +) + +func TestGoogleAntigravityOAuthHelpers(t *testing.T) { + authURL := BuildAntigravityAuthorizeURL("challenge", "state") + parsed, err := url.Parse(authURL) + if err != nil { + t.Fatalf("expected valid antigravity auth URL, got err=%v", err) + } + query := parsed.Query() + if query.Get("client_id") != AntigravityClientID() { + t.Fatalf("expected antigravity client id in query") + } + if query.Get("redirect_uri") != AntigravityRedirectURI() { + t.Fatalf("expected antigravity redirect uri in query") + } + if query.Get("code_challenge") != "challenge" || query.Get("state") != "state" { + t.Fatalf("expected challenge/state in antigravity auth query") + } + + if got := ResolveAntigravityProjectID(nil); got != AntigravityDefaultProjectID() { + t.Fatalf("expected default project for nil payload, got %q", got) + } + if got := ResolveAntigravityProjectID(map[string]any{"cloudaicompanionProject": "explicit-project"}); got != "explicit-project" { + t.Fatalf("expected explicit string project id, got %q", got) + } + if got := ResolveAntigravityProjectID(map[string]any{"cloudaicompanionProject": map[string]any{"id": 
"nested-project"}}); got != "nested-project" { + t.Fatalf("expected nested project id, got %q", got) + } + if got := ResolveAntigravityProjectID(map[string]any{"cloudaicompanionProject": map[string]any{"id": ""}}); got != AntigravityDefaultProjectID() { + t.Fatalf("expected default project when nested id empty, got %q", got) + } +} diff --git a/pkg/ai/oauth/google_gemini_cli.go b/pkg/ai/oauth/google_gemini_cli.go new file mode 100644 index 00000000..126b6829 --- /dev/null +++ b/pkg/ai/oauth/google_gemini_cli.go @@ -0,0 +1,88 @@ +package oauth + +import ( + "encoding/json" + "net/url" + "strings" +) + +const ( + geminiCliClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com" + geminiCliClientSecret = "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" + geminiCliRedirectURI = "http://localhost:8085/oauth2callback" + geminiCliAuthURL = "https://accounts.google.com/o/oauth2/v2/auth" + geminiCliTokenURL = "https://oauth2.googleapis.com/token" +) + +var geminiCliScopes = []string{ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/userinfo.email", + "https://www.googleapis.com/auth/userinfo.profile", +} + +func GeminiCliClientID() string { + return geminiCliClientID +} + +func GeminiCliClientSecret() string { + return geminiCliClientSecret +} + +func GeminiCliRedirectURI() string { + return geminiCliRedirectURI +} + +func GeminiCliTokenURL() string { + return geminiCliTokenURL +} + +func BuildGeminiCliAuthorizeURL(codeChallenge string, state string) string { + params := url.Values{} + params.Set("client_id", geminiCliClientID) + params.Set("response_type", "code") + params.Set("redirect_uri", geminiCliRedirectURI) + params.Set("scope", strings.Join(geminiCliScopes, " ")) + params.Set("code_challenge", codeChallenge) + params.Set("code_challenge_method", "S256") + params.Set("state", state) + params.Set("access_type", "offline") + params.Set("prompt", "consent") + return geminiCliAuthURL + "?" 
+ params.Encode() +} + +func ParseOAuthRedirectURL(input string) (code string, state string) { + trimmed := strings.TrimSpace(input) + if trimmed == "" { + return "", "" + } + parsed, err := url.Parse(trimmed) + if err != nil { + return "", "" + } + return strings.TrimSpace(parsed.Query().Get("code")), strings.TrimSpace(parsed.Query().Get("state")) +} + +func BuildGoogleOAuthAPIKey(accessToken string, projectID string) (string, error) { + payload := map[string]string{ + "token": strings.TrimSpace(accessToken), + "projectId": strings.TrimSpace(projectID), + } + raw, err := json.Marshal(payload) + if err != nil { + return "", err + } + return string(raw), nil +} + +func ParseGoogleOAuthAPIKey(apiKey string) (token string, projectID string, ok bool) { + var payload struct { + Token string `json:"token"` + ProjectID string `json:"projectId"` + } + if err := json.Unmarshal([]byte(apiKey), &payload); err != nil { + return "", "", false + } + token = strings.TrimSpace(payload.Token) + projectID = strings.TrimSpace(payload.ProjectID) + return token, projectID, token != "" && projectID != "" +} diff --git a/pkg/ai/oauth/google_gemini_cli_helpers_test.go b/pkg/ai/oauth/google_gemini_cli_helpers_test.go new file mode 100644 index 00000000..ff48c13d --- /dev/null +++ b/pkg/ai/oauth/google_gemini_cli_helpers_test.go @@ -0,0 +1,41 @@ +package oauth + +import ( + "net/url" + "testing" +) + +func TestGoogleGeminiCliOAuthHelpers(t *testing.T) { + authURL := BuildGeminiCliAuthorizeURL("challenge", "state") + parsed, err := url.Parse(authURL) + if err != nil { + t.Fatalf("expected valid gemini-cli auth URL, got err=%v", err) + } + query := parsed.Query() + if query.Get("client_id") != GeminiCliClientID() { + t.Fatalf("expected gemini client id in query") + } + if query.Get("redirect_uri") != GeminiCliRedirectURI() { + t.Fatalf("expected redirect uri in query") + } + if query.Get("code_challenge") != "challenge" || query.Get("state") != "state" { + t.Fatalf("expected challenge/state 
in query") + } + + code, state := ParseOAuthRedirectURL("http://localhost:8085/oauth2callback?code=abc&state=xyz") + if code != "abc" || state != "xyz" { + t.Fatalf("expected parsed redirect code/state, got code=%q state=%q", code, state) + } + + apiKey, err := BuildGoogleOAuthAPIKey("token-1", "project-1") + if err != nil { + t.Fatalf("unexpected error building google oauth api key: %v", err) + } + token, projectID, ok := ParseGoogleOAuthAPIKey(apiKey) + if !ok || token != "token-1" || projectID != "project-1" { + t.Fatalf("unexpected parsed google oauth api key: token=%q project=%q ok=%v", token, projectID, ok) + } + if _, _, ok := ParseGoogleOAuthAPIKey("{invalid-json"); ok { + t.Fatalf("expected invalid json api key parse to fail") + } +} diff --git a/pkg/ai/oauth/openai_codex.go b/pkg/ai/oauth/openai_codex.go new file mode 100644 index 00000000..0e5b755b --- /dev/null +++ b/pkg/ai/oauth/openai_codex.go @@ -0,0 +1,96 @@ +package oauth + +import ( + "encoding/base64" + "encoding/json" + "net/url" + "strings" +) + +const ( + openAICodexClientID = "app_EMoamEEZ73f0CkXaXp7hrann" + openAICodexAuthorize = "https://auth.openai.com/oauth/authorize" + openAICodexTokenURL = "https://auth.openai.com/oauth/token" + openAICodexRedirectURI = "http://localhost:1455/auth/callback" + openAICodexScope = "openid profile email offline_access" + openAICodexJWTClaim = "https://api.openai.com/auth" +) + +func OpenAICodexClientID() string { + return openAICodexClientID +} + +func OpenAICodexTokenURL() string { + return openAICodexTokenURL +} + +func BuildOpenAICodexAuthorizeURL(codeChallenge string, state string, originator string) string { + if strings.TrimSpace(originator) == "" { + originator = "pi" + } + params := url.Values{} + params.Set("response_type", "code") + params.Set("client_id", openAICodexClientID) + params.Set("redirect_uri", openAICodexRedirectURI) + params.Set("scope", openAICodexScope) + params.Set("code_challenge", codeChallenge) + 
params.Set("code_challenge_method", "S256") + params.Set("state", state) + params.Set("id_token_add_organizations", "true") + params.Set("codex_cli_simplified_flow", "true") + params.Set("originator", originator) + return openAICodexAuthorize + "?" + params.Encode() +} + +func ParseOpenAICodexAuthorizationInput(input string) (code string, state string) { + value := strings.TrimSpace(input) + if value == "" { + return "", "" + } + if parsed, err := url.Parse(value); err == nil { + query := parsed.Query() + if query.Get("code") != "" || query.Get("state") != "" { + return strings.TrimSpace(query.Get("code")), strings.TrimSpace(query.Get("state")) + } + } + if strings.Contains(value, "#") { + parts := strings.SplitN(value, "#", 2) + return strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]) + } + if strings.Contains(value, "code=") { + query, err := url.ParseQuery(value) + if err == nil { + return strings.TrimSpace(query.Get("code")), strings.TrimSpace(query.Get("state")) + } + } + return value, "" +} + +func ExtractOpenAICodexAccountID(accessToken string) string { + parts := strings.Split(accessToken, ".") + if len(parts) != 3 { + return "" + } + payloadSegment := parts[1] + decoded, err := base64.RawURLEncoding.DecodeString(payloadSegment) + if err != nil { + decoded, err = base64.URLEncoding.DecodeString(payloadSegment) + if err != nil { + return "" + } + } + var payload map[string]any + if err := json.Unmarshal(decoded, &payload); err != nil { + return "" + } + authRaw, ok := payload[openAICodexJWTClaim] + if !ok { + return "" + } + authClaims, ok := authRaw.(map[string]any) + if !ok { + return "" + } + accountID, _ := authClaims["chatgpt_account_id"].(string) + return strings.TrimSpace(accountID) +} diff --git a/pkg/ai/oauth/openai_codex_helpers_test.go b/pkg/ai/oauth/openai_codex_helpers_test.go new file mode 100644 index 00000000..d6e80289 --- /dev/null +++ b/pkg/ai/oauth/openai_codex_helpers_test.go @@ -0,0 +1,56 @@ +package oauth + +import ( + 
"encoding/base64" + "encoding/json" + "net/url" + "testing" +) + +func TestOpenAICodexOAuthHelpers(t *testing.T) { + authURL := BuildOpenAICodexAuthorizeURL("challenge", "state", "pi") + parsed, err := url.Parse(authURL) + if err != nil { + t.Fatalf("expected valid codex auth URL, got err=%v", err) + } + query := parsed.Query() + if query.Get("client_id") != OpenAICodexClientID() { + t.Fatalf("expected codex client id in query") + } + if query.Get("code_challenge") != "challenge" || query.Get("state") != "state" { + t.Fatalf("expected challenge/state in codex auth query") + } + if query.Get("originator") != "pi" { + t.Fatalf("expected originator=pi in query") + } + + code, state := ParseOpenAICodexAuthorizationInput("http://localhost:1455/auth/callback?code=abc&state=xyz") + if code != "abc" || state != "xyz" { + t.Fatalf("expected URL parser code/state, got code=%q state=%q", code, state) + } + code, state = ParseOpenAICodexAuthorizationInput("abc#xyz") + if code != "abc" || state != "xyz" { + t.Fatalf("expected hash parser code/state, got code=%q state=%q", code, state) + } + code, state = ParseOpenAICodexAuthorizationInput("code=abc&state=xyz") + if code != "abc" || state != "xyz" { + t.Fatalf("expected query-string parser code/state, got code=%q state=%q", code, state) + } + + payload := map[string]any{ + "https://api.openai.com/auth": map[string]any{ + "chatgpt_account_id": "acct_123", + }, + } + rawPayload, err := json.Marshal(payload) + if err != nil { + t.Fatalf("failed to marshal jwt payload: %v", err) + } + jwt := "header." 
+ base64.RawURLEncoding.EncodeToString(rawPayload) + ".sig" + if got := ExtractOpenAICodexAccountID(jwt); got != "acct_123" { + t.Fatalf("expected extracted account id acct_123, got %q", got) + } + if got := ExtractOpenAICodexAccountID("invalid-token"); got != "" { + t.Fatalf("expected invalid jwt extraction to return empty string, got %q", got) + } +} From 85aa1c983f30a343c145b3b22df6a5ac1f3891b2 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 04:58:35 +0000 Subject: [PATCH 29/75] Add pkg-ai bridge dry-run execution hook MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/connector/streaming_runtime_selector.go | 34 +++++++++++++++++++ .../streaming_runtime_selector_test.go | 13 +++++++ 2 files changed, 47 insertions(+) diff --git a/pkg/connector/streaming_runtime_selector.go b/pkg/connector/streaming_runtime_selector.go index 60d04451..a1bb1eaf 100644 --- a/pkg/connector/streaming_runtime_selector.go +++ b/pkg/connector/streaming_runtime_selector.go @@ -8,6 +8,7 @@ import ( "time" aipkg "github.com/beeper/ai-bridge/pkg/ai" + aiproviders "github.com/beeper/ai-bridge/pkg/ai/providers" airuntime "github.com/beeper/ai-bridge/pkg/runtime" "github.com/openai/openai-go/v3" "maunium.net/go/mautrix/bridgev2" @@ -27,6 +28,11 @@ func pkgAIRuntimeEnabled() bool { return value == "1" || value == "true" || value == "yes" || value == "on" } +func pkgAIRuntimeDryRunEnabled() bool { + value := strings.ToLower(strings.TrimSpace(os.Getenv("PI_USE_PKG_AI_RUNTIME_DRY_RUN"))) + return value == "1" || value == "true" || value == "yes" || value == "on" +} + func chooseStreamingRuntimePath(hasAudio bool, modelAPI ModelAPI, preferPkgAI bool) streamingRuntimePath { if hasAudio { return streamingRuntimeChatCompletions @@ -66,6 +72,9 @@ func (oc *AIClient) streamWithPkgAIBridge( Str("ai_model_provider", string(aiModel.Provider)). Str("ai_model_id", aiModel.ID). 
Msg("pkg/ai runtime bridge flag enabled; prepared adapter context/model and delegating to existing runtime path") + if pkgAIRuntimeDryRunEnabled() { + oc.runPkgAIBridgeDryRun(ctx, aiModel, aiContext) + } switch oc.resolveModelAPI(meta) { case ModelAPIChatCompletions: return oc.streamChatCompletions(ctx, evt, portal, meta, prompt) @@ -74,6 +83,31 @@ func (oc *AIClient) streamWithPkgAIBridge( } } +func (oc *AIClient) runPkgAIBridgeDryRun(ctx context.Context, model aipkg.Model, aiContext aipkg.Context) { + aiproviders.RegisterBuiltInAPIProviders() + stream, err := aipkg.Stream(model, aiContext, &aipkg.StreamOptions{ + Ctx: ctx, + MaxTokens: model.MaxTokens, + }) + if err != nil { + oc.loggerForContext(ctx).Warn().Err(err).Msg("pkg/ai dry-run failed to create stream") + return + } + events := streamEventsFromAIStream(ctx, stream) + count := 0 + for evt := range events { + count++ + if evt.Type == StreamEventError { + oc.loggerForContext(ctx).Debug().Err(evt.Error).Int("event_count", count).Msg("pkg/ai dry-run produced error event") + return + } + if evt.Type == StreamEventComplete { + oc.loggerForContext(ctx).Debug().Int("event_count", count).Str("finish_reason", evt.FinishReason).Msg("pkg/ai dry-run completed") + return + } + } +} + func buildPkgAIContext(systemPrompt string, prompt []openai.ChatCompletionMessageParamUnion) aipkg.Context { unified := chatPromptToUnifiedMessages(prompt) return toAIContext(systemPrompt, unified, nil) diff --git a/pkg/connector/streaming_runtime_selector_test.go b/pkg/connector/streaming_runtime_selector_test.go index d9c1208c..ebbe054d 100644 --- a/pkg/connector/streaming_runtime_selector_test.go +++ b/pkg/connector/streaming_runtime_selector_test.go @@ -26,6 +26,19 @@ func TestPkgAIRuntimeEnabledFromEnv(t *testing.T) { if pkgAIRuntimeEnabled() { t.Fatalf("expected runtime flag disabled for value off") } + + t.Setenv("PI_USE_PKG_AI_RUNTIME_DRY_RUN", "") + if pkgAIRuntimeDryRunEnabled() { + t.Fatalf("expected dry-run flag disabled by 
default") + } + t.Setenv("PI_USE_PKG_AI_RUNTIME_DRY_RUN", "yes") + if !pkgAIRuntimeDryRunEnabled() { + t.Fatalf("expected dry-run flag enabled for value yes") + } + t.Setenv("PI_USE_PKG_AI_RUNTIME_DRY_RUN", "0") + if pkgAIRuntimeDryRunEnabled() { + t.Fatalf("expected dry-run flag disabled for value 0") + } } func TestChooseStreamingRuntimePath(t *testing.T) { From 62ec9f289d5fae9ce38c8ed74c61ca3f37d92419 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 05:12:27 +0000 Subject: [PATCH 30/75] Add optional pkg-ai provider stream bridge with fallback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/connector/pkg_ai_provider_bridge.go | 104 +++++++++++++++++++ pkg/connector/pkg_ai_provider_bridge_test.go | 95 +++++++++++++++++ pkg/connector/provider_openai.go | 12 +++ 3 files changed, 211 insertions(+) create mode 100644 pkg/connector/pkg_ai_provider_bridge.go create mode 100644 pkg/connector/pkg_ai_provider_bridge_test.go diff --git a/pkg/connector/pkg_ai_provider_bridge.go b/pkg/connector/pkg_ai_provider_bridge.go new file mode 100644 index 00000000..8c7181b3 --- /dev/null +++ b/pkg/connector/pkg_ai_provider_bridge.go @@ -0,0 +1,104 @@ +package connector + +import ( + "context" + "os" + "strings" + "time" + + aipkg "github.com/beeper/ai-bridge/pkg/ai" + aiproviders "github.com/beeper/ai-bridge/pkg/ai/providers" +) + +func pkgAIProviderRuntimeEnabled() bool { + value := strings.ToLower(strings.TrimSpace(os.Getenv("PI_USE_PKG_AI_PROVIDER_RUNTIME"))) + return value == "1" || value == "true" || value == "yes" || value == "on" +} + +func inferProviderNameFromBaseURL(baseURL string) string { + lower := strings.ToLower(strings.TrimSpace(baseURL)) + switch { + case strings.Contains(lower, "openrouter.ai"): + return "openrouter" + case strings.Contains(lower, "beeper.com"): + return "beeper" + case strings.Contains(lower, "magicproxy"): + return "magic-proxy" + case 
strings.Contains(lower, "azure.com"): + return "azure-openai-responses" + default: + return "openai" + } +} + +func buildPkgAIModelFromGenerateParams(params GenerateParams, baseURL string) aipkg.Model { + modelID := strings.TrimSpace(params.Model) + provider := inferProviderNameFromBaseURL(baseURL) + api := aipkg.APIOpenAIResponses + if provider == "openrouter" { + api = aipkg.APIOpenAICompletions + } + return aipkg.Model{ + ID: modelID, + Name: modelID, + Provider: aipkg.Provider(provider), + API: api, + BaseURL: strings.TrimSpace(baseURL), + Input: []string{"text"}, + MaxTokens: max(params.MaxCompletionTokens, 4096), + } +} + +func shouldFallbackFromPkgAIEvent(event StreamEvent) bool { + if event.Type != StreamEventError || event.Error == nil { + return false + } + errText := strings.ToLower(strings.TrimSpace(event.Error.Error())) + return strings.Contains(errText, "not implemented yet") || + strings.Contains(errText, "no api provider registered") +} + +func tryGenerateStreamWithPkgAI( + ctx context.Context, + baseURL string, + params GenerateParams, +) (<-chan StreamEvent, bool) { + aiproviders.RegisterBuiltInAPIProviders() + model := buildPkgAIModelFromGenerateParams(params, baseURL) + aiContext := toAIContext(params.SystemPrompt, params.Messages, params.Tools) + + temp := params.Temperature + options := &aipkg.StreamOptions{ + Ctx: ctx, + MaxTokens: params.MaxCompletionTokens, + Temperature: &temp, + } + + stream, err := aipkg.Stream(model, aiContext, options) + if err != nil { + return nil, false + } + + mapped := streamEventsFromAIStream(ctx, stream) + select { + case first, ok := <-mapped: + if !ok { + return nil, false + } + if shouldFallbackFromPkgAIEvent(first) { + return nil, false + } + out := make(chan StreamEvent, 64) + go func() { + defer close(out) + out <- first + for event := range mapped { + out <- event + } + }() + return out, true + case <-time.After(50 * time.Millisecond): + // No immediate events: proceed with pkg/ai channel and let caller 
consume. + return mapped, true + } +} diff --git a/pkg/connector/pkg_ai_provider_bridge_test.go b/pkg/connector/pkg_ai_provider_bridge_test.go new file mode 100644 index 00000000..42c58e86 --- /dev/null +++ b/pkg/connector/pkg_ai_provider_bridge_test.go @@ -0,0 +1,95 @@ +package connector + +import ( + "context" + "errors" + "testing" +) + +func TestPkgAIProviderRuntimeEnabled(t *testing.T) { + t.Setenv("PI_USE_PKG_AI_PROVIDER_RUNTIME", "true") + if !pkgAIProviderRuntimeEnabled() { + t.Fatalf("expected runtime flag to be enabled") + } + + t.Setenv("PI_USE_PKG_AI_PROVIDER_RUNTIME", "0") + if pkgAIProviderRuntimeEnabled() { + t.Fatalf("expected runtime flag to be disabled") + } +} + +func TestInferProviderNameFromBaseURL(t *testing.T) { + cases := []struct { + name string + baseURL string + want string + }{ + {name: "default", baseURL: "", want: "openai"}, + {name: "openrouter", baseURL: "https://openrouter.ai/api/v1", want: "openrouter"}, + {name: "beeper proxy", baseURL: "https://ai.beeper.com/openai", want: "beeper"}, + {name: "magic proxy", baseURL: "https://magicproxy.example/v1", want: "magic-proxy"}, + {name: "azure", baseURL: "https://my-openai.azure.com", want: "azure-openai-responses"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := inferProviderNameFromBaseURL(tc.baseURL) + if got != tc.want { + t.Fatalf("inferProviderNameFromBaseURL(%q) = %q, want %q", tc.baseURL, got, tc.want) + } + }) + } +} + +func TestBuildPkgAIModelFromGenerateParams(t *testing.T) { + openRouter := buildPkgAIModelFromGenerateParams(GenerateParams{ + Model: "openai/gpt-4o-mini", + MaxCompletionTokens: 256, + }, "https://openrouter.ai/api/v1") + if openRouter.API != "openai-completions" { + t.Fatalf("expected openrouter to map to openai-completions API, got %q", openRouter.API) + } + if openRouter.MaxTokens != 4096 { + t.Fatalf("expected minimum max tokens guard, got %d", openRouter.MaxTokens) + } + + openAI := 
buildPkgAIModelFromGenerateParams(GenerateParams{ + Model: "gpt-4.1-mini", + MaxCompletionTokens: 16384, + }, "") + if openAI.API != "openai-responses" { + t.Fatalf("expected openai to map to openai-responses API, got %q", openAI.API) + } + if openAI.MaxTokens != 16384 { + t.Fatalf("unexpected max tokens: %d", openAI.MaxTokens) + } +} + +func TestShouldFallbackFromPkgAIEvent(t *testing.T) { + if !shouldFallbackFromPkgAIEvent(StreamEvent{ + Type: StreamEventError, + Error: errors.New("provider runtime is not implemented yet"), + }) { + t.Fatalf("expected not-implemented errors to trigger fallback") + } + if shouldFallbackFromPkgAIEvent(StreamEvent{Type: StreamEventDelta, Delta: "ok"}) { + t.Fatalf("did not expect non-error events to trigger fallback") + } +} + +func TestTryGenerateStreamWithPkgAIFallsBackOnStubbedProviders(t *testing.T) { + events, ok := tryGenerateStreamWithPkgAI(context.Background(), "", GenerateParams{ + Model: "gpt-4.1-mini", + Messages: []UnifiedMessage{ + { + Role: RoleUser, + Content: []ContentPart{ + {Type: ContentTypeText, Text: "hello"}, + }, + }, + }, + }) + if ok { + t.Fatalf("expected fallback mode with stubbed providers, got events=%v", events) + } +} diff --git a/pkg/connector/provider_openai.go b/pkg/connector/provider_openai.go index 52293462..b7405703 100644 --- a/pkg/connector/provider_openai.go +++ b/pkg/connector/provider_openai.go @@ -265,6 +265,18 @@ func (o *OpenAIProvider) buildResponsesParams(params GenerateParams) responses.R // GenerateStream generates a streaming response from OpenAI using Responses API func (o *OpenAIProvider) GenerateStream(ctx context.Context, params GenerateParams) (<-chan StreamEvent, error) { + if pkgAIProviderRuntimeEnabled() { + if pkgAIEvents, ok := tryGenerateStreamWithPkgAI(ctx, o.baseURL, params); ok { + o.log.Debug(). + Str("model", params.Model). + Msg("Using pkg/ai provider runtime for OpenAI stream") + return pkgAIEvents, nil + } + o.log.Warn(). + Str("model", params.Model). 
+ Msg("pkg/ai provider runtime fallback to existing OpenAI stream path") + } + events := make(chan StreamEvent, 100) go func() { From d1dafacf3ad3bee8671a21e24e9f3d3af5a7ad13 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 05:14:32 +0000 Subject: [PATCH 31/75] Add Gemini CLI Claude thinking header helpers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/ai/providers/google_gemini_cli.go | 18 +++++++++++++++++ pkg/ai/providers/google_gemini_cli_test.go | 23 ++++++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/pkg/ai/providers/google_gemini_cli.go b/pkg/ai/providers/google_gemini_cli.go index 90598645..149d5d7b 100644 --- a/pkg/ai/providers/google_gemini_cli.go +++ b/pkg/ai/providers/google_gemini_cli.go @@ -10,6 +10,8 @@ import ( "github.com/beeper/ai-bridge/pkg/ai" ) +const ClaudeThinkingBetaHeader = "interleaved-thinking-2025-05-14" + func ExtractRetryDelay(errorText string, headers http.Header) (int, bool) { return extractRetryDelayAt(errorText, headers, time.Now()) } @@ -130,3 +132,19 @@ func NormalizeGoogleToolCall(name string, args map[string]any, id string, though } return block } + +func IsClaudeThinkingModel(modelID string) bool { + normalized := strings.ToLower(strings.TrimSpace(modelID)) + return strings.Contains(normalized, "claude") && strings.Contains(normalized, "thinking") +} + +func BuildGeminiCLIHeaders(model ai.Model, headers map[string]string) map[string]string { + out := map[string]string{} + for k, v := range headers { + out[k] = v + } + if IsClaudeThinkingModel(model.ID) { + out["anthropic-beta"] = ClaudeThinkingBetaHeader + } + return out +} diff --git a/pkg/ai/providers/google_gemini_cli_test.go b/pkg/ai/providers/google_gemini_cli_test.go index 9f54771b..95f3f6be 100644 --- a/pkg/ai/providers/google_gemini_cli_test.go +++ b/pkg/ai/providers/google_gemini_cli_test.go @@ -5,6 +5,8 @@ import ( "strconv" "testing" "time" + + 
"github.com/beeper/ai-bridge/pkg/ai" ) func TestExtractRetryDelay(t *testing.T) { @@ -44,3 +46,24 @@ func TestExtractRetryDelay(t *testing.T) { } }) } + +func TestBuildGeminiCLIHeaders(t *testing.T) { + t.Run("adds anthropic beta for claude thinking model", func(t *testing.T) { + headers := BuildGeminiCLIHeaders(ai.Model{ID: "claude-opus-4-5-thinking"}, map[string]string{ + "authorization": "Bearer token", + }) + if headers["anthropic-beta"] != ClaudeThinkingBetaHeader { + t.Fatalf("expected anthropic-beta header %q, got %q", ClaudeThinkingBetaHeader, headers["anthropic-beta"]) + } + if headers["authorization"] != "Bearer token" { + t.Fatalf("expected existing headers to be preserved") + } + }) + + t.Run("does not add anthropic beta for gemini model", func(t *testing.T) { + headers := BuildGeminiCLIHeaders(ai.Model{ID: "gemini-2.5-flash"}, nil) + if _, ok := headers["anthropic-beta"]; ok { + t.Fatalf("did not expect anthropic-beta header for gemini model") + } + }) +} From a382bccb22140315a76229a526ed6924bf436978 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 05:20:52 +0000 Subject: [PATCH 32/75] Implement pkg-ai OpenAI responses streaming runtime MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/ai/providers/openai_responses_runtime.go | 216 ++++++++++++++++++ .../openai_responses_runtime_test.go | 69 ++++++ pkg/ai/providers/register_builtins.go | 7 +- pkg/connector/pkg_ai_provider_bridge.go | 2 + pkg/connector/pkg_ai_provider_bridge_test.go | 2 +- pkg/connector/provider_openai.go | 5 +- 6 files changed, 298 insertions(+), 3 deletions(-) create mode 100644 pkg/ai/providers/openai_responses_runtime.go create mode 100644 pkg/ai/providers/openai_responses_runtime_test.go diff --git a/pkg/ai/providers/openai_responses_runtime.go b/pkg/ai/providers/openai_responses_runtime.go new file mode 100644 index 00000000..7f63c6b1 --- /dev/null +++ 
b/pkg/ai/providers/openai_responses_runtime.go @@ -0,0 +1,216 @@ +package providers + +import ( + "context" + "encoding/json" + "strings" + "time" + + "github.com/openai/openai-go/v3" + "github.com/openai/openai-go/v3/option" + "github.com/openai/openai-go/v3/packages/param" + "github.com/openai/openai-go/v3/responses" + + "github.com/beeper/ai-bridge/pkg/ai" + "github.com/beeper/ai-bridge/pkg/shared/httputil" +) + +func streamOpenAIResponses(model ai.Model, c ai.Context, options *ai.StreamOptions) *ai.AssistantMessageEventStream { + openAIOptions := OpenAIResponsesOptions{} + if options != nil { + openAIOptions.StreamOptions = *options + } + return streamOpenAIResponsesWithOptions(model, c, openAIOptions) +} + +func streamSimpleOpenAIResponses(model ai.Model, c ai.Context, options *ai.SimpleStreamOptions) *ai.AssistantMessageEventStream { + base := BuildBaseOptions(model, options, "") + var effort ai.ThinkingLevel + if options != nil { + effort = options.Reasoning + } + if !ai.SupportsXhigh(model) { + effort = ClampReasoning(effort) + } + return streamOpenAIResponsesWithOptions(model, c, OpenAIResponsesOptions{ + StreamOptions: base, + ReasoningEffort: effort, + }) +} + +func streamOpenAIResponsesWithOptions( + model ai.Model, + c ai.Context, + options OpenAIResponsesOptions, +) *ai.AssistantMessageEventStream { + stream := ai.NewAssistantMessageEventStream(128) + go func() { + apiKey := strings.TrimSpace(options.StreamOptions.APIKey) + if apiKey == "" { + apiKey = strings.TrimSpace(ai.GetEnvAPIKey(string(model.Provider))) + } + if apiKey == "" { + pushProviderError(stream, model, "missing API key for OpenAI responses runtime") + return + } + + payload := BuildOpenAIResponsesParams(model, c, options) + if options.StreamOptions.OnPayload != nil { + options.StreamOptions.OnPayload(payload) + } + + request := param.Override[responses.ResponseNewParams](payload) + reqOptions := []option.RequestOption{option.WithAPIKey(apiKey)} + if baseURL := 
strings.TrimSpace(model.BaseURL); baseURL != "" { + reqOptions = append(reqOptions, option.WithBaseURL(baseURL)) + } + reqOptions = httputil.AppendHeaderOptions(reqOptions, model.Headers) + reqOptions = httputil.AppendHeaderOptions(reqOptions, options.StreamOptions.Headers) + + client := openai.NewClient(reqOptions...) + runCtx := options.StreamOptions.Ctx + if runCtx == nil { + runCtx = context.Background() + } + + openAIStream := client.Responses.NewStreaming(runCtx, request) + if openAIStream == nil { + pushProviderError(stream, model, "failed to create OpenAI responses stream") + return + } + + var textBuilder strings.Builder + var thinkingBuilder strings.Builder + toolCalls := make([]ai.ContentBlock, 0, 2) + var completedResponse responses.Response + + for openAIStream.Next() { + event := openAIStream.Current() + switch event.Type { + case "response.output_text.delta": + textBuilder.WriteString(event.Delta) + stream.Push(ai.AssistantMessageEvent{ + Type: ai.EventTextDelta, + Delta: event.Delta, + }) + case "response.reasoning_text.delta": + thinkingBuilder.WriteString(event.Delta) + stream.Push(ai.AssistantMessageEvent{ + Type: ai.EventThinkingDelta, + Delta: event.Delta, + }) + case "response.function_call_arguments.done": + toolCall := ai.ContentBlock{ + Type: ai.ContentTypeToolCall, + ID: strings.TrimSpace(event.ItemID), + Name: strings.TrimSpace(event.Name), + Arguments: parseToolArguments(event.Arguments), + } + toolCalls = append(toolCalls, toolCall) + stream.Push(ai.AssistantMessageEvent{ + Type: ai.EventToolCallEnd, + ToolCall: &toolCall, + }) + case "response.completed": + completedResponse = event.Response + case "error": + pushProviderError(stream, model, strings.TrimSpace(event.Message)) + return + } + } + + if err := openAIStream.Err(); err != nil { + pushProviderError(stream, model, err.Error()) + return + } + + assistantMessage := ai.Message{ + Role: ai.RoleAssistant, + API: model.API, + Provider: model.Provider, + Model: model.ID, + Timestamp: 
time.Now().UnixMilli(), + StopReason: mapOpenAIResponseStatus(completedResponse.Status), + } + assistantMessage.Usage = ai.Usage{ + Input: int(completedResponse.Usage.InputTokens), + Output: int(completedResponse.Usage.OutputTokens), + TotalTokens: int(completedResponse.Usage.TotalTokens), + } + assistantMessage.Usage.Cost = ai.CalculateCost(model, assistantMessage.Usage) + + if thinking := strings.TrimSpace(thinkingBuilder.String()); thinking != "" { + assistantMessage.Content = append(assistantMessage.Content, ai.ContentBlock{ + Type: ai.ContentTypeThinking, + Thinking: thinking, + }) + } + if text := strings.TrimSpace(textBuilder.String()); text != "" { + assistantMessage.Content = append(assistantMessage.Content, ai.ContentBlock{ + Type: ai.ContentTypeText, + Text: text, + }) + } + if len(toolCalls) > 0 { + assistantMessage.Content = append(assistantMessage.Content, toolCalls...) + } + if len(toolCalls) > 0 && assistantMessage.StopReason == ai.StopReasonStop { + assistantMessage.StopReason = ai.StopReasonToolUse + } + stream.Push(ai.AssistantMessageEvent{ + Type: ai.EventDone, + Message: assistantMessage, + Reason: assistantMessage.StopReason, + }) + }() + return stream +} + +func parseToolArguments(raw string) map[string]any { + raw = strings.TrimSpace(raw) + if raw == "" { + return map[string]any{} + } + args := map[string]any{} + if err := json.Unmarshal([]byte(raw), &args); err != nil { + return map[string]any{"_raw": raw} + } + return args +} + +func pushProviderError(stream *ai.AssistantMessageEventStream, model ai.Model, errText string) { + if strings.TrimSpace(errText) == "" { + errText = "openai responses stream failed" + } + stream.Push(ai.AssistantMessageEvent{ + Type: ai.EventError, + Error: ai.Message{ + Role: ai.RoleAssistant, + API: model.API, + Provider: model.Provider, + Model: model.ID, + StopReason: ai.StopReasonError, + ErrorMessage: strings.TrimSpace(errText), + Timestamp: time.Now().UnixMilli(), + }, + Reason: ai.StopReasonError, + }) +} + 
+func mapOpenAIResponseStatus(status responses.ResponseStatus) ai.StopReason { + switch status { + case responses.ResponseStatusCompleted: + return ai.StopReasonStop + case responses.ResponseStatusInProgress, responses.ResponseStatusIncomplete: + return ai.StopReasonLength + case responses.ResponseStatusCancelled: + return ai.StopReasonAborted + case responses.ResponseStatusFailed: + return ai.StopReasonError + default: + if strings.TrimSpace(string(status)) == "" { + return ai.StopReasonStop + } + return ai.StopReasonStop + } +} diff --git a/pkg/ai/providers/openai_responses_runtime_test.go b/pkg/ai/providers/openai_responses_runtime_test.go new file mode 100644 index 00000000..4c047f38 --- /dev/null +++ b/pkg/ai/providers/openai_responses_runtime_test.go @@ -0,0 +1,69 @@ +package providers + +import ( + "context" + "io" + "strings" + "testing" + "time" + + "github.com/openai/openai-go/v3/responses" + + "github.com/beeper/ai-bridge/pkg/ai" +) + +func TestStreamOpenAIResponses_MissingAPIKeyEmitsError(t *testing.T) { + t.Setenv("OPENAI_API_KEY", "") + stream := streamOpenAIResponses(ai.Model{ + ID: "gpt-4.1-mini", + Provider: "openai", + API: ai.APIOpenAIResponses, + }, ai.Context{ + Messages: []ai.Message{{Role: ai.RoleUser, Text: "hello"}}, + }, &ai.StreamOptions{}) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + evt, err := stream.Next(ctx) + if err != nil { + t.Fatalf("expected terminal error event, got %v", err) + } + if evt.Type != ai.EventError { + t.Fatalf("expected error event, got %s", evt.Type) + } + if !strings.Contains(strings.ToLower(evt.Error.ErrorMessage), "api key") { + t.Fatalf("expected missing api key message, got %q", evt.Error.ErrorMessage) + } + if _, err = stream.Next(ctx); err != io.EOF { + t.Fatalf("expected EOF after terminal event, got %v", err) + } +} + +func TestParseToolArguments(t *testing.T) { + if got := parseToolArguments(""); len(got) != 0 { + t.Fatalf("expected empty map for empty 
args, got %#v", got) + } + valid := parseToolArguments(`{"x":1}`) + if val, ok := valid["x"]; !ok || val.(float64) != 1 { + t.Fatalf("expected parsed JSON args, got %#v", valid) + } + invalid := parseToolArguments("{oops") + if invalid["_raw"] != "{oops" { + t.Fatalf("expected raw fallback on invalid JSON, got %#v", invalid) + } +} + +func TestMapOpenAIResponseStatus(t *testing.T) { + cases := map[responses.ResponseStatus]ai.StopReason{ + responses.ResponseStatusCompleted: ai.StopReasonStop, + responses.ResponseStatusInProgress: ai.StopReasonLength, + responses.ResponseStatusIncomplete: ai.StopReasonLength, + responses.ResponseStatusCancelled: ai.StopReasonAborted, + responses.ResponseStatusFailed: ai.StopReasonError, + } + for in, want := range cases { + if got := mapOpenAIResponseStatus(in); got != want { + t.Fatalf("mapOpenAIResponseStatus(%q) = %q, want %q", in, got, want) + } + } +} diff --git a/pkg/ai/providers/register_builtins.go b/pkg/ai/providers/register_builtins.go index f5fc8f18..e9484a0b 100644 --- a/pkg/ai/providers/register_builtins.go +++ b/pkg/ai/providers/register_builtins.go @@ -50,9 +50,14 @@ func notImplementedSimpleStream(apiID ai.Api) ai.StreamSimpleFn { // RegisterBuiltInAPIProviders registers providers implemented in this package. 
func RegisterBuiltInAPIProviders() { + ai.RegisterAPIProvider(ai.APIProvider{ + API: ai.APIOpenAIResponses, + Stream: streamOpenAIResponses, + StreamSimple: streamSimpleOpenAIResponses, + }, BuiltinProviderSourceID) + for _, apiID := range []ai.Api{ ai.APIOpenAICompletions, - ai.APIOpenAIResponses, ai.APIAzureOpenAIResponse, ai.APIOpenAICodexResponse, ai.APIAnthropicMessages, diff --git a/pkg/connector/pkg_ai_provider_bridge.go b/pkg/connector/pkg_ai_provider_bridge.go index 8c7181b3..68af63a1 100644 --- a/pkg/connector/pkg_ai_provider_bridge.go +++ b/pkg/connector/pkg_ai_provider_bridge.go @@ -61,6 +61,7 @@ func shouldFallbackFromPkgAIEvent(event StreamEvent) bool { func tryGenerateStreamWithPkgAI( ctx context.Context, baseURL string, + apiKey string, params GenerateParams, ) (<-chan StreamEvent, bool) { aiproviders.RegisterBuiltInAPIProviders() @@ -72,6 +73,7 @@ func tryGenerateStreamWithPkgAI( Ctx: ctx, MaxTokens: params.MaxCompletionTokens, Temperature: &temp, + APIKey: strings.TrimSpace(apiKey), } stream, err := aipkg.Stream(model, aiContext, options) diff --git a/pkg/connector/pkg_ai_provider_bridge_test.go b/pkg/connector/pkg_ai_provider_bridge_test.go index 42c58e86..3dbc3641 100644 --- a/pkg/connector/pkg_ai_provider_bridge_test.go +++ b/pkg/connector/pkg_ai_provider_bridge_test.go @@ -78,7 +78,7 @@ func TestShouldFallbackFromPkgAIEvent(t *testing.T) { } func TestTryGenerateStreamWithPkgAIFallsBackOnStubbedProviders(t *testing.T) { - events, ok := tryGenerateStreamWithPkgAI(context.Background(), "", GenerateParams{ + events, ok := tryGenerateStreamWithPkgAI(context.Background(), "https://openrouter.ai/api/v1", "", GenerateParams{ Model: "gpt-4.1-mini", Messages: []UnifiedMessage{ { diff --git a/pkg/connector/provider_openai.go b/pkg/connector/provider_openai.go index b7405703..0b734411 100644 --- a/pkg/connector/provider_openai.go +++ b/pkg/connector/provider_openai.go @@ -27,6 +27,7 @@ import ( type OpenAIProvider struct { client openai.Client log 
zerolog.Logger + apiKey string baseURL string } @@ -82,6 +83,7 @@ func NewOpenAIProviderWithUserID(apiKey, baseURL, userID string, log zerolog.Log return &OpenAIProvider{ client: client, log: log.With().Str("provider", "openai").Logger(), + apiKey: apiKey, baseURL: baseURL, }, nil } @@ -195,6 +197,7 @@ func NewOpenAIProviderWithPDFPlugin(apiKey, baseURL, userID, pdfEngine string, h return &OpenAIProvider{ client: client, log: log.With().Str("provider", "openai").Str("pdf_engine", pdfEngine).Logger(), + apiKey: apiKey, baseURL: baseURL, }, nil } @@ -266,7 +269,7 @@ func (o *OpenAIProvider) buildResponsesParams(params GenerateParams) responses.R // GenerateStream generates a streaming response from OpenAI using Responses API func (o *OpenAIProvider) GenerateStream(ctx context.Context, params GenerateParams) (<-chan StreamEvent, error) { if pkgAIProviderRuntimeEnabled() { - if pkgAIEvents, ok := tryGenerateStreamWithPkgAI(ctx, o.baseURL, params); ok { + if pkgAIEvents, ok := tryGenerateStreamWithPkgAI(ctx, o.baseURL, o.apiKey, params); ok { o.log.Debug(). Str("model", params.Model). 
Msg("Using pkg/ai provider runtime for OpenAI stream") From 72ec17c3b695f348d1ba96e7c8248edb331e2eb3 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 05:24:10 +0000 Subject: [PATCH 33/75] Add pkg-ai OpenAI completions streaming runtime MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- .../providers/openai_completions_runtime.go | 207 ++++++++++++++++++ .../openai_completions_runtime_test.go | 54 +++++ pkg/ai/providers/register_builtins.go | 6 +- pkg/connector/pkg_ai_provider_bridge.go | 2 + pkg/connector/pkg_ai_provider_bridge_test.go | 9 +- 5 files changed, 276 insertions(+), 2 deletions(-) create mode 100644 pkg/ai/providers/openai_completions_runtime.go create mode 100644 pkg/ai/providers/openai_completions_runtime_test.go diff --git a/pkg/ai/providers/openai_completions_runtime.go b/pkg/ai/providers/openai_completions_runtime.go new file mode 100644 index 00000000..cd07a5af --- /dev/null +++ b/pkg/ai/providers/openai_completions_runtime.go @@ -0,0 +1,207 @@ +package providers + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/openai/openai-go/v3" + "github.com/openai/openai-go/v3/option" + "github.com/openai/openai-go/v3/packages/param" + + "github.com/beeper/ai-bridge/pkg/ai" + "github.com/beeper/ai-bridge/pkg/shared/httputil" +) + +func streamOpenAICompletions(model ai.Model, c ai.Context, options *ai.StreamOptions) *ai.AssistantMessageEventStream { + openAIOptions := OpenAICompletionsOptions{} + if options != nil { + openAIOptions.StreamOptions = *options + } + return streamOpenAICompletionsWithOptions(model, c, openAIOptions) +} + +func streamSimpleOpenAICompletions(model ai.Model, c ai.Context, options *ai.SimpleStreamOptions) *ai.AssistantMessageEventStream { + base := BuildBaseOptions(model, options, "") + var effort ai.ThinkingLevel + if options != nil { + effort = options.Reasoning + } + if !ai.SupportsXhigh(model) { + effort = 
ClampReasoning(effort) + } + return streamOpenAICompletionsWithOptions(model, c, OpenAICompletionsOptions{ + StreamOptions: base, + ReasoningEffort: effort, + }) +} + +func streamOpenAICompletionsWithOptions( + model ai.Model, + c ai.Context, + options OpenAICompletionsOptions, +) *ai.AssistantMessageEventStream { + stream := ai.NewAssistantMessageEventStream(128) + go func() { + apiKey := strings.TrimSpace(options.StreamOptions.APIKey) + if apiKey == "" { + apiKey = strings.TrimSpace(ai.GetEnvAPIKey(string(model.Provider))) + } + if apiKey == "" { + pushProviderError(stream, model, "missing API key for OpenAI completions runtime") + return + } + + payload := BuildOpenAICompletionsParams(model, c, options) + if options.StreamOptions.OnPayload != nil { + options.StreamOptions.OnPayload(payload) + } + + request := param.Override[openai.ChatCompletionNewParams](payload) + reqOptions := []option.RequestOption{option.WithAPIKey(apiKey)} + if baseURL := strings.TrimSpace(model.BaseURL); baseURL != "" { + reqOptions = append(reqOptions, option.WithBaseURL(baseURL)) + } + reqOptions = httputil.AppendHeaderOptions(reqOptions, model.Headers) + reqOptions = httputil.AppendHeaderOptions(reqOptions, options.StreamOptions.Headers) + client := openai.NewClient(reqOptions...) 
+ + runCtx := options.StreamOptions.Ctx + if runCtx == nil { + runCtx = context.Background() + } + openAIStream := client.Chat.Completions.NewStreaming(runCtx, request) + if openAIStream == nil { + pushProviderError(stream, model, "failed to create OpenAI completions stream") + return + } + + var textBuilder strings.Builder + toolStates := map[int]*toolCallState{} + toolOrder := make([]int, 0, 2) + usage := ai.Usage{} + stopReason := ai.StopReasonStop + toolEventsEmitted := false + + for openAIStream.Next() { + chunk := openAIStream.Current() + if chunk.Usage.TotalTokens > 0 || chunk.Usage.PromptTokens > 0 || chunk.Usage.CompletionTokens > 0 { + usage = ai.Usage{ + Input: int(chunk.Usage.PromptTokens), + Output: int(chunk.Usage.CompletionTokens), + TotalTokens: int(chunk.Usage.TotalTokens), + } + } + + for _, choice := range chunk.Choices { + if choice.Delta.Content != "" { + textBuilder.WriteString(choice.Delta.Content) + stream.Push(ai.AssistantMessageEvent{ + Type: ai.EventTextDelta, + Delta: choice.Delta.Content, + }) + } + + for _, toolDelta := range choice.Delta.ToolCalls { + idx := int(toolDelta.Index) + state, ok := toolStates[idx] + if !ok { + state = &toolCallState{ + ID: fmt.Sprintf("call_%d", idx), + } + toolStates[idx] = state + toolOrder = append(toolOrder, idx) + } + if strings.TrimSpace(toolDelta.ID) != "" { + state.ID = strings.TrimSpace(toolDelta.ID) + } + if strings.TrimSpace(toolDelta.Function.Name) != "" { + state.Name = strings.TrimSpace(toolDelta.Function.Name) + } + if toolDelta.Function.Arguments != "" { + state.Arguments.WriteString(toolDelta.Function.Arguments) + } + } + + if choice.FinishReason != "" { + stopReason = mapChatCompletionFinishReason(string(choice.FinishReason)) + } + } + } + + if err := openAIStream.Err(); err != nil { + pushProviderError(stream, model, err.Error()) + return + } + + content := make([]ai.ContentBlock, 0, len(toolOrder)+1) + if text := strings.TrimSpace(textBuilder.String()); text != "" { + content = 
append(content, ai.ContentBlock{ + Type: ai.ContentTypeText, + Text: text, + }) + } + + for _, idx := range toolOrder { + state := toolStates[idx] + if state == nil || strings.TrimSpace(state.Name) == "" { + continue + } + toolCall := ai.ContentBlock{ + Type: ai.ContentTypeToolCall, + ID: state.ID, + Name: state.Name, + Arguments: parseToolArguments(state.Arguments.String()), + } + content = append(content, toolCall) + stream.Push(ai.AssistantMessageEvent{ + Type: ai.EventToolCallEnd, + ToolCall: &toolCall, + }) + toolEventsEmitted = true + } + + if toolEventsEmitted && stopReason == ai.StopReasonStop { + stopReason = ai.StopReasonToolUse + } + usage.Cost = ai.CalculateCost(model, usage) + assistantMessage := ai.Message{ + Role: ai.RoleAssistant, + API: model.API, + Provider: model.Provider, + Model: model.ID, + Content: content, + Usage: usage, + StopReason: stopReason, + Timestamp: time.Now().UnixMilli(), + } + stream.Push(ai.AssistantMessageEvent{ + Type: ai.EventDone, + Message: assistantMessage, + Reason: assistantMessage.StopReason, + }) + }() + return stream +} + +type toolCallState struct { + ID string + Name string + Arguments strings.Builder +} + +func mapChatCompletionFinishReason(reason string) ai.StopReason { + switch strings.ToLower(strings.TrimSpace(reason)) { + case "stop": + return ai.StopReasonStop + case "length": + return ai.StopReasonLength + case "tool_calls", "tool": + return ai.StopReasonToolUse + case "content_filter", "error": + return ai.StopReasonError + default: + return ai.StopReasonStop + } +} diff --git a/pkg/ai/providers/openai_completions_runtime_test.go b/pkg/ai/providers/openai_completions_runtime_test.go new file mode 100644 index 00000000..4bd71297 --- /dev/null +++ b/pkg/ai/providers/openai_completions_runtime_test.go @@ -0,0 +1,54 @@ +package providers + +import ( + "context" + "io" + "strings" + "testing" + "time" + + "github.com/beeper/ai-bridge/pkg/ai" +) + +func TestStreamOpenAICompletions_MissingAPIKeyEmitsError(t 
*testing.T) { + t.Setenv("OPENROUTER_API_KEY", "") + stream := streamOpenAICompletions(ai.Model{ + ID: "openai/gpt-4o-mini", + Provider: "openrouter", + API: ai.APIOpenAICompletions, + }, ai.Context{ + Messages: []ai.Message{{Role: ai.RoleUser, Text: "hello"}}, + }, &ai.StreamOptions{}) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + evt, err := stream.Next(ctx) + if err != nil { + t.Fatalf("expected terminal error event, got %v", err) + } + if evt.Type != ai.EventError { + t.Fatalf("expected error event, got %s", evt.Type) + } + if !strings.Contains(strings.ToLower(evt.Error.ErrorMessage), "api key") { + t.Fatalf("expected missing api key message, got %q", evt.Error.ErrorMessage) + } + if _, err = stream.Next(ctx); err != io.EOF { + t.Fatalf("expected EOF after terminal event, got %v", err) + } +} + +func TestMapChatCompletionFinishReason(t *testing.T) { + cases := map[string]ai.StopReason{ + "stop": ai.StopReasonStop, + "length": ai.StopReasonLength, + "tool_calls": ai.StopReasonToolUse, + "tool": ai.StopReasonToolUse, + "error": ai.StopReasonError, + "": ai.StopReasonStop, + } + for in, want := range cases { + if got := mapChatCompletionFinishReason(in); got != want { + t.Fatalf("mapChatCompletionFinishReason(%q) = %q, want %q", in, got, want) + } + } +} diff --git a/pkg/ai/providers/register_builtins.go b/pkg/ai/providers/register_builtins.go index e9484a0b..c2614e08 100644 --- a/pkg/ai/providers/register_builtins.go +++ b/pkg/ai/providers/register_builtins.go @@ -55,9 +55,13 @@ func RegisterBuiltInAPIProviders() { Stream: streamOpenAIResponses, StreamSimple: streamSimpleOpenAIResponses, }, BuiltinProviderSourceID) + ai.RegisterAPIProvider(ai.APIProvider{ + API: ai.APIOpenAICompletions, + Stream: streamOpenAICompletions, + StreamSimple: streamSimpleOpenAICompletions, + }, BuiltinProviderSourceID) for _, apiID := range []ai.Api{ - ai.APIOpenAICompletions, ai.APIAzureOpenAIResponse, ai.APIOpenAICodexResponse, 
ai.APIAnthropicMessages, diff --git a/pkg/connector/pkg_ai_provider_bridge.go b/pkg/connector/pkg_ai_provider_bridge.go index 68af63a1..011586b0 100644 --- a/pkg/connector/pkg_ai_provider_bridge.go +++ b/pkg/connector/pkg_ai_provider_bridge.go @@ -37,6 +37,8 @@ func buildPkgAIModelFromGenerateParams(params GenerateParams, baseURL string) ai api := aipkg.APIOpenAIResponses if provider == "openrouter" { api = aipkg.APIOpenAICompletions + } else if provider == "azure-openai-responses" { + api = aipkg.APIAzureOpenAIResponse } return aipkg.Model{ ID: modelID, diff --git a/pkg/connector/pkg_ai_provider_bridge_test.go b/pkg/connector/pkg_ai_provider_bridge_test.go index 3dbc3641..33ee9afb 100644 --- a/pkg/connector/pkg_ai_provider_bridge_test.go +++ b/pkg/connector/pkg_ai_provider_bridge_test.go @@ -63,6 +63,13 @@ func TestBuildPkgAIModelFromGenerateParams(t *testing.T) { if openAI.MaxTokens != 16384 { t.Fatalf("unexpected max tokens: %d", openAI.MaxTokens) } + + azure := buildPkgAIModelFromGenerateParams(GenerateParams{ + Model: "gpt-4.1-mini", + }, "https://my-openai.azure.com") + if azure.API != "azure-openai-responses" { + t.Fatalf("expected azure base URL to map to azure-openai-responses API, got %q", azure.API) + } } func TestShouldFallbackFromPkgAIEvent(t *testing.T) { @@ -78,7 +85,7 @@ func TestShouldFallbackFromPkgAIEvent(t *testing.T) { } func TestTryGenerateStreamWithPkgAIFallsBackOnStubbedProviders(t *testing.T) { - events, ok := tryGenerateStreamWithPkgAI(context.Background(), "https://openrouter.ai/api/v1", "", GenerateParams{ + events, ok := tryGenerateStreamWithPkgAI(context.Background(), "https://my-openai.azure.com", "", GenerateParams{ Model: "gpt-4.1-mini", Messages: []UnifiedMessage{ { From 4dc66e94bd1e4d38c17ce1f364e862df13b5aa54 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 05:28:25 +0000 Subject: [PATCH 34/75] Add pkg-ai Azure OpenAI responses streaming runtime MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 
Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- .../azure_openai_responses_runtime.go | 175 ++++++++++++++++++ .../azure_openai_responses_runtime_test.go | 68 +++++++ pkg/ai/providers/register_builtins.go | 6 +- pkg/connector/pkg_ai_provider_bridge_test.go | 10 +- 4 files changed, 255 insertions(+), 4 deletions(-) create mode 100644 pkg/ai/providers/azure_openai_responses_runtime.go create mode 100644 pkg/ai/providers/azure_openai_responses_runtime_test.go diff --git a/pkg/ai/providers/azure_openai_responses_runtime.go b/pkg/ai/providers/azure_openai_responses_runtime.go new file mode 100644 index 00000000..fbb7028a --- /dev/null +++ b/pkg/ai/providers/azure_openai_responses_runtime.go @@ -0,0 +1,175 @@ +package providers + +import ( + "context" + "net/http" + "strings" + "time" + + "github.com/openai/openai-go/v3" + "github.com/openai/openai-go/v3/option" + "github.com/openai/openai-go/v3/packages/param" + "github.com/openai/openai-go/v3/responses" + + "github.com/beeper/ai-bridge/pkg/ai" + "github.com/beeper/ai-bridge/pkg/shared/httputil" +) + +func streamAzureOpenAIResponses(model ai.Model, c ai.Context, options *ai.StreamOptions) *ai.AssistantMessageEventStream { + azureOptions := AzureOpenAIResponsesOptions{} + if options != nil { + azureOptions.StreamOptions = *options + } + return streamAzureOpenAIResponsesWithOptions(model, c, azureOptions) +} + +func streamSimpleAzureOpenAIResponses(model ai.Model, c ai.Context, options *ai.SimpleStreamOptions) *ai.AssistantMessageEventStream { + base := BuildBaseOptions(model, options, "") + var effort ai.ThinkingLevel + if options != nil { + effort = options.Reasoning + } + if !ai.SupportsXhigh(model) { + effort = ClampReasoning(effort) + } + return streamAzureOpenAIResponsesWithOptions(model, c, AzureOpenAIResponsesOptions{ + OpenAIResponsesOptions: OpenAIResponsesOptions{ + StreamOptions: base, + ReasoningEffort: effort, + }, + }) +} + +func streamAzureOpenAIResponsesWithOptions( + model ai.Model, + c 
ai.Context, + options AzureOpenAIResponsesOptions, +) *ai.AssistantMessageEventStream { + stream := ai.NewAssistantMessageEventStream(128) + go func() { + baseURL, apiVersion, err := ResolveAzureConfig(model, &options) + if err != nil { + pushProviderError(stream, model, err.Error()) + return + } + + apiKey := strings.TrimSpace(options.StreamOptions.APIKey) + if apiKey == "" { + apiKey = strings.TrimSpace(ai.GetEnvAPIKey("azure-openai-responses")) + } + if apiKey == "" { + pushProviderError(stream, model, "missing API key for Azure OpenAI responses runtime") + return + } + + payload := BuildAzureOpenAIResponsesParams(model, c, options) + if options.StreamOptions.OnPayload != nil { + options.StreamOptions.OnPayload(payload) + } + request := param.Override[responses.ResponseNewParams](payload) + + reqOptions := []option.RequestOption{ + option.WithAPIKey(apiKey), + option.WithBaseURL(baseURL), + option.WithHeader("api-key", apiKey), + } + if apiVersion != "" && apiVersion != "v1" { + reqOptions = append(reqOptions, option.WithMiddleware(func(req *http.Request, next option.MiddlewareNext) (*http.Response, error) { + q := req.URL.Query() + q.Set("api-version", apiVersion) + req.URL.RawQuery = q.Encode() + return next(req) + })) + } + reqOptions = httputil.AppendHeaderOptions(reqOptions, model.Headers) + reqOptions = httputil.AppendHeaderOptions(reqOptions, options.StreamOptions.Headers) + + client := openai.NewClient(reqOptions...) 
+ runCtx := options.StreamOptions.Ctx + if runCtx == nil { + runCtx = context.Background() + } + + openAIStream := client.Responses.NewStreaming(runCtx, request) + if openAIStream == nil { + pushProviderError(stream, model, "failed to create Azure OpenAI responses stream") + return + } + + var textBuilder strings.Builder + var thinkingBuilder strings.Builder + toolCalls := make([]ai.ContentBlock, 0, 2) + var completedResponse responses.Response + + for openAIStream.Next() { + event := openAIStream.Current() + switch event.Type { + case "response.output_text.delta": + textBuilder.WriteString(event.Delta) + stream.Push(ai.AssistantMessageEvent{Type: ai.EventTextDelta, Delta: event.Delta}) + case "response.reasoning_text.delta": + thinkingBuilder.WriteString(event.Delta) + stream.Push(ai.AssistantMessageEvent{Type: ai.EventThinkingDelta, Delta: event.Delta}) + case "response.function_call_arguments.done": + toolCall := ai.ContentBlock{ + Type: ai.ContentTypeToolCall, + ID: strings.TrimSpace(event.ItemID), + Name: strings.TrimSpace(event.Name), + Arguments: parseToolArguments(event.Arguments), + } + toolCalls = append(toolCalls, toolCall) + stream.Push(ai.AssistantMessageEvent{Type: ai.EventToolCallEnd, ToolCall: &toolCall}) + case "response.completed": + completedResponse = event.Response + case "error": + pushProviderError(stream, model, strings.TrimSpace(event.Message)) + return + } + } + + if err := openAIStream.Err(); err != nil { + pushProviderError(stream, model, err.Error()) + return + } + + assistantMessage := ai.Message{ + Role: ai.RoleAssistant, + API: model.API, + Provider: model.Provider, + Model: model.ID, + Timestamp: time.Now().UnixMilli(), + StopReason: mapOpenAIResponseStatus(completedResponse.Status), + Usage: ai.Usage{ + Input: int(completedResponse.Usage.InputTokens), + Output: int(completedResponse.Usage.OutputTokens), + TotalTokens: int(completedResponse.Usage.TotalTokens), + }, + } + if thinking := strings.TrimSpace(thinkingBuilder.String()); 
thinking != "" { + assistantMessage.Content = append(assistantMessage.Content, ai.ContentBlock{ + Type: ai.ContentTypeThinking, + Thinking: thinking, + }) + } + if text := strings.TrimSpace(textBuilder.String()); text != "" { + assistantMessage.Content = append(assistantMessage.Content, ai.ContentBlock{ + Type: ai.ContentTypeText, + Text: text, + }) + } + if len(toolCalls) > 0 { + assistantMessage.Content = append(assistantMessage.Content, toolCalls...) + } + if len(toolCalls) > 0 && assistantMessage.StopReason == ai.StopReasonStop { + assistantMessage.StopReason = ai.StopReasonToolUse + } + assistantMessage.Usage.Cost = ai.CalculateCost(model, assistantMessage.Usage) + + stream.Push(ai.AssistantMessageEvent{ + Type: ai.EventDone, + Message: assistantMessage, + Reason: assistantMessage.StopReason, + }) + }() + return stream +} diff --git a/pkg/ai/providers/azure_openai_responses_runtime_test.go b/pkg/ai/providers/azure_openai_responses_runtime_test.go new file mode 100644 index 00000000..07ac5915 --- /dev/null +++ b/pkg/ai/providers/azure_openai_responses_runtime_test.go @@ -0,0 +1,68 @@ +package providers + +import ( + "context" + "io" + "strings" + "testing" + "time" + + "github.com/beeper/ai-bridge/pkg/ai" +) + +func TestStreamAzureOpenAIResponses_MissingBaseURLEmitsError(t *testing.T) { + t.Setenv("AZURE_OPENAI_BASE_URL", "") + t.Setenv("AZURE_OPENAI_RESOURCE_NAME", "") + t.Setenv("AZURE_OPENAI_API_KEY", "test-key") + stream := streamAzureOpenAIResponses(ai.Model{ + ID: "gpt-4.1-mini", + Provider: "azure-openai-responses", + API: ai.APIAzureOpenAIResponse, + }, ai.Context{ + Messages: []ai.Message{{Role: ai.RoleUser, Text: "hello"}}, + }, &ai.StreamOptions{}) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + evt, err := stream.Next(ctx) + if err != nil { + t.Fatalf("expected terminal error event, got %v", err) + } + if evt.Type != ai.EventError { + t.Fatalf("expected error event, got %s", evt.Type) + } + if 
!strings.Contains(strings.ToLower(evt.Error.ErrorMessage), "base url") { + t.Fatalf("expected base url error message, got %q", evt.Error.ErrorMessage) + } + if _, err = stream.Next(ctx); err != io.EOF { + t.Fatalf("expected EOF after terminal event, got %v", err) + } +} + +func TestStreamAzureOpenAIResponses_MissingAPIKeyEmitsError(t *testing.T) { + t.Setenv("AZURE_OPENAI_BASE_URL", "https://my-resource.openai.azure.com/openai/v1") + t.Setenv("AZURE_OPENAI_API_KEY", "") + stream := streamAzureOpenAIResponses(ai.Model{ + ID: "gpt-4.1-mini", + Provider: "azure-openai-responses", + API: ai.APIAzureOpenAIResponse, + }, ai.Context{ + Messages: []ai.Message{{Role: ai.RoleUser, Text: "hello"}}, + }, &ai.StreamOptions{}) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + evt, err := stream.Next(ctx) + if err != nil { + t.Fatalf("expected terminal error event, got %v", err) + } + if evt.Type != ai.EventError { + t.Fatalf("expected error event, got %s", evt.Type) + } + if !strings.Contains(strings.ToLower(evt.Error.ErrorMessage), "api key") { + t.Fatalf("expected missing api key message, got %q", evt.Error.ErrorMessage) + } + if _, err = stream.Next(ctx); err != io.EOF { + t.Fatalf("expected EOF after terminal event, got %v", err) + } +} diff --git a/pkg/ai/providers/register_builtins.go b/pkg/ai/providers/register_builtins.go index c2614e08..6ec16222 100644 --- a/pkg/ai/providers/register_builtins.go +++ b/pkg/ai/providers/register_builtins.go @@ -60,9 +60,13 @@ func RegisterBuiltInAPIProviders() { Stream: streamOpenAICompletions, StreamSimple: streamSimpleOpenAICompletions, }, BuiltinProviderSourceID) + ai.RegisterAPIProvider(ai.APIProvider{ + API: ai.APIAzureOpenAIResponse, + Stream: streamAzureOpenAIResponses, + StreamSimple: streamSimpleAzureOpenAIResponses, + }, BuiltinProviderSourceID) for _, apiID := range []ai.Api{ - ai.APIAzureOpenAIResponse, ai.APIOpenAICodexResponse, ai.APIAnthropicMessages, ai.APIGoogleGenerativeAI, 
diff --git a/pkg/connector/pkg_ai_provider_bridge_test.go b/pkg/connector/pkg_ai_provider_bridge_test.go index 33ee9afb..107826ad 100644 --- a/pkg/connector/pkg_ai_provider_bridge_test.go +++ b/pkg/connector/pkg_ai_provider_bridge_test.go @@ -84,7 +84,7 @@ func TestShouldFallbackFromPkgAIEvent(t *testing.T) { } } -func TestTryGenerateStreamWithPkgAIFallsBackOnStubbedProviders(t *testing.T) { +func TestTryGenerateStreamWithPkgAIReturnsRuntimeErrorEventsWhenProviderResolved(t *testing.T) { events, ok := tryGenerateStreamWithPkgAI(context.Background(), "https://my-openai.azure.com", "", GenerateParams{ Model: "gpt-4.1-mini", Messages: []UnifiedMessage{ @@ -96,7 +96,11 @@ func TestTryGenerateStreamWithPkgAIFallsBackOnStubbedProviders(t *testing.T) { }, }, }) - if ok { - t.Fatalf("expected fallback mode with stubbed providers, got events=%v", events) + if !ok { + t.Fatalf("expected pkg/ai stream to be selected") + } + event := <-events + if event.Type != StreamEventError { + t.Fatalf("expected runtime error event without credentials, got %#v", event) } } From 455a400d74a5a71136c11ed21ed5e8bc11e69d70 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 05:32:40 +0000 Subject: [PATCH 35/75] Map reasoning options through pkg-ai bridge stream simple MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/connector/pkg_ai_provider_bridge.go | 39 +++++++++++++++++++- pkg/connector/pkg_ai_provider_bridge_test.go | 37 ++++++++++++++++++- 2 files changed, 74 insertions(+), 2 deletions(-) diff --git a/pkg/connector/pkg_ai_provider_bridge.go b/pkg/connector/pkg_ai_provider_bridge.go index 011586b0..cc864278 100644 --- a/pkg/connector/pkg_ai_provider_bridge.go +++ b/pkg/connector/pkg_ai_provider_bridge.go @@ -46,11 +46,37 @@ func buildPkgAIModelFromGenerateParams(params GenerateParams, baseURL string) ai Provider: aipkg.Provider(provider), API: api, BaseURL: strings.TrimSpace(baseURL), + Reasoning: 
modelSupportsReasoning(modelID) || strings.TrimSpace(params.ReasoningEffort) != "", Input: []string{"text"}, MaxTokens: max(params.MaxCompletionTokens, 4096), } } +func modelSupportsReasoning(modelID string) bool { + modelID = strings.ToLower(strings.TrimSpace(modelID)) + return strings.HasPrefix(modelID, "gpt-5") || + strings.HasPrefix(modelID, "o1") || + strings.HasPrefix(modelID, "o3") || + strings.Contains(modelID, "thinking") +} + +func parseThinkingLevel(value string) aipkg.ThinkingLevel { + switch strings.ToLower(strings.TrimSpace(value)) { + case "minimal": + return aipkg.ThinkingMinimal + case "low": + return aipkg.ThinkingLow + case "medium": + return aipkg.ThinkingMedium + case "high": + return aipkg.ThinkingHigh + case "xhigh": + return aipkg.ThinkingXHigh + default: + return "" + } +} + func shouldFallbackFromPkgAIEvent(event StreamEvent) bool { if event.Type != StreamEventError || event.Error == nil { return false @@ -78,7 +104,18 @@ func tryGenerateStreamWithPkgAI( APIKey: strings.TrimSpace(apiKey), } - stream, err := aipkg.Stream(model, aiContext, options) + var ( + stream *aipkg.AssistantMessageEventStream + err error + ) + if reasoning := parseThinkingLevel(params.ReasoningEffort); reasoning != "" { + stream, err = aipkg.StreamSimple(model, aiContext, &aipkg.SimpleStreamOptions{ + StreamOptions: *options, + Reasoning: reasoning, + }) + } else { + stream, err = aipkg.Stream(model, aiContext, options) + } if err != nil { return nil, false } diff --git a/pkg/connector/pkg_ai_provider_bridge_test.go b/pkg/connector/pkg_ai_provider_bridge_test.go index 107826ad..1f5be601 100644 --- a/pkg/connector/pkg_ai_provider_bridge_test.go +++ b/pkg/connector/pkg_ai_provider_bridge_test.go @@ -54,7 +54,7 @@ func TestBuildPkgAIModelFromGenerateParams(t *testing.T) { } openAI := buildPkgAIModelFromGenerateParams(GenerateParams{ - Model: "gpt-4.1-mini", + Model: "gpt-5-mini", MaxCompletionTokens: 16384, }, "") if openAI.API != "openai-responses" { @@ -63,6 +63,9 @@ 
func TestBuildPkgAIModelFromGenerateParams(t *testing.T) { if openAI.MaxTokens != 16384 { t.Fatalf("unexpected max tokens: %d", openAI.MaxTokens) } + if !openAI.Reasoning { + t.Fatalf("expected gpt-5 family model to be marked as reasoning capable") + } azure := buildPkgAIModelFromGenerateParams(GenerateParams{ Model: "gpt-4.1-mini", @@ -70,6 +73,21 @@ func TestBuildPkgAIModelFromGenerateParams(t *testing.T) { if azure.API != "azure-openai-responses" { t.Fatalf("expected azure base URL to map to azure-openai-responses API, got %q", azure.API) } + + nonReasoning := buildPkgAIModelFromGenerateParams(GenerateParams{ + Model: "gpt-4.1-mini", + }, "") + if nonReasoning.Reasoning { + t.Fatalf("did not expect non-reasoning model to be marked as reasoning capable") + } + + withReasoningOverride := buildPkgAIModelFromGenerateParams(GenerateParams{ + Model: "gpt-4.1-mini", + ReasoningEffort: "high", + }, "") + if !withReasoningOverride.Reasoning { + t.Fatalf("expected reasoning effort override to mark model as reasoning capable") + } } func TestShouldFallbackFromPkgAIEvent(t *testing.T) { @@ -104,3 +122,20 @@ func TestTryGenerateStreamWithPkgAIReturnsRuntimeErrorEventsWhenProviderResolved t.Fatalf("expected runtime error event without credentials, got %#v", event) } } + +func TestParseThinkingLevel(t *testing.T) { + cases := map[string]string{ + "minimal": "minimal", + "low": "low", + "medium": "medium", + "high": "high", + "xhigh": "xhigh", + "none": "", + "": "", + } + for in, want := range cases { + if got := string(parseThinkingLevel(in)); got != want { + t.Fatalf("parseThinkingLevel(%q) = %q, want %q", in, got, want) + } + } +} From 0d2445c92e4b057c7fc5cbdc1f95e0a1d053c186 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 05:33:58 +0000 Subject: [PATCH 36/75] Assert builtins use real OpenAI runtime implementations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- 
pkg/ai/providers/register_builtins_test.go | 23 ++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/pkg/ai/providers/register_builtins_test.go b/pkg/ai/providers/register_builtins_test.go index 42ea0ca6..3cb80c75 100644 --- a/pkg/ai/providers/register_builtins_test.go +++ b/pkg/ai/providers/register_builtins_test.go @@ -3,6 +3,7 @@ package providers import ( "context" "io" + "strings" "testing" "time" @@ -40,7 +41,29 @@ func TestRegisterBuiltInAPIProviders(t *testing.T) { if evt.Error.StopReason != ai.StopReasonError { t.Fatalf("expected stopReason=error, got %s", evt.Error.StopReason) } + if strings.Contains(strings.ToLower(evt.Error.ErrorMessage), "not implemented") { + t.Fatalf("expected openai responses runtime implementation, got stub error: %q", evt.Error.ErrorMessage) + } if _, err := stream.Next(ctx); err != io.EOF { t.Fatalf("expected EOF after terminal event, got %v", err) } + + completionsStream, err := ai.Stream(ai.Model{ + ID: "openai/gpt-4o-mini", + Provider: "openrouter", + API: ai.APIOpenAICompletions, + }, ai.Context{}, nil) + if err != nil { + t.Fatalf("unexpected completions stream resolve error: %v", err) + } + completionsEvt, err := completionsStream.Next(ctx) + if err != nil { + t.Fatalf("expected completions terminal error event, got %v", err) + } + if completionsEvt.Type != ai.EventError { + t.Fatalf("expected completions error event, got %s", completionsEvt.Type) + } + if strings.Contains(strings.ToLower(completionsEvt.Error.ErrorMessage), "not implemented") { + t.Fatalf("expected openai completions runtime implementation, got stub error: %q", completionsEvt.Error.ErrorMessage) + } } From 9f317530da85af37b65c7d9047e0f1f91034c51b Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 05:40:41 +0000 Subject: [PATCH 37/75] Add Codex responses payload helper parity functions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- 
pkg/ai/providers/openai_codex_responses.go | 115 ++++++++++++++++++ .../providers/openai_codex_responses_test.go | 86 +++++++++++++ 2 files changed, 201 insertions(+) diff --git a/pkg/ai/providers/openai_codex_responses.go b/pkg/ai/providers/openai_codex_responses.go index 5ec752a2..6b1f77cc 100644 --- a/pkg/ai/providers/openai_codex_responses.go +++ b/pkg/ai/providers/openai_codex_responses.go @@ -3,11 +3,126 @@ package providers import ( "encoding/json" "fmt" + "net/url" "strings" "github.com/beeper/ai-bridge/pkg/ai" ) +const defaultCodexBaseURL = "https://chatgpt.com/backend-api" + +var codexToolCallProviders = map[string]struct{}{ + "openai": {}, + "openai-codex": {}, + "opencode": {}, +} + +type OpenAICodexResponsesOptions struct { + StreamOptions ai.StreamOptions + ReasoningEffort string + ReasoningSummary string + TextVerbosity string +} + +func BuildOpenAICodexResponsesParams(model ai.Model, context ai.Context, options OpenAICodexResponsesOptions) map[string]any { + messages := ConvertResponsesMessages(model, context, codexToolCallProviders, &ConvertResponsesMessagesOptions{ + IncludeSystemPrompt: false, + }) + params := map[string]any{ + "model": model.ID, + "store": false, + "stream": true, + "instructions": context.SystemPrompt, + "input": messages, + "include": []string{"reasoning.encrypted_content"}, + "tool_choice": "auto", + "parallel_tool_calls": true, + "text": map[string]any{ + "verbosity": coalesceCodexTextVerbosity(options.TextVerbosity), + }, + } + if strings.TrimSpace(options.StreamOptions.SessionID) != "" { + params["prompt_cache_key"] = options.StreamOptions.SessionID + } + if options.StreamOptions.Temperature != nil { + params["temperature"] = *options.StreamOptions.Temperature + } + if len(context.Tools) > 0 { + params["tools"] = ConvertResponsesTools(context.Tools, false) + } + if strings.TrimSpace(options.ReasoningEffort) != "" { + effort := ClampCodexReasoningEffort(model.ID, options.ReasoningEffort) + summary := 
strings.TrimSpace(options.ReasoningSummary) + if summary == "" { + summary = "auto" + } + params["reasoning"] = map[string]any{ + "effort": effort, + "summary": summary, + } + } + return params +} + +func ClampCodexReasoningEffort(modelID string, effort string) string { + effort = strings.ToLower(strings.TrimSpace(effort)) + id := modelID + if strings.Contains(id, "/") { + parts := strings.Split(id, "/") + id = parts[len(parts)-1] + } + if (strings.HasPrefix(id, "gpt-5.2") || strings.HasPrefix(id, "gpt-5.3")) && effort == "minimal" { + return "low" + } + if id == "gpt-5.1" && effort == "xhigh" { + return "high" + } + if id == "gpt-5.1-codex-mini" { + if effort == "high" || effort == "xhigh" { + return "high" + } + return "medium" + } + return effort +} + +func ResolveCodexURL(baseURL string) string { + normalized := strings.TrimRight(strings.TrimSpace(baseURL), "/") + if normalized == "" { + normalized = defaultCodexBaseURL + } + if strings.HasSuffix(normalized, "/codex/responses") { + return normalized + } + if strings.HasSuffix(normalized, "/codex") { + return normalized + "/responses" + } + return normalized + "/codex/responses" +} + +func ResolveCodexWebSocketURL(baseURL string) string { + resolved := ResolveCodexURL(baseURL) + parsed, err := url.Parse(resolved) + if err != nil { + return resolved + } + if parsed.Scheme == "https" { + parsed.Scheme = "wss" + } else if parsed.Scheme == "http" { + parsed.Scheme = "ws" + } + return parsed.String() +} + +func coalesceCodexTextVerbosity(value string) string { + switch strings.ToLower(strings.TrimSpace(value)) { + case "low", "high": + return strings.ToLower(strings.TrimSpace(value)) + default: + return "medium" + } +} + // ProcessCodexSSEPayload maps Codex SSE payload chunks into unified stream events. // This is a deterministic helper used by tests while the full transport integration // is being ported. 
diff --git a/pkg/ai/providers/openai_codex_responses_test.go b/pkg/ai/providers/openai_codex_responses_test.go index 9df5e81d..0155eb4e 100644 --- a/pkg/ai/providers/openai_codex_responses_test.go +++ b/pkg/ai/providers/openai_codex_responses_test.go @@ -1,6 +1,7 @@ package providers import ( + "encoding/json" "strings" "testing" @@ -63,3 +64,88 @@ func TestProcessCodexSSEPayload_MapsToAssistantEvents(t *testing.T) { t.Fatalf("unexpected output content: %+v", output.Content) } } + +func TestClampCodexReasoningEffort(t *testing.T) { + cases := []struct { + modelID string + effort string + want string + }{ + {modelID: "gpt-5.2", effort: "minimal", want: "low"}, + {modelID: "openai/gpt-5.3-pro", effort: "minimal", want: "low"}, + {modelID: "gpt-5.1", effort: "xhigh", want: "high"}, + {modelID: "gpt-5.1-codex-mini", effort: "low", want: "medium"}, + {modelID: "gpt-5.1-codex-mini", effort: "xhigh", want: "high"}, + {modelID: "gpt-4.1-mini", effort: "high", want: "high"}, + } + for _, tc := range cases { + got := ClampCodexReasoningEffort(tc.modelID, tc.effort) + if got != tc.want { + t.Fatalf("ClampCodexReasoningEffort(%q,%q)=%q want %q", tc.modelID, tc.effort, got, tc.want) + } + } +} + +func TestResolveCodexURLAndWebSocketURL(t *testing.T) { + if got := ResolveCodexURL(""); got != "https://chatgpt.com/backend-api/codex/responses" { + t.Fatalf("unexpected default codex URL: %q", got) + } + if got := ResolveCodexURL("https://chatgpt.com/backend-api/codex"); got != "https://chatgpt.com/backend-api/codex/responses" { + t.Fatalf("unexpected codex URL for /codex base: %q", got) + } + if got := ResolveCodexURL("https://chatgpt.com/backend-api/codex/responses"); got != "https://chatgpt.com/backend-api/codex/responses" { + t.Fatalf("unexpected codex URL when already resolved: %q", got) + } + if got := ResolveCodexWebSocketURL("https://chatgpt.com/backend-api"); !strings.HasPrefix(got, "wss://") { + t.Fatalf("expected websocket URL to use wss scheme, got %q", got) + } +} + 
+func TestBuildOpenAICodexResponsesParams(t *testing.T) { + temp := 0.2 + params := BuildOpenAICodexResponsesParams(ai.Model{ + ID: "gpt-5.1-codex-mini", + Provider: "openai-codex", + API: ai.APIOpenAICodexResponse, + }, ai.Context{ + SystemPrompt: "you are helpful", + Messages: []ai.Message{ + {Role: ai.RoleUser, Text: "say hi"}, + }, + Tools: []ai.Tool{ + {Name: "lookup", Description: "Lookup docs", Parameters: map[string]any{"type": "object"}}, + }, + }, OpenAICodexResponsesOptions{ + StreamOptions: ai.StreamOptions{ + SessionID: "session-1", + Temperature: &temp, + }, + ReasoningEffort: "xhigh", + ReasoningSummary: "detailed", + }) + + if params["model"] != "gpt-5.1-codex-mini" { + t.Fatalf("unexpected model in params: %#v", params["model"]) + } + if params["prompt_cache_key"] != "session-1" { + t.Fatalf("expected prompt cache key, got %#v", params["prompt_cache_key"]) + } + reasoning, ok := params["reasoning"].(map[string]any) + if !ok { + t.Fatalf("expected reasoning object in params") + } + if reasoning["effort"] != "high" { + t.Fatalf("expected xhigh clamp to high for codex-mini, got %#v", reasoning["effort"]) + } + if reasoning["summary"] != "detailed" { + t.Fatalf("expected reasoning summary detailed, got %#v", reasoning["summary"]) + } + tools, ok := params["tools"].([]map[string]any) + if !ok || len(tools) != 1 { + t.Fatalf("expected one tool payload, got %#v", params["tools"]) + } + body, _ := json.Marshal(params) + if !strings.Contains(string(body), "reasoning.encrypted_content") { + t.Fatalf("expected include reasoning encrypted content in payload") + } +} From bff4cda7dcfcf17a5460ed2b0b9ce09faf8e22f2 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 05:47:13 +0000 Subject: [PATCH 38/75] Add pkg-ai OpenAI Codex responses streaming runtime MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- .../openai_codex_responses_runtime.go | 205 ++++++++++++++++++ 
.../openai_codex_responses_runtime_test.go | 74 +++++++ pkg/ai/providers/register_builtins.go | 6 +- pkg/ai/providers/register_builtins_test.go | 19 ++ 4 files changed, 303 insertions(+), 1 deletion(-) create mode 100644 pkg/ai/providers/openai_codex_responses_runtime.go create mode 100644 pkg/ai/providers/openai_codex_responses_runtime_test.go diff --git a/pkg/ai/providers/openai_codex_responses_runtime.go b/pkg/ai/providers/openai_codex_responses_runtime.go new file mode 100644 index 00000000..286dd576 --- /dev/null +++ b/pkg/ai/providers/openai_codex_responses_runtime.go @@ -0,0 +1,205 @@ +package providers + +import ( + "context" + "encoding/base64" + "encoding/json" + "strings" + "time" + + "github.com/openai/openai-go/v3" + "github.com/openai/openai-go/v3/option" + "github.com/openai/openai-go/v3/packages/param" + "github.com/openai/openai-go/v3/responses" + + "github.com/beeper/ai-bridge/pkg/ai" + "github.com/beeper/ai-bridge/pkg/shared/httputil" +) + +const codexJWTClaimPath = "https://api.openai.com/auth" + +func streamOpenAICodexResponses(model ai.Model, c ai.Context, options *ai.StreamOptions) *ai.AssistantMessageEventStream { + codexOptions := OpenAICodexResponsesOptions{} + if options != nil { + codexOptions.StreamOptions = *options + } + return streamOpenAICodexResponsesWithOptions(model, c, codexOptions) +} + +func streamSimpleOpenAICodexResponses(model ai.Model, c ai.Context, options *ai.SimpleStreamOptions) *ai.AssistantMessageEventStream { + base := BuildBaseOptions(model, options, "") + effort := "" + if options != nil && options.Reasoning != "" { + reasoning := options.Reasoning + if !ai.SupportsXhigh(model) { + reasoning = ClampReasoning(reasoning) + } + effort = string(reasoning) + } + return streamOpenAICodexResponsesWithOptions(model, c, OpenAICodexResponsesOptions{ + StreamOptions: base, + ReasoningEffort: effort, + }) +} + +func streamOpenAICodexResponsesWithOptions( + model ai.Model, + c ai.Context, + options OpenAICodexResponsesOptions, 
+) *ai.AssistantMessageEventStream { + stream := ai.NewAssistantMessageEventStream(128) + go func() { + apiKey := strings.TrimSpace(options.StreamOptions.APIKey) + if apiKey == "" { + apiKey = strings.TrimSpace(ai.GetEnvAPIKey(string(model.Provider))) + } + if apiKey == "" { + pushProviderError(stream, model, "missing API key for OpenAI Codex responses runtime") + return + } + + payload := BuildOpenAICodexResponsesParams(model, c, options) + if options.StreamOptions.OnPayload != nil { + options.StreamOptions.OnPayload(payload) + } + request := param.Override[responses.ResponseNewParams](payload) + + baseURL := resolveCodexSDKBaseURL(model.BaseURL) + reqOptions := []option.RequestOption{ + option.WithAPIKey(apiKey), + option.WithBaseURL(baseURL), + option.WithHeader("OpenAI-Beta", "responses=experimental"), + option.WithHeader("originator", "pi"), + } + if accountID := extractCodexAccountID(apiKey); accountID != "" { + reqOptions = append(reqOptions, option.WithHeader("chatgpt-account-id", accountID)) + } + reqOptions = httputil.AppendHeaderOptions(reqOptions, model.Headers) + reqOptions = httputil.AppendHeaderOptions(reqOptions, options.StreamOptions.Headers) + + client := openai.NewClient(reqOptions...) 
+ runCtx := options.StreamOptions.Ctx + if runCtx == nil { + runCtx = context.Background() + } + + openAIStream := client.Responses.NewStreaming(runCtx, request) + if openAIStream == nil { + pushProviderError(stream, model, "failed to create OpenAI Codex responses stream") + return + } + + var textBuilder strings.Builder + var thinkingBuilder strings.Builder + toolCalls := make([]ai.ContentBlock, 0, 2) + var completedResponse responses.Response + + for openAIStream.Next() { + event := openAIStream.Current() + switch event.Type { + case "response.output_text.delta": + textBuilder.WriteString(event.Delta) + stream.Push(ai.AssistantMessageEvent{Type: ai.EventTextDelta, Delta: event.Delta}) + case "response.reasoning_text.delta": + thinkingBuilder.WriteString(event.Delta) + stream.Push(ai.AssistantMessageEvent{Type: ai.EventThinkingDelta, Delta: event.Delta}) + case "response.function_call_arguments.done": + toolCall := ai.ContentBlock{ + Type: ai.ContentTypeToolCall, + ID: strings.TrimSpace(event.ItemID), + Name: strings.TrimSpace(event.Name), + Arguments: parseToolArguments(event.Arguments), + } + toolCalls = append(toolCalls, toolCall) + stream.Push(ai.AssistantMessageEvent{Type: ai.EventToolCallEnd, ToolCall: &toolCall}) + case "response.completed": + completedResponse = event.Response + case "error": + pushProviderError(stream, model, strings.TrimSpace(event.Message)) + return + } + } + + if err := openAIStream.Err(); err != nil { + pushProviderError(stream, model, err.Error()) + return + } + + assistantMessage := ai.Message{ + Role: ai.RoleAssistant, + API: model.API, + Provider: model.Provider, + Model: model.ID, + Timestamp: time.Now().UnixMilli(), + StopReason: mapOpenAIResponseStatus(completedResponse.Status), + Usage: ai.Usage{ + Input: int(completedResponse.Usage.InputTokens), + Output: int(completedResponse.Usage.OutputTokens), + TotalTokens: int(completedResponse.Usage.TotalTokens), + }, + } + if thinking := strings.TrimSpace(thinkingBuilder.String()); 
thinking != "" { + assistantMessage.Content = append(assistantMessage.Content, ai.ContentBlock{ + Type: ai.ContentTypeThinking, + Thinking: thinking, + }) + } + if text := strings.TrimSpace(textBuilder.String()); text != "" { + assistantMessage.Content = append(assistantMessage.Content, ai.ContentBlock{ + Type: ai.ContentTypeText, + Text: text, + }) + } + if len(toolCalls) > 0 { + assistantMessage.Content = append(assistantMessage.Content, toolCalls...) + } + if len(toolCalls) > 0 && assistantMessage.StopReason == ai.StopReasonStop { + assistantMessage.StopReason = ai.StopReasonToolUse + } + assistantMessage.Usage.Cost = ai.CalculateCost(model, assistantMessage.Usage) + + stream.Push(ai.AssistantMessageEvent{ + Type: ai.EventDone, + Message: assistantMessage, + Reason: assistantMessage.StopReason, + }) + }() + return stream +} + +func resolveCodexSDKBaseURL(baseURL string) string { + resolved := strings.TrimSpace(baseURL) + if resolved == "" { + return strings.TrimRight(defaultCodexBaseURL, "/") + "/codex" + } + resolved = strings.TrimRight(resolved, "/") + if strings.HasSuffix(resolved, "/codex/responses") { + return strings.TrimSuffix(resolved, "/responses") + } + if strings.HasSuffix(resolved, "/codex") { + return resolved + } + return resolved + "/codex" +} + +func extractCodexAccountID(token string) string { + parts := strings.Split(strings.TrimSpace(token), ".") + if len(parts) != 3 { + return "" + } + payload := parts[1] + decoded, err := base64.RawURLEncoding.DecodeString(payload) + if err != nil { + return "" + } + claims := map[string]any{} + if err := json.Unmarshal(decoded, &claims); err != nil { + return "" + } + authClaims, ok := claims[codexJWTClaimPath].(map[string]any) + if !ok { + return "" + } + accountID, _ := authClaims["chatgpt_account_id"].(string) + return strings.TrimSpace(accountID) +} diff --git a/pkg/ai/providers/openai_codex_responses_runtime_test.go b/pkg/ai/providers/openai_codex_responses_runtime_test.go new file mode 100644 index 
00000000..4303f0b1 --- /dev/null +++ b/pkg/ai/providers/openai_codex_responses_runtime_test.go @@ -0,0 +1,74 @@ +package providers + +import ( + "context" + "encoding/base64" + "encoding/json" + "io" + "strings" + "testing" + "time" + + "github.com/beeper/ai-bridge/pkg/ai" +) + +func TestStreamOpenAICodexResponses_MissingAPIKeyEmitsError(t *testing.T) { + stream := streamOpenAICodexResponses(ai.Model{ + ID: "gpt-5.1-codex-mini", + Provider: "openai-codex", + API: ai.APIOpenAICodexResponse, + }, ai.Context{ + Messages: []ai.Message{{Role: ai.RoleUser, Text: "hello"}}, + }, &ai.StreamOptions{}) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + evt, err := stream.Next(ctx) + if err != nil { + t.Fatalf("expected terminal error event, got %v", err) + } + if evt.Type != ai.EventError { + t.Fatalf("expected error event, got %s", evt.Type) + } + if !strings.Contains(strings.ToLower(evt.Error.ErrorMessage), "api key") { + t.Fatalf("expected missing api key message, got %q", evt.Error.ErrorMessage) + } + if _, err := stream.Next(ctx); err != io.EOF { + t.Fatalf("expected EOF after terminal event, got %v", err) + } +} + +func TestResolveCodexSDKBaseURL(t *testing.T) { + cases := []struct { + in string + want string + }{ + {in: "", want: "https://chatgpt.com/backend-api/codex"}, + {in: "https://chatgpt.com/backend-api", want: "https://chatgpt.com/backend-api/codex"}, + {in: "https://chatgpt.com/backend-api/codex", want: "https://chatgpt.com/backend-api/codex"}, + {in: "https://chatgpt.com/backend-api/codex/responses", want: "https://chatgpt.com/backend-api/codex"}, + } + for _, tc := range cases { + if got := resolveCodexSDKBaseURL(tc.in); got != tc.want { + t.Fatalf("resolveCodexSDKBaseURL(%q) = %q, want %q", tc.in, got, tc.want) + } + } +} + +func TestExtractCodexAccountID(t *testing.T) { + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none"}`)) + payloadObj := map[string]any{ + codexJWTClaimPath: map[string]any{ 
+ "chatgpt_account_id": "acct_123", + }, + } + payloadBytes, _ := json.Marshal(payloadObj) + payload := base64.RawURLEncoding.EncodeToString(payloadBytes) + token := header + "." + payload + ".sig" + if got := extractCodexAccountID(token); got != "acct_123" { + t.Fatalf("expected account id acct_123, got %q", got) + } + if got := extractCodexAccountID("not-a-jwt"); got != "" { + t.Fatalf("expected empty account id for invalid token, got %q", got) + } +} diff --git a/pkg/ai/providers/register_builtins.go b/pkg/ai/providers/register_builtins.go index 6ec16222..f9355d83 100644 --- a/pkg/ai/providers/register_builtins.go +++ b/pkg/ai/providers/register_builtins.go @@ -65,9 +65,13 @@ func RegisterBuiltInAPIProviders() { Stream: streamAzureOpenAIResponses, StreamSimple: streamSimpleAzureOpenAIResponses, }, BuiltinProviderSourceID) + ai.RegisterAPIProvider(ai.APIProvider{ + API: ai.APIOpenAICodexResponse, + Stream: streamOpenAICodexResponses, + StreamSimple: streamSimpleOpenAICodexResponses, + }, BuiltinProviderSourceID) for _, apiID := range []ai.Api{ - ai.APIOpenAICodexResponse, ai.APIAnthropicMessages, ai.APIGoogleGenerativeAI, ai.APIGoogleGeminiCLI, diff --git a/pkg/ai/providers/register_builtins_test.go b/pkg/ai/providers/register_builtins_test.go index 3cb80c75..94f6851b 100644 --- a/pkg/ai/providers/register_builtins_test.go +++ b/pkg/ai/providers/register_builtins_test.go @@ -66,4 +66,23 @@ func TestRegisterBuiltInAPIProviders(t *testing.T) { if strings.Contains(strings.ToLower(completionsEvt.Error.ErrorMessage), "not implemented") { t.Fatalf("expected openai completions runtime implementation, got stub error: %q", completionsEvt.Error.ErrorMessage) } + + codexStream, err := ai.Stream(ai.Model{ + ID: "gpt-5.1-codex-mini", + Provider: "openai-codex", + API: ai.APIOpenAICodexResponse, + }, ai.Context{}, nil) + if err != nil { + t.Fatalf("unexpected codex stream resolve error: %v", err) + } + codexEvt, err := codexStream.Next(ctx) + if err != nil { + 
t.Fatalf("expected codex terminal error event, got %v", err) + } + if codexEvt.Type != ai.EventError { + t.Fatalf("expected codex error event, got %s", codexEvt.Type) + } + if strings.Contains(strings.ToLower(codexEvt.Error.ErrorMessage), "not implemented") { + t.Fatalf("expected codex runtime implementation, got stub error: %q", codexEvt.Error.ErrorMessage) + } } From 5d99997ffa3320c5a76245e7736d40362371e11d Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 05:53:16 +0000 Subject: [PATCH 39/75] Add vertex and bedrock env auth key parity MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/ai/env_api_keys.go | 41 ++++++++++++++++++++++++++++- pkg/ai/env_api_keys_test.go | 51 +++++++++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+), 1 deletion(-) create mode 100644 pkg/ai/env_api_keys_test.go diff --git a/pkg/ai/env_api_keys.go b/pkg/ai/env_api_keys.go index b65cb87e..0bc6da47 100644 --- a/pkg/ai/env_api_keys.go +++ b/pkg/ai/env_api_keys.go @@ -1,6 +1,9 @@ package ai -import "os" +import ( + "os" + "path/filepath" +) func GetEnvAPIKey(provider string) string { switch provider { @@ -21,6 +24,21 @@ func GetEnvAPIKey(provider string) string { return os.Getenv("OPENAI_API_KEY") case "azure-openai-responses": return os.Getenv("AZURE_OPENAI_API_KEY") + case "google-vertex": + if hasVertexADCCredentials() && hasVertexProject() && os.Getenv("GOOGLE_CLOUD_LOCATION") != "" { + return "" + } + return "" + case "amazon-bedrock": + if os.Getenv("AWS_PROFILE") != "" || + (os.Getenv("AWS_ACCESS_KEY_ID") != "" && os.Getenv("AWS_SECRET_ACCESS_KEY") != "") || + os.Getenv("AWS_BEARER_TOKEN_BEDROCK") != "" || + os.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" || + os.Getenv("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" || + os.Getenv("AWS_WEB_IDENTITY_TOKEN_FILE") != "" { + return "" + } + return "" case "google": return os.Getenv("GEMINI_API_KEY") case "groq": @@ -51,3 
+69,24 @@ func GetEnvAPIKey(provider string) string { return "" } } + +func hasVertexProject() bool { + return os.Getenv("GOOGLE_CLOUD_PROJECT") != "" || os.Getenv("GCLOUD_PROJECT") != "" +} + +func hasVertexADCCredentials() bool { + if path := os.Getenv("GOOGLE_APPLICATION_CREDENTIALS"); path != "" { + if _, err := os.Stat(path); err == nil { + return true + } + } + home, err := os.UserHomeDir() + if err != nil || home == "" { + return false + } + adcPath := filepath.Join(home, ".config", "gcloud", "application_default_credentials.json") + if _, err := os.Stat(adcPath); err == nil { + return true + } + return false +} diff --git a/pkg/ai/env_api_keys_test.go b/pkg/ai/env_api_keys_test.go new file mode 100644 index 00000000..b9990420 --- /dev/null +++ b/pkg/ai/env_api_keys_test.go @@ -0,0 +1,51 @@ +package ai + +import ( + "os" + "path/filepath" + "testing" +) + +func TestGetEnvAPIKey_GoogleVertexAuthenticated(t *testing.T) { + home := t.TempDir() + adcPath := filepath.Join(home, ".config", "gcloud", "application_default_credentials.json") + if err := os.MkdirAll(filepath.Dir(adcPath), 0o755); err != nil { + t.Fatalf("failed to create ADC directory: %v", err) + } + if err := os.WriteFile(adcPath, []byte(`{"type":"authorized_user"}`), 0o600); err != nil { + t.Fatalf("failed to write ADC file: %v", err) + } + + t.Setenv("HOME", home) + t.Setenv("GOOGLE_APPLICATION_CREDENTIALS", "") + t.Setenv("GOOGLE_CLOUD_PROJECT", "test-project") + t.Setenv("GOOGLE_CLOUD_LOCATION", "us-central1") + + if got := GetEnvAPIKey("google-vertex"); got != "" { + t.Fatalf("expected for google-vertex, got %q", got) + } +} + +func TestGetEnvAPIKey_GoogleVertexMissingContext(t *testing.T) { + t.Setenv("GOOGLE_APPLICATION_CREDENTIALS", "") + t.Setenv("GOOGLE_CLOUD_PROJECT", "") + t.Setenv("GCLOUD_PROJECT", "") + t.Setenv("GOOGLE_CLOUD_LOCATION", "") + if got := GetEnvAPIKey("google-vertex"); got != "" { + t.Fatalf("expected empty key when google-vertex env incomplete, got %q", got) + } +} + 
+func TestGetEnvAPIKey_AmazonBedrockAuthenticated(t *testing.T) { + t.Setenv("AWS_PROFILE", "default") + if got := GetEnvAPIKey("amazon-bedrock"); got != "" { + t.Fatalf("expected for amazon-bedrock profile auth, got %q", got) + } + + t.Setenv("AWS_PROFILE", "") + t.Setenv("AWS_ACCESS_KEY_ID", "AKIA123") + t.Setenv("AWS_SECRET_ACCESS_KEY", "secret") + if got := GetEnvAPIKey("amazon-bedrock"); got != "" { + t.Fatalf("expected for amazon-bedrock IAM auth, got %q", got) + } +} From 71f49b4d916295f9edc6c9438fa8afe589be9d6e Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 06:11:40 +0000 Subject: [PATCH 40/75] Add pkg-ai Anthropic messages streaming runtime MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- go.mod | 1 + go.sum | 2 + pkg/ai/providers/anthropic_runtime.go | 283 +++++++++++++++++++++ pkg/ai/providers/anthropic_runtime_test.go | 79 ++++++ pkg/ai/providers/register_builtins.go | 6 +- pkg/ai/providers/register_builtins_test.go | 19 ++ 6 files changed, 389 insertions(+), 1 deletion(-) create mode 100644 pkg/ai/providers/anthropic_runtime.go create mode 100644 pkg/ai/providers/anthropic_runtime_test.go diff --git a/go.mod b/go.mod index 1a74ed89..3f589dcc 100644 --- a/go.mod +++ b/go.mod @@ -28,6 +28,7 @@ require ( require ( filippo.io/edwards25519 v1.2.0 // indirect github.com/andybalholm/cascadia v1.3.3 // indirect + github.com/anthropics/anthropic-sdk-go v1.26.0 // indirect github.com/coreos/go-systemd/v22 v22.6.0 // indirect github.com/dlclark/regexp2 v1.10.0 // indirect github.com/google/jsonschema-go v0.3.0 // indirect diff --git a/go.sum b/go.sum index 8d66122d..d5960f8b 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,8 @@ github.com/PuerkitoBio/goquery v1.11.0 h1:jZ7pwMQXIITcUXNH83LLk+txlaEy6NVOfTuP43 github.com/PuerkitoBio/goquery v1.11.0/go.mod h1:wQHgxUOU3JGuj3oD/QFfxUdlzW6xPHfqyHre6VMY4DQ= github.com/andybalholm/cascadia v1.3.3 
h1:AG2YHrzJIm4BZ19iwJ/DAua6Btl3IwJX+VI4kktS1LM= github.com/andybalholm/cascadia v1.3.3/go.mod h1:xNd9bqTn98Ln4DwST8/nG+H0yuB8Hmgu1YHNnWw0GeA= +github.com/anthropics/anthropic-sdk-go v1.26.0 h1:oUTzFaUpAevfuELAP1sjL6CQJ9HHAfT7CoSYSac11PY= +github.com/anthropics/anthropic-sdk-go v1.26.0/go.mod h1:qUKmaW+uuPB64iy1l+4kOSvaLqPXnHTTBKH6RVZ7q5Q= github.com/beeper/bridge-manager v0.14.0 h1:7XeZfHeDiOuwLUe6UiX/HCywthw1s0Q7xhrmDzzW9FA= github.com/beeper/bridge-manager v0.14.0/go.mod h1:pherlTADz3wkojdc2AvAsR3mS1yG5jF9/OaxkHqPy4Y= github.com/beeper/desktop-api-go v0.2.0 h1:VrwB1FCEiuPycGo6TsYSVVSKQIWFg22xmlRWVJ88E0A= diff --git a/pkg/ai/providers/anthropic_runtime.go b/pkg/ai/providers/anthropic_runtime.go new file mode 100644 index 00000000..fb7d3ff0 --- /dev/null +++ b/pkg/ai/providers/anthropic_runtime.go @@ -0,0 +1,283 @@ +package providers + +import ( + "context" + "strings" + "time" + + anthropic "github.com/anthropics/anthropic-sdk-go" + anthropicoption "github.com/anthropics/anthropic-sdk-go/option" + anthropicparam "github.com/anthropics/anthropic-sdk-go/packages/param" + + "github.com/beeper/ai-bridge/pkg/ai" +) + +func streamAnthropicMessages(model ai.Model, c ai.Context, options *ai.StreamOptions) *ai.AssistantMessageEventStream { + anthropicOptions := AnthropicOptions{} + if options != nil { + anthropicOptions.StreamOptions = *options + } + return streamAnthropicMessagesWithOptions(model, c, anthropicOptions) +} + +func streamSimpleAnthropicMessages(model ai.Model, c ai.Context, options *ai.SimpleStreamOptions) *ai.AssistantMessageEventStream { + base := BuildBaseOptions(model, options, "") + if options == nil || options.Reasoning == "" { + return streamAnthropicMessagesWithOptions(model, c, AnthropicOptions{ + StreamOptions: base, + ThinkingEnabled: false, + }) + } + if supportsAdaptiveThinkingModel(model.ID) { + return streamAnthropicMessagesWithOptions(model, c, AnthropicOptions{ + StreamOptions: base, + ThinkingEnabled: true, + Effort: 
mapAnthropicThinkingEffort(model.ID, options.Reasoning), + }) + } + + adjustedMaxTokens, thinkingBudget := AdjustMaxTokensForThinking( + base.MaxTokens, + model.MaxTokens, + options.Reasoning, + options.ThinkingBudgets, + ) + base.MaxTokens = adjustedMaxTokens + return streamAnthropicMessagesWithOptions(model, c, AnthropicOptions{ + StreamOptions: base, + ThinkingEnabled: true, + ThinkingBudgetTokens: thinkingBudget, + }) +} + +func streamAnthropicMessagesWithOptions( + model ai.Model, + c ai.Context, + options AnthropicOptions, +) *ai.AssistantMessageEventStream { + stream := ai.NewAssistantMessageEventStream(128) + go func() { + apiKey := strings.TrimSpace(options.StreamOptions.APIKey) + if apiKey == "" { + apiKey = strings.TrimSpace(ai.GetEnvAPIKey(string(model.Provider))) + } + if apiKey == "" { + pushProviderError(stream, model, "missing API key for Anthropic messages runtime") + return + } + + payload := BuildAnthropicParams(model, c, options) + if options.StreamOptions.OnPayload != nil { + options.StreamOptions.OnPayload(payload) + } + betaHeader := "" + if rawBeta, ok := payload["anthropic-beta"].(string); ok { + betaHeader = strings.TrimSpace(rawBeta) + delete(payload, "anthropic-beta") + } + request := anthropicparam.Override[anthropic.MessageNewParams](payload) + + reqOptions := []anthropicoption.RequestOption{} + if isOAuthAnthropicToken(apiKey) || model.Provider == "github-copilot" { + reqOptions = append(reqOptions, anthropicoption.WithAuthToken(apiKey)) + } else { + reqOptions = append(reqOptions, anthropicoption.WithAPIKey(apiKey)) + } + if baseURL := strings.TrimSpace(model.BaseURL); baseURL != "" { + reqOptions = append(reqOptions, anthropicoption.WithBaseURL(baseURL)) + } + if betaHeader != "" { + reqOptions = append(reqOptions, anthropicoption.WithHeader("anthropic-beta", betaHeader)) + } + reqOptions = appendAnthropicHeaderOptions(reqOptions, model.Headers) + reqOptions = appendAnthropicHeaderOptions(reqOptions, options.StreamOptions.Headers) + 
+ client := anthropic.NewClient(reqOptions...) + runCtx := options.StreamOptions.Ctx + if runCtx == nil { + runCtx = context.Background() + } + + anthropicStream := client.Messages.NewStreaming(runCtx, request) + if anthropicStream == nil { + pushProviderError(stream, model, "failed to create Anthropic messages stream") + return + } + + accumulated := anthropic.Message{} + for anthropicStream.Next() { + event := anthropicStream.Current() + if err := (&accumulated).Accumulate(event); err != nil { + pushProviderError(stream, model, err.Error()) + return + } + + switch eventVariant := event.AsAny().(type) { + case anthropic.ContentBlockDeltaEvent: + switch deltaVariant := eventVariant.Delta.AsAny().(type) { + case anthropic.TextDelta: + stream.Push(ai.AssistantMessageEvent{ + Type: ai.EventTextDelta, + Delta: deltaVariant.Text, + }) + case anthropic.ThinkingDelta: + stream.Push(ai.AssistantMessageEvent{ + Type: ai.EventThinkingDelta, + Delta: deltaVariant.Thinking, + }) + } + case anthropic.ContentBlockStopEvent: + contentIndex := int(eventVariant.Index) + if contentIndex < 0 || contentIndex >= len(accumulated.Content) { + continue + } + if toolUse, ok := accumulated.Content[contentIndex].AsAny().(anthropic.ToolUseBlock); ok { + toolCall := ai.ContentBlock{ + Type: ai.ContentTypeToolCall, + ID: strings.TrimSpace(toolUse.ID), + Name: strings.TrimSpace(toolUse.Name), + Arguments: parseToolArguments(string(toolUse.Input)), + } + stream.Push(ai.AssistantMessageEvent{ + Type: ai.EventToolCallEnd, + ContentIndex: contentIndex, + ToolCall: &toolCall, + }) + } + } + } + + if err := anthropicStream.Err(); err != nil { + pushProviderError(stream, model, err.Error()) + return + } + + assistantMessage := anthropicMessageToAIMessage(model, accumulated) + stream.Push(ai.AssistantMessageEvent{ + Type: ai.EventDone, + Message: assistantMessage, + Reason: assistantMessage.StopReason, + }) + }() + return stream +} + +func anthropicMessageToAIMessage(model ai.Model, msg 
anthropic.Message) ai.Message { + out := ai.Message{ + Role: ai.RoleAssistant, + API: model.API, + Provider: model.Provider, + Model: model.ID, + Timestamp: time.Now().UnixMilli(), + StopReason: mapAnthropicStopReason(msg.StopReason), + Usage: ai.Usage{ + Input: int(msg.Usage.InputTokens), + Output: int(msg.Usage.OutputTokens), + CacheRead: int(msg.Usage.CacheReadInputTokens), + CacheWrite: int(msg.Usage.CacheCreationInputTokens), + }, + } + + for _, block := range msg.Content { + switch blockVariant := block.AsAny().(type) { + case anthropic.TextBlock: + if strings.TrimSpace(blockVariant.Text) == "" { + continue + } + out.Content = append(out.Content, ai.ContentBlock{ + Type: ai.ContentTypeText, + Text: blockVariant.Text, + }) + case anthropic.ThinkingBlock: + if strings.TrimSpace(blockVariant.Thinking) == "" { + continue + } + out.Content = append(out.Content, ai.ContentBlock{ + Type: ai.ContentTypeThinking, + Thinking: blockVariant.Thinking, + ThinkingSignature: blockVariant.Signature, + }) + case anthropic.RedactedThinkingBlock: + out.Content = append(out.Content, ai.ContentBlock{ + Type: ai.ContentTypeThinking, + Thinking: blockVariant.Data, + Redacted: true, + }) + case anthropic.ToolUseBlock: + out.Content = append(out.Content, ai.ContentBlock{ + Type: ai.ContentTypeToolCall, + ID: strings.TrimSpace(blockVariant.ID), + Name: strings.TrimSpace(blockVariant.Name), + Arguments: parseToolArguments(string(blockVariant.Input)), + }) + } + } + + out.Usage.TotalTokens = out.Usage.Input + out.Usage.Output + out.Usage.CacheRead + out.Usage.CacheWrite + out.Usage.Cost = ai.CalculateCost(model, out.Usage) + if out.StopReason == ai.StopReasonStop { + for _, block := range out.Content { + if block.Type == ai.ContentTypeToolCall { + out.StopReason = ai.StopReasonToolUse + break + } + } + } + return out +} + +func appendAnthropicHeaderOptions( + opts []anthropicoption.RequestOption, + headers map[string]string, +) []anthropicoption.RequestOption { + for key, value := range 
headers { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + continue + } + opts = append(opts, anthropicoption.WithHeader(key, trimmed)) + } + return opts +} + +func isOAuthAnthropicToken(apiKey string) bool { + return strings.Contains(apiKey, "sk-ant-oat") +} + +func supportsAdaptiveThinkingModel(modelID string) bool { + id := strings.ToLower(strings.TrimSpace(modelID)) + return strings.Contains(id, "opus-4-6") || strings.Contains(id, "opus-4.6") || + strings.Contains(id, "sonnet-4-6") || strings.Contains(id, "sonnet-4.6") +} + +func mapAnthropicThinkingEffort(modelID string, level ai.ThinkingLevel) string { + switch level { + case ai.ThinkingMinimal, ai.ThinkingLow: + return "low" + case ai.ThinkingMedium: + return "medium" + case ai.ThinkingHigh: + return "high" + case ai.ThinkingXHigh: + id := strings.ToLower(strings.TrimSpace(modelID)) + if strings.Contains(id, "opus-4-6") || strings.Contains(id, "opus-4.6") { + return "max" + } + return "high" + default: + return "high" + } +} + +func mapAnthropicStopReason(reason anthropic.StopReason) ai.StopReason { + switch reason { + case anthropic.StopReasonMaxTokens: + return ai.StopReasonLength + case anthropic.StopReasonToolUse: + return ai.StopReasonToolUse + case anthropic.StopReasonEndTurn, anthropic.StopReasonStopSequence: + return ai.StopReasonStop + default: + return ai.StopReasonStop + } +} diff --git a/pkg/ai/providers/anthropic_runtime_test.go b/pkg/ai/providers/anthropic_runtime_test.go new file mode 100644 index 00000000..ef6a9a11 --- /dev/null +++ b/pkg/ai/providers/anthropic_runtime_test.go @@ -0,0 +1,79 @@ +package providers + +import ( + "context" + "io" + "strings" + "testing" + "time" + + anthropic "github.com/anthropics/anthropic-sdk-go" + + "github.com/beeper/ai-bridge/pkg/ai" +) + +func TestStreamAnthropicMessages_MissingAPIKeyEmitsError(t *testing.T) { + t.Setenv("ANTHROPIC_OAUTH_TOKEN", "") + t.Setenv("ANTHROPIC_API_KEY", "") + stream := streamAnthropicMessages(ai.Model{ + ID: 
"claude-sonnet-4-5", + Provider: "anthropic", + API: ai.APIAnthropicMessages, + }, ai.Context{ + Messages: []ai.Message{{Role: ai.RoleUser, Text: "hello"}}, + }, &ai.StreamOptions{}) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + evt, err := stream.Next(ctx) + if err != nil { + t.Fatalf("expected terminal error event, got %v", err) + } + if evt.Type != ai.EventError { + t.Fatalf("expected error event, got %s", evt.Type) + } + if !strings.Contains(strings.ToLower(evt.Error.ErrorMessage), "api key") { + t.Fatalf("expected missing api key message, got %q", evt.Error.ErrorMessage) + } + if _, err = stream.Next(ctx); err != io.EOF { + t.Fatalf("expected EOF after terminal event, got %v", err) + } +} + +func TestMapAnthropicStopReason(t *testing.T) { + cases := map[anthropic.StopReason]ai.StopReason{ + anthropic.StopReasonEndTurn: ai.StopReasonStop, + anthropic.StopReasonStopSequence: ai.StopReasonStop, + anthropic.StopReasonMaxTokens: ai.StopReasonLength, + anthropic.StopReasonToolUse: ai.StopReasonToolUse, + } + for in, want := range cases { + if got := mapAnthropicStopReason(in); got != want { + t.Fatalf("mapAnthropicStopReason(%q) = %q, want %q", in, got, want) + } + } +} + +func TestMapAnthropicThinkingEffort(t *testing.T) { + if got := mapAnthropicThinkingEffort("claude-opus-4-6", ai.ThinkingXHigh); got != "max" { + t.Fatalf("expected xhigh on opus 4.6 to map to max, got %q", got) + } + if got := mapAnthropicThinkingEffort("claude-sonnet-4-6", ai.ThinkingXHigh); got != "high" { + t.Fatalf("expected xhigh on sonnet 4.6 to map to high, got %q", got) + } + if got := mapAnthropicThinkingEffort("claude-sonnet-4-5", ai.ThinkingMinimal); got != "low" { + t.Fatalf("expected minimal to map to low, got %q", got) + } +} + +func TestSupportsAdaptiveThinkingModel(t *testing.T) { + if !supportsAdaptiveThinkingModel("claude-opus-4-6") { + t.Fatalf("expected opus 4.6 to support adaptive thinking") + } + if 
!supportsAdaptiveThinkingModel("claude-sonnet-4.6") { + t.Fatalf("expected sonnet 4.6 to support adaptive thinking") + } + if supportsAdaptiveThinkingModel("claude-sonnet-4-5") { + t.Fatalf("did not expect sonnet 4.5 to support adaptive thinking") + } +} diff --git a/pkg/ai/providers/register_builtins.go b/pkg/ai/providers/register_builtins.go index f9355d83..18f254a6 100644 --- a/pkg/ai/providers/register_builtins.go +++ b/pkg/ai/providers/register_builtins.go @@ -70,9 +70,13 @@ func RegisterBuiltInAPIProviders() { Stream: streamOpenAICodexResponses, StreamSimple: streamSimpleOpenAICodexResponses, }, BuiltinProviderSourceID) + ai.RegisterAPIProvider(ai.APIProvider{ + API: ai.APIAnthropicMessages, + Stream: streamAnthropicMessages, + StreamSimple: streamSimpleAnthropicMessages, + }, BuiltinProviderSourceID) for _, apiID := range []ai.Api{ - ai.APIAnthropicMessages, ai.APIGoogleGenerativeAI, ai.APIGoogleGeminiCLI, ai.APIGoogleVertex, diff --git a/pkg/ai/providers/register_builtins_test.go b/pkg/ai/providers/register_builtins_test.go index 94f6851b..887ed645 100644 --- a/pkg/ai/providers/register_builtins_test.go +++ b/pkg/ai/providers/register_builtins_test.go @@ -85,4 +85,23 @@ func TestRegisterBuiltInAPIProviders(t *testing.T) { if strings.Contains(strings.ToLower(codexEvt.Error.ErrorMessage), "not implemented") { t.Fatalf("expected codex runtime implementation, got stub error: %q", codexEvt.Error.ErrorMessage) } + + anthropicStream, err := ai.Stream(ai.Model{ + ID: "claude-sonnet-4-5", + Provider: "anthropic", + API: ai.APIAnthropicMessages, + }, ai.Context{}, nil) + if err != nil { + t.Fatalf("unexpected anthropic stream resolve error: %v", err) + } + anthropicEvt, err := anthropicStream.Next(ctx) + if err != nil { + t.Fatalf("expected anthropic terminal error event, got %v", err) + } + if anthropicEvt.Type != ai.EventError { + t.Fatalf("expected anthropic error event, got %s", anthropicEvt.Type) + } + if 
strings.Contains(strings.ToLower(anthropicEvt.Error.ErrorMessage), "not implemented") { + t.Fatalf("expected anthropic runtime implementation, got stub error: %q", anthropicEvt.Error.ErrorMessage) + } } From cca7fccd4447136a20c8a8928786b152a171da8e Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 06:17:10 +0000 Subject: [PATCH 41/75] Add Gemini CLI empty-stream retry parity helpers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/ai/providers/google_gemini_cli.go | 16 +++++++++++++ pkg/ai/providers/google_gemini_cli_test.go | 28 ++++++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/pkg/ai/providers/google_gemini_cli.go b/pkg/ai/providers/google_gemini_cli.go index 149d5d7b..4f4db9f0 100644 --- a/pkg/ai/providers/google_gemini_cli.go +++ b/pkg/ai/providers/google_gemini_cli.go @@ -11,6 +11,10 @@ import ( ) const ClaudeThinkingBetaHeader = "interleaved-thinking-2025-05-14" +const ( + MaxGeminiEmptyStreamRetries = 2 + EmptyStreamBaseDelayMs = 500 +) func ExtractRetryDelay(errorText string, headers http.Header) (int, bool) { return extractRetryDelayAt(errorText, headers, time.Now()) @@ -148,3 +152,15 @@ func BuildGeminiCLIHeaders(model ai.Model, headers map[string]string) map[string } return out } + +func GeminiEmptyStreamBackoff(attempt int) (time.Duration, bool) { + if attempt <= 0 || attempt > MaxGeminiEmptyStreamRetries { + return 0, false + } + delayMs := EmptyStreamBaseDelayMs * (1 << (attempt - 1)) + return time.Duration(delayMs) * time.Millisecond, true +} + +func ShouldRetryGeminiEmptyStream(hasContent bool, emptyAttempt int) bool { + return !hasContent && emptyAttempt < MaxGeminiEmptyStreamRetries +} diff --git a/pkg/ai/providers/google_gemini_cli_test.go b/pkg/ai/providers/google_gemini_cli_test.go index 95f3f6be..4d1f3882 100644 --- a/pkg/ai/providers/google_gemini_cli_test.go +++ b/pkg/ai/providers/google_gemini_cli_test.go @@ -67,3 +67,31 @@ func 
TestBuildGeminiCLIHeaders(t *testing.T) { } }) } + +func TestGeminiEmptyStreamRetryHelpers(t *testing.T) { + if delay, ok := GeminiEmptyStreamBackoff(1); !ok || delay != 500*time.Millisecond { + t.Fatalf("expected first retry backoff 500ms, got %v (ok=%v)", delay, ok) + } + if delay, ok := GeminiEmptyStreamBackoff(2); !ok || delay != time.Second { + t.Fatalf("expected second retry backoff 1s, got %v (ok=%v)", delay, ok) + } + if _, ok := GeminiEmptyStreamBackoff(0); ok { + t.Fatalf("did not expect backoff for attempt 0") + } + if _, ok := GeminiEmptyStreamBackoff(3); ok { + t.Fatalf("did not expect backoff beyond max retries") + } + + if !ShouldRetryGeminiEmptyStream(false, 0) { + t.Fatalf("expected retry on first empty attempt") + } + if !ShouldRetryGeminiEmptyStream(false, 1) { + t.Fatalf("expected retry on second empty attempt") + } + if ShouldRetryGeminiEmptyStream(false, 2) { + t.Fatalf("did not expect retry beyond max attempts") + } + if ShouldRetryGeminiEmptyStream(true, 0) { + t.Fatalf("did not expect retry when content was received") + } +} From a040e102eaf7366cf7f75959c6144c36d46f2119 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 06:45:44 +0000 Subject: [PATCH 42/75] Add pkg-ai Google and Vertex streaming runtimes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- go.mod | 13 + go.sum | 104 ++++++ pkg/ai/providers/google_runtime.go | 378 +++++++++++++++++++++ pkg/ai/providers/google_runtime_test.go | 85 +++++ pkg/ai/providers/register_builtins.go | 12 +- pkg/ai/providers/register_builtins_test.go | 38 +++ 6 files changed, 628 insertions(+), 2 deletions(-) create mode 100644 pkg/ai/providers/google_runtime.go create mode 100644 pkg/ai/providers/google_runtime_test.go diff --git a/go.mod b/go.mod index 3f589dcc..b2c80c58 100644 --- a/go.mod +++ b/go.mod @@ -26,12 +26,20 @@ require ( ) require ( + cloud.google.com/go v0.116.0 // indirect + cloud.google.com/go/auth 
v0.9.3 // indirect + cloud.google.com/go/compute/metadata v0.5.0 // indirect filippo.io/edwards25519 v1.2.0 // indirect github.com/andybalholm/cascadia v1.3.3 // indirect github.com/anthropics/anthropic-sdk-go v1.26.0 // indirect github.com/coreos/go-systemd/v22 v22.6.0 // indirect github.com/dlclark/regexp2 v1.10.0 // indirect + github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect + github.com/google/go-cmp v0.7.0 // indirect github.com/google/jsonschema-go v0.3.0 // indirect + github.com/google/s2a-go v0.1.8 // indirect + github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect + github.com/gorilla/websocket v1.5.3 // indirect github.com/lib/pq v1.11.2 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect @@ -44,6 +52,7 @@ require ( github.com/yosida95/uritemplate/v3 v3.0.2 // indirect github.com/yuin/goldmark v1.7.16 // indirect go.mau.fi/zeroconfig v0.2.0 // indirect + go.opencensus.io v0.24.0 // indirect golang.org/x/crypto v0.48.0 // indirect golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa // indirect golang.org/x/mod v0.33.0 // indirect @@ -52,6 +61,10 @@ require ( golang.org/x/sync v0.19.0 // indirect golang.org/x/sys v0.41.0 // indirect golang.org/x/text v0.34.0 // indirect + google.golang.org/genai v1.48.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect + google.golang.org/grpc v1.66.2 // indirect + google.golang.org/protobuf v1.36.11 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect maunium.net/go/mauflag v1.0.0 // indirect ) diff --git a/go.sum b/go.sum index d5960f8b..e8f629a7 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,13 @@ +cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.116.0 h1:B3fRrSDkLRt5qSHWe40ERJvhvnQwdZiHu0bJOpldweE= +cloud.google.com/go v0.116.0/go.mod h1:cEPSRWPzZEswwdr9BxE6ChEn01dWlTaF05LiC2Xs70U= 
+cloud.google.com/go/auth v0.9.3 h1:VOEUIAADkkLtyfr3BLa3R8Ed/j6w1jTBmARx+wb5w5U= +cloud.google.com/go/auth v0.9.3/go.mod h1:7z6VY+7h3KUdRov5F1i8NDP5ZzWKYmEPO842BgCsmTk= +cloud.google.com/go/compute/metadata v0.5.0 h1:Zr0eK8JbFv6+Wi4ilXAR8FJ3wyNdpxHKJNPos6LTZOY= +cloud.google.com/go/compute/metadata v0.5.0/go.mod h1:aHnloV2TPI38yx4s9+wAZhHykWvVCfu7hQbF+9CWoiY= filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo= filippo.io/edwards25519 v1.2.0/go.mod h1:xzAOLCNug/yB62zG1bQ8uziwrIqIuxhctzJT18Q77mc= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/PuerkitoBio/goquery v1.11.0 h1:jZ7pwMQXIITcUXNH83LLk+txlaEy6NVOfTuP43xxfqw= @@ -12,27 +20,62 @@ github.com/beeper/bridge-manager v0.14.0 h1:7XeZfHeDiOuwLUe6UiX/HCywthw1s0Q7xhrm github.com/beeper/bridge-manager v0.14.0/go.mod h1:pherlTADz3wkojdc2AvAsR3mS1yG5jF9/OaxkHqPy4Y= github.com/beeper/desktop-api-go v0.2.0 h1:VrwB1FCEiuPycGo6TsYSVVSKQIWFg22xmlRWVJ88E0A= github.com/beeper/desktop-api-go v0.2.0/go.mod h1:y9Mk83OdQWo6ldLTcPyaUPrwjkmvy/3QkhHqZLhU/mA= +github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coreos/go-systemd/v22 v22.6.0 h1:aGVa/v8B7hpb0TKl0MWoAavPDmHvobFe5R5zn0bCJWo= github.com/coreos/go-systemd/v22 v22.6.0/go.mod 
h1:iG+pp635Fo7ZmV/j14KUcmEyWF+0X7Lua8rrTWzYgWU= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0= github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/dyatlov/go-opengraph/opengraph v0.0.0-20220524092352-606d7b1e5f8a h1:etIrTD8BQqzColk9nKRusM9um5+1q0iOEJLqfBMIK64= github.com/dyatlov/go-opengraph/opengraph v0.0.0-20220524092352-606d7b1e5f8a/go.mod h1:emQhSYTXqB0xxjLITTw4EaWZ+8IIQYw+kx9GqNUKdLg= +github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= +github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= +github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= 
+github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= +github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q= github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= +github.com/google/s2a-go v0.1.8 h1:zZDs9gcbt9ZPLV0ndSyQk6Kacx2g/X+SKYovpnz3SMM= +github.com/google/s2a-go v0.1.8/go.mod 
h1:6iNWHTpQ+nfNRN5E00MSdfDwVesa8hhS32PhPO8deJA= +github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/googleapis/enterprise-certificate-proxy v0.3.4 h1:XYIDZApgAnrN1c855gTgghdIA6Stxb52D5RnLI1SLyw= +github.com/googleapis/enterprise-certificate-proxy v0.3.4/go.mod h1:YKe7cfqYXjKGpGvmSg28/fFvhNzinZQm8DGnaburhGA= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/lib/pq v1.11.2 h1:x6gxUeu39V0BHZiugWe8LXZYZ+Utk7hSJGThs8sdzfs= github.com/lib/pq v1.11.2/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= @@ -55,6 +98,7 @@ github.com/pkoukk/tiktoken-go v0.1.8 h1:85ENo+3FpWgAACBaEUVp+lctuTcYUO7BtmfhlN/Q github.com/pkoukk/tiktoken-go v0.1.8/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= @@ -63,6 +107,12 @@ github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= github.com/skip2/go-qrcode 
v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= @@ -86,7 +136,10 @@ go.mau.fi/util v0.9.6 h1:2nsvxm49KhI3wrFltr0+wSUBlnQ4CMtykuELjpIU+ts= go.mau.fi/util v0.9.6/go.mod h1:sIJpRH7Iy5Ad1SBuxQoatxtIeErgzxCtjd/2hCMkYMI= go.mau.fi/zeroconfig v0.2.0 h1:e/OGEERqVRRKlgaro7E6bh8xXiKFSXB3eNNIud7FUjU= go.mau.fi/zeroconfig v0.2.0/go.mod h1:J0Vn0prHNOm493oZoQ84kq83ZaNCYZnq+noI1b1eN8w= +go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= +go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= @@ -94,10 +147,14 @@ golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v golang.org/x/crypto 
v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= +golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa h1:Zt3DZoOFFYkKhDT3v7Lm9FDMEV06GpzjG2jrqW+QTE0= golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa/go.mod h1:K79w1Vqn7PoiZn+TkNpx3BUWUQksGO3JcVX6qIjytmA= golang.org/x/image v0.35.0 h1:LKjiHdgMtO8z7Fh18nGY6KDcoEtVfsgLDPeLyguqb7I= golang.org/x/image v0.35.0/go.mod h1:MwPLTVgvxSASsxdLzKrl8BRFuyqMyGhLwmC+TO1Sybk= +golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= @@ -105,7 +162,13 @@ golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net 
v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220520000938-2e3eb7b945c2/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= @@ -117,8 +180,11 @@ golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60= golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -128,7 +194,10 @@ golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.10.0/go.mod 
h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -165,6 +234,10 @@ golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod 
h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= @@ -173,12 +246,43 @@ golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxb golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k= golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/genai v1.48.0 h1:1vb15G291wAjJJueisMDpUhssljhEdJU2t5qTidrVPs= +google.golang.org/genai v1.48.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk= +google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 h1:pPJltXNxVzT4pK9yD8vR9X75DaWYYmLGMsEvBfFQZzQ= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU= +google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= +google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= +google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= +google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= +google.golang.org/grpc v1.66.2 
h1:3QdXkuq3Bkh7w+ywLdLvM56cmGvQHUMZpiCzt6Rqaoo= +google.golang.org/grpc v1.66.2/go.mod h1:s3/l6xSSCURdVfAnL+TqCNMyTDAGN6+lZeVxnZR128Y= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= +google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= +google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod 
h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M= maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA= maunium.net/go/mautrix v0.26.4-0.20260304003850-fef4326fbce6 h1:Kv/cJeaI7yObLLydiIAyKhl7/8Rto+PmGZTBBfJM3Q0= diff --git a/pkg/ai/providers/google_runtime.go b/pkg/ai/providers/google_runtime.go new file mode 100644 index 00000000..ee8b19a1 --- /dev/null +++ b/pkg/ai/providers/google_runtime.go @@ -0,0 +1,378 @@ +package providers + +import ( + "context" + "encoding/base64" + "strings" + "time" + + "google.golang.org/genai" + + "github.com/beeper/ai-bridge/pkg/ai" +) + +func streamGoogleGenerativeAI(model ai.Model, c ai.Context, options *ai.StreamOptions) *ai.AssistantMessageEventStream { + googleOptions := GoogleOptions{} + if options != nil { + googleOptions.StreamOptions = *options + } + return streamGoogleWithBackend(model, c, googleOptions, genai.BackendGeminiAPI) +} + +func streamSimpleGoogleGenerativeAI(model ai.Model, c ai.Context, options *ai.SimpleStreamOptions) *ai.AssistantMessageEventStream { + base := BuildBaseOptions(model, options, "") + return streamGoogleWithBackend(model, c, buildGoogleOptionsFromSimple(model, base, options), genai.BackendGeminiAPI) +} + +func streamGoogleVertex(model ai.Model, c ai.Context, options *ai.StreamOptions) *ai.AssistantMessageEventStream { + googleOptions := GoogleOptions{} + if options != nil { + googleOptions.StreamOptions = *options + } + return streamGoogleWithBackend(model, c, googleOptions, genai.BackendVertexAI) +} + +func streamSimpleGoogleVertex(model ai.Model, c ai.Context, options *ai.SimpleStreamOptions) *ai.AssistantMessageEventStream { + base := BuildBaseOptions(model, options, "") + return streamGoogleWithBackend(model, c, buildGoogleOptionsFromSimple(model, base, options), 
genai.BackendVertexAI) +} + +func buildGoogleOptionsFromSimple(model ai.Model, base ai.StreamOptions, options *ai.SimpleStreamOptions) GoogleOptions { + out := GoogleOptions{StreamOptions: base} + if options == nil || options.Reasoning == "" || !model.Reasoning { + return out + } + level := strings.ToLower(strings.TrimSpace(string(options.Reasoning))) + if level == "xhigh" { + level = "high" + } + adjustedMaxTokens, thinkingBudget := AdjustMaxTokensForThinking( + base.MaxTokens, + model.MaxTokens, + options.Reasoning, + options.ThinkingBudgets, + ) + out.StreamOptions.MaxTokens = adjustedMaxTokens + out.Thinking = &GoogleThinkingOptions{ + Enabled: true, + Level: level, + } + if thinkingBudget > 0 { + out.Thinking.BudgetTokens = &thinkingBudget + } + return out +} + +func streamGoogleWithBackend( + model ai.Model, + c ai.Context, + options GoogleOptions, + backend genai.Backend, +) *ai.AssistantMessageEventStream { + stream := ai.NewAssistantMessageEventStream(128) + go func() { + runCtx := options.StreamOptions.Ctx + if runCtx == nil { + runCtx = context.Background() + } + + client, err := newGoogleClient(runCtx, backend, options.StreamOptions.APIKey) + if err != nil { + pushProviderError(stream, model, err.Error()) + return + } + + payload := BuildGoogleGenerateContentParams(model, c, options) + if options.StreamOptions.OnPayload != nil { + options.StreamOptions.OnPayload(payload) + } + + contents := convertGoogleContextToGenAIContents(model, c) + config := buildGenAIContentConfig(model, c, options) + textBuilder := strings.Builder{} + thinkingBuilder := strings.Builder{} + toolCalls := make([]ai.ContentBlock, 0, 2) + usage := ai.Usage{} + stopReason := ai.StopReasonStop + + for result, err := range client.Models.GenerateContentStream(runCtx, model.ID, contents, config) { + if err != nil { + pushProviderError(stream, model, err.Error()) + return + } + if result == nil { + continue + } + if result.UsageMetadata != nil { + usage = ai.Usage{ + Input: 
int(result.UsageMetadata.PromptTokenCount), + Output: int(result.UsageMetadata.CandidatesTokenCount), + TotalTokens: int(result.UsageMetadata.TotalTokenCount), + } + } + + for _, candidate := range result.Candidates { + if candidate == nil { + continue + } + if candidate.FinishReason != "" { + stopReason = MapGoogleStopReason(string(candidate.FinishReason)) + } + if candidate.Content == nil { + continue + } + for _, part := range candidate.Content.Parts { + if part == nil { + continue + } + if part.FunctionCall != nil { + toolCall := ai.ContentBlock{ + Type: ai.ContentTypeToolCall, + ID: strings.TrimSpace(part.FunctionCall.ID), + Name: strings.TrimSpace(part.FunctionCall.Name), + Arguments: part.FunctionCall.Args, + } + toolCalls = append(toolCalls, toolCall) + stream.Push(ai.AssistantMessageEvent{ + Type: ai.EventToolCallEnd, + ToolCall: &toolCall, + }) + } + if strings.TrimSpace(part.Text) != "" { + if part.Thought { + thinkingBuilder.WriteString(part.Text) + stream.Push(ai.AssistantMessageEvent{ + Type: ai.EventThinkingDelta, + Delta: part.Text, + }) + } else { + textBuilder.WriteString(part.Text) + stream.Push(ai.AssistantMessageEvent{ + Type: ai.EventTextDelta, + Delta: part.Text, + }) + } + } + } + } + } + + usage.Cost = ai.CalculateCost(model, usage) + assistantMessage := ai.Message{ + Role: ai.RoleAssistant, + API: model.API, + Provider: model.Provider, + Model: model.ID, + Usage: usage, + StopReason: stopReason, + Timestamp: time.Now().UnixMilli(), + } + if thinking := strings.TrimSpace(thinkingBuilder.String()); thinking != "" { + assistantMessage.Content = append(assistantMessage.Content, ai.ContentBlock{ + Type: ai.ContentTypeThinking, + Thinking: thinking, + }) + } + if text := strings.TrimSpace(textBuilder.String()); text != "" { + assistantMessage.Content = append(assistantMessage.Content, ai.ContentBlock{ + Type: ai.ContentTypeText, + Text: text, + }) + } + if len(toolCalls) > 0 { + assistantMessage.Content = append(assistantMessage.Content, 
toolCalls...) + } + if len(toolCalls) > 0 && assistantMessage.StopReason == ai.StopReasonStop { + assistantMessage.StopReason = ai.StopReasonToolUse + } + stream.Push(ai.AssistantMessageEvent{ + Type: ai.EventDone, + Message: assistantMessage, + Reason: assistantMessage.StopReason, + }) + }() + return stream +} + +func newGoogleClient(ctx context.Context, backend genai.Backend, apiKey string) (*genai.Client, error) { + switch backend { + case genai.BackendGeminiAPI: + if strings.TrimSpace(apiKey) == "" { + apiKey = ai.GetEnvAPIKey("google") + } + if strings.TrimSpace(apiKey) == "" { + return nil, errProvider("missing API key for Google Generative AI runtime") + } + return genai.NewClient(ctx, &genai.ClientConfig{ + APIKey: apiKey, + Backend: genai.BackendGeminiAPI, + }) + case genai.BackendVertexAI: + project, err := ResolveGoogleVertexProject(nil) + if err != nil { + return nil, err + } + location, err := ResolveGoogleVertexLocation(nil) + if err != nil { + return nil, err + } + if ai.GetEnvAPIKey("google-vertex") == "" { + return nil, errProvider("missing ADC credentials for Google Vertex runtime") + } + return genai.NewClient(ctx, &genai.ClientConfig{ + Project: project, + Location: location, + Backend: genai.BackendVertexAI, + }) + default: + return nil, errProvider("unsupported Google backend") + } +} + +func convertGoogleContextToGenAIContents(model ai.Model, c ai.Context) []*genai.Content { + googleMessages := ConvertGoogleMessages(model, c) + out := make([]*genai.Content, 0, len(googleMessages)) + for _, msg := range googleMessages { + parts := make([]*genai.Part, 0, len(msg.Parts)) + for _, part := range msg.Parts { + switch { + case strings.TrimSpace(part.Text) != "": + p := &genai.Part{ + Text: part.Text, + Thought: part.Thought, + } + if strings.TrimSpace(part.ThoughtSignature) != "" { + p.ThoughtSignature = []byte(part.ThoughtSignature) + } + parts = append(parts, p) + case part.FunctionCall != nil: + parts = append(parts, 
genai.NewPartFromFunctionCall(part.FunctionCall.Name, part.FunctionCall.Args)) + case part.FunctionResponse != nil: + parts = append(parts, genai.NewPartFromFunctionResponse(part.FunctionResponse.Name, part.FunctionResponse.Response)) + case part.InlineData != nil: + if data, ok := decodeBase64(part.InlineData.Data); ok { + parts = append(parts, genai.NewPartFromBytes(data, part.InlineData.MimeType)) + } + } + } + if len(parts) == 0 { + continue + } + out = append(out, &genai.Content{ + Role: msg.Role, + Parts: parts, + }) + } + return out +} + +func buildGenAIContentConfig(model ai.Model, c ai.Context, options GoogleOptions) *genai.GenerateContentConfig { + config := &genai.GenerateContentConfig{} + if options.StreamOptions.Temperature != nil { + temp := float32(*options.StreamOptions.Temperature) + config.Temperature = &temp + } + if options.StreamOptions.MaxTokens > 0 { + config.MaxOutputTokens = int32(options.StreamOptions.MaxTokens) + } + if strings.TrimSpace(c.SystemPrompt) != "" { + config.SystemInstruction = &genai.Content{ + Parts: []*genai.Part{{Text: c.SystemPrompt}}, + } + } + if len(c.Tools) > 0 { + config.Tools = convertGoogleToolsToGenAI(c.Tools) + if strings.TrimSpace(options.ToolChoice) != "" { + config.ToolConfig = &genai.ToolConfig{ + FunctionCallingConfig: &genai.FunctionCallingConfig{ + Mode: mapGoogleToolChoiceToGenAI(options.ToolChoice), + }, + } + } + } + if options.Thinking != nil && options.Thinking.Enabled && model.Reasoning { + thinking := &genai.ThinkingConfig{ + IncludeThoughts: true, + } + if options.Thinking.BudgetTokens != nil && *options.Thinking.BudgetTokens > 0 { + value := int32(*options.Thinking.BudgetTokens) + thinking.ThinkingBudget = &value + } + if level := mapThinkingLevelToGenAI(options.Thinking.Level); level != "" { + thinking.ThinkingLevel = level + } + config.ThinkingConfig = thinking + } + return config +} + +func convertGoogleToolsToGenAI(tools []ai.Tool) []*genai.Tool { + out := make([]*genai.Tool, 0, len(tools)) + 
if len(tools) == 0 { + return out + } + declarations := make([]*genai.FunctionDeclaration, 0, len(tools)) + for _, tool := range tools { + declarations = append(declarations, &genai.FunctionDeclaration{ + Name: tool.Name, + Description: tool.Description, + ParametersJsonSchema: tool.Parameters, + }) + } + out = append(out, &genai.Tool{ + FunctionDeclarations: declarations, + }) + return out +} + +func mapGoogleToolChoiceToGenAI(choice string) genai.FunctionCallingConfigMode { + switch strings.ToLower(strings.TrimSpace(choice)) { + case "none": + return genai.FunctionCallingConfigModeNone + case "any": + return genai.FunctionCallingConfigModeAny + default: + return genai.FunctionCallingConfigModeAuto + } +} + +func mapThinkingLevelToGenAI(level string) genai.ThinkingLevel { + switch strings.ToLower(strings.TrimSpace(level)) { + case "minimal": + return genai.ThinkingLevelMinimal + case "low": + return genai.ThinkingLevelLow + case "medium": + return genai.ThinkingLevelMedium + case "high", "xhigh": + return genai.ThinkingLevelHigh + default: + return "" + } +} + +func decodeBase64(value string) ([]byte, bool) { + value = strings.TrimSpace(value) + if value == "" { + return nil, false + } + if data, err := base64.StdEncoding.DecodeString(value); err == nil { + return data, true + } + if data, err := base64.RawStdEncoding.DecodeString(value); err == nil { + return data, true + } + return nil, false +} + +func errProvider(message string) error { + return &providerError{message: strings.TrimSpace(message)} +} + +type providerError struct { + message string +} + +func (e *providerError) Error() string { + return e.message +} diff --git a/pkg/ai/providers/google_runtime_test.go b/pkg/ai/providers/google_runtime_test.go new file mode 100644 index 00000000..054de2fe --- /dev/null +++ b/pkg/ai/providers/google_runtime_test.go @@ -0,0 +1,85 @@ +package providers + +import ( + "context" + "io" + "strings" + "testing" + "time" + + "google.golang.org/genai" + + 
"github.com/beeper/ai-bridge/pkg/ai" +) + +func TestStreamGoogleGenerativeAI_MissingAPIKeyEmitsError(t *testing.T) { + t.Setenv("GEMINI_API_KEY", "") + stream := streamGoogleGenerativeAI(ai.Model{ + ID: "gemini-2.5-flash", + Provider: "google", + API: ai.APIGoogleGenerativeAI, + }, ai.Context{ + Messages: []ai.Message{{Role: ai.RoleUser, Text: "hello"}}, + }, &ai.StreamOptions{}) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + evt, err := stream.Next(ctx) + if err != nil { + t.Fatalf("expected terminal error event, got %v", err) + } + if evt.Type != ai.EventError { + t.Fatalf("expected error event, got %s", evt.Type) + } + if !strings.Contains(strings.ToLower(evt.Error.ErrorMessage), "api key") { + t.Fatalf("expected missing api key message, got %q", evt.Error.ErrorMessage) + } + if _, err := stream.Next(ctx); err != io.EOF { + t.Fatalf("expected EOF after terminal event, got %v", err) + } +} + +func TestStreamGoogleVertex_MissingEnvEmitsError(t *testing.T) { + t.Setenv("GOOGLE_CLOUD_PROJECT", "") + t.Setenv("GCLOUD_PROJECT", "") + t.Setenv("GOOGLE_CLOUD_LOCATION", "") + stream := streamGoogleVertex(ai.Model{ + ID: "gemini-2.5-flash", + Provider: "google-vertex", + API: ai.APIGoogleVertex, + }, ai.Context{ + Messages: []ai.Message{{Role: ai.RoleUser, Text: "hello"}}, + }, &ai.StreamOptions{}) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + evt, err := stream.Next(ctx) + if err != nil { + t.Fatalf("expected terminal error event, got %v", err) + } + if evt.Type != ai.EventError { + t.Fatalf("expected error event, got %s", evt.Type) + } + if !strings.Contains(strings.ToLower(evt.Error.ErrorMessage), "project") { + t.Fatalf("expected missing project error, got %q", evt.Error.ErrorMessage) + } +} + +func TestGoogleRuntimeHelperMappings(t *testing.T) { + if got := mapGoogleToolChoiceToGenAI("any"); got != genai.FunctionCallingConfigModeAny { + t.Fatalf("expected any tool choice, 
got %q", got) + } + if got := mapGoogleToolChoiceToGenAI("none"); got != genai.FunctionCallingConfigModeNone { + t.Fatalf("expected none tool choice, got %q", got) + } + if got := mapGoogleToolChoiceToGenAI("auto"); got != genai.FunctionCallingConfigModeAuto { + t.Fatalf("expected auto tool choice, got %q", got) + } + + if got := mapThinkingLevelToGenAI("xhigh"); got != genai.ThinkingLevelHigh { + t.Fatalf("expected xhigh to clamp to high, got %q", got) + } + if got := mapThinkingLevelToGenAI("minimal"); got != genai.ThinkingLevelMinimal { + t.Fatalf("expected minimal thinking level, got %q", got) + } +} diff --git a/pkg/ai/providers/register_builtins.go b/pkg/ai/providers/register_builtins.go index 18f254a6..c3d878a1 100644 --- a/pkg/ai/providers/register_builtins.go +++ b/pkg/ai/providers/register_builtins.go @@ -75,11 +75,19 @@ func RegisterBuiltInAPIProviders() { Stream: streamAnthropicMessages, StreamSimple: streamSimpleAnthropicMessages, }, BuiltinProviderSourceID) + ai.RegisterAPIProvider(ai.APIProvider{ + API: ai.APIGoogleGenerativeAI, + Stream: streamGoogleGenerativeAI, + StreamSimple: streamSimpleGoogleGenerativeAI, + }, BuiltinProviderSourceID) + ai.RegisterAPIProvider(ai.APIProvider{ + API: ai.APIGoogleVertex, + Stream: streamGoogleVertex, + StreamSimple: streamSimpleGoogleVertex, + }, BuiltinProviderSourceID) for _, apiID := range []ai.Api{ - ai.APIGoogleGenerativeAI, ai.APIGoogleGeminiCLI, - ai.APIGoogleVertex, ai.APIBedrockConverse, } { ai.RegisterAPIProvider(ai.APIProvider{ diff --git a/pkg/ai/providers/register_builtins_test.go b/pkg/ai/providers/register_builtins_test.go index 887ed645..07483bc1 100644 --- a/pkg/ai/providers/register_builtins_test.go +++ b/pkg/ai/providers/register_builtins_test.go @@ -104,4 +104,42 @@ func TestRegisterBuiltInAPIProviders(t *testing.T) { if strings.Contains(strings.ToLower(anthropicEvt.Error.ErrorMessage), "not implemented") { t.Fatalf("expected anthropic runtime implementation, got stub error: %q", 
anthropicEvt.Error.ErrorMessage) } + + googleStream, err := ai.Stream(ai.Model{ + ID: "gemini-2.5-flash", + Provider: "google", + API: ai.APIGoogleGenerativeAI, + }, ai.Context{}, nil) + if err != nil { + t.Fatalf("unexpected google stream resolve error: %v", err) + } + googleEvt, err := googleStream.Next(ctx) + if err != nil { + t.Fatalf("expected google terminal error event, got %v", err) + } + if googleEvt.Type != ai.EventError { + t.Fatalf("expected google error event, got %s", googleEvt.Type) + } + if strings.Contains(strings.ToLower(googleEvt.Error.ErrorMessage), "not implemented") { + t.Fatalf("expected google runtime implementation, got stub error: %q", googleEvt.Error.ErrorMessage) + } + + vertexStream, err := ai.Stream(ai.Model{ + ID: "gemini-2.5-flash", + Provider: "google-vertex", + API: ai.APIGoogleVertex, + }, ai.Context{}, nil) + if err != nil { + t.Fatalf("unexpected vertex stream resolve error: %v", err) + } + vertexEvt, err := vertexStream.Next(ctx) + if err != nil { + t.Fatalf("expected vertex terminal error event, got %v", err) + } + if vertexEvt.Type != ai.EventError { + t.Fatalf("expected vertex error event, got %s", vertexEvt.Type) + } + if strings.Contains(strings.ToLower(vertexEvt.Error.ErrorMessage), "not implemented") { + t.Fatalf("expected vertex runtime implementation, got stub error: %q", vertexEvt.Error.ErrorMessage) + } } From 0a11d303779462a8032ad97682749d05878659cc Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 06:56:01 +0000 Subject: [PATCH 43/75] Add pkg-ai Bedrock converse runtime implementation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- go.mod | 16 + go.sum | 32 ++ pkg/ai/providers/amazon_bedrock_runtime.go | 350 ++++++++++++++++++ .../providers/amazon_bedrock_runtime_test.go | 73 ++++ pkg/ai/providers/register_builtins.go | 6 +- pkg/ai/providers/register_builtins_test.go | 19 + 6 files changed, 495 insertions(+), 1 deletion(-) 
create mode 100644 pkg/ai/providers/amazon_bedrock_runtime.go create mode 100644 pkg/ai/providers/amazon_bedrock_runtime_test.go diff --git a/go.mod b/go.mod index b2c80c58..b5eb8cb3 100644 --- a/go.mod +++ b/go.mod @@ -32,6 +32,22 @@ require ( filippo.io/edwards25519 v1.2.0 // indirect github.com/andybalholm/cascadia v1.3.3 // indirect github.com/anthropics/anthropic-sdk-go v1.26.0 // indirect + github.com/aws/aws-sdk-go-v2 v1.41.3 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.6 // indirect + github.com/aws/aws-sdk-go-v2/config v1.32.11 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.19.11 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.19 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.19 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.19 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.5 // indirect + github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.50.1 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.6 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.19 // indirect + github.com/aws/aws-sdk-go-v2/service/signin v1.0.7 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.30.12 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.16 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.41.8 // indirect + github.com/aws/smithy-go v1.24.2 // indirect github.com/coreos/go-systemd/v22 v22.6.0 // indirect github.com/dlclark/regexp2 v1.10.0 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect diff --git a/go.sum b/go.sum index e8f629a7..00b22819 100644 --- a/go.sum +++ b/go.sum @@ -16,6 +16,38 @@ github.com/andybalholm/cascadia v1.3.3 h1:AG2YHrzJIm4BZ19iwJ/DAua6Btl3IwJX+VI4kk github.com/andybalholm/cascadia v1.3.3/go.mod h1:xNd9bqTn98Ln4DwST8/nG+H0yuB8Hmgu1YHNnWw0GeA= github.com/anthropics/anthropic-sdk-go v1.26.0 
h1:oUTzFaUpAevfuELAP1sjL6CQJ9HHAfT7CoSYSac11PY= github.com/anthropics/anthropic-sdk-go v1.26.0/go.mod h1:qUKmaW+uuPB64iy1l+4kOSvaLqPXnHTTBKH6RVZ7q5Q= +github.com/aws/aws-sdk-go-v2 v1.41.3 h1:4kQ/fa22KjDt13QCy1+bYADvdgcxpfH18f0zP542kZA= +github.com/aws/aws-sdk-go-v2 v1.41.3/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.6 h1:N4lRUXZpZ1KVEUn6hxtco/1d2lgYhNn1fHkkl8WhlyQ= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.6/go.mod h1:lyw7GFp3qENLh7kwzf7iMzAxDn+NzjXEAGjKS2UOKqI= +github.com/aws/aws-sdk-go-v2/config v1.32.11 h1:ftxI5sgz8jZkckuUHXfC/wMUc8u3fG1vQS0plr2F2Zs= +github.com/aws/aws-sdk-go-v2/config v1.32.11/go.mod h1:twF11+6ps9aNRKEDimksp923o44w/Thk9+8YIlzWMmo= +github.com/aws/aws-sdk-go-v2/credentials v1.19.11 h1:NdV8cwCcAXrCWyxArt58BrvZJ9pZ9Fhf9w6Uh5W3Uyc= +github.com/aws/aws-sdk-go-v2/credentials v1.19.11/go.mod h1:30yY2zqkMPdrvxBqzI9xQCM+WrlrZKSOpSJEsylVU+8= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.19 h1:INUvJxmhdEbVulJYHI061k4TVuS3jzzthNvjqvVvTKM= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.19/go.mod h1:FpZN2QISLdEBWkayloda+sZjVJL+e9Gl0k1SyTgcswU= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.19 h1:/sECfyq2JTifMI2JPyZ4bdRN77zJmr6SrS1eL3augIA= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.19/go.mod h1:dMf8A5oAqr9/oxOfLkC/c2LU/uMcALP0Rgn2BD5LWn0= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.19 h1:AWeJMk33GTBf6J20XJe6qZoRSJo0WfUhsMdUKhoODXE= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.19/go.mod h1:+GWrYoaAsV7/4pNHpwh1kiNLXkKaSoppxQq9lbH8Ejw= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.5 h1:clHU5fm//kWS1C2HgtgWxfQbFbx4b6rx+5jzhgX9HrI= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.5/go.mod h1:O3h0IK87yXci+kg6flUKzJnWeziQUKciKrLjcatSNcY= +github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.50.1 h1:tnLUbtNW5c056BEbQ4xvlZaakvgdaEdiKF87R1fxuoo= +github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.50.1/go.mod 
h1:DYDD64rVUpCvpLyuWCiTaaSfrW2O9GiDo8S6fNo8ZI0= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.6 h1:XAq62tBTJP/85lFD5oqOOe7YYgWxY9LvWq8plyDvDVg= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.6/go.mod h1:x0nZssQ3qZSnIcePWLvcoFisRXJzcTVvYpAAdYX8+GI= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.19 h1:X1Tow7suZk9UCJHE1Iw9GMZJJl0dAnKXXP1NaSDHwmw= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.19/go.mod h1:/rARO8psX+4sfjUQXp5LLifjUt8DuATZ31WptNJTyQA= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.7 h1:Y2cAXlClHsXkkOvWZFXATr34b0hxxloeQu/pAZz2row= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.7/go.mod h1:idzZ7gmDeqeNrSPkdbtMp9qWMgcBwykA7P7Rzh5DXVU= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.12 h1:iSsvB9EtQ09YrsmIc44Heqlx5ByGErqhPK1ZQLppias= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.12/go.mod h1:fEWYKTRGoZNl8tZ77i61/ccwOMJdGxwOhWCkp6TXAr0= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.16 h1:EnUdUqRP1CNzt2DkV67tJx6XDN4xlfBFm+bzeNOQVb0= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.16/go.mod h1:Jic/xv0Rq/pFNCh3WwpH4BEqdbSAl+IyHro8LbibHD8= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.8 h1:XQTQTF75vnug2TXS8m7CVJfC2nniYPZnO1D4Np761Oo= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.8/go.mod h1:Xgx+PR1NUOjNmQY+tRMnouRp83JRM8pRMw/vCaVhPkI= +github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng= +github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= github.com/beeper/bridge-manager v0.14.0 h1:7XeZfHeDiOuwLUe6UiX/HCywthw1s0Q7xhrmDzzW9FA= github.com/beeper/bridge-manager v0.14.0/go.mod h1:pherlTADz3wkojdc2AvAsR3mS1yG5jF9/OaxkHqPy4Y= github.com/beeper/desktop-api-go v0.2.0 h1:VrwB1FCEiuPycGo6TsYSVVSKQIWFg22xmlRWVJ88E0A= diff --git a/pkg/ai/providers/amazon_bedrock_runtime.go b/pkg/ai/providers/amazon_bedrock_runtime.go new file mode 100644 index 00000000..2d47f5d8 --- /dev/null +++ 
b/pkg/ai/providers/amazon_bedrock_runtime.go @@ -0,0 +1,350 @@ +package providers + +import ( + "context" + "os" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + bedrockdocument "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/document" + bedrocktypes "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" + + "github.com/beeper/ai-bridge/pkg/ai" +) + +func streamBedrockConverse(model ai.Model, c ai.Context, options *ai.StreamOptions) *ai.AssistantMessageEventStream { + stream := ai.NewAssistantMessageEventStream(128) + go func() { + var streamOptions ai.StreamOptions + if options != nil { + streamOptions = *options + } + runCtx := streamOptions.Ctx + if runCtx == nil { + runCtx = context.Background() + } + + if ai.GetEnvAPIKey("amazon-bedrock") == "" { + pushProviderError(stream, model, "missing AWS credentials for Amazon Bedrock runtime") + return + } + cfg, err := loadBedrockAWSConfig(runCtx) + if err != nil { + pushProviderError(stream, model, err.Error()) + return + } + client := bedrockruntime.NewFromConfig(cfg) + + payload := BuildBedrockConverseInput(model, c, BedrockOptions{StreamOptions: streamOptions}) + if streamOptions.OnPayload != nil { + streamOptions.OnPayload(payload) + } + + input := buildBedrockConverseInput(model, c, streamOptions) + resp, err := client.Converse(runCtx, input) + if err != nil { + pushProviderError(stream, model, err.Error()) + return + } + message := ai.Message{ + Role: ai.RoleAssistant, + API: model.API, + Provider: model.Provider, + Model: model.ID, + StopReason: mapBedrockStopReason(resp.StopReason), + Timestamp: time.Now().UnixMilli(), + } + if resp.Usage != nil { + message.Usage = ai.Usage{ + Input: int(aws.ToInt32(resp.Usage.InputTokens)), + Output: int(aws.ToInt32(resp.Usage.OutputTokens)), + TotalTokens: int(aws.ToInt32(resp.Usage.TotalTokens)), + CacheRead: 
int(aws.ToInt32(resp.Usage.CacheReadInputTokens)), + CacheWrite: int(aws.ToInt32(resp.Usage.CacheWriteInputTokens)), + } + message.Usage.Cost = ai.CalculateCost(model, message.Usage) + } + if outputMessage, ok := resp.Output.(*bedrocktypes.ConverseOutputMemberMessage); ok { + for _, block := range outputMessage.Value.Content { + switch blockValue := block.(type) { + case *bedrocktypes.ContentBlockMemberText: + if strings.TrimSpace(blockValue.Value) == "" { + continue + } + stream.Push(ai.AssistantMessageEvent{ + Type: ai.EventTextDelta, + Delta: blockValue.Value, + }) + message.Content = append(message.Content, ai.ContentBlock{ + Type: ai.ContentTypeText, + Text: blockValue.Value, + }) + case *bedrocktypes.ContentBlockMemberReasoningContent: + thinkingBlock := convertBedrockReasoningBlock(blockValue.Value) + if thinkingBlock == nil { + continue + } + stream.Push(ai.AssistantMessageEvent{ + Type: ai.EventThinkingDelta, + Delta: thinkingBlock.Thinking, + }) + message.Content = append(message.Content, *thinkingBlock) + case *bedrocktypes.ContentBlockMemberToolUse: + toolCall := convertBedrockToolUseBlock(blockValue.Value) + if toolCall == nil { + continue + } + stream.Push(ai.AssistantMessageEvent{ + Type: ai.EventToolCallEnd, + ToolCall: toolCall, + }) + message.Content = append(message.Content, *toolCall) + } + } + } + if message.StopReason == ai.StopReasonStop { + for _, block := range message.Content { + if block.Type == ai.ContentTypeToolCall { + message.StopReason = ai.StopReasonToolUse + break + } + } + } + stream.Push(ai.AssistantMessageEvent{ + Type: ai.EventDone, + Message: message, + Reason: message.StopReason, + }) + }() + return stream +} + +func streamSimpleBedrockConverse(model ai.Model, c ai.Context, options *ai.SimpleStreamOptions) *ai.AssistantMessageEventStream { + base := BuildBaseOptions(model, options, "") + return streamBedrockConverse(model, c, &base) +} + +func loadBedrockAWSConfig(ctx context.Context) (aws.Config, error) { + if region := 
strings.TrimSpace(os.Getenv("AWS_REGION")); region != "" { + return config.LoadDefaultConfig(ctx, config.WithRegion(region)) + } + if region := strings.TrimSpace(os.Getenv("AWS_DEFAULT_REGION")); region != "" { + return config.LoadDefaultConfig(ctx, config.WithRegion(region)) + } + return config.LoadDefaultConfig(ctx) +} + +func buildBedrockConverseInput(model ai.Model, c ai.Context, options ai.StreamOptions) *bedrockruntime.ConverseInput { + input := &bedrockruntime.ConverseInput{ + ModelId: aws.String(model.ID), + Messages: convertContextToBedrockMessages(model, c), + } + if strings.TrimSpace(c.SystemPrompt) != "" { + input.System = []bedrocktypes.SystemContentBlock{ + &bedrocktypes.SystemContentBlockMemberText{Value: c.SystemPrompt}, + } + } + if len(c.Tools) > 0 { + input.ToolConfig = &bedrocktypes.ToolConfiguration{ + Tools: convertBedrockTools(c.Tools), + } + input.ToolConfig.ToolChoice = mapBedrockToolChoice("auto") + } + if options.MaxTokens > 0 || options.Temperature != nil { + inference := &bedrocktypes.InferenceConfiguration{} + if options.MaxTokens > 0 { + inference.MaxTokens = aws.Int32(int32(options.MaxTokens)) + } + if options.Temperature != nil { + inference.Temperature = aws.Float32(float32(*options.Temperature)) + } + input.InferenceConfig = inference + } + return input +} + +func convertContextToBedrockMessages(model ai.Model, c ai.Context) []bedrocktypes.Message { + normalized := TransformMessages(c.Messages, model, func(id string, _ ai.Model, _ ai.Message) string { + return NormalizeBedrockToolCallID(id) + }) + out := make([]bedrocktypes.Message, 0, len(normalized)) + for _, msg := range normalized { + switch msg.Role { + case ai.RoleUser: + blocks := make([]bedrocktypes.ContentBlock, 0, max(1, len(msg.Content))) + if strings.TrimSpace(msg.Text) != "" { + blocks = append(blocks, &bedrocktypes.ContentBlockMemberText{Value: msg.Text}) + } + for _, block := range msg.Content { + if block.Type == ai.ContentTypeText && strings.TrimSpace(block.Text) 
!= "" { + blocks = append(blocks, &bedrocktypes.ContentBlockMemberText{Value: block.Text}) + } + } + if len(blocks) == 0 { + continue + } + out = append(out, bedrocktypes.Message{ + Role: bedrocktypes.ConversationRoleUser, + Content: blocks, + }) + case ai.RoleAssistant: + blocks := make([]bedrocktypes.ContentBlock, 0, len(msg.Content)) + for _, block := range msg.Content { + switch block.Type { + case ai.ContentTypeText: + if strings.TrimSpace(block.Text) == "" { + continue + } + blocks = append(blocks, &bedrocktypes.ContentBlockMemberText{Value: block.Text}) + case ai.ContentTypeThinking: + if strings.TrimSpace(block.Thinking) == "" { + continue + } + reasoning := bedrocktypes.ReasoningTextBlock{ + Text: aws.String(block.Thinking), + } + if strings.TrimSpace(block.ThinkingSignature) != "" { + reasoning.Signature = aws.String(block.ThinkingSignature) + } + blocks = append(blocks, &bedrocktypes.ContentBlockMemberReasoningContent{ + Value: &bedrocktypes.ReasoningContentBlockMemberReasoningText{Value: reasoning}, + }) + case ai.ContentTypeToolCall: + blocks = append(blocks, &bedrocktypes.ContentBlockMemberToolUse{ + Value: bedrocktypes.ToolUseBlock{ + Name: aws.String(block.Name), + ToolUseId: aws.String(block.ID), + Input: bedrockdocument.NewLazyDocument(block.Arguments), + }, + }) + } + } + if len(blocks) == 0 { + continue + } + out = append(out, bedrocktypes.Message{ + Role: bedrocktypes.ConversationRoleAssistant, + Content: blocks, + }) + case ai.RoleToolResult: + content := make([]bedrocktypes.ToolResultContentBlock, 0, 1) + resultText := msg.Text + if strings.TrimSpace(resultText) == "" { + var parts []string + for _, block := range msg.Content { + if block.Type == ai.ContentTypeText && strings.TrimSpace(block.Text) != "" { + parts = append(parts, block.Text) + } + } + resultText = strings.Join(parts, "\n") + } + if strings.TrimSpace(resultText) == "" { + resultText = "(empty tool result)" + } + content = append(content, 
&bedrocktypes.ToolResultContentBlockMemberText{Value: resultText}) + status := bedrocktypes.ToolResultStatusSuccess + if msg.IsError { + status = bedrocktypes.ToolResultStatusError + } + out = append(out, bedrocktypes.Message{ + Role: bedrocktypes.ConversationRoleUser, + Content: []bedrocktypes.ContentBlock{ + &bedrocktypes.ContentBlockMemberToolResult{ + Value: bedrocktypes.ToolResultBlock{ + ToolUseId: aws.String(msg.ToolCallID), + Content: content, + Status: status, + }, + }, + }, + }) + } + } + return out +} + +func convertBedrockTools(tools []ai.Tool) []bedrocktypes.Tool { + out := make([]bedrocktypes.Tool, 0, len(tools)) + for _, tool := range tools { + spec := bedrocktypes.ToolSpecification{ + Name: aws.String(tool.Name), + Description: aws.String(tool.Description), + InputSchema: &bedrocktypes.ToolInputSchemaMemberJson{ + Value: bedrockdocument.NewLazyDocument(tool.Parameters), + }, + } + out = append(out, &bedrocktypes.ToolMemberToolSpec{Value: spec}) + } + return out +} + +func mapBedrockToolChoice(choice string) bedrocktypes.ToolChoice { + switch strings.ToLower(strings.TrimSpace(choice)) { + case "any": + return &bedrocktypes.ToolChoiceMemberAny{Value: bedrocktypes.AnyToolChoice{}} + case "none": + return nil + default: + return &bedrocktypes.ToolChoiceMemberAuto{Value: bedrocktypes.AutoToolChoice{}} + } +} + +func mapBedrockStopReason(reason bedrocktypes.StopReason) ai.StopReason { + switch reason { + case bedrocktypes.StopReasonMaxTokens: + return ai.StopReasonLength + case bedrocktypes.StopReasonToolUse: + return ai.StopReasonToolUse + case bedrocktypes.StopReasonEndTurn, bedrocktypes.StopReasonStopSequence: + return ai.StopReasonStop + default: + return ai.StopReasonError + } +} + +func convertBedrockToolUseBlock(block bedrocktypes.ToolUseBlock) *ai.ContentBlock { + if strings.TrimSpace(aws.ToString(block.Name)) == "" { + return nil + } + arguments := map[string]any{} + if block.Input != nil { + _ = block.Input.UnmarshalSmithyDocument(&arguments) + 
} + return &ai.ContentBlock{ + Type: ai.ContentTypeToolCall, + ID: strings.TrimSpace(aws.ToString(block.ToolUseId)), + Name: strings.TrimSpace(aws.ToString(block.Name)), + Arguments: arguments, + } +} + +func convertBedrockReasoningBlock(block bedrocktypes.ReasoningContentBlock) *ai.ContentBlock { + switch value := block.(type) { + case *bedrocktypes.ReasoningContentBlockMemberReasoningText: + if value.Value.Text == nil || strings.TrimSpace(*value.Value.Text) == "" { + return nil + } + out := &ai.ContentBlock{ + Type: ai.ContentTypeThinking, + Thinking: strings.TrimSpace(*value.Value.Text), + } + if value.Value.Signature != nil { + out.ThinkingSignature = strings.TrimSpace(*value.Value.Signature) + } + return out + case *bedrocktypes.ReasoningContentBlockMemberRedactedContent: + if len(value.Value) == 0 { + return nil + } + return &ai.ContentBlock{ + Type: ai.ContentTypeThinking, + Thinking: string(value.Value), + Redacted: true, + } + default: + return nil + } +} diff --git a/pkg/ai/providers/amazon_bedrock_runtime_test.go b/pkg/ai/providers/amazon_bedrock_runtime_test.go new file mode 100644 index 00000000..3224fb98 --- /dev/null +++ b/pkg/ai/providers/amazon_bedrock_runtime_test.go @@ -0,0 +1,73 @@ +package providers + +import ( + "context" + "io" + "strings" + "testing" + "time" + + bedrocktypes "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" + + "github.com/beeper/ai-bridge/pkg/ai" +) + +func TestStreamBedrockConverse_MissingCredentialsEmitsError(t *testing.T) { + t.Setenv("AWS_PROFILE", "") + t.Setenv("AWS_ACCESS_KEY_ID", "") + t.Setenv("AWS_SECRET_ACCESS_KEY", "") + t.Setenv("AWS_BEARER_TOKEN_BEDROCK", "") + t.Setenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "") + t.Setenv("AWS_CONTAINER_CREDENTIALS_FULL_URI", "") + t.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", "") + + stream := streamBedrockConverse(ai.Model{ + ID: "us.anthropic.claude-3-5-sonnet-20241022-v2:0", + Provider: "amazon-bedrock", + API: ai.APIBedrockConverse, + }, ai.Context{ + Messages: 
[]ai.Message{{Role: ai.RoleUser, Text: "hello"}}, + }, &ai.StreamOptions{}) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + evt, err := stream.Next(ctx) + if err != nil { + t.Fatalf("expected terminal error event, got %v", err) + } + if evt.Type != ai.EventError { + t.Fatalf("expected error event, got %s", evt.Type) + } + if !strings.Contains(strings.ToLower(evt.Error.ErrorMessage), "credentials") { + t.Fatalf("expected missing credentials message, got %q", evt.Error.ErrorMessage) + } + if _, err := stream.Next(ctx); err != io.EOF { + t.Fatalf("expected EOF after terminal event, got %v", err) + } +} + +func TestMapBedrockStopReason(t *testing.T) { + cases := map[bedrocktypes.StopReason]ai.StopReason{ + bedrocktypes.StopReasonEndTurn: ai.StopReasonStop, + bedrocktypes.StopReasonStopSequence: ai.StopReasonStop, + bedrocktypes.StopReasonMaxTokens: ai.StopReasonLength, + bedrocktypes.StopReasonToolUse: ai.StopReasonToolUse, + } + for in, want := range cases { + if got := mapBedrockStopReason(in); got != want { + t.Fatalf("mapBedrockStopReason(%q) = %q, want %q", in, got, want) + } + } +} + +func TestMapBedrockToolChoice(t *testing.T) { + if mapBedrockToolChoice("none") != nil { + t.Fatalf("expected none tool choice to map to nil") + } + if got := mapBedrockToolChoice("any"); got == nil { + t.Fatalf("expected any tool choice") + } + if got := mapBedrockToolChoice("auto"); got == nil { + t.Fatalf("expected auto tool choice") + } +} diff --git a/pkg/ai/providers/register_builtins.go b/pkg/ai/providers/register_builtins.go index c3d878a1..f696cb31 100644 --- a/pkg/ai/providers/register_builtins.go +++ b/pkg/ai/providers/register_builtins.go @@ -85,10 +85,14 @@ func RegisterBuiltInAPIProviders() { Stream: streamGoogleVertex, StreamSimple: streamSimpleGoogleVertex, }, BuiltinProviderSourceID) + ai.RegisterAPIProvider(ai.APIProvider{ + API: ai.APIBedrockConverse, + Stream: streamBedrockConverse, + StreamSimple: 
streamSimpleBedrockConverse, + }, BuiltinProviderSourceID) for _, apiID := range []ai.Api{ ai.APIGoogleGeminiCLI, - ai.APIBedrockConverse, } { ai.RegisterAPIProvider(ai.APIProvider{ API: apiID, diff --git a/pkg/ai/providers/register_builtins_test.go b/pkg/ai/providers/register_builtins_test.go index 07483bc1..ba5e3233 100644 --- a/pkg/ai/providers/register_builtins_test.go +++ b/pkg/ai/providers/register_builtins_test.go @@ -142,4 +142,23 @@ func TestRegisterBuiltInAPIProviders(t *testing.T) { if strings.Contains(strings.ToLower(vertexEvt.Error.ErrorMessage), "not implemented") { t.Fatalf("expected vertex runtime implementation, got stub error: %q", vertexEvt.Error.ErrorMessage) } + + bedrockStream, err := ai.Stream(ai.Model{ + ID: "us.anthropic.claude-3-5-sonnet-20241022-v2:0", + Provider: "amazon-bedrock", + API: ai.APIBedrockConverse, + }, ai.Context{}, nil) + if err != nil { + t.Fatalf("unexpected bedrock stream resolve error: %v", err) + } + bedrockEvt, err := bedrockStream.Next(ctx) + if err != nil { + t.Fatalf("expected bedrock terminal error event, got %v", err) + } + if bedrockEvt.Type != ai.EventError { + t.Fatalf("expected bedrock error event, got %s", bedrockEvt.Type) + } + if strings.Contains(strings.ToLower(bedrockEvt.Error.ErrorMessage), "not implemented") { + t.Fatalf("expected bedrock runtime implementation, got stub error: %q", bedrockEvt.Error.ErrorMessage) + } } From 10f965e3e0f8c94bffbedd32ddc9f1559df9605e Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 06:58:59 +0000 Subject: [PATCH 44/75] Expand pkg-ai bridge provider inference coverage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/connector/pkg_ai_provider_bridge.go | 45 +++++++++++++++++--- pkg/connector/pkg_ai_provider_bridge_test.go | 42 ++++++++++++++++++ 2 files changed, 81 insertions(+), 6 deletions(-) diff --git a/pkg/connector/pkg_ai_provider_bridge.go 
b/pkg/connector/pkg_ai_provider_bridge.go index cc864278..e23ff02b 100644 --- a/pkg/connector/pkg_ai_provider_bridge.go +++ b/pkg/connector/pkg_ai_provider_bridge.go @@ -20,6 +20,16 @@ func inferProviderNameFromBaseURL(baseURL string) string { switch { case strings.Contains(lower, "openrouter.ai"): return "openrouter" + case strings.Contains(lower, "api.anthropic.com"): + return "anthropic" + case strings.Contains(lower, "cloudcode-pa.googleapis.com"): + return "google-gemini-cli" + case strings.Contains(lower, "aiplatform.googleapis.com"), strings.Contains(lower, "vertex"): + return "google-vertex" + case strings.Contains(lower, "googleapis.com"), strings.Contains(lower, "generativelanguage.googleapis.com"): + return "google" + case strings.Contains(lower, "bedrock"): + return "amazon-bedrock" case strings.Contains(lower, "beeper.com"): return "beeper" case strings.Contains(lower, "magicproxy"): @@ -34,12 +44,7 @@ func inferProviderNameFromBaseURL(baseURL string) string { func buildPkgAIModelFromGenerateParams(params GenerateParams, baseURL string) aipkg.Model { modelID := strings.TrimSpace(params.Model) provider := inferProviderNameFromBaseURL(baseURL) - api := aipkg.APIOpenAIResponses - if provider == "openrouter" { - api = aipkg.APIOpenAICompletions - } else if provider == "azure-openai-responses" { - api = aipkg.APIAzureOpenAIResponse - } + api := inferAPIFromProviderModel(provider, modelID) return aipkg.Model{ ID: modelID, Name: modelID, @@ -52,6 +57,34 @@ func buildPkgAIModelFromGenerateParams(params GenerateParams, baseURL string) ai } } +func inferAPIFromProviderModel(provider string, modelID string) aipkg.Api { + switch provider { + case "openrouter": + return aipkg.APIOpenAICompletions + case "azure-openai-responses": + return aipkg.APIAzureOpenAIResponse + case "anthropic": + return aipkg.APIAnthropicMessages + case "google": + return aipkg.APIGoogleGenerativeAI + case "google-gemini-cli": + return aipkg.APIGoogleGeminiCLI + case "google-vertex": + 
return aipkg.APIGoogleVertex + case "amazon-bedrock": + return aipkg.APIBedrockConverse + } + model := strings.ToLower(strings.TrimSpace(modelID)) + switch { + case strings.HasPrefix(model, "claude-"): + return aipkg.APIAnthropicMessages + case strings.HasPrefix(model, "gemini-"): + return aipkg.APIGoogleGenerativeAI + default: + return aipkg.APIOpenAIResponses + } +} + func modelSupportsReasoning(modelID string) bool { modelID = strings.ToLower(strings.TrimSpace(modelID)) return strings.HasPrefix(modelID, "gpt-5") || diff --git a/pkg/connector/pkg_ai_provider_bridge_test.go b/pkg/connector/pkg_ai_provider_bridge_test.go index 1f5be601..60cf6dc3 100644 --- a/pkg/connector/pkg_ai_provider_bridge_test.go +++ b/pkg/connector/pkg_ai_provider_bridge_test.go @@ -4,6 +4,8 @@ import ( "context" "errors" "testing" + + aipkg "github.com/beeper/ai-bridge/pkg/ai" ) func TestPkgAIProviderRuntimeEnabled(t *testing.T) { @@ -29,6 +31,11 @@ func TestInferProviderNameFromBaseURL(t *testing.T) { {name: "beeper proxy", baseURL: "https://ai.beeper.com/openai", want: "beeper"}, {name: "magic proxy", baseURL: "https://magicproxy.example/v1", want: "magic-proxy"}, {name: "azure", baseURL: "https://my-openai.azure.com", want: "azure-openai-responses"}, + {name: "anthropic", baseURL: "https://api.anthropic.com", want: "anthropic"}, + {name: "google cloudcode", baseURL: "https://cloudcode-pa.googleapis.com", want: "google-gemini-cli"}, + {name: "google mldev", baseURL: "https://generativelanguage.googleapis.com", want: "google"}, + {name: "google vertex", baseURL: "https://us-central1-aiplatform.googleapis.com", want: "google-vertex"}, + {name: "bedrock", baseURL: "https://bedrock-runtime.us-east-1.amazonaws.com", want: "amazon-bedrock"}, } for _, tc := range cases { @@ -74,6 +81,27 @@ func TestBuildPkgAIModelFromGenerateParams(t *testing.T) { t.Fatalf("expected azure base URL to map to azure-openai-responses API, got %q", azure.API) } + anthropic := 
buildPkgAIModelFromGenerateParams(GenerateParams{ + Model: "claude-sonnet-4-5", + }, "https://api.anthropic.com") + if anthropic.API != "anthropic-messages" { + t.Fatalf("expected anthropic base URL to map to anthropic-messages API, got %q", anthropic.API) + } + + google := buildPkgAIModelFromGenerateParams(GenerateParams{ + Model: "gemini-2.5-flash", + }, "https://generativelanguage.googleapis.com") + if google.API != "google-generative-ai" { + t.Fatalf("expected google base URL to map to google-generative-ai API, got %q", google.API) + } + + bedrock := buildPkgAIModelFromGenerateParams(GenerateParams{ + Model: "us.anthropic.claude-3-5-sonnet-20241022-v2:0", + }, "https://bedrock-runtime.us-east-1.amazonaws.com") + if bedrock.API != aipkg.APIBedrockConverse { + t.Fatalf("expected bedrock base URL to map to %q API, got %q", aipkg.APIBedrockConverse, bedrock.API) + } + nonReasoning := buildPkgAIModelFromGenerateParams(GenerateParams{ Model: "gpt-4.1-mini", }, "") @@ -88,6 +116,20 @@ func TestBuildPkgAIModelFromGenerateParams(t *testing.T) { if !withReasoningOverride.Reasoning { t.Fatalf("expected reasoning effort override to mark model as reasoning capable") } + + heuristicAnthropic := buildPkgAIModelFromGenerateParams(GenerateParams{ + Model: "claude-3-7-sonnet-latest", + }, "") + if heuristicAnthropic.API != "anthropic-messages" { + t.Fatalf("expected claude model heuristic to map to anthropic API, got %q", heuristicAnthropic.API) + } + + heuristicGoogle := buildPkgAIModelFromGenerateParams(GenerateParams{ + Model: "gemini-2.5-pro", + }, "") + if heuristicGoogle.API != "google-generative-ai" { + t.Fatalf("expected gemini model heuristic to map to google API, got %q", heuristicGoogle.API) + } } func TestShouldFallbackFromPkgAIEvent(t *testing.T) { From 3d89e55e519ca112758066ef377408ebb4bdb506 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 07:05:13 +0000 Subject: [PATCH 45/75] Route OpenAI generate through pkg-ai bridge path MIME-Version: 1.0 
Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/connector/pkg_ai_provider_bridge.go | 109 +++++++++++++++++++ pkg/connector/pkg_ai_provider_bridge_test.go | 82 ++++++++++++++ pkg/connector/provider_openai.go | 15 +++ 3 files changed, 206 insertions(+) diff --git a/pkg/connector/pkg_ai_provider_bridge.go b/pkg/connector/pkg_ai_provider_bridge.go index e23ff02b..d1e48b14 100644 --- a/pkg/connector/pkg_ai_provider_bridge.go +++ b/pkg/connector/pkg_ai_provider_bridge.go @@ -2,6 +2,8 @@ package connector import ( "context" + "encoding/json" + "errors" "os" "strings" "time" @@ -119,6 +121,17 @@ func shouldFallbackFromPkgAIEvent(event StreamEvent) bool { strings.Contains(errText, "no api provider registered") } +func shouldFallbackFromPkgAIError(err error) bool { + if err == nil { + return false + } + errText := strings.ToLower(strings.TrimSpace(err.Error())) + return strings.Contains(errText, "not implemented yet") || + strings.Contains(errText, "no api provider registered") || + strings.Contains(errText, "has no stream function") || + strings.Contains(errText, "has no streamsimple function") +} + func tryGenerateStreamWithPkgAI( ctx context.Context, baseURL string, @@ -176,3 +189,99 @@ func tryGenerateStreamWithPkgAI( return mapped, true } } + +func tryGenerateWithPkgAI( + ctx context.Context, + baseURL string, + apiKey string, + params GenerateParams, +) (*GenerateResponse, bool, error) { + aiproviders.RegisterBuiltInAPIProviders() + model := buildPkgAIModelFromGenerateParams(params, baseURL) + aiContext := toAIContext(params.SystemPrompt, params.Messages, params.Tools) + + temp := params.Temperature + options := &aipkg.StreamOptions{ + Ctx: ctx, + MaxTokens: params.MaxCompletionTokens, + Temperature: &temp, + APIKey: strings.TrimSpace(apiKey), + } + + var ( + message aipkg.Message + err error + ) + if reasoning := parseThinkingLevel(params.ReasoningEffort); reasoning != "" { + message, err = 
aipkg.CompleteSimple(model, aiContext, &aipkg.SimpleStreamOptions{ + StreamOptions: *options, + Reasoning: reasoning, + }) + } else { + message, err = aipkg.Complete(model, aiContext, options) + } + if err != nil { + if shouldFallbackFromPkgAIError(err) { + return nil, false, nil + } + return nil, true, err + } + if message.StopReason == aipkg.StopReasonError && strings.TrimSpace(message.ErrorMessage) != "" { + runtimeErr := errors.New(strings.TrimSpace(message.ErrorMessage)) + if shouldFallbackFromPkgAIError(runtimeErr) { + return nil, false, nil + } + return nil, true, runtimeErr + } + return generateResponseFromAIMessage(message), true, nil +} + +func generateResponseFromAIMessage(message aipkg.Message) *GenerateResponse { + var contentParts []string + var thinkingParts []string + toolCalls := make([]ToolCallResult, 0) + for _, block := range message.Content { + switch block.Type { + case aipkg.ContentTypeText: + if text := strings.TrimSpace(block.Text); text != "" { + contentParts = append(contentParts, text) + } + case aipkg.ContentTypeThinking: + if thinking := strings.TrimSpace(block.Thinking); thinking != "" { + thinkingParts = append(thinkingParts, thinking) + } + case aipkg.ContentTypeToolCall: + argumentsJSON := "{}" + if block.Arguments != nil { + if raw, err := json.Marshal(block.Arguments); err == nil { + argumentsJSON = string(raw) + } + } + toolCalls = append(toolCalls, ToolCallResult{ + ID: strings.TrimSpace(block.ID), + Name: strings.TrimSpace(block.Name), + Arguments: argumentsJSON, + }) + } + } + + content := strings.Join(contentParts, "\n") + if strings.TrimSpace(content) == "" && len(thinkingParts) > 0 { + content = strings.Join(thinkingParts, "\n") + } + finishReason := strings.TrimSpace(string(message.StopReason)) + if finishReason == "" { + finishReason = "stop" + } + + return &GenerateResponse{ + Content: content, + FinishReason: finishReason, + ToolCalls: toolCalls, + Usage: UsageInfo{ + PromptTokens: message.Usage.Input, + 
CompletionTokens: message.Usage.Output, + TotalTokens: message.Usage.TotalTokens, + }, + } +} diff --git a/pkg/connector/pkg_ai_provider_bridge_test.go b/pkg/connector/pkg_ai_provider_bridge_test.go index 60cf6dc3..ecbe8036 100644 --- a/pkg/connector/pkg_ai_provider_bridge_test.go +++ b/pkg/connector/pkg_ai_provider_bridge_test.go @@ -144,6 +144,18 @@ func TestShouldFallbackFromPkgAIEvent(t *testing.T) { } } +func TestShouldFallbackFromPkgAIError(t *testing.T) { + if !shouldFallbackFromPkgAIError(errors.New("provider runtime is not implemented yet")) { + t.Fatalf("expected not-implemented errors to trigger fallback") + } + if !shouldFallbackFromPkgAIError(errors.New("provider x has no stream function")) { + t.Fatalf("expected missing stream function errors to trigger fallback") + } + if shouldFallbackFromPkgAIError(errors.New("missing API key for provider")) { + t.Fatalf("did not expect runtime credential errors to trigger fallback") + } +} + func TestTryGenerateStreamWithPkgAIReturnsRuntimeErrorEventsWhenProviderResolved(t *testing.T) { events, ok := tryGenerateStreamWithPkgAI(context.Background(), "https://my-openai.azure.com", "", GenerateParams{ Model: "gpt-4.1-mini", @@ -165,6 +177,76 @@ func TestTryGenerateStreamWithPkgAIReturnsRuntimeErrorEventsWhenProviderResolved } } +func TestTryGenerateWithPkgAIFallsBackOnStubbedProviders(t *testing.T) { + resp, handled, err := tryGenerateWithPkgAI(context.Background(), "https://cloudcode-pa.googleapis.com", "", GenerateParams{ + Model: "gemini-2.5-flash", + Messages: []UnifiedMessage{ + { + Role: RoleUser, + Content: []ContentPart{ + {Type: ContentTypeText, Text: "hello"}, + }, + }, + }, + }) + if err != nil { + t.Fatalf("expected nil error on fallback path, got %v", err) + } + if handled { + t.Fatalf("expected fallback for stubbed google-gemini-cli runtime") + } + if resp != nil { + t.Fatalf("expected nil response on fallback path") + } +} + +func TestTryGenerateWithPkgAIReturnsRuntimeErrorWhenProviderResolved(t 
*testing.T) { + resp, handled, err := tryGenerateWithPkgAI(context.Background(), "https://api.anthropic.com", "", GenerateParams{ + Model: "claude-sonnet-4-5", + Messages: []UnifiedMessage{ + { + Role: RoleUser, + Content: []ContentPart{ + {Type: ContentTypeText, Text: "hello"}, + }, + }, + }, + }) + if !handled { + t.Fatalf("expected anthropic provider to be handled by pkg/ai runtime") + } + if err == nil { + t.Fatalf("expected runtime error without credentials") + } + if resp != nil { + t.Fatalf("expected nil response when runtime returns error") + } +} + +func TestGenerateResponseFromAIMessage(t *testing.T) { + resp := generateResponseFromAIMessage(aipkg.Message{ + StopReason: aipkg.StopReasonToolUse, + Usage: aipkg.Usage{Input: 7, Output: 3, TotalTokens: 10}, + Content: []aipkg.ContentBlock{ + {Type: aipkg.ContentTypeThinking, Thinking: "plan"}, + {Type: aipkg.ContentTypeText, Text: "answer"}, + {Type: aipkg.ContentTypeToolCall, ID: "call_1", Name: "search", Arguments: map[string]any{"q": "go"}}, + }, + }) + if resp.Content != "answer" { + t.Fatalf("expected content text extraction, got %q", resp.Content) + } + if resp.FinishReason != string(aipkg.StopReasonToolUse) { + t.Fatalf("unexpected finish reason: %q", resp.FinishReason) + } + if len(resp.ToolCalls) != 1 || resp.ToolCalls[0].Name != "search" { + t.Fatalf("expected tool call mapping, got %#v", resp.ToolCalls) + } + if resp.Usage.TotalTokens != 10 { + t.Fatalf("expected usage mapping, got %#v", resp.Usage) + } +} + func TestParseThinkingLevel(t *testing.T) { cases := map[string]string{ "minimal": "minimal", diff --git a/pkg/connector/provider_openai.go b/pkg/connector/provider_openai.go index 0b734411..35de563f 100644 --- a/pkg/connector/provider_openai.go +++ b/pkg/connector/provider_openai.go @@ -375,6 +375,21 @@ func (o *OpenAIProvider) GenerateStream(ctx context.Context, params GeneratePara // Generate performs a non-streaming generation using Responses API func (o *OpenAIProvider) Generate(ctx 
context.Context, params GenerateParams) (*GenerateResponse, error) { + if pkgAIProviderRuntimeEnabled() { + if pkgAIResp, handled, err := tryGenerateWithPkgAI(ctx, o.baseURL, o.apiKey, params); handled { + if err != nil { + return nil, fmt.Errorf("pkg/ai generation failed: %w", err) + } + o.log.Debug(). + Str("model", params.Model). + Msg("Using pkg/ai provider runtime for OpenAI generate") + return pkgAIResp, nil + } + o.log.Warn(). + Str("model", params.Model). + Msg("pkg/ai provider runtime fallback to existing OpenAI generate path") + } + // Responses input supports images and PDFs but not audio/video, so fall back to // Chat Completions when unsupported media is present. if hasUnsupportedResponsesUnifiedMessages(params.Messages) { From 4573e7c55848c2bc839f6e8f1bd3b097eceaf562 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 07:09:01 +0000 Subject: [PATCH 46/75] Add connector test for pkg-ai generate bridge usage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/connector/provider_openai_pkg_ai_test.go | 36 ++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 pkg/connector/provider_openai_pkg_ai_test.go diff --git a/pkg/connector/provider_openai_pkg_ai_test.go b/pkg/connector/provider_openai_pkg_ai_test.go new file mode 100644 index 00000000..e3344daa --- /dev/null +++ b/pkg/connector/provider_openai_pkg_ai_test.go @@ -0,0 +1,36 @@ +package connector + +import ( + "context" + "strings" + "testing" + + "github.com/rs/zerolog" +) + +func TestOpenAIProviderGenerate_UsesPkgAIBridgeWhenEnabled(t *testing.T) { + t.Setenv("PI_USE_PKG_AI_PROVIDER_RUNTIME", "true") + t.Setenv("ANTHROPIC_API_KEY", "") + t.Setenv("ANTHROPIC_OAUTH_TOKEN", "") + + provider, err := NewOpenAIProviderWithBaseURL("", "https://api.anthropic.com", zerolog.Nop()) + if err != nil { + t.Fatalf("unexpected provider init error: %v", err) + } + + _, err = 
provider.Generate(context.Background(), GenerateParams{ + Model: "claude-sonnet-4-5", + Messages: []UnifiedMessage{ + { + Role: RoleUser, + Content: []ContentPart{{Type: ContentTypeText, Text: "hello"}}, + }, + }, + }) + if err == nil { + t.Fatalf("expected pkg/ai runtime error without anthropic credentials") + } + if !strings.Contains(strings.ToLower(err.Error()), "pkg/ai generation failed") { + t.Fatalf("expected pkg/ai bridge error prefix, got %q", err.Error()) + } +} From dfa464ebaceb0ca6231b773d8cf266b96b872721 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 07:27:30 +0000 Subject: [PATCH 47/75] Add controlled pkg-ai stream bridge execution path MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/connector/streaming_runtime_selector.go | 153 ++++++++++++++++++ .../streaming_runtime_selector_test.go | 101 ++++++++++++ 2 files changed, 254 insertions(+) diff --git a/pkg/connector/streaming_runtime_selector.go b/pkg/connector/streaming_runtime_selector.go index a1bb1eaf..6d0f0514 100644 --- a/pkg/connector/streaming_runtime_selector.go +++ b/pkg/connector/streaming_runtime_selector.go @@ -75,6 +75,20 @@ func (oc *AIClient) streamWithPkgAIBridge( if pkgAIRuntimeDryRunEnabled() { oc.runPkgAIBridgeDryRun(ctx, aiModel, aiContext) } + if oc.shouldUsePkgAIBridgeStreaming(meta, prompt) { + if baseURL, apiKey, ok := oc.pkgAIProviderBridgeCredentials(); ok { + params := oc.buildPkgAIBridgeGenerateParams(meta, prompt) + if events, handled := tryGenerateStreamWithPkgAI(ctx, baseURL, apiKey, params); handled { + oc.loggerForContext(ctx).Debug(). + Str("model", params.Model). + Msg("Executing pkg/ai runtime bridge event stream path") + return oc.streamPkgAIBridgeEvents(ctx, evt, portal, meta, prompt, events) + } + oc.loggerForContext(ctx).Debug(). + Str("model", params.Model). 
+ Msg("pkg/ai bridge event stream path requested fallback") + } + } switch oc.resolveModelAPI(meta) { case ModelAPIChatCompletions: return oc.streamChatCompletions(ctx, evt, portal, meta, prompt) @@ -205,3 +219,142 @@ func chatPromptToUnifiedMessages(prompt []openai.ChatCompletionMessageParamUnion } return out } + +func (oc *AIClient) pkgAIProviderBridgeCredentials() (string, string, bool) { + provider, ok := oc.provider.(*OpenAIProvider) + if !ok || provider == nil { + return "", "", false + } + return provider.baseURL, provider.apiKey, true +} + +func (oc *AIClient) shouldUsePkgAIBridgeStreaming( + meta *PortalMetadata, + prompt []openai.ChatCompletionMessageParamUnion, +) bool { + if meta != nil && meta.Capabilities.SupportsToolCalling { + return false + } + return !promptContainsToolCalls(prompt) +} + +func promptContainsToolCalls(prompt []openai.ChatCompletionMessageParamUnion) bool { + for _, msg := range prompt { + if msg.OfTool != nil { + return true + } + if msg.OfAssistant != nil && len(msg.OfAssistant.ToolCalls) > 0 { + return true + } + } + return false +} + +func (oc *AIClient) buildPkgAIBridgeGenerateParams( + meta *PortalMetadata, + prompt []openai.ChatCompletionMessageParamUnion, +) GenerateParams { + return GenerateParams{ + Model: oc.effectiveModel(meta), + Messages: chatPromptToUnifiedMessages(prompt), + SystemPrompt: oc.effectivePrompt(meta), + Temperature: oc.effectiveTemperature(meta), + MaxCompletionTokens: oc.effectiveMaxTokens(meta), + ReasoningEffort: oc.effectiveReasoningEffort(meta), + } +} + +func (oc *AIClient) streamPkgAIBridgeEvents( + ctx context.Context, + evt *event.Event, + portal *bridgev2.Portal, + meta *PortalMetadata, + prompt []openai.ChatCompletionMessageParamUnion, + events <-chan StreamEvent, +) (bool, *ContextLengthError, error) { + log := oc.loggerForContext(ctx).With(). + Str("action", "stream_pkg_ai_bridge_events"). 
+ Logger() + + prep, _, typingCleanup := oc.prepareStreamingRun(ctx, log, evt, portal, meta, prompt) + defer typingCleanup() + state := prep.State + typingSignals := prep.TypingSignals + touchTyping := prep.TouchTyping + isHeartbeat := prep.IsHeartbeat + + oc.emitUIStart(ctx, portal, state, meta) + + for { + select { + case <-ctx.Done(): + state.finishReason = "cancelled" + if state.hasInitialMessageTarget() && state.accumulated.Len() > 0 { + oc.flushPartialStreamingMessage(context.Background(), portal, state, meta) + } + oc.uiEmitter(state).EmitUIAbort(ctx, portal, "cancelled") + oc.emitUIFinish(ctx, portal, state, meta) + return false, nil, streamFailureError(state, ctx.Err()) + case event, ok := <-events: + if !ok { + state.completedAtMs = time.Now().UnixMilli() + oc.finalizeResponsesStream(ctx, log, portal, state, meta) + return true, nil, nil + } + + oc.markMessageSendSuccess(ctx, portal, evt, state) + switch event.Type { + case StreamEventDelta: + touchTyping() + if err := oc.handleResponseOutputTextDelta( + ctx, + log, + portal, + state, + meta, + typingSignals, + isHeartbeat, + event.Delta, + "failed to send initial streaming message", + "Failed to send initial streaming message", + ); err != nil { + return false, nil, &PreDeltaError{Err: err} + } + case StreamEventReasoning: + touchTyping() + if err := oc.handleResponseReasoningTextDelta( + ctx, + log, + portal, + state, + meta, + isHeartbeat, + event.ReasoningDelta, + "failed to send initial streaming message", + "Failed to send initial streaming message", + ); err != nil { + return false, nil, &PreDeltaError{Err: err} + } + case StreamEventComplete: + if reason := strings.TrimSpace(event.FinishReason); reason != "" { + state.finishReason = reason + } + state.responseID = strings.TrimSpace(event.ResponseID) + if event.Usage != nil { + state.promptTokens = int64(event.Usage.PromptTokens) + state.completionTokens = int64(event.Usage.CompletionTokens) + state.reasoningTokens = 
int64(event.Usage.ReasoningTokens) + state.totalTokens = int64(event.Usage.TotalTokens) + oc.uiEmitter(state).EmitUIMessageMetadata(ctx, portal, oc.buildUIMessageMetadata(state, meta, true)) + } + case StreamEventError: + if cle := ParseContextLengthError(event.Error); cle != nil { + return false, cle, nil + } + oc.uiEmitter(state).EmitUIError(ctx, portal, event.Error.Error()) + oc.emitUIFinish(ctx, portal, state, meta) + return false, nil, streamFailureError(state, event.Error) + } + } + } +} diff --git a/pkg/connector/streaming_runtime_selector_test.go b/pkg/connector/streaming_runtime_selector_test.go index ebbe054d..d65942c1 100644 --- a/pkg/connector/streaming_runtime_selector_test.go +++ b/pkg/connector/streaming_runtime_selector_test.go @@ -118,3 +118,104 @@ func TestBuildPkgAIContext_UsesSystemPromptAndMappedMessages(t *testing.T) { t.Fatalf("unexpected mapped roles: %#v", ctx.Messages) } } + +func TestPromptContainsToolCalls(t *testing.T) { + if promptContainsToolCalls([]openai.ChatCompletionMessageParamUnion{ + openai.UserMessage("hello"), + }) { + t.Fatalf("did not expect tool call detection for plain user prompt") + } + if !promptContainsToolCalls([]openai.ChatCompletionMessageParamUnion{ + { + OfAssistant: &openai.ChatCompletionAssistantMessageParam{ + ToolCalls: []openai.ChatCompletionMessageToolCallUnionParam{ + { + OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{ + ID: "call_1", + Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{ + Name: "search", + Arguments: "{}", + }, + }, + }, + }, + }, + }, + }) { + t.Fatalf("expected assistant tool calls to be detected") + } + if !promptContainsToolCalls([]openai.ChatCompletionMessageParamUnion{ + openai.ToolMessage("tool result", "call_1"), + }) { + t.Fatalf("expected tool role messages to be detected") + } +} + +func TestShouldUsePkgAIBridgeStreaming(t *testing.T) { + client := &AIClient{} + if !client.shouldUsePkgAIBridgeStreaming(&PortalMetadata{}, 
[]openai.ChatCompletionMessageParamUnion{ + openai.UserMessage("hello"), + }) { + t.Fatalf("expected bridge streaming to be enabled for non-tool prompt") + } + if client.shouldUsePkgAIBridgeStreaming(&PortalMetadata{ + Capabilities: ModelCapabilities{SupportsToolCalling: true}, + }, []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage("hello"), + }) { + t.Fatalf("expected bridge streaming disabled when tool calling is enabled") + } +} + +func TestBuildPkgAIBridgeGenerateParams(t *testing.T) { + client := &AIClient{} + meta := &PortalMetadata{ + Model: "claude-sonnet-4-5", + SystemPrompt: "You are helpful", + Temperature: 0.2, + MaxCompletionTokens: 2048, + ReasoningEffort: "medium", + Capabilities: ModelCapabilities{ + SupportsReasoning: true, + }, + } + params := client.buildPkgAIBridgeGenerateParams(meta, []openai.ChatCompletionMessageParamUnion{ + openai.SystemMessage("ignored"), + openai.UserMessage("hello"), + openai.AssistantMessage("hi"), + }) + if params.Model != "claude-sonnet-4-5" { + t.Fatalf("unexpected model mapping: %q", params.Model) + } + if params.SystemPrompt != "You are helpful" { + t.Fatalf("unexpected system prompt mapping: %q", params.SystemPrompt) + } + if params.Temperature != 0.2 { + t.Fatalf("unexpected temperature mapping: %f", params.Temperature) + } + if params.MaxCompletionTokens != 2048 { + t.Fatalf("unexpected max token mapping: %d", params.MaxCompletionTokens) + } + if params.ReasoningEffort != "medium" { + t.Fatalf("unexpected reasoning mapping: %q", params.ReasoningEffort) + } + if len(params.Messages) != 2 { + t.Fatalf("expected mapped user+assistant messages, got %d", len(params.Messages)) + } +} + +func TestPkgAIProviderBridgeCredentials(t *testing.T) { + client := &AIClient{ + provider: &OpenAIProvider{ + baseURL: "https://api.anthropic.com", + apiKey: "secret", + }, + } + baseURL, apiKey, ok := client.pkgAIProviderBridgeCredentials() + if !ok { + t.Fatalf("expected credential extraction for OpenAIProvider") + } + if 
baseURL != "https://api.anthropic.com" || apiKey != "secret" { + t.Fatalf("unexpected credential extraction: %q %q", baseURL, apiKey) + } +} From a82843cd85284fec748f740e6f8bf6377bb1e441 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 07:35:33 +0000 Subject: [PATCH 48/75] Document pkg-ai runtime bridge flags and rollout MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- README.md | 1 + docs/pkg-ai-runtime-migration.md | 48 ++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+) create mode 100644 docs/pkg-ai-runtime-migration.md diff --git a/README.md b/README.md index 08f7054e..8aab51a0 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,7 @@ Experimental Matrix ↔ AI bridge for Beeper, built on top of [mautrix/bridgev2] - `docs/matrix-ai-matrix-spec-v1.md`: Full Matrix transport spec (events, streaming, approvals, state, and schema examples). - `docs/bridge-orchestrator.md`: One-command bridge management in this repo. +- `docs/pkg-ai-runtime-migration.md`: Feature flags and rollout notes for connector ↔ `pkg/ai` runtime bridging. ## Bridge Orchestrator diff --git a/docs/pkg-ai-runtime-migration.md b/docs/pkg-ai-runtime-migration.md new file mode 100644 index 00000000..72e1d7d4 --- /dev/null +++ b/docs/pkg-ai-runtime-migration.md @@ -0,0 +1,48 @@ +# pkg/ai Runtime Migration Notes + +This repository now includes a standalone `pkg/ai` Go port of `pi-mono/packages/ai`, +plus controlled connector bridge paths that can route runtime execution to `pkg/ai`. + +## Feature flags + +### Connector runtime selector + +- `PI_USE_PKG_AI_RUNTIME=1` + - Enables connector runtime selection path (`streamWithPkgAIBridge`). + - Keeps safe fallback to existing Responses/Chat Completions code paths. + +- `PI_USE_PKG_AI_RUNTIME_DRY_RUN=1` + - Runs optional `pkg/ai` dry-run stream consumption for diagnostics while still + executing the existing connector runtime path. 
+ +### Provider runtime bridge + +- `PI_USE_PKG_AI_PROVIDER_RUNTIME=1` + - Enables `OpenAIProvider` bridging for: + - `GenerateStream(...)` via `tryGenerateStreamWithPkgAI(...)` + - `Generate(...)` via `tryGenerateWithPkgAI(...)` + - Includes guarded fallback for unresolved/stubbed provider APIs. + +## Current bridge behavior + +- Streaming (`PI_USE_PKG_AI_RUNTIME`): + - Controlled live pkg/ai event consumption is enabled for safe non-tool + scenarios. + - Falls back to legacy streaming runtime when bridge conditions are not met. + +- Provider abstraction (`PI_USE_PKG_AI_PROVIDER_RUNTIME`): + - Routes both streaming and non-streaming provider calls through pkg/ai where + possible. + - Preserves existing connector behavior on fallback-class errors. + +## High-signal test commands + +```bash +go test ./pkg/ai/... +CGO_ENABLED=0 go test ./pkg/connector -run "TestPkgAIProviderRuntimeEnabled|TestInferProviderNameFromBaseURL|TestBuildPkgAIModelFromGenerateParams|TestShouldFallbackFromPkgAIEvent|TestShouldFallbackFromPkgAIError|TestTryGenerateStreamWithPkgAIReturnsRuntimeErrorEventsWhenProviderResolved|TestTryGenerateWithPkgAIFallsBackOnStubbedProviders|TestTryGenerateWithPkgAIReturnsRuntimeErrorWhenProviderResolved|TestGenerateResponseFromAIMessage|TestParseThinkingLevel|TestOpenAIProviderGenerate_UsesPkgAIBridgeWhenEnabled|TestPkgAIRuntimeEnabledFromEnv|TestChooseStreamingRuntimePath|TestPromptContainsToolCalls|TestShouldUsePkgAIBridgeStreaming|TestBuildPkgAIBridgeGenerateParams|TestPkgAIProviderBridgeCredentials|TestAIEventToStreamEvent_Mapping|TestStreamEventsFromAIStream|TestToAIContext_MapsMessagesAndTools" +``` + +## Notes + +- Full integration remains feature-gated. +- Fallback behavior is intentional and required for incremental rollout safety. 
From 460560ea4a10b59758a0bc19b1f839ec2c76fb79 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 07:57:50 +0000 Subject: [PATCH 49/75] Implement pkg-ai Google Gemini CLI streaming runtime MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/ai/providers/google_gemini_cli_runtime.go | 587 ++++++++++++++++++ .../google_gemini_cli_runtime_test.go | 134 ++++ pkg/ai/providers/register_builtins.go | 15 +- pkg/ai/providers/register_builtins_test.go | 19 + pkg/connector/pkg_ai_provider_bridge_test.go | 12 +- 5 files changed, 751 insertions(+), 16 deletions(-) create mode 100644 pkg/ai/providers/google_gemini_cli_runtime.go create mode 100644 pkg/ai/providers/google_gemini_cli_runtime_test.go diff --git a/pkg/ai/providers/google_gemini_cli_runtime.go b/pkg/ai/providers/google_gemini_cli_runtime.go new file mode 100644 index 00000000..f57be129 --- /dev/null +++ b/pkg/ai/providers/google_gemini_cli_runtime.go @@ -0,0 +1,587 @@ +package providers + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "math/rand" + "net/http" + "os" + "strings" + "time" + + "github.com/beeper/ai-bridge/pkg/ai" + "github.com/beeper/ai-bridge/pkg/ai/oauth" + "github.com/beeper/ai-bridge/pkg/ai/utils" +) + +const ( + defaultGeminiCLIEndpoint = "https://cloudcode-pa.googleapis.com" + antigravityDailyEndpoint = "https://daily-cloudcode-pa.sandbox.googleapis.com" + defaultAntigravityVersion = "1.18.3" + maxGeminiCLIRetries = 3 + baseGeminiCLIRetryDelay = 1000 * time.Millisecond + geminiCLIScannerBufferMaxSize = 16 * 1024 * 1024 +) + +var antigravityEndpointFallbacks = []string{antigravityDailyEndpoint, defaultGeminiCLIEndpoint} + +const antigravitySystemInstruction = "You are Antigravity, a powerful agentic AI coding assistant designed by the Google Deepmind team working on Advanced Agentic Coding." + + "You are pair programming with a USER to solve their coding task. 
The task may require creating a new codebase, modifying or debugging an existing codebase, or simply answering a question." + + "**Absolute paths only**" + + "**Proactiveness**" + +type googleGeminiCLIOptions struct { + StreamOptions ai.StreamOptions + ToolChoice string + Thinking *GoogleThinkingOptions + ProjectID string +} + +type cloudCodeAssistResponseChunk struct { + Response *struct { + Candidates []struct { + Content struct { + Parts []struct { + Text string `json:"text,omitempty"` + Thought bool `json:"thought,omitempty"` + ThoughtSignature string `json:"thoughtSignature,omitempty"` + FunctionCall *struct { + Name string `json:"name"` + Args map[string]any `json:"args"` + ID string `json:"id,omitempty"` + } `json:"functionCall,omitempty"` + } `json:"parts"` + } `json:"content"` + FinishReason string `json:"finishReason,omitempty"` + } `json:"candidates"` + UsageMetadata *struct { + PromptTokenCount int `json:"promptTokenCount,omitempty"` + CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"` + ThoughtsTokenCount int `json:"thoughtsTokenCount,omitempty"` + TotalTokenCount int `json:"totalTokenCount,omitempty"` + CachedContentTokenCount int `json:"cachedContentTokenCount,omitempty"` + } `json:"usageMetadata,omitempty"` + } `json:"response,omitempty"` +} + +func streamGoogleGeminiCLI(model ai.Model, c ai.Context, options *ai.StreamOptions) *ai.AssistantMessageEventStream { + geminiOptions := googleGeminiCLIOptions{} + if options != nil { + geminiOptions.StreamOptions = *options + } + return streamGoogleGeminiCLIWithOptions(model, c, geminiOptions) +} + +func streamSimpleGoogleGeminiCLI(model ai.Model, c ai.Context, options *ai.SimpleStreamOptions) *ai.AssistantMessageEventStream { + base := BuildBaseOptions(model, options, "") + if options == nil || options.Reasoning == "" { + return streamGoogleGeminiCLIWithOptions(model, c, googleGeminiCLIOptions{ + StreamOptions: base, + Thinking: &GoogleThinkingOptions{Enabled: false}, + }) + } + + effort := 
ClampReasoning(options.Reasoning) + if isGemini3Model(model.ID) { + return streamGoogleGeminiCLIWithOptions(model, c, googleGeminiCLIOptions{ + StreamOptions: base, + Thinking: &GoogleThinkingOptions{ + Enabled: true, + Level: getGeminiCLIThinkingLevel(effort, model.ID), + }, + }) + } + + maxTokens, thinkingBudget := AdjustMaxTokensForThinking( + base.MaxTokens, + model.MaxTokens, + effort, + options.ThinkingBudgets, + ) + base.MaxTokens = maxTokens + return streamGoogleGeminiCLIWithOptions(model, c, googleGeminiCLIOptions{ + StreamOptions: base, + Thinking: &GoogleThinkingOptions{ + Enabled: true, + BudgetTokens: &thinkingBudget, + }, + }) +} + +func streamGoogleGeminiCLIWithOptions( + model ai.Model, + c ai.Context, + options googleGeminiCLIOptions, +) *ai.AssistantMessageEventStream { + stream := ai.NewAssistantMessageEventStream(128) + go func() { + runCtx := options.StreamOptions.Ctx + if runCtx == nil { + runCtx = context.Background() + } + + apiKeyRaw := strings.TrimSpace(options.StreamOptions.APIKey) + if apiKeyRaw == "" { + pushProviderError(stream, model, "google cloud code assist requires OAuth authentication") + return + } + accessToken, projectID, ok := oauth.ParseGoogleOAuthAPIKey(apiKeyRaw) + if !ok { + pushProviderError(stream, model, "invalid google cloud credentials, re-authentication required") + return + } + if strings.TrimSpace(options.ProjectID) != "" { + projectID = strings.TrimSpace(options.ProjectID) + } + + isAntigravity := strings.EqualFold(string(model.Provider), "google-antigravity") + endpoints := geminiCLIEndpoints(model, isAntigravity) + requestBody := BuildGoogleGeminiCLIRequest(model, c, projectID, options, isAntigravity) + if options.StreamOptions.OnPayload != nil { + options.StreamOptions.OnPayload(requestBody) + } + requestBodyJSON, err := json.Marshal(requestBody) + if err != nil { + pushProviderError(stream, model, err.Error()) + return + } + + requestHeaders := map[string]string{ + "Authorization": "Bearer " + accessToken, + 
"Content-Type": "application/json", + "Accept": "text/event-stream", + } + for key, value := range geminiCLIBaseHeaders(isAntigravity) { + requestHeaders[key] = value + } + for key, value := range BuildGeminiCLIHeaders(model, model.Headers) { + requestHeaders[key] = value + } + for key, value := range options.StreamOptions.Headers { + requestHeaders[key] = value + } + + stopReason := ai.StopReasonStop + usage := ai.Usage{} + var textBuilder strings.Builder + var thinkingBuilder strings.Builder + toolCalls := make([]ai.ContentBlock, 0) + + requestURL := "" + var response *http.Response + var lastErr error + for attempt := 0; attempt <= maxGeminiCLIRetries; attempt++ { + if runCtx.Err() != nil { + pushProviderError(stream, model, runCtx.Err().Error()) + return + } + endpoint := endpoints[minInt(attempt, len(endpoints)-1)] + requestURL = strings.TrimRight(endpoint, "/") + "/v1internal:streamGenerateContent?alt=sse" + response, lastErr = doGeminiCLIRequest(runCtx, requestURL, requestHeaders, requestBodyJSON) + if lastErr == nil && response != nil && response.StatusCode >= 200 && response.StatusCode < 300 { + break + } + if response != nil { + bodyBytes, _ := io.ReadAll(response.Body) + _ = response.Body.Close() + errorText := string(bodyBytes) + if shouldRetryGeminiCLIStatus(response.StatusCode, errorText) && attempt < maxGeminiCLIRetries { + delay := baseGeminiCLIRetryDelay * time.Duration(1< maxDelay { + pushProviderError(stream, model, fmt.Sprintf("server requested %ds retry delay (max %ds): %s", parsedDelayMs/1000, maxDelay/1000, extractGeminiCLIErrorMessage(errorText))) + return + } + delay = time.Duration(parsedDelayMs) * time.Millisecond + } + if sleepErr := sleepWithContext(runCtx, delay); sleepErr != nil { + pushProviderError(stream, model, sleepErr.Error()) + return + } + continue + } + pushProviderError(stream, model, fmt.Sprintf("cloud code assist API error (%d): %s", response.StatusCode, extractGeminiCLIErrorMessage(errorText))) + return + } + if lastErr 
!= nil && attempt < maxGeminiCLIRetries { + delay := baseGeminiCLIRetryDelay * time.Duration(1<= 300 { + bodyBytes, _ := io.ReadAll(retryResp.Body) + _ = retryResp.Body.Close() + pushProviderError(stream, model, fmt.Sprintf("cloud code assist API error (%d): %s", retryResp.StatusCode, extractGeminiCLIErrorMessage(string(bodyBytes)))) + return + } + textBuilder.Reset() + thinkingBuilder.Reset() + toolCalls = toolCalls[:0] + usage = ai.Usage{} + stopReason = ai.StopReasonStop + currentResponse = retryResp + } + if !receivedContent { + pushProviderError(stream, model, "cloud code assist API returned an empty response") + return + } + + message := ai.Message{ + Role: ai.RoleAssistant, + API: model.API, + Provider: model.Provider, + Model: model.ID, + StopReason: stopReason, + Usage: usage, + Timestamp: time.Now().UnixMilli(), + } + if thinking := strings.TrimSpace(thinkingBuilder.String()); thinking != "" { + message.Content = append(message.Content, ai.ContentBlock{ + Type: ai.ContentTypeThinking, + Thinking: thinking, + }) + } + if text := strings.TrimSpace(textBuilder.String()); text != "" { + message.Content = append(message.Content, ai.ContentBlock{ + Type: ai.ContentTypeText, + Text: text, + }) + } + if len(toolCalls) > 0 { + message.Content = append(message.Content, toolCalls...) 
+ } + if message.StopReason == ai.StopReasonStop && len(toolCalls) > 0 { + message.StopReason = ai.StopReasonToolUse + } + message.Usage.Cost = ai.CalculateCost(model, message.Usage) + stream.Push(ai.AssistantMessageEvent{ + Type: ai.EventDone, + Message: message, + Reason: message.StopReason, + }) + }() + return stream +} + +func BuildGoogleGeminiCLIRequest( + model ai.Model, + context ai.Context, + projectID string, + options googleGeminiCLIOptions, + isAntigravity bool, +) map[string]any { + base := BuildGoogleGenerateContentParams(model, context, GoogleOptions{ + StreamOptions: options.StreamOptions, + ToolChoice: options.ToolChoice, + Thinking: options.Thinking, + }) + request := map[string]any{ + "contents": base["contents"], + } + if cfg, ok := base["config"].(map[string]any); ok { + generationConfig := map[string]any{} + for key, value := range cfg { + if key == "systemInstruction" { + continue + } + generationConfig[key] = value + } + if len(generationConfig) > 0 { + request["generationConfig"] = generationConfig + } + if systemInstruction, ok := cfg["systemInstruction"].(string); ok && strings.TrimSpace(systemInstruction) != "" { + request["systemInstruction"] = map[string]any{ + "parts": []map[string]any{ + {"text": utils.SanitizeSurrogates(systemInstruction)}, + }, + } + } + } + if sessionID := strings.TrimSpace(options.StreamOptions.SessionID); sessionID != "" { + request["sessionId"] = sessionID + } + if isAntigravity { + existingParts := []map[string]any{} + if instruction, ok := request["systemInstruction"].(map[string]any); ok { + if parts, ok := instruction["parts"].([]map[string]any); ok { + existingParts = append(existingParts, parts...) 
+ } + } + request["systemInstruction"] = map[string]any{ + "role": "user", + "parts": append([]map[string]any{ + {"text": antigravitySystemInstruction}, + {"text": "Please ignore following [ignore]" + antigravitySystemInstruction + "[/ignore]"}, + }, existingParts...), + } + } + out := map[string]any{ + "project": projectID, + "model": model.ID, + "request": request, + "userAgent": "pi-coding-agent", + "requestId": fmt.Sprintf("pi-%d-%d", time.Now().UnixMilli(), rand.Int63()), + } + if isAntigravity { + out["requestType"] = "agent" + out["userAgent"] = "antigravity" + } + return out +} + +func consumeGeminiCLIResponse( + body io.Reader, + textBuilder *strings.Builder, + thinkingBuilder *strings.Builder, + toolCalls *[]ai.ContentBlock, + stream *ai.AssistantMessageEventStream, + usage *ai.Usage, + stopReason *ai.StopReason, +) (bool, error) { + scanner := bufio.NewScanner(body) + scanner.Buffer(make([]byte, 4096), geminiCLIScannerBufferMaxSize) + hasContent := false + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "data:") { + continue + } + rawJSON := strings.TrimSpace(strings.TrimPrefix(line, "data:")) + if rawJSON == "" { + continue + } + var chunk cloudCodeAssistResponseChunk + if err := json.Unmarshal([]byte(rawJSON), &chunk); err != nil { + continue + } + if chunk.Response == nil { + continue + } + if chunk.Response.UsageMetadata != nil && usage != nil { + promptTokens := chunk.Response.UsageMetadata.PromptTokenCount + cacheReadTokens := chunk.Response.UsageMetadata.CachedContentTokenCount + usage.Input = promptTokens - cacheReadTokens + usage.Output = chunk.Response.UsageMetadata.CandidatesTokenCount + chunk.Response.UsageMetadata.ThoughtsTokenCount + usage.CacheRead = cacheReadTokens + usage.CacheWrite = 0 + usage.TotalTokens = chunk.Response.UsageMetadata.TotalTokenCount + } + if len(chunk.Response.Candidates) == 0 { + continue + } + candidate := chunk.Response.Candidates[0] + if strings.TrimSpace(candidate.FinishReason) != "" 
&& stopReason != nil { + *stopReason = MapGoogleStopReason(candidate.FinishReason) + } + for _, part := range candidate.Content.Parts { + if strings.TrimSpace(part.Text) != "" { + hasContent = true + if part.Thought { + thinkingBuilder.WriteString(part.Text) + stream.Push(ai.AssistantMessageEvent{Type: ai.EventThinkingDelta, Delta: part.Text}) + } else { + textBuilder.WriteString(part.Text) + stream.Push(ai.AssistantMessageEvent{Type: ai.EventTextDelta, Delta: part.Text}) + } + } + if part.FunctionCall != nil { + hasContent = true + toolCallID := strings.TrimSpace(part.FunctionCall.ID) + if toolCallID == "" { + toolCallID = fmt.Sprintf("%s_%d", strings.TrimSpace(part.FunctionCall.Name), time.Now().UnixMilli()) + } + toolCall := NormalizeGoogleToolCall( + part.FunctionCall.Name, + part.FunctionCall.Args, + toolCallID, + part.ThoughtSignature, + ) + *toolCalls = append(*toolCalls, toolCall) + stream.Push(ai.AssistantMessageEvent{Type: ai.EventToolCallEnd, ToolCall: &toolCall}) + } + } + } + if err := scanner.Err(); err != nil { + return hasContent, err + } + return hasContent, nil +} + +func geminiCLIBaseHeaders(isAntigravity bool) map[string]string { + if isAntigravity { + version := strings.TrimSpace(os.Getenv("PI_AI_ANTIGRAVITY_VERSION")) + if version == "" { + version = defaultAntigravityVersion + } + return map[string]string{ + "User-Agent": fmt.Sprintf("antigravity/%s darwin/arm64", version), + "X-Goog-Api-Client": "google-cloud-sdk vscode_cloudshelleditor/0.1", + "Client-Metadata": `{"ideType":"IDE_UNSPECIFIED","platform":"PLATFORM_UNSPECIFIED","pluginType":"GEMINI"}`, + } + } + return map[string]string{ + "User-Agent": "google-cloud-sdk vscode_cloudshelleditor/0.1", + "X-Goog-Api-Client": "gl-node/22.17.0", + "Client-Metadata": `{"ideType":"IDE_UNSPECIFIED","platform":"PLATFORM_UNSPECIFIED","pluginType":"GEMINI"}`, + } +} + +func geminiCLIEndpoints(model ai.Model, isAntigravity bool) []string { + if baseURL := strings.TrimSpace(model.BaseURL); baseURL != "" { 
+ return []string{baseURL} + } + if isAntigravity { + return antigravityEndpointFallbacks + } + return []string{defaultGeminiCLIEndpoint} +} + +func doGeminiCLIRequest( + ctx context.Context, + url string, + headers map[string]string, + body []byte, +) (*http.Response, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return nil, err + } + for key, value := range headers { + if strings.TrimSpace(value) == "" { + continue + } + req.Header.Set(key, value) + } + return http.DefaultClient.Do(req) +} + +func shouldRetryGeminiCLIStatus(status int, errorText string) bool { + if status == http.StatusTooManyRequests || + status == http.StatusInternalServerError || + status == http.StatusBadGateway || + status == http.StatusServiceUnavailable || + status == http.StatusGatewayTimeout { + return true + } + lower := strings.ToLower(errorText) + return strings.Contains(lower, "resource exhausted") || + strings.Contains(lower, "rate limit") || + strings.Contains(lower, "overloaded") || + strings.Contains(lower, "service unavailable") || + strings.Contains(lower, "other side closed") +} + +func extractGeminiCLIErrorMessage(errorText string) string { + var payload struct { + Error *struct { + Message string `json:"message"` + } `json:"error"` + } + if err := json.Unmarshal([]byte(errorText), &payload); err == nil && payload.Error != nil && strings.TrimSpace(payload.Error.Message) != "" { + return strings.TrimSpace(payload.Error.Message) + } + return strings.TrimSpace(errorText) +} + +func sleepWithContext(ctx context.Context, delay time.Duration) error { + if delay <= 0 { + return nil + } + timer := time.NewTimer(delay) + defer timer.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} + +func isGemini3Model(modelID string) bool { + id := strings.ToLower(strings.TrimSpace(modelID)) + return strings.Contains(id, "gemini-3-pro") || strings.Contains(id, "gemini-3-flash") || 
+ strings.Contains(id, "gemini-3.1-pro") || strings.Contains(id, "gemini-3.1-flash") +} + +func isGemini3ProModel(modelID string) bool { + id := strings.ToLower(strings.TrimSpace(modelID)) + return strings.Contains(id, "gemini-3-pro") || strings.Contains(id, "gemini-3.1-pro") +} + +func getGeminiCLIThinkingLevel(level ai.ThinkingLevel, modelID string) string { + if isGemini3ProModel(modelID) { + switch level { + case ai.ThinkingMinimal, ai.ThinkingLow: + return "LOW" + default: + return "HIGH" + } + } + switch level { + case ai.ThinkingMinimal: + return "MINIMAL" + case ai.ThinkingLow: + return "LOW" + case ai.ThinkingMedium: + return "MEDIUM" + case ai.ThinkingHigh, ai.ThinkingXHigh: + return "HIGH" + default: + return "MEDIUM" + } +} diff --git a/pkg/ai/providers/google_gemini_cli_runtime_test.go b/pkg/ai/providers/google_gemini_cli_runtime_test.go new file mode 100644 index 00000000..b457f934 --- /dev/null +++ b/pkg/ai/providers/google_gemini_cli_runtime_test.go @@ -0,0 +1,134 @@ +package providers + +import ( + "context" + "io" + "strings" + "testing" + "time" + + "github.com/beeper/ai-bridge/pkg/ai" +) + +func TestStreamGoogleGeminiCLI_MissingAPIKeyEmitsError(t *testing.T) { + stream := streamGoogleGeminiCLI(ai.Model{ + ID: "gemini-2.5-flash", + Provider: "google-gemini-cli", + API: ai.APIGoogleGeminiCLI, + }, ai.Context{ + Messages: []ai.Message{{Role: ai.RoleUser, Text: "hello"}}, + }, &ai.StreamOptions{}) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + evt, err := stream.Next(ctx) + if err != nil { + t.Fatalf("expected terminal error event, got %v", err) + } + if evt.Type != ai.EventError { + t.Fatalf("expected error event, got %s", evt.Type) + } + if !strings.Contains(strings.ToLower(evt.Error.ErrorMessage), "oauth") { + t.Fatalf("expected oauth auth message, got %q", evt.Error.ErrorMessage) + } + if _, err := stream.Next(ctx); err != io.EOF { + t.Fatalf("expected EOF after terminal event, got %v", err) + } +} 
+ +func TestStreamGoogleGeminiCLI_InvalidAPIKeyEmitsError(t *testing.T) { + stream := streamGoogleGeminiCLI(ai.Model{ + ID: "gemini-2.5-flash", + Provider: "google-gemini-cli", + API: ai.APIGoogleGeminiCLI, + }, ai.Context{ + Messages: []ai.Message{{Role: ai.RoleUser, Text: "hello"}}, + }, &ai.StreamOptions{ + APIKey: "not-json", + }) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + evt, err := stream.Next(ctx) + if err != nil { + t.Fatalf("expected terminal error event, got %v", err) + } + if evt.Type != ai.EventError { + t.Fatalf("expected error event, got %s", evt.Type) + } + if !strings.Contains(strings.ToLower(evt.Error.ErrorMessage), "invalid google cloud credentials") { + t.Fatalf("expected invalid credentials message, got %q", evt.Error.ErrorMessage) + } +} + +func TestBuildGoogleGeminiCLIRequest(t *testing.T) { + temp := 0.4 + request := BuildGoogleGeminiCLIRequest( + ai.Model{ID: "gemini-2.5-flash", Reasoning: true}, + ai.Context{ + SystemPrompt: "You are helpful", + Messages: []ai.Message{ + {Role: ai.RoleUser, Text: "hello"}, + }, + Tools: []ai.Tool{ + {Name: "search", Description: "search tool", Parameters: map[string]any{"type": "object"}}, + }, + }, + "project-123", + googleGeminiCLIOptions{ + StreamOptions: ai.StreamOptions{ + Temperature: &temp, + MaxTokens: 2048, + SessionID: "session-1", + }, + ToolChoice: "any", + Thinking: &GoogleThinkingOptions{ + Enabled: true, + Level: "HIGH", + }, + }, + false, + ) + + if request["project"] != "project-123" { + t.Fatalf("expected project mapping, got %#v", request["project"]) + } + if request["model"] != "gemini-2.5-flash" { + t.Fatalf("expected model mapping, got %#v", request["model"]) + } + req, ok := request["request"].(map[string]any) + if !ok { + t.Fatalf("expected nested request object, got %#v", request["request"]) + } + if req["sessionId"] != "session-1" { + t.Fatalf("expected session id mapping, got %#v", req["sessionId"]) + } + if _, ok := 
req["contents"]; !ok { + t.Fatalf("expected converted contents in request") + } + if _, ok := req["systemInstruction"]; !ok { + t.Fatalf("expected system instruction object") + } + genCfg, ok := req["generationConfig"].(map[string]any) + if !ok { + t.Fatalf("expected generationConfig object, got %#v", req["generationConfig"]) + } + if genCfg["maxOutputTokens"] != 2048 { + t.Fatalf("expected maxOutputTokens mapping, got %#v", genCfg["maxOutputTokens"]) + } + if genCfg["temperature"] != 0.4 { + t.Fatalf("expected temperature mapping, got %#v", genCfg["temperature"]) + } +} + +func TestGeminiCLIThinkingLevel(t *testing.T) { + if got := getGeminiCLIThinkingLevel(ai.ThinkingMinimal, "gemini-3-pro"); got != "LOW" { + t.Fatalf("expected pro minimal -> LOW, got %q", got) + } + if got := getGeminiCLIThinkingLevel(ai.ThinkingMedium, "gemini-3-pro"); got != "HIGH" { + t.Fatalf("expected pro medium -> HIGH, got %q", got) + } + if got := getGeminiCLIThinkingLevel(ai.ThinkingMedium, "gemini-3-flash"); got != "MEDIUM" { + t.Fatalf("expected flash medium -> MEDIUM, got %q", got) + } +} diff --git a/pkg/ai/providers/register_builtins.go b/pkg/ai/providers/register_builtins.go index f696cb31..529b256e 100644 --- a/pkg/ai/providers/register_builtins.go +++ b/pkg/ai/providers/register_builtins.go @@ -90,16 +90,11 @@ func RegisterBuiltInAPIProviders() { Stream: streamBedrockConverse, StreamSimple: streamSimpleBedrockConverse, }, BuiltinProviderSourceID) - - for _, apiID := range []ai.Api{ - ai.APIGoogleGeminiCLI, - } { - ai.RegisterAPIProvider(ai.APIProvider{ - API: apiID, - Stream: notImplementedStream(apiID), - StreamSimple: notImplementedSimpleStream(apiID), - }, BuiltinProviderSourceID) - } + ai.RegisterAPIProvider(ai.APIProvider{ + API: ai.APIGoogleGeminiCLI, + Stream: streamGoogleGeminiCLI, + StreamSimple: streamSimpleGoogleGeminiCLI, + }, BuiltinProviderSourceID) } func ResetAPIProviders() { diff --git a/pkg/ai/providers/register_builtins_test.go 
b/pkg/ai/providers/register_builtins_test.go index ba5e3233..5e037341 100644 --- a/pkg/ai/providers/register_builtins_test.go +++ b/pkg/ai/providers/register_builtins_test.go @@ -124,6 +124,25 @@ func TestRegisterBuiltInAPIProviders(t *testing.T) { t.Fatalf("expected google runtime implementation, got stub error: %q", googleEvt.Error.ErrorMessage) } + geminiCLIStream, err := ai.Stream(ai.Model{ + ID: "gemini-2.5-flash", + Provider: "google-gemini-cli", + API: ai.APIGoogleGeminiCLI, + }, ai.Context{}, nil) + if err != nil { + t.Fatalf("unexpected gemini-cli stream resolve error: %v", err) + } + geminiCLIEvt, err := geminiCLIStream.Next(ctx) + if err != nil { + t.Fatalf("expected gemini-cli terminal error event, got %v", err) + } + if geminiCLIEvt.Type != ai.EventError { + t.Fatalf("expected gemini-cli error event, got %s", geminiCLIEvt.Type) + } + if strings.Contains(strings.ToLower(geminiCLIEvt.Error.ErrorMessage), "not implemented") { + t.Fatalf("expected gemini-cli runtime implementation, got stub error: %q", geminiCLIEvt.Error.ErrorMessage) + } + vertexStream, err := ai.Stream(ai.Model{ ID: "gemini-2.5-flash", Provider: "google-vertex", diff --git a/pkg/connector/pkg_ai_provider_bridge_test.go b/pkg/connector/pkg_ai_provider_bridge_test.go index ecbe8036..52005475 100644 --- a/pkg/connector/pkg_ai_provider_bridge_test.go +++ b/pkg/connector/pkg_ai_provider_bridge_test.go @@ -177,7 +177,7 @@ func TestTryGenerateStreamWithPkgAIReturnsRuntimeErrorEventsWhenProviderResolved } } -func TestTryGenerateWithPkgAIFallsBackOnStubbedProviders(t *testing.T) { +func TestTryGenerateWithPkgAIReturnsRuntimeErrorForGeminiCLI(t *testing.T) { resp, handled, err := tryGenerateWithPkgAI(context.Background(), "https://cloudcode-pa.googleapis.com", "", GenerateParams{ Model: "gemini-2.5-flash", Messages: []UnifiedMessage{ @@ -189,14 +189,14 @@ func TestTryGenerateWithPkgAIFallsBackOnStubbedProviders(t *testing.T) { }, }, }) - if err != nil { - t.Fatalf("expected nil error on fallback 
path, got %v", err) + if !handled { + t.Fatalf("expected gemini-cli provider to be handled by pkg/ai runtime") } - if handled { - t.Fatalf("expected fallback for stubbed google-gemini-cli runtime") + if err == nil { + t.Fatalf("expected runtime error without OAuth credentials") } if resp != nil { - t.Fatalf("expected nil response on fallback path") + t.Fatalf("expected nil response when runtime returns error") } } From 5aa5ce78e479b89e3ecea4bb60ddbebf84afd2bb Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 08:01:02 +0000 Subject: [PATCH 50/75] Widen pkg-ai stream bridge eligibility for tool-safe chats MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/connector/streaming_runtime_selector.go | 22 +++++++++++++++++-- .../streaming_runtime_selector_test.go | 15 ++++++++++--- 2 files changed, 32 insertions(+), 5 deletions(-) diff --git a/pkg/connector/streaming_runtime_selector.go b/pkg/connector/streaming_runtime_selector.go index 6d0f0514..24cd5f0f 100644 --- a/pkg/connector/streaming_runtime_selector.go +++ b/pkg/connector/streaming_runtime_selector.go @@ -75,7 +75,7 @@ func (oc *AIClient) streamWithPkgAIBridge( if pkgAIRuntimeDryRunEnabled() { oc.runPkgAIBridgeDryRun(ctx, aiModel, aiContext) } - if oc.shouldUsePkgAIBridgeStreaming(meta, prompt) { + if oc.shouldUsePkgAIBridgeStreaming(ctx, meta, prompt) { if baseURL, apiKey, ok := oc.pkgAIProviderBridgeCredentials(); ok { params := oc.buildPkgAIBridgeGenerateParams(meta, prompt) if events, handled := tryGenerateStreamWithPkgAI(ctx, baseURL, apiKey, params); handled { @@ -229,13 +229,31 @@ func (oc *AIClient) pkgAIProviderBridgeCredentials() (string, string, bool) { } func (oc *AIClient) shouldUsePkgAIBridgeStreaming( + ctx context.Context, meta *PortalMetadata, prompt []openai.ChatCompletionMessageParamUnion, ) bool { if meta != nil && meta.Capabilities.SupportsToolCalling { + if oc.selectedBuiltinToolCountSafe(ctx, meta) > 
0 { + return false + } + if resolveAgentID(meta) != "" { + return false + } + } + if promptContainsToolCalls(prompt) { return false } - return !promptContainsToolCalls(prompt) + return true +} + +func (oc *AIClient) selectedBuiltinToolCountSafe(ctx context.Context, meta *PortalMetadata) (count int) { + defer func() { + if recover() != nil { + count = 0 + } + }() + return len(oc.selectedBuiltinToolsForTurn(ctx, meta)) } func promptContainsToolCalls(prompt []openai.ChatCompletionMessageParamUnion) bool { diff --git a/pkg/connector/streaming_runtime_selector_test.go b/pkg/connector/streaming_runtime_selector_test.go index d65942c1..23c9a51b 100644 --- a/pkg/connector/streaming_runtime_selector_test.go +++ b/pkg/connector/streaming_runtime_selector_test.go @@ -1,6 +1,7 @@ package connector import ( + "context" "testing" "github.com/openai/openai-go/v3" @@ -153,17 +154,25 @@ func TestPromptContainsToolCalls(t *testing.T) { func TestShouldUsePkgAIBridgeStreaming(t *testing.T) { client := &AIClient{} - if !client.shouldUsePkgAIBridgeStreaming(&PortalMetadata{}, []openai.ChatCompletionMessageParamUnion{ + if !client.shouldUsePkgAIBridgeStreaming(context.Background(), &PortalMetadata{}, []openai.ChatCompletionMessageParamUnion{ openai.UserMessage("hello"), }) { t.Fatalf("expected bridge streaming to be enabled for non-tool prompt") } - if client.shouldUsePkgAIBridgeStreaming(&PortalMetadata{ + if !client.shouldUsePkgAIBridgeStreaming(context.Background(), &PortalMetadata{ Capabilities: ModelCapabilities{SupportsToolCalling: true}, }, []openai.ChatCompletionMessageParamUnion{ openai.UserMessage("hello"), }) { - t.Fatalf("expected bridge streaming disabled when tool calling is enabled") + t.Fatalf("expected bridge streaming enabled when tool calling has no active tools") + } + if client.shouldUsePkgAIBridgeStreaming(context.Background(), &PortalMetadata{ + Capabilities: ModelCapabilities{SupportsToolCalling: true}, + AgentID: "agent-1", + }, 
[]openai.ChatCompletionMessageParamUnion{ + openai.UserMessage("hello"), + }) { + t.Fatalf("expected bridge streaming disabled when agent tool mode is active") } } From 8692f34f3a78ac1d9e2d55aadea899cb4102aaf3 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 08:03:20 +0000 Subject: [PATCH 51/75] Map antigravity endpoints to pkg-ai Gemini CLI API MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/connector/pkg_ai_provider_bridge.go | 6 +++++- pkg/connector/pkg_ai_provider_bridge_test.go | 8 ++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/pkg/connector/pkg_ai_provider_bridge.go b/pkg/connector/pkg_ai_provider_bridge.go index d1e48b14..8bce3b40 100644 --- a/pkg/connector/pkg_ai_provider_bridge.go +++ b/pkg/connector/pkg_ai_provider_bridge.go @@ -24,7 +24,9 @@ func inferProviderNameFromBaseURL(baseURL string) string { return "openrouter" case strings.Contains(lower, "api.anthropic.com"): return "anthropic" - case strings.Contains(lower, "cloudcode-pa.googleapis.com"): + case strings.Contains(lower, "daily-cloudcode-pa.sandbox.googleapis.com"), + strings.Contains(lower, "cloudcode-pa.sandbox.googleapis.com"), + strings.Contains(lower, "cloudcode-pa.googleapis.com"): return "google-gemini-cli" case strings.Contains(lower, "aiplatform.googleapis.com"), strings.Contains(lower, "vertex"): return "google-vertex" @@ -71,6 +73,8 @@ func inferAPIFromProviderModel(provider string, modelID string) aipkg.Api { return aipkg.APIGoogleGenerativeAI case "google-gemini-cli": return aipkg.APIGoogleGeminiCLI + case "google-antigravity": + return aipkg.APIGoogleGeminiCLI case "google-vertex": return aipkg.APIGoogleVertex case "amazon-bedrock": diff --git a/pkg/connector/pkg_ai_provider_bridge_test.go b/pkg/connector/pkg_ai_provider_bridge_test.go index 52005475..86f89c95 100644 --- a/pkg/connector/pkg_ai_provider_bridge_test.go +++ 
b/pkg/connector/pkg_ai_provider_bridge_test.go @@ -33,6 +33,7 @@ func TestInferProviderNameFromBaseURL(t *testing.T) { {name: "azure", baseURL: "https://my-openai.azure.com", want: "azure-openai-responses"}, {name: "anthropic", baseURL: "https://api.anthropic.com", want: "anthropic"}, {name: "google cloudcode", baseURL: "https://cloudcode-pa.googleapis.com", want: "google-gemini-cli"}, + {name: "google antigravity", baseURL: "https://daily-cloudcode-pa.sandbox.googleapis.com", want: "google-gemini-cli"}, {name: "google mldev", baseURL: "https://generativelanguage.googleapis.com", want: "google"}, {name: "google vertex", baseURL: "https://us-central1-aiplatform.googleapis.com", want: "google-vertex"}, {name: "bedrock", baseURL: "https://bedrock-runtime.us-east-1.amazonaws.com", want: "amazon-bedrock"}, @@ -95,6 +96,13 @@ func TestBuildPkgAIModelFromGenerateParams(t *testing.T) { t.Fatalf("expected google base URL to map to google-generative-ai API, got %q", google.API) } + antigravity := buildPkgAIModelFromGenerateParams(GenerateParams{ + Model: "gemini-2.5-pro", + }, "https://daily-cloudcode-pa.sandbox.googleapis.com") + if antigravity.API != "google-gemini-cli" { + t.Fatalf("expected antigravity endpoint to map to google-gemini-cli API, got %q", antigravity.API) + } + bedrock := buildPkgAIModelFromGenerateParams(GenerateParams{ Model: "us.anthropic.claude-3-5-sonnet-20241022-v2:0", }, "https://bedrock-runtime.us-east-1.amazonaws.com") From e6c705664ff2a26fbb9053d55a164f4c3c015f90 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 08:07:40 +0000 Subject: [PATCH 52/75] Add env-gated connector pkg-ai bridge e2e tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- docs/pkg-ai-runtime-migration.md | 14 ++ .../pkg_ai_provider_bridge_e2e_test.go | 153 ++++++++++++++++++ 2 files changed, 167 insertions(+) create mode 100644 pkg/connector/pkg_ai_provider_bridge_e2e_test.go diff --git 
a/docs/pkg-ai-runtime-migration.md b/docs/pkg-ai-runtime-migration.md index 72e1d7d4..59e513b9 100644 --- a/docs/pkg-ai-runtime-migration.md +++ b/docs/pkg-ai-runtime-migration.md @@ -42,6 +42,20 @@ go test ./pkg/ai/... CGO_ENABLED=0 go test ./pkg/connector -run "TestPkgAIProviderRuntimeEnabled|TestInferProviderNameFromBaseURL|TestBuildPkgAIModelFromGenerateParams|TestShouldFallbackFromPkgAIEvent|TestShouldFallbackFromPkgAIError|TestTryGenerateStreamWithPkgAIReturnsRuntimeErrorEventsWhenProviderResolved|TestTryGenerateWithPkgAIFallsBackOnStubbedProviders|TestTryGenerateWithPkgAIReturnsRuntimeErrorWhenProviderResolved|TestGenerateResponseFromAIMessage|TestParseThinkingLevel|TestOpenAIProviderGenerate_UsesPkgAIBridgeWhenEnabled|TestPkgAIRuntimeEnabledFromEnv|TestChooseStreamingRuntimePath|TestPromptContainsToolCalls|TestShouldUsePkgAIBridgeStreaming|TestBuildPkgAIBridgeGenerateParams|TestPkgAIProviderBridgeCredentials|TestAIEventToStreamEvent_Mapping|TestStreamEventsFromAIStream|TestToAIContext_MapsMessagesAndTools" ``` +## Connector bridge env-gated provider validation + +To validate real provider happy paths for connector bridge routing (OpenAI, Anthropic, Google), set credentials and: + +```bash +PI_AI_E2E=1 CGO_ENABLED=0 go test ./pkg/connector -run "TestPkgAIProviderBridgeE2E_" +``` + +Optional model overrides: + +- `PI_AI_E2E_OPENAI_MODEL` +- `PI_AI_E2E_ANTHROPIC_MODEL` +- `PI_AI_E2E_GOOGLE_MODEL` + ## Notes - Full integration remains feature-gated. 
diff --git a/pkg/connector/pkg_ai_provider_bridge_e2e_test.go b/pkg/connector/pkg_ai_provider_bridge_e2e_test.go new file mode 100644 index 00000000..85af8043 --- /dev/null +++ b/pkg/connector/pkg_ai_provider_bridge_e2e_test.go @@ -0,0 +1,153 @@ +package connector + +import ( + "context" + "os" + "strings" + "testing" + "time" +) + +func TestPkgAIProviderBridgeE2E_CompleteOpenAI(t *testing.T) { + requirePkgAIE2E(t) + apiKey := strings.TrimSpace(os.Getenv("OPENAI_API_KEY")) + if apiKey == "" { + t.Skip("OPENAI_API_KEY is not set") + } + model := envOrDefault("PI_AI_E2E_OPENAI_MODEL", "gpt-4o-mini") + + resp, handled, err := tryGenerateWithPkgAI(context.Background(), "", apiKey, GenerateParams{ + Model: model, + Messages: []UnifiedMessage{ + {Role: RoleUser, Content: []ContentPart{{Type: ContentTypeText, Text: "Reply with the single word OK."}}}, + }, + MaxCompletionTokens: 128, + }) + if err != nil { + t.Fatalf("expected successful pkg/ai completion, got error: %v", err) + } + if !handled { + t.Fatalf("expected pkg/ai bridge to handle OpenAI completion") + } + if resp == nil || strings.TrimSpace(resp.Content) == "" { + t.Fatalf("expected non-empty completion response") + } +} + +func TestPkgAIProviderBridgeE2E_CompleteAnthropic(t *testing.T) { + requirePkgAIE2E(t) + apiKey := strings.TrimSpace(os.Getenv("ANTHROPIC_API_KEY")) + if apiKey == "" { + t.Skip("ANTHROPIC_API_KEY is not set") + } + model := envOrDefault("PI_AI_E2E_ANTHROPIC_MODEL", "claude-3-5-haiku-latest") + + resp, handled, err := tryGenerateWithPkgAI(context.Background(), "https://api.anthropic.com", apiKey, GenerateParams{ + Model: model, + Messages: []UnifiedMessage{ + {Role: RoleUser, Content: []ContentPart{{Type: ContentTypeText, Text: "Reply with the single word OK."}}}, + }, + MaxCompletionTokens: 128, + }) + if err != nil { + t.Fatalf("expected successful pkg/ai completion, got error: %v", err) + } + if !handled { + t.Fatalf("expected pkg/ai bridge to handle Anthropic completion") + } + if resp 
== nil || strings.TrimSpace(resp.Content) == "" { + t.Fatalf("expected non-empty completion response") + } +} + +func TestPkgAIProviderBridgeE2E_CompleteGoogle(t *testing.T) { + requirePkgAIE2E(t) + apiKey := strings.TrimSpace(os.Getenv("GEMINI_API_KEY")) + if apiKey == "" { + t.Skip("GEMINI_API_KEY is not set") + } + model := envOrDefault("PI_AI_E2E_GOOGLE_MODEL", "gemini-2.5-flash") + + resp, handled, err := tryGenerateWithPkgAI(context.Background(), "https://generativelanguage.googleapis.com", apiKey, GenerateParams{ + Model: model, + Messages: []UnifiedMessage{ + {Role: RoleUser, Content: []ContentPart{{Type: ContentTypeText, Text: "Reply with the single word OK."}}}, + }, + MaxCompletionTokens: 128, + }) + if err != nil { + t.Fatalf("expected successful pkg/ai completion, got error: %v", err) + } + if !handled { + t.Fatalf("expected pkg/ai bridge to handle Google completion") + } + if resp == nil || strings.TrimSpace(resp.Content) == "" { + t.Fatalf("expected non-empty completion response") + } +} + +func TestPkgAIProviderBridgeE2E_StreamOpenAI(t *testing.T) { + requirePkgAIE2E(t) + apiKey := strings.TrimSpace(os.Getenv("OPENAI_API_KEY")) + if apiKey == "" { + t.Skip("OPENAI_API_KEY is not set") + } + model := envOrDefault("PI_AI_E2E_OPENAI_MODEL", "gpt-4o-mini") + + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + events, handled := tryGenerateStreamWithPkgAI(ctx, "", apiKey, GenerateParams{ + Model: model, + Messages: []UnifiedMessage{ + {Role: RoleUser, Content: []ContentPart{{Type: ContentTypeText, Text: "Reply with the single word OK."}}}, + }, + MaxCompletionTokens: 128, + }) + if !handled { + t.Fatalf("expected pkg/ai stream bridge to handle OpenAI streaming") + } + + receivedDelta := false + receivedComplete := false + for { + select { + case <-ctx.Done(): + t.Fatalf("timed out waiting for stream completion") + case evt, ok := <-events: + if !ok { + if !receivedComplete { + t.Fatalf("stream closed before 
complete event") + } + return + } + switch evt.Type { + case StreamEventDelta: + if strings.TrimSpace(evt.Delta) != "" { + receivedDelta = true + } + case StreamEventComplete: + receivedComplete = true + if !receivedDelta { + t.Fatalf("expected at least one text delta before completion") + } + return + case StreamEventError: + t.Fatalf("unexpected stream error: %v", evt.Error) + } + } + } +} + +func requirePkgAIE2E(t *testing.T) { + t.Helper() + if strings.TrimSpace(os.Getenv("PI_AI_E2E")) != "1" { + t.Skip("set PI_AI_E2E=1 to run connector pkg/ai bridge e2e tests") + } +} + +func envOrDefault(key, fallback string) string { + if value := strings.TrimSpace(os.Getenv(key)); value != "" { + return value + } + return fallback +} From b24c25dc909c83ffc0949a19ddc349d5b5fc7a1b Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 08:58:15 +0000 Subject: [PATCH 53/75] Add OpenAI abort handling and pkg-ai e2e stream tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/ai/e2e/abort_test.go | 67 ++++++++- pkg/ai/e2e/stream_test.go | 131 +++++++++++++++++- .../providers/openai_completions_runtime.go | 8 ++ pkg/ai/providers/openai_responses_runtime.go | 8 ++ pkg/ai/providers/runtime_abort.go | 44 ++++++ pkg/ai/providers/runtime_abort_test.go | 45 ++++++ 6 files changed, 292 insertions(+), 11 deletions(-) create mode 100644 pkg/ai/providers/runtime_abort.go create mode 100644 pkg/ai/providers/runtime_abort_test.go diff --git a/pkg/ai/e2e/abort_test.go b/pkg/ai/e2e/abort_test.go index b4a657f4..7a47eb93 100644 --- a/pkg/ai/e2e/abort_test.go +++ b/pkg/ai/e2e/abort_test.go @@ -1,14 +1,71 @@ package e2e import ( + "context" + "io" "os" + "strings" "testing" + "time" + + "github.com/beeper/ai-bridge/pkg/ai" + "github.com/beeper/ai-bridge/pkg/ai/providers" ) -// Scaffolding parity target for pi-mono/packages/ai/test/abort.test.ts. 
-func TestAbortE2EParityScaffold(t *testing.T) { - if os.Getenv("PI_AI_E2E") == "" { - t.Skip("set PI_AI_E2E=1 to enable ai package e2e tests") +func TestAbortE2E_OpenAIStream(t *testing.T) { + requirePIAIE2E(t) + apiKey := strings.TrimSpace(os.Getenv("OPENAI_API_KEY")) + if apiKey == "" { + t.Skip("OPENAI_API_KEY is not set") + } + model := openAIE2EModel() + providers.ResetAPIProviders() + + runCtx, cancelRun := context.WithCancel(context.Background()) + defer cancelRun() + readCtx, cancelRead := context.WithTimeout(context.Background(), 60*time.Second) + defer cancelRead() + + stream, err := ai.Stream(model, ai.Context{ + Messages: []ai.Message{ + { + Role: ai.RoleUser, + Text: "Write a long explanation (at least 30 lines) about why unit tests are valuable.", + Timestamp: time.Now().UnixMilli(), + }, + }, + }, &ai.StreamOptions{ + APIKey: apiKey, + Ctx: runCtx, + MaxTokens: 2048, + }) + if err != nil { + t.Fatalf("stream creation failed: %v", err) + } + + cancelled := false + for { + evt, nextErr := stream.Next(readCtx) + if nextErr == io.EOF { + break + } + if nextErr != nil { + t.Fatalf("stream read failed: %v", nextErr) + } + if !cancelled && evt.Type == ai.EventTextDelta && strings.TrimSpace(evt.Delta) != "" { + cancelRun() + cancelled = true + } + } + + if !cancelled { + t.Skip("stream completed before cancellation could be triggered") + } + result, resultErr := stream.Result() + if resultErr != nil { + t.Fatalf("stream result failed: %v", resultErr) + } + if result.StopReason != ai.StopReasonAborted { + t.Fatalf("expected stop reason %q after cancel, got %q (error: %s)", ai.StopReasonAborted, result.StopReason, result.ErrorMessage) } - t.Skip("abort e2e parity test pending full provider runtime port") } diff --git a/pkg/ai/e2e/stream_test.go b/pkg/ai/e2e/stream_test.go index eea5cfe9..c683c1a9 100644 --- a/pkg/ai/e2e/stream_test.go +++ b/pkg/ai/e2e/stream_test.go @@ -1,15 +1,134 @@ package e2e import ( + "context" + "io" "os" + "strings" "testing" + "time" 
+ + "github.com/beeper/ai-bridge/pkg/ai" + "github.com/beeper/ai-bridge/pkg/ai/providers" ) -// Scaffolding parity target for pi-mono/packages/ai/test/stream.test.ts. -// This is intentionally env-gated while provider runtime integration is in progress. -func TestGenerateE2EParityScaffold(t *testing.T) { - if os.Getenv("PI_AI_E2E") == "" { - t.Skip("set PI_AI_E2E=1 to enable ai package e2e tests") +func TestGenerateE2E_OpenAIComplete(t *testing.T) { + requirePIAIE2E(t) + apiKey := strings.TrimSpace(os.Getenv("OPENAI_API_KEY")) + if apiKey == "" { + t.Skip("OPENAI_API_KEY is not set") + } + model := openAIE2EModel() + providers.ResetAPIProviders() + + response, err := ai.Complete(model, ai.Context{ + Messages: []ai.Message{ + { + Role: ai.RoleUser, + Text: "Reply with the single word OK.", + Timestamp: time.Now().UnixMilli(), + }, + }, + }, &ai.StreamOptions{APIKey: apiKey}) + if err != nil { + t.Fatalf("complete failed: %v", err) + } + if response.StopReason == ai.StopReasonError { + t.Fatalf("unexpected error stop reason: %s", response.ErrorMessage) + } + if len(response.Content) == 0 { + t.Fatalf("expected non-empty response content") + } + text := strings.ToLower(strings.TrimSpace(firstText(response))) + if text == "" || !strings.Contains(text, "ok") { + t.Fatalf("expected response text to contain 'ok', got %q", text) + } +} + +func TestGenerateE2E_OpenAIStream(t *testing.T) { + requirePIAIE2E(t) + apiKey := strings.TrimSpace(os.Getenv("OPENAI_API_KEY")) + if apiKey == "" { + t.Skip("OPENAI_API_KEY is not set") + } + model := openAIE2EModel() + providers.ResetAPIProviders() + + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + stream, err := ai.Stream(model, ai.Context{ + Messages: []ai.Message{ + { + Role: ai.RoleUser, + Text: "Reply with the single word OK.", + Timestamp: time.Now().UnixMilli(), + }, + }, + }, &ai.StreamOptions{ + APIKey: apiKey, + Ctx: ctx, + }) + if err != nil { + t.Fatalf("stream creation failed: 
%v", err) + } + + receivedDelta := false + receivedDone := false + for { + evt, nextErr := stream.Next(ctx) + if nextErr == io.EOF { + break + } + if nextErr != nil { + t.Fatalf("stream read failed: %v", nextErr) + } + switch evt.Type { + case ai.EventTextDelta: + if strings.TrimSpace(evt.Delta) != "" { + receivedDelta = true + } + case ai.EventDone: + receivedDone = true + case ai.EventError: + t.Fatalf("unexpected stream error: %s", evt.Error.ErrorMessage) + } + } + + message, resultErr := stream.Result() + if resultErr != nil { + t.Fatalf("stream result failed: %v", resultErr) + } + if message.StopReason == ai.StopReasonError { + t.Fatalf("unexpected error stop reason: %s", message.ErrorMessage) + } + if !receivedDone { + t.Fatalf("expected done event before stream close") + } + if !receivedDelta { + t.Fatalf("expected at least one text delta event") + } +} + +func openAIE2EModel() ai.Model { + modelID := strings.TrimSpace(os.Getenv("PI_AI_E2E_OPENAI_MODEL")) + if modelID == "" { + modelID = "gpt-4o-mini" + } + baseURL := strings.TrimSpace(os.Getenv("PI_AI_E2E_OPENAI_BASE_URL")) + return ai.Model{ + ID: modelID, + Name: modelID, + API: ai.APIOpenAIResponses, + Provider: "openai", + BaseURL: baseURL, + } +} + +func firstText(message ai.Message) string { + for _, block := range message.Content { + if block.Type == ai.ContentTypeText { + return block.Text + } } - t.Skip("stream e2e parity test pending full provider runtime port") + return "" } diff --git a/pkg/ai/providers/openai_completions_runtime.go b/pkg/ai/providers/openai_completions_runtime.go index cd07a5af..997953f6 100644 --- a/pkg/ai/providers/openai_completions_runtime.go +++ b/pkg/ai/providers/openai_completions_runtime.go @@ -130,7 +130,15 @@ func streamOpenAICompletionsWithOptions( } } + if isContextAborted(runCtx, nil) { + pushProviderAborted(stream, model) + return + } if err := openAIStream.Err(); err != nil { + if isContextAborted(runCtx, err) { + pushProviderAborted(stream, model) + return + } 
pushProviderError(stream, model, err.Error()) return } diff --git a/pkg/ai/providers/openai_responses_runtime.go b/pkg/ai/providers/openai_responses_runtime.go index 7f63c6b1..87fc1a57 100644 --- a/pkg/ai/providers/openai_responses_runtime.go +++ b/pkg/ai/providers/openai_responses_runtime.go @@ -119,7 +119,15 @@ func streamOpenAIResponsesWithOptions( } } + if isContextAborted(runCtx, nil) { + pushProviderAborted(stream, model) + return + } if err := openAIStream.Err(); err != nil { + if isContextAborted(runCtx, err) { + pushProviderAborted(stream, model) + return + } pushProviderError(stream, model, err.Error()) return } diff --git a/pkg/ai/providers/runtime_abort.go b/pkg/ai/providers/runtime_abort.go new file mode 100644 index 00000000..87e15269 --- /dev/null +++ b/pkg/ai/providers/runtime_abort.go @@ -0,0 +1,44 @@ +package providers + +import ( + "context" + "errors" + "strings" + "time" + + "github.com/beeper/ai-bridge/pkg/ai" +) + +func isContextAborted(runCtx context.Context, err error) bool { + if runCtx != nil { + ctxErr := runCtx.Err() + if errors.Is(ctxErr, context.Canceled) || errors.Is(ctxErr, context.DeadlineExceeded) { + return true + } + } + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return true + } + if err == nil { + return false + } + lowerErr := strings.ToLower(strings.TrimSpace(err.Error())) + return strings.Contains(lowerErr, "context canceled") || + strings.Contains(lowerErr, "context cancelled") || + strings.Contains(lowerErr, "deadline exceeded") +} + +func pushProviderAborted(stream *ai.AssistantMessageEventStream, model ai.Model) { + stream.Push(ai.AssistantMessageEvent{ + Type: ai.EventDone, + Message: ai.Message{ + Role: ai.RoleAssistant, + API: model.API, + Provider: model.Provider, + Model: model.ID, + StopReason: ai.StopReasonAborted, + Timestamp: time.Now().UnixMilli(), + }, + Reason: ai.StopReasonAborted, + }) +} diff --git a/pkg/ai/providers/runtime_abort_test.go 
b/pkg/ai/providers/runtime_abort_test.go new file mode 100644 index 00000000..d5a01cbd --- /dev/null +++ b/pkg/ai/providers/runtime_abort_test.go @@ -0,0 +1,45 @@ +package providers + +import ( + "context" + "errors" + "testing" + + "github.com/beeper/ai-bridge/pkg/ai" +) + +func TestIsContextAborted(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if !isContextAborted(ctx, nil) { + t.Fatalf("expected canceled context to be treated as aborted") + } + if !isContextAborted(context.Background(), context.DeadlineExceeded) { + t.Fatalf("expected deadline exceeded error to be treated as aborted") + } + if !isContextAborted(context.Background(), errors.New("request failed: context canceled")) { + t.Fatalf("expected context canceled message to be treated as aborted") + } + if isContextAborted(context.Background(), errors.New("provider rejected request")) { + t.Fatalf("did not expect non-cancellation error to be treated as aborted") + } +} + +func TestPushProviderAborted(t *testing.T) { + stream := ai.NewAssistantMessageEventStream(1) + model := ai.Model{ + ID: "gpt-5-mini", + Provider: "openai", + API: ai.APIOpenAIResponses, + } + + pushProviderAborted(stream, model) + + msg, err := stream.Result() + if err != nil { + t.Fatalf("expected aborted result without error, got %v", err) + } + if msg.StopReason != ai.StopReasonAborted { + t.Fatalf("expected stop reason %q, got %q", ai.StopReasonAborted, msg.StopReason) + } +} From d5301a4da8897b5f24da9715c56054d6b9002e5f Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 09:32:45 +0000 Subject: [PATCH 54/75] Port OpenAI tool-call and token parity e2e tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- docs/pkg-ai-runtime-migration.md | 20 ++++ pkg/ai/e2e/abort_test.go | 7 ++ pkg/ai/e2e/parity_openai_test.go | 162 ++++++++++++++++++++++++++++ pkg/ai/e2e/parity_scaffolds_test.go | 15 --- 4 files changed, 
189 insertions(+), 15 deletions(-) create mode 100644 pkg/ai/e2e/parity_openai_test.go diff --git a/docs/pkg-ai-runtime-migration.md b/docs/pkg-ai-runtime-migration.md index 59e513b9..4e30cc08 100644 --- a/docs/pkg-ai-runtime-migration.md +++ b/docs/pkg-ai-runtime-migration.md @@ -56,6 +56,26 @@ Optional model overrides: - `PI_AI_E2E_ANTHROPIC_MODEL` - `PI_AI_E2E_GOOGLE_MODEL` +## pkg/ai env-gated OpenAI parity e2e tests + +The `pkg/ai/e2e` suite now includes live OpenAI parity checks for: + +- basic complete/stream flows (`stream.test.ts` parity subset), +- stream cancel behavior (`abort.test.ts` parity subset), +- orphan tool-call recovery (`tool-call-without-result.test.ts` parity subset), +- usage total-token accounting (`total-tokens.test.ts` parity subset). + +Run with: + +```bash +PI_AI_E2E=1 OPENAI_API_KEY=... go test ./pkg/ai/e2e -run "TestGenerateE2E_OpenAI|TestAbortE2E_OpenAIStream|TestToolCallWithoutResultE2E_OpenAI|TestTotalTokensE2E_OpenAI" +``` + +Optional overrides: + +- `PI_AI_E2E_OPENAI_MODEL` (default: `gpt-4o-mini`) +- `PI_AI_E2E_OPENAI_BASE_URL` (for OpenAI-compatible endpoints) + ## Notes - Full integration remains feature-gated. 
diff --git a/pkg/ai/e2e/abort_test.go b/pkg/ai/e2e/abort_test.go index 7a47eb93..e9d34b4e 100644 --- a/pkg/ai/e2e/abort_test.go +++ b/pkg/ai/e2e/abort_test.go @@ -68,4 +68,11 @@ func TestAbortE2E_OpenAIStream(t *testing.T) { if result.StopReason != ai.StopReasonAborted { t.Fatalf("expected stop reason %q after cancel, got %q (error: %s)", ai.StopReasonAborted, result.StopReason, result.ErrorMessage) } + if result.Usage.Input != 0 || result.Usage.Output != 0 { + t.Fatalf( + "expected OpenAI aborted stream usage to be zero, got input=%d output=%d", + result.Usage.Input, + result.Usage.Output, + ) + } } diff --git a/pkg/ai/e2e/parity_openai_test.go b/pkg/ai/e2e/parity_openai_test.go new file mode 100644 index 00000000..5e8d0b6b --- /dev/null +++ b/pkg/ai/e2e/parity_openai_test.go @@ -0,0 +1,162 @@ +package e2e + +import ( + "os" + "strings" + "testing" + "time" + + "github.com/beeper/ai-bridge/pkg/ai" + "github.com/beeper/ai-bridge/pkg/ai/providers" +) + +func TestToolCallWithoutResultE2E_OpenAI(t *testing.T) { + requirePIAIE2E(t) + apiKey := strings.TrimSpace(os.Getenv("OPENAI_API_KEY")) + if apiKey == "" { + t.Skip("OPENAI_API_KEY is not set") + } + model := openAIE2EModel() + providers.ResetAPIProviders() + + tool := ai.Tool{ + Name: "calculate", + Description: "Evaluate math expressions", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "expression": map[string]any{"type": "string"}, + }, + "required": []any{"expression"}, + }, + } + context := ai.Context{ + SystemPrompt: "Use the calculate tool for arithmetic operations.", + Tools: []ai.Tool{tool}, + Messages: []ai.Message{ + { + Role: ai.RoleUser, + Text: "Please calculate 25 * 18 using the calculate tool.", + Timestamp: time.Now().UnixMilli(), + }, + }, + } + + first, err := ai.Complete(model, context, &ai.StreamOptions{ + APIKey: apiKey, + MaxTokens: 512, + }) + if err != nil { + t.Fatalf("first complete failed: %v", err) + } + toolCall, ok := findFirstToolCall(first) + if 
!ok { + t.Fatalf("expected tool call in first response, stop=%q err=%q", first.StopReason, first.ErrorMessage) + } + + context.Messages = append(context.Messages, first) + context.Messages = append(context.Messages, ai.Message{ + Role: ai.RoleUser, + Text: "Never mind; just tell me what is 2+2?", + Timestamp: time.Now().UnixMilli(), + }) + + second, err := ai.Complete(model, context, &ai.StreamOptions{ + APIKey: apiKey, + MaxTokens: 512, + }) + if err != nil { + t.Fatalf("second complete failed: %v", err) + } + if second.StopReason == ai.StopReasonError { + t.Fatalf("expected non-error response after orphan tool call, got %q", second.ErrorMessage) + } + if len(second.Content) == 0 { + t.Fatalf("expected non-empty response content") + } + if second.StopReason != ai.StopReasonStop && second.StopReason != ai.StopReasonToolUse { + t.Fatalf("unexpected stop reason %q", second.StopReason) + } + if second.StopReason == ai.StopReasonToolUse { + if _, hasTool := findFirstToolCall(second); !hasTool { + t.Fatalf("expected toolUse response to include a tool call") + } + } + if strings.TrimSpace(toolCall.ID) == "" { + t.Fatalf("expected first tool call id to be populated") + } +} + +func TestTotalTokensE2E_OpenAI(t *testing.T) { + requirePIAIE2E(t) + apiKey := strings.TrimSpace(os.Getenv("OPENAI_API_KEY")) + if apiKey == "" { + t.Skip("OPENAI_API_KEY is not set") + } + model := openAIE2EModel() + providers.ResetAPIProviders() + + longSystemPrompt := strings.Repeat( + "You are a concise assistant. Include only the requested answer.\n", + 60, + ) + context := ai.Context{ + SystemPrompt: longSystemPrompt, + Messages: []ai.Message{ + { + Role: ai.RoleUser, + Text: "What is 2 + 2? 
Reply with one token only.", + Timestamp: time.Now().UnixMilli(), + }, + }, + } + + first, err := ai.Complete(model, context, &ai.StreamOptions{ + APIKey: apiKey, + MaxTokens: 128, + }) + if err != nil { + t.Fatalf("first complete failed: %v", err) + } + assertTotalTokensEqualsComponents(t, first.Usage) + + context.Messages = append(context.Messages, first) + context.Messages = append(context.Messages, ai.Message{ + Role: ai.RoleUser, + Text: "Now what is 3 + 3? Reply with one token only.", + Timestamp: time.Now().UnixMilli(), + }) + second, err := ai.Complete(model, context, &ai.StreamOptions{ + APIKey: apiKey, + MaxTokens: 128, + }) + if err != nil { + t.Fatalf("second complete failed: %v", err) + } + assertTotalTokensEqualsComponents(t, second.Usage) +} + +func findFirstToolCall(message ai.Message) (ai.ContentBlock, bool) { + for _, block := range message.Content { + if block.Type == ai.ContentTypeToolCall { + return block, true + } + } + return ai.ContentBlock{}, false +} + +func assertTotalTokensEqualsComponents(t *testing.T, usage ai.Usage) { + t.Helper() + computed := usage.Input + usage.Output + usage.CacheRead + usage.CacheWrite + if usage.TotalTokens != computed { + t.Fatalf( + "total tokens mismatch: got %d want %d (input=%d output=%d cacheRead=%d cacheWrite=%d)", + usage.TotalTokens, + computed, + usage.Input, + usage.Output, + usage.CacheRead, + usage.CacheWrite, + ) + } +} diff --git a/pkg/ai/e2e/parity_scaffolds_test.go b/pkg/ai/e2e/parity_scaffolds_test.go index 0997d136..8be8c5a4 100644 --- a/pkg/ai/e2e/parity_scaffolds_test.go +++ b/pkg/ai/e2e/parity_scaffolds_test.go @@ -15,11 +15,6 @@ func requirePIAIE2E(t *testing.T) { } } -func TestToolCallWithoutResultE2EParityScaffold(t *testing.T) { - requirePIAIE2E(t) - t.Skip("parity scaffold for tool-call-without-result.test.ts pending runtime implementation") -} - func TestInterleavedThinkingE2EParityScaffold(t *testing.T) { requirePIAIE2E(t) t.Skip("parity scaffold for interleaved-thinking.test.ts 
pending runtime implementation") @@ -40,16 +35,6 @@ func TestAnthropicToolNameNormalizationE2EParityScaffold(t *testing.T) { t.Skip("parity scaffold for anthropic-tool-name-normalization.test.ts pending runtime implementation") } -func TestTokenStatsOnAbortE2EParityScaffold(t *testing.T) { - requirePIAIE2E(t) - t.Skip("parity scaffold for tokens.test.ts pending runtime implementation") -} - -func TestTotalTokensE2EParityScaffold(t *testing.T) { - requirePIAIE2E(t) - t.Skip("parity scaffold for total-tokens.test.ts pending runtime implementation") -} - func TestCrossProviderHandoffE2EParityScaffold(t *testing.T) { requirePIAIE2E(t) t.Skip("parity scaffold for cross-provider-handoff.test.ts pending runtime implementation") From 81f547027340c4592a47da4e446ce4abb8fed104 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 09:34:07 +0000 Subject: [PATCH 55/75] Remove obsolete builtin runtime stub helpers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/ai/providers/register_builtins.go | 46 +-------------------------- 1 file changed, 1 insertion(+), 45 deletions(-) diff --git a/pkg/ai/providers/register_builtins.go b/pkg/ai/providers/register_builtins.go index 529b256e..b340d3ba 100644 --- a/pkg/ai/providers/register_builtins.go +++ b/pkg/ai/providers/register_builtins.go @@ -1,53 +1,9 @@ package providers -import ( - "time" - - "github.com/beeper/ai-bridge/pkg/ai" -) +import "github.com/beeper/ai-bridge/pkg/ai" const BuiltinProviderSourceID = "pkg/ai/providers/register_builtins" -func notImplementedStream(apiID ai.Api) ai.StreamFn { - return func(model ai.Model, _ ai.Context, _ *ai.StreamOptions) *ai.AssistantMessageEventStream { - stream := ai.NewAssistantMessageEventStream(2) - stream.Push(ai.AssistantMessageEvent{ - Type: ai.EventError, - Error: ai.Message{ - Role: ai.RoleAssistant, - API: apiID, - Provider: model.Provider, - Model: model.ID, - StopReason: ai.StopReasonError, 
- ErrorMessage: "provider runtime is not implemented yet", - Timestamp: time.Now().UnixMilli(), - }, - Reason: ai.StopReasonError, - }) - return stream - } -} - -func notImplementedSimpleStream(apiID ai.Api) ai.StreamSimpleFn { - return func(model ai.Model, _ ai.Context, _ *ai.SimpleStreamOptions) *ai.AssistantMessageEventStream { - stream := ai.NewAssistantMessageEventStream(2) - stream.Push(ai.AssistantMessageEvent{ - Type: ai.EventError, - Error: ai.Message{ - Role: ai.RoleAssistant, - API: apiID, - Provider: model.Provider, - Model: model.ID, - StopReason: ai.StopReasonError, - ErrorMessage: "provider runtime is not implemented yet", - Timestamp: time.Now().UnixMilli(), - }, - Reason: ai.StopReasonError, - }) - return stream - } -} - // RegisterBuiltInAPIProviders registers providers implemented in this package. func RegisterBuiltInAPIProviders() { ai.RegisterAPIProvider(ai.APIProvider{ From 921e09fc9ab757ba496ea4c7cc8e99e112586693 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 09:36:07 +0000 Subject: [PATCH 56/75] Add OpenAI context overflow e2e parity test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- docs/pkg-ai-runtime-migration.md | 4 +- pkg/ai/e2e/context_overflow_test.go | 64 ++++++++++++++++++++++++++--- 2 files changed, 62 insertions(+), 6 deletions(-) diff --git a/docs/pkg-ai-runtime-migration.md b/docs/pkg-ai-runtime-migration.md index 4e30cc08..ee5d2126 100644 --- a/docs/pkg-ai-runtime-migration.md +++ b/docs/pkg-ai-runtime-migration.md @@ -64,17 +64,19 @@ The `pkg/ai/e2e` suite now includes live OpenAI parity checks for: - stream cancel behavior (`abort.test.ts` parity subset), - orphan tool-call recovery (`tool-call-without-result.test.ts` parity subset), - usage total-token accounting (`total-tokens.test.ts` parity subset). +- context-overflow detection (`context-overflow.test.ts` parity subset). Run with: ```bash -PI_AI_E2E=1 OPENAI_API_KEY=... 
go test ./pkg/ai/e2e -run "TestGenerateE2E_OpenAI|TestAbortE2E_OpenAIStream|TestToolCallWithoutResultE2E_OpenAI|TestTotalTokensE2E_OpenAI" +PI_AI_E2E=1 OPENAI_API_KEY=... go test ./pkg/ai/e2e -run "TestGenerateE2E_OpenAI|TestAbortE2E_OpenAIStream|TestToolCallWithoutResultE2E_OpenAI|TestTotalTokensE2E_OpenAI|TestContextOverflowE2E_OpenAI" ``` Optional overrides: - `PI_AI_E2E_OPENAI_MODEL` (default: `gpt-4o-mini`) - `PI_AI_E2E_OPENAI_BASE_URL` (for OpenAI-compatible endpoints) +- `PI_AI_E2E_OPENAI_CONTEXT_WINDOW` (default: `128000`, or `400000` for `gpt-5*` models) ## Notes diff --git a/pkg/ai/e2e/context_overflow_test.go b/pkg/ai/e2e/context_overflow_test.go index 75e144a7..65f698a8 100644 --- a/pkg/ai/e2e/context_overflow_test.go +++ b/pkg/ai/e2e/context_overflow_test.go @@ -2,13 +2,67 @@ package e2e import ( "os" + "strconv" + "strings" "testing" + "time" + + "github.com/beeper/ai-bridge/pkg/ai" + "github.com/beeper/ai-bridge/pkg/ai/providers" + aiutils "github.com/beeper/ai-bridge/pkg/ai/utils" ) -// Scaffolding parity target for pi-mono/packages/ai/test/context-overflow.test.ts. -func TestContextOverflowE2EParityScaffold(t *testing.T) { - if os.Getenv("PI_AI_E2E") == "" { - t.Skip("set PI_AI_E2E=1 to enable ai package e2e tests") +func TestContextOverflowE2E_OpenAI(t *testing.T) { + requirePIAIE2E(t) + apiKey := strings.TrimSpace(os.Getenv("OPENAI_API_KEY")) + if apiKey == "" { + t.Skip("OPENAI_API_KEY is not set") + } + model := openAIE2EModel() + model.ContextWindow = openAIE2EContextWindow() + if model.ContextWindow <= 0 { + t.Skip("model context window is unknown") + } + providers.ResetAPIProviders() + + overflowContent := strings.Repeat("Lorem ipsum dolor sit amet, consectetur adipiscing elit. 
", (model.ContextWindow+10000)/10) + response, err := ai.Complete(model, ai.Context{ + SystemPrompt: "You are a helpful assistant.", + Messages: []ai.Message{ + { + Role: ai.RoleUser, + Text: overflowContent, + Timestamp: time.Now().UnixMilli(), + }, + }, + }, &ai.StreamOptions{ + APIKey: apiKey, + MaxTokens: 64, + }) + if err != nil { + t.Fatalf("complete failed: %v", err) + } + + if !aiutils.IsContextOverflow(response, model.ContextWindow) { + t.Fatalf( + "expected context overflow detection for stop=%q err=%q input=%d cacheRead=%d window=%d", + response.StopReason, + response.ErrorMessage, + response.Usage.Input, + response.Usage.CacheRead, + model.ContextWindow, + ) + } +} + +func openAIE2EContextWindow() int { + if raw := strings.TrimSpace(os.Getenv("PI_AI_E2E_OPENAI_CONTEXT_WINDOW")); raw != "" { + if v, err := strconv.Atoi(raw); err == nil { + return v + } + } + if strings.Contains(strings.ToLower(strings.TrimSpace(os.Getenv("PI_AI_E2E_OPENAI_MODEL"))), "gpt-5") { + return 400000 } - t.Skip("context overflow e2e parity test pending full provider runtime port") + return 128000 } From eb43ff02f4b7bdd24ccdc6b91453c11b4c79fdc9 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 09:39:02 +0000 Subject: [PATCH 57/75] Normalize provider abort handling across streaming runtimes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/ai/providers/anthropic_runtime.go | 8 ++++++++ pkg/ai/providers/azure_openai_responses_runtime.go | 8 ++++++++ pkg/ai/providers/google_runtime.go | 8 ++++++++ pkg/ai/providers/openai_codex_responses_runtime.go | 8 ++++++++ 4 files changed, 32 insertions(+) diff --git a/pkg/ai/providers/anthropic_runtime.go b/pkg/ai/providers/anthropic_runtime.go index fb7d3ff0..69ebd04d 100644 --- a/pkg/ai/providers/anthropic_runtime.go +++ b/pkg/ai/providers/anthropic_runtime.go @@ -147,7 +147,15 @@ func streamAnthropicMessagesWithOptions( } } + if 
isContextAborted(runCtx, nil) { + pushProviderAborted(stream, model) + return + } if err := anthropicStream.Err(); err != nil { + if isContextAborted(runCtx, err) { + pushProviderAborted(stream, model) + return + } pushProviderError(stream, model, err.Error()) return } diff --git a/pkg/ai/providers/azure_openai_responses_runtime.go b/pkg/ai/providers/azure_openai_responses_runtime.go index fbb7028a..5d2c3f7d 100644 --- a/pkg/ai/providers/azure_openai_responses_runtime.go +++ b/pkg/ai/providers/azure_openai_responses_runtime.go @@ -127,7 +127,15 @@ func streamAzureOpenAIResponsesWithOptions( } } + if isContextAborted(runCtx, nil) { + pushProviderAborted(stream, model) + return + } if err := openAIStream.Err(); err != nil { + if isContextAborted(runCtx, err) { + pushProviderAborted(stream, model) + return + } pushProviderError(stream, model, err.Error()) return } diff --git a/pkg/ai/providers/google_runtime.go b/pkg/ai/providers/google_runtime.go index ee8b19a1..58a504b1 100644 --- a/pkg/ai/providers/google_runtime.go +++ b/pkg/ai/providers/google_runtime.go @@ -97,6 +97,10 @@ func streamGoogleWithBackend( for result, err := range client.Models.GenerateContentStream(runCtx, model.ID, contents, config) { if err != nil { + if isContextAborted(runCtx, err) { + pushProviderAborted(stream, model) + return + } pushProviderError(stream, model, err.Error()) return } @@ -156,6 +160,10 @@ func streamGoogleWithBackend( } } } + if isContextAborted(runCtx, nil) { + pushProviderAborted(stream, model) + return + } usage.Cost = ai.CalculateCost(model, usage) assistantMessage := ai.Message{ diff --git a/pkg/ai/providers/openai_codex_responses_runtime.go b/pkg/ai/providers/openai_codex_responses_runtime.go index 286dd576..bba54cd9 100644 --- a/pkg/ai/providers/openai_codex_responses_runtime.go +++ b/pkg/ai/providers/openai_codex_responses_runtime.go @@ -120,7 +120,15 @@ func streamOpenAICodexResponsesWithOptions( } } + if isContextAborted(runCtx, nil) { + pushProviderAborted(stream, 
model) + return + } if err := openAIStream.Err(); err != nil { + if isContextAborted(runCtx, err) { + pushProviderAborted(stream, model) + return + } pushProviderError(stream, model, err.Error()) return } From 3f42719bf6c682319a36acbe92576e3a97e6e57a Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 09:41:48 +0000 Subject: [PATCH 58/75] Add Anthropic and Google pkg-ai parity e2e tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- docs/pkg-ai-runtime-migration.md | 12 +- pkg/ai/e2e/parity_provider_runtime_test.go | 185 +++++++++++++++++++++ 2 files changed, 193 insertions(+), 4 deletions(-) create mode 100644 pkg/ai/e2e/parity_provider_runtime_test.go diff --git a/docs/pkg-ai-runtime-migration.md b/docs/pkg-ai-runtime-migration.md index ee5d2126..cf3a1d48 100644 --- a/docs/pkg-ai-runtime-migration.md +++ b/docs/pkg-ai-runtime-migration.md @@ -56,20 +56,22 @@ Optional model overrides: - `PI_AI_E2E_ANTHROPIC_MODEL` - `PI_AI_E2E_GOOGLE_MODEL` -## pkg/ai env-gated OpenAI parity e2e tests +## pkg/ai env-gated provider parity e2e tests -The `pkg/ai/e2e` suite now includes live OpenAI parity checks for: +The `pkg/ai/e2e` suite now includes live provider parity checks for: -- basic complete/stream flows (`stream.test.ts` parity subset), +- OpenAI basic complete/stream flows (`stream.test.ts` parity subset), - stream cancel behavior (`abort.test.ts` parity subset), - orphan tool-call recovery (`tool-call-without-result.test.ts` parity subset), - usage total-token accounting (`total-tokens.test.ts` parity subset). - context-overflow detection (`context-overflow.test.ts` parity subset). +- Anthropic and Google complete/stream smoke coverage. Run with: ```bash -PI_AI_E2E=1 OPENAI_API_KEY=... 
go test ./pkg/ai/e2e -run "TestGenerateE2E_OpenAI|TestAbortE2E_OpenAIStream|TestToolCallWithoutResultE2E_OpenAI|TestTotalTokensE2E_OpenAI|TestContextOverflowE2E_OpenAI" +PI_AI_E2E=1 OPENAI_API_KEY=... ANTHROPIC_API_KEY=... GEMINI_API_KEY=... \ + go test ./pkg/ai/e2e -run "TestGenerateE2E_OpenAI|TestAbortE2E_OpenAIStream|TestToolCallWithoutResultE2E_OpenAI|TestTotalTokensE2E_OpenAI|TestContextOverflowE2E_OpenAI|TestGenerateE2E_Anthropic|TestGenerateE2E_Google" ``` Optional overrides: @@ -77,6 +79,8 @@ Optional overrides: - `PI_AI_E2E_OPENAI_MODEL` (default: `gpt-4o-mini`) - `PI_AI_E2E_OPENAI_BASE_URL` (for OpenAI-compatible endpoints) - `PI_AI_E2E_OPENAI_CONTEXT_WINDOW` (default: `128000`, or `400000` for `gpt-5*` models) +- `PI_AI_E2E_ANTHROPIC_MODEL` / `PI_AI_E2E_ANTHROPIC_BASE_URL` +- `PI_AI_E2E_GOOGLE_MODEL` / `PI_AI_E2E_GOOGLE_BASE_URL` ## Notes diff --git a/pkg/ai/e2e/parity_provider_runtime_test.go b/pkg/ai/e2e/parity_provider_runtime_test.go new file mode 100644 index 00000000..6f80fa33 --- /dev/null +++ b/pkg/ai/e2e/parity_provider_runtime_test.go @@ -0,0 +1,185 @@ +package e2e + +import ( + "context" + "io" + "os" + "strings" + "testing" + "time" + + "github.com/beeper/ai-bridge/pkg/ai" + "github.com/beeper/ai-bridge/pkg/ai/providers" +) + +func TestGenerateE2E_AnthropicComplete(t *testing.T) { + requirePIAIE2E(t) + apiKey := strings.TrimSpace(os.Getenv("ANTHROPIC_API_KEY")) + if apiKey == "" { + t.Skip("ANTHROPIC_API_KEY is not set") + } + model := anthropicE2EModel() + providers.ResetAPIProviders() + + response, err := ai.Complete(model, ai.Context{ + Messages: []ai.Message{ + { + Role: ai.RoleUser, + Text: "Reply with the single word OK.", + Timestamp: time.Now().UnixMilli(), + }, + }, + }, &ai.StreamOptions{ + APIKey: apiKey, + MaxTokens: 128, + }) + if err != nil { + t.Fatalf("complete failed: %v", err) + } + if response.StopReason == ai.StopReasonError { + t.Fatalf("unexpected error stop reason: %s", response.ErrorMessage) + } + if 
strings.TrimSpace(firstText(response)) == "" { + t.Fatalf("expected non-empty text response") + } +} + +func TestGenerateE2E_GoogleComplete(t *testing.T) { + requirePIAIE2E(t) + apiKey := strings.TrimSpace(os.Getenv("GEMINI_API_KEY")) + if apiKey == "" { + t.Skip("GEMINI_API_KEY is not set") + } + model := googleE2EModel() + providers.ResetAPIProviders() + + response, err := ai.Complete(model, ai.Context{ + Messages: []ai.Message{ + { + Role: ai.RoleUser, + Text: "Reply with the single word OK.", + Timestamp: time.Now().UnixMilli(), + }, + }, + }, &ai.StreamOptions{ + APIKey: apiKey, + MaxTokens: 128, + }) + if err != nil { + t.Fatalf("complete failed: %v", err) + } + if response.StopReason == ai.StopReasonError { + t.Fatalf("unexpected error stop reason: %s", response.ErrorMessage) + } + if strings.TrimSpace(firstText(response)) == "" { + t.Fatalf("expected non-empty text response") + } +} + +func TestGenerateE2E_AnthropicStream(t *testing.T) { + requirePIAIE2E(t) + apiKey := strings.TrimSpace(os.Getenv("ANTHROPIC_API_KEY")) + if apiKey == "" { + t.Skip("ANTHROPIC_API_KEY is not set") + } + providers.ResetAPIProviders() + runBasicStreamE2E(t, anthropicE2EModel(), apiKey) +} + +func TestGenerateE2E_GoogleStream(t *testing.T) { + requirePIAIE2E(t) + apiKey := strings.TrimSpace(os.Getenv("GEMINI_API_KEY")) + if apiKey == "" { + t.Skip("GEMINI_API_KEY is not set") + } + providers.ResetAPIProviders() + runBasicStreamE2E(t, googleE2EModel(), apiKey) +} + +func runBasicStreamE2E(t *testing.T, model ai.Model, apiKey string) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + stream, err := ai.Stream(model, ai.Context{ + Messages: []ai.Message{ + { + Role: ai.RoleUser, + Text: "Reply with the single word OK.", + Timestamp: time.Now().UnixMilli(), + }, + }, + }, &ai.StreamOptions{ + APIKey: apiKey, + MaxTokens: 128, + Ctx: ctx, + }) + if err != nil { + t.Fatalf("stream creation failed: %v", err) + } + + receivedDone 
:= false + receivedDelta := false + for { + evt, nextErr := stream.Next(ctx) + if nextErr == io.EOF { + break + } + if nextErr != nil { + t.Fatalf("stream read failed: %v", nextErr) + } + switch evt.Type { + case ai.EventTextDelta: + if strings.TrimSpace(evt.Delta) != "" { + receivedDelta = true + } + case ai.EventDone: + receivedDone = true + case ai.EventError: + t.Fatalf("unexpected stream error: %s", evt.Error.ErrorMessage) + } + } + + response, err := stream.Result() + if err != nil { + t.Fatalf("stream result failed: %v", err) + } + if response.StopReason == ai.StopReasonError { + t.Fatalf("unexpected stream stop error: %s", response.ErrorMessage) + } + if !receivedDone { + t.Fatalf("expected done event") + } + if !receivedDelta && strings.TrimSpace(firstText(response)) == "" { + t.Fatalf("expected either streamed deltas or non-empty final text") + } +} + +func anthropicE2EModel() ai.Model { + modelID := strings.TrimSpace(os.Getenv("PI_AI_E2E_ANTHROPIC_MODEL")) + if modelID == "" { + modelID = "claude-3-5-haiku-latest" + } + baseURL := strings.TrimSpace(os.Getenv("PI_AI_E2E_ANTHROPIC_BASE_URL")) + return ai.Model{ + ID: modelID, + Name: modelID, + API: ai.APIAnthropicMessages, + Provider: "anthropic", + BaseURL: baseURL, + } +} + +func googleE2EModel() ai.Model { + modelID := strings.TrimSpace(os.Getenv("PI_AI_E2E_GOOGLE_MODEL")) + if modelID == "" { + modelID = "gemini-2.5-flash" + } + baseURL := strings.TrimSpace(os.Getenv("PI_AI_E2E_GOOGLE_BASE_URL")) + return ai.Model{ + ID: modelID, + Name: modelID, + API: ai.APIGoogleGenerativeAI, + Provider: "google", + BaseURL: baseURL, + } +} From a0d98d585656eb60b5f771e354eba6a29084a394 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 09:43:55 +0000 Subject: [PATCH 59/75] Add pkg-ai test parity tracker documentation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- README.md | 1 + docs/pkg-ai-test-parity.md | 63 
++++++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+) create mode 100644 docs/pkg-ai-test-parity.md diff --git a/README.md b/README.md index 8aab51a0..b9b98d17 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,7 @@ Experimental Matrix ↔ AI bridge for Beeper, built on top of [mautrix/bridgev2] - `docs/matrix-ai-matrix-spec-v1.md`: Full Matrix transport spec (events, streaming, approvals, state, and schema examples). - `docs/bridge-orchestrator.md`: One-command bridge management in this repo. - `docs/pkg-ai-runtime-migration.md`: Feature flags and rollout notes for connector ↔ `pkg/ai` runtime bridging. +- `docs/pkg-ai-test-parity.md`: Port-status tracker for `pi-mono/packages/ai` test parity. ## Bridge Orchestrator diff --git a/docs/pkg-ai-test-parity.md b/docs/pkg-ai-test-parity.md new file mode 100644 index 00000000..48e97184 --- /dev/null +++ b/docs/pkg-ai-test-parity.md @@ -0,0 +1,63 @@ +# pkg/ai Test Parity Tracker + +This document tracks parity between upstream `pi-mono/packages/ai/test/*.test.ts` +and the Go port in `pkg/ai`. + +Legend: + +- ✅ **Ported**: implemented as Go test(s) with runtime behavior coverage. +- 🧪 **Env-gated**: implemented as live provider test, skipped without credentials. +- 📝 **Scaffold**: placeholder test exists; behavior not fully ported yet. 
+ +## Current parity snapshot + +### Core stream/runtime/e2e behavior + +- `stream.test.ts` → ✅🧪 `pkg/ai/e2e/stream_test.go`, `pkg/ai/e2e/parity_provider_runtime_test.go` +- `abort.test.ts` → ✅🧪 `pkg/ai/e2e/abort_test.go` +- `context-overflow.test.ts` → ✅🧪 `pkg/ai/e2e/context_overflow_test.go` +- `tool-call-without-result.test.ts` → ✅🧪 `pkg/ai/e2e/parity_openai_test.go` +- `total-tokens.test.ts` → ✅🧪 `pkg/ai/e2e/parity_openai_test.go` +- `tokens.test.ts` → ✅🧪 `pkg/ai/e2e/abort_test.go` (OpenAI subset) + +### Provider/unit parity + +- `openai-completions-tool-choice.test.ts` → ✅ `pkg/ai/providers/openai_completions_test.go` +- `openai-completions-tool-result-images.test.ts` → ✅ `pkg/ai/providers/openai_completions_test.go` +- `openai-codex-stream.test.ts` → ✅ `pkg/ai/providers/openai_codex_responses_test.go` +- `google-gemini-cli-retry-delay.test.ts` → ✅ `pkg/ai/providers/google_gemini_cli_test.go` +- `google-gemini-cli-empty-stream.test.ts` → ✅ `pkg/ai/providers/google_gemini_cli_test.go` +- `google-gemini-cli-claude-thinking-header.test.ts` → ✅ `pkg/ai/providers/google_gemini_cli_test.go` +- `google-tool-call-missing-args.test.ts` → ✅ `pkg/ai/providers/google_tool_call_missing_args_test.go` +- `google-shared-gemini3-unsigned-tool-call.test.ts` → ✅ `pkg/ai/providers/google_shared_test.go` +- `google-thinking-signature.test.ts` → ✅ `pkg/ai/providers/google_shared_test.go` +- `transform-messages-copilot-openai-to-anthropic.test.ts` → ✅ `pkg/ai/providers/transform_messages_test.go` +- `tool-call-id-normalization.test.ts` → ✅ `pkg/ai/providers/openai_responses_shared_test.go`, `pkg/ai/providers/openai_completions_convert_test.go` +- `anthropic-tool-name-normalization.test.ts` → ✅ `pkg/ai/providers/anthropic_test.go` +- `cache-retention.test.ts` → ✅ `pkg/ai/providers/cache_retention_test.go` +- `image-tool-result.test.ts` → ✅ `pkg/ai/providers/openai_completions_test.go` +- `unicode-surrogate.test.ts` → ✅ `pkg/ai/utils/sanitize_unicode_test.go` +- 
`supports-xhigh.test.ts` / `xhigh.test.ts` → ✅ `pkg/ai/models_test.go` +- `interleaved-thinking.test.ts` (deterministic parts) → ✅ `pkg/ai/providers/anthropic_test.go`, `pkg/ai/providers/amazon_bedrock_test.go` +- `bedrock-models.test.ts` (deterministic parts) → ✅ `pkg/ai/providers/amazon_bedrock_test.go` + +### OAuth parity + +- `oauth.ts` (provider/token helper semantics) → ✅ `pkg/ai/oauth/*_test.go` + +### Remaining scaffolds in Go e2e suite + +The following are currently kept as env-gated scaffolds in +`pkg/ai/e2e/parity_scaffolds_test.go`: + +- 📝 `interleaved-thinking.test.ts` +- 📝 `bedrock-models.test.ts` +- 📝 `cross-provider-handoff.test.ts` +- 📝 `openai-responses-reasoning-replay-e2e.test.ts` +- 📝 `google-gemini-cli-empty-stream.test.ts` (full live parity) +- 📝 `xhigh.test.ts` (live) +- 📝 `zen.test.ts` +- 📝 `empty.test.ts` +- 📝 `image-tool-result.test.ts` (live) +- 📝 `google-gemini-cli-claude-thinking-header.test.ts` (live) +- 📝 `github-copilot-anthropic.test.ts` (live) From b056b7db68669fca19ec7384443c342ab27afbe5 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 09:47:30 +0000 Subject: [PATCH 60/75] Handle cancellation as aborted in Gemini CLI runtime MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/ai/providers/google_gemini_cli_runtime.go | 26 ++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/pkg/ai/providers/google_gemini_cli_runtime.go b/pkg/ai/providers/google_gemini_cli_runtime.go index f57be129..d099d3d4 100644 --- a/pkg/ai/providers/google_gemini_cli_runtime.go +++ b/pkg/ai/providers/google_gemini_cli_runtime.go @@ -176,7 +176,7 @@ func streamGoogleGeminiCLIWithOptions( var lastErr error for attempt := 0; attempt <= maxGeminiCLIRetries; attempt++ { if runCtx.Err() != nil { - pushProviderError(stream, model, runCtx.Err().Error()) + pushProviderAborted(stream, model) return } endpoint := endpoints[minInt(attempt, 
len(endpoints)-1)] @@ -203,6 +203,10 @@ func streamGoogleGeminiCLIWithOptions( delay = time.Duration(parsedDelayMs) * time.Millisecond } if sleepErr := sleepWithContext(runCtx, delay); sleepErr != nil { + if isContextAborted(runCtx, sleepErr) { + pushProviderAborted(stream, model) + return + } pushProviderError(stream, model, sleepErr.Error()) return } @@ -214,6 +218,10 @@ func streamGoogleGeminiCLIWithOptions( if lastErr != nil && attempt < maxGeminiCLIRetries { delay := baseGeminiCLIRetryDelay * time.Duration(1< Date: Wed, 4 Mar 2026 09:50:34 +0000 Subject: [PATCH 61/75] Fix OpenAI responses handoff function-call ID pairing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- pkg/ai/providers/openai_responses_shared.go | 18 ++++- .../providers/openai_responses_shared_test.go | 79 +++++++++++++++++++ 2 files changed, 94 insertions(+), 3 deletions(-) diff --git a/pkg/ai/providers/openai_responses_shared.go b/pkg/ai/providers/openai_responses_shared.go index 72777c74..9cb00c59 100644 --- a/pkg/ai/providers/openai_responses_shared.go +++ b/pkg/ai/providers/openai_responses_shared.go @@ -96,6 +96,10 @@ func ConvertResponsesMessages( "content": content, }) case ai.RoleAssistant: + isDifferentModel := msg.Model != "" && + msg.Model != model.ID && + msg.Provider == model.Provider && + msg.API == model.API for _, block := range msg.Content { switch block.Type { case ai.ContentTypeText: @@ -130,13 +134,21 @@ func ConvertResponsesMessages( b, _ := json.Marshal(block.Arguments) args = string(b) } - messages = append(messages, map[string]any{ + functionCall := map[string]any{ "type": "function_call", - "id": itemID, "call_id": callID, "name": block.Name, "arguments": args, - }) + } + if itemID != "" { + // For same-provider different-model handoffs, omit item IDs that + // can trigger OpenAI pairing validation against foreign reasoning + // history from prior model turns. 
+ if !(isDifferentModel && strings.HasPrefix(itemID, "fc_")) { + functionCall["id"] = itemID + } + } + messages = append(messages, functionCall) } } case ai.RoleToolResult: diff --git a/pkg/ai/providers/openai_responses_shared_test.go b/pkg/ai/providers/openai_responses_shared_test.go index 9adf7c3a..ff705937 100644 --- a/pkg/ai/providers/openai_responses_shared_test.go +++ b/pkg/ai/providers/openai_responses_shared_test.go @@ -100,3 +100,82 @@ func TestConvertResponsesMessages_CanOmitSystemPrompt(t *testing.T) { t.Fatalf("expected no system/developer prompt in output when omitted, got %#v", first) } } + +func TestConvertResponsesMessages_OmitsFunctionCallItemIDForDifferentModel(t *testing.T) { + model := ai.Model{ + ID: "gpt-5.2-codex", + Provider: "openai", + API: ai.APIOpenAIResponses, + } + context := ai.Context{ + Messages: []ai.Message{ + {Role: ai.RoleUser, Text: "use tool"}, + { + Role: ai.RoleAssistant, + Content: []ai.ContentBlock{ + { + Type: ai.ContentTypeToolCall, + ID: "call_123|fc_456", + Name: "double_number", + Arguments: map[string]any{"value": 21}, + }, + }, + Provider: "openai", + API: ai.APIOpenAIResponses, + Model: "gpt-5-mini", + StopReason: ai.StopReasonToolUse, + }, + }, + } + + output := ConvertResponsesMessages(model, context, openAIToolCallProviders, nil) + if len(output) < 2 { + t.Fatalf("expected function call message in output, got %d entries", len(output)) + } + functionCall := output[len(output)-1] + if functionCall["type"] != "function_call" { + t.Fatalf("expected function_call entry, got %#v", functionCall) + } + if _, hasID := functionCall["id"]; hasID { + t.Fatalf("expected function_call id to be omitted for different-model handoff, got %#v", functionCall["id"]) + } + if callID, _ := functionCall["call_id"].(string); callID != "call_123" { + t.Fatalf("expected call_id preserved, got %q", callID) + } +} + +func TestConvertResponsesMessages_DropsAbortedReasoningOnlyAssistant(t *testing.T) { + model := ai.Model{ + ID: "gpt-5-mini", 
+ Provider: "openai", + API: ai.APIOpenAIResponses, + } + context := ai.Context{ + Messages: []ai.Message{ + {Role: ai.RoleUser, Text: "use tool"}, + { + Role: ai.RoleAssistant, + Content: []ai.ContentBlock{ + { + Type: ai.ContentTypeThinking, + Thinking: "", + ThinkingSignature: `{"type":"reasoning","id":"rs_123","summary":[]}`, + }, + }, + Provider: "openai", + API: ai.APIOpenAIResponses, + Model: "gpt-5-mini", + StopReason: ai.StopReasonAborted, + }, + {Role: ai.RoleUser, Text: "say hi"}, + }, + } + + output := ConvertResponsesMessages(model, context, openAIToolCallProviders, nil) + for _, item := range output { + itemType, _ := item["type"].(string) + if itemType == "reasoning" { + t.Fatalf("expected aborted reasoning history to be omitted, got %#v", item) + } + } +} From c575c96d9e20264741749aa4865f6192dddcf3b5 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 09:50:47 +0000 Subject: [PATCH 62/75] Update parity tracker for reasoning replay coverage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- docs/pkg-ai-test-parity.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/pkg-ai-test-parity.md b/docs/pkg-ai-test-parity.md index 48e97184..3e7d3b05 100644 --- a/docs/pkg-ai-test-parity.md +++ b/docs/pkg-ai-test-parity.md @@ -25,6 +25,7 @@ Legend: - `openai-completions-tool-choice.test.ts` → ✅ `pkg/ai/providers/openai_completions_test.go` - `openai-completions-tool-result-images.test.ts` → ✅ `pkg/ai/providers/openai_completions_test.go` - `openai-codex-stream.test.ts` → ✅ `pkg/ai/providers/openai_codex_responses_test.go` +- `openai-responses-reasoning-replay-e2e.test.ts` (message conversion semantics) → ✅ `pkg/ai/providers/openai_responses_shared_test.go` - `google-gemini-cli-retry-delay.test.ts` → ✅ `pkg/ai/providers/google_gemini_cli_test.go` - `google-gemini-cli-empty-stream.test.ts` → ✅ `pkg/ai/providers/google_gemini_cli_test.go` - 
`google-gemini-cli-claude-thinking-header.test.ts` → ✅ `pkg/ai/providers/google_gemini_cli_test.go` @@ -53,7 +54,6 @@ The following are currently kept as env-gated scaffolds in - 📝 `interleaved-thinking.test.ts` - 📝 `bedrock-models.test.ts` - 📝 `cross-provider-handoff.test.ts` -- 📝 `openai-responses-reasoning-replay-e2e.test.ts` - 📝 `google-gemini-cli-empty-stream.test.ts` (full live parity) - 📝 `xhigh.test.ts` (live) - 📝 `zen.test.ts` From fa4d14008c329d8ef6c12ac213b04975d5c5ca8d Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 09:52:49 +0000 Subject: [PATCH 63/75] Add OpenAI reasoning replay parity e2e coverage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- docs/pkg-ai-runtime-migration.md | 4 +- docs/pkg-ai-test-parity.md | 2 +- .../e2e/openai_reasoning_replay_e2e_test.go | 197 ++++++++++++++++++ pkg/ai/e2e/parity_scaffolds_test.go | 5 - 4 files changed, 201 insertions(+), 7 deletions(-) create mode 100644 pkg/ai/e2e/openai_reasoning_replay_e2e_test.go diff --git a/docs/pkg-ai-runtime-migration.md b/docs/pkg-ai-runtime-migration.md index cf3a1d48..3c7f3e13 100644 --- a/docs/pkg-ai-runtime-migration.md +++ b/docs/pkg-ai-runtime-migration.md @@ -65,13 +65,14 @@ The `pkg/ai/e2e` suite now includes live provider parity checks for: - orphan tool-call recovery (`tool-call-without-result.test.ts` parity subset), - usage total-token accounting (`total-tokens.test.ts` parity subset). - context-overflow detection (`context-overflow.test.ts` parity subset). +- OpenAI Responses reasoning replay/handoff (`openai-responses-reasoning-replay-e2e.test.ts` subset). - Anthropic and Google complete/stream smoke coverage. Run with: ```bash PI_AI_E2E=1 OPENAI_API_KEY=... ANTHROPIC_API_KEY=... GEMINI_API_KEY=... 
\ - go test ./pkg/ai/e2e -run "TestGenerateE2E_OpenAI|TestAbortE2E_OpenAIStream|TestToolCallWithoutResultE2E_OpenAI|TestTotalTokensE2E_OpenAI|TestContextOverflowE2E_OpenAI|TestGenerateE2E_Anthropic|TestGenerateE2E_Google" + go test ./pkg/ai/e2e -run "TestGenerateE2E_OpenAI|TestAbortE2E_OpenAIStream|TestToolCallWithoutResultE2E_OpenAI|TestTotalTokensE2E_OpenAI|TestContextOverflowE2E_OpenAI|TestOpenAIReasoningReplayE2E_|TestGenerateE2E_Anthropic|TestGenerateE2E_Google" ``` Optional overrides: @@ -79,6 +80,7 @@ Optional overrides: - `PI_AI_E2E_OPENAI_MODEL` (default: `gpt-4o-mini`) - `PI_AI_E2E_OPENAI_BASE_URL` (for OpenAI-compatible endpoints) - `PI_AI_E2E_OPENAI_CONTEXT_WINDOW` (default: `128000`, or `400000` for `gpt-5*` models) +- `PI_AI_E2E_OPENAI_REASONING_SOURCE_MODEL` / `PI_AI_E2E_OPENAI_REASONING_TARGET_MODEL` - `PI_AI_E2E_ANTHROPIC_MODEL` / `PI_AI_E2E_ANTHROPIC_BASE_URL` - `PI_AI_E2E_GOOGLE_MODEL` / `PI_AI_E2E_GOOGLE_BASE_URL` diff --git a/docs/pkg-ai-test-parity.md b/docs/pkg-ai-test-parity.md index 3e7d3b05..50b2e9bf 100644 --- a/docs/pkg-ai-test-parity.md +++ b/docs/pkg-ai-test-parity.md @@ -19,13 +19,13 @@ Legend: - `tool-call-without-result.test.ts` → ✅🧪 `pkg/ai/e2e/parity_openai_test.go` - `total-tokens.test.ts` → ✅🧪 `pkg/ai/e2e/parity_openai_test.go` - `tokens.test.ts` → ✅🧪 `pkg/ai/e2e/abort_test.go` (OpenAI subset) +- `openai-responses-reasoning-replay-e2e.test.ts` → ✅🧪 `pkg/ai/e2e/openai_reasoning_replay_e2e_test.go` (+ deterministic conversion assertions in `pkg/ai/providers/openai_responses_shared_test.go`) ### Provider/unit parity - `openai-completions-tool-choice.test.ts` → ✅ `pkg/ai/providers/openai_completions_test.go` - `openai-completions-tool-result-images.test.ts` → ✅ `pkg/ai/providers/openai_completions_test.go` - `openai-codex-stream.test.ts` → ✅ `pkg/ai/providers/openai_codex_responses_test.go` -- `openai-responses-reasoning-replay-e2e.test.ts` (message conversion semantics) → ✅ `pkg/ai/providers/openai_responses_shared_test.go` - 
`google-gemini-cli-retry-delay.test.ts` → ✅ `pkg/ai/providers/google_gemini_cli_test.go` - `google-gemini-cli-empty-stream.test.ts` → ✅ `pkg/ai/providers/google_gemini_cli_test.go` - `google-gemini-cli-claude-thinking-header.test.ts` → ✅ `pkg/ai/providers/google_gemini_cli_test.go` diff --git a/pkg/ai/e2e/openai_reasoning_replay_e2e_test.go b/pkg/ai/e2e/openai_reasoning_replay_e2e_test.go new file mode 100644 index 00000000..59e2d7c0 --- /dev/null +++ b/pkg/ai/e2e/openai_reasoning_replay_e2e_test.go @@ -0,0 +1,197 @@ +package e2e + +import ( + "os" + "strings" + "testing" + "time" + + "github.com/beeper/ai-bridge/pkg/ai" + "github.com/beeper/ai-bridge/pkg/ai/providers" +) + +func TestOpenAIReasoningReplayE2E_SkipsAbortedReasoningHistory(t *testing.T) { + requirePIAIE2E(t) + apiKey := strings.TrimSpace(os.Getenv("OPENAI_API_KEY")) + if apiKey == "" { + t.Skip("OPENAI_API_KEY is not set") + } + model := openAIReasoningSourceModel() + providers.ResetAPIProviders() + + context := ai.Context{ + SystemPrompt: "You are a helpful assistant.", + Tools: []ai.Tool{doubleNumberTool()}, + Messages: []ai.Message{ + { + Role: ai.RoleUser, + Text: "Use the double_number tool to double 21.", + Timestamp: time.Now().UnixMilli(), + }, + { + Role: ai.RoleAssistant, + Content: []ai.ContentBlock{ + { + Type: ai.ContentTypeThinking, + Thinking: "", + ThinkingSignature: `{"type":"reasoning","id":"rs_123","summary":[{"type":"summary_text","text":"tool required"}]}`, + }, + }, + Provider: "openai", + API: ai.APIOpenAIResponses, + Model: model.ID, + StopReason: ai.StopReasonAborted, + Timestamp: time.Now().UnixMilli(), + }, + { + Role: ai.RoleUser, + Text: "Say hello to confirm you can continue.", + Timestamp: time.Now().UnixMilli(), + }, + }, + } + + response, err := ai.CompleteSimple(model, context, &ai.SimpleStreamOptions{ + StreamOptions: ai.StreamOptions{ + APIKey: apiKey, + MaxTokens: 256, + }, + Reasoning: ai.ThinkingHigh, + }) + if err != nil { + t.Fatalf("complete failed: %v", err) 
+ } + if response.StopReason == ai.StopReasonError { + t.Fatalf("expected no provider error, got %q", response.ErrorMessage) + } + if len(response.Content) == 0 { + t.Fatalf("expected non-empty response content") + } +} + +func TestOpenAIReasoningReplayE2E_SameProviderDifferentModelHandoff(t *testing.T) { + requirePIAIE2E(t) + apiKey := strings.TrimSpace(os.Getenv("OPENAI_API_KEY")) + if apiKey == "" { + t.Skip("OPENAI_API_KEY is not set") + } + sourceModel := openAIReasoningSourceModel() + targetModel := openAIReasoningTargetModel() + providers.ResetAPIProviders() + + context := ai.Context{ + SystemPrompt: "You are a helpful assistant. Answer concisely.", + Tools: []ai.Tool{doubleNumberTool()}, + Messages: []ai.Message{ + { + Role: ai.RoleUser, + Text: "Use the double_number tool to double 21.", + Timestamp: time.Now().UnixMilli(), + }, + { + Role: ai.RoleAssistant, + Content: []ai.ContentBlock{ + { + Type: ai.ContentTypeThinking, + Thinking: "I should call the tool first.", + ThinkingSignature: `{"type":"reasoning","id":"rs_abc","summary":[{"type":"summary_text","text":"call tool"}]}`, + }, + { + Type: ai.ContentTypeToolCall, + ID: "call_123|fc_456", + Name: "double_number", + Arguments: map[string]any{"value": 21}, + }, + }, + Provider: sourceModel.Provider, + API: sourceModel.API, + Model: sourceModel.ID, + StopReason: ai.StopReasonToolUse, + Timestamp: time.Now().UnixMilli(), + }, + { + Role: ai.RoleToolResult, + ToolCallID: "call_123|fc_456", + ToolName: "double_number", + Content: []ai.ContentBlock{ + {Type: ai.ContentTypeText, Text: "42"}, + }, + Timestamp: time.Now().UnixMilli(), + }, + { + Role: ai.RoleUser, + Text: "What was the result? 
Answer with just the number.", + Timestamp: time.Now().UnixMilli(), + }, + }, + } + + response, err := ai.CompleteSimple(targetModel, context, &ai.SimpleStreamOptions{ + StreamOptions: ai.StreamOptions{ + APIKey: apiKey, + MaxTokens: 256, + }, + Reasoning: ai.ThinkingHigh, + }) + if err != nil { + t.Fatalf("complete failed: %v", err) + } + if response.StopReason == ai.StopReasonError { + t.Fatalf("expected no provider error, got %q", response.ErrorMessage) + } + text := strings.ToLower(strings.TrimSpace(firstText(response))) + if text == "" { + t.Fatalf("expected non-empty text response") + } + if !strings.Contains(text, "42") && + !strings.Contains(text, "forty-two") && + !strings.Contains(text, "forty two") { + t.Fatalf("expected handoff response to reference tool result, got %q", text) + } +} + +func openAIReasoningSourceModel() ai.Model { + modelID := strings.TrimSpace(os.Getenv("PI_AI_E2E_OPENAI_REASONING_SOURCE_MODEL")) + if modelID == "" { + modelID = "gpt-5-mini" + } + baseURL := strings.TrimSpace(os.Getenv("PI_AI_E2E_OPENAI_BASE_URL")) + return ai.Model{ + ID: modelID, + Name: modelID, + API: ai.APIOpenAIResponses, + Provider: "openai", + BaseURL: baseURL, + Reasoning: true, + } +} + +func openAIReasoningTargetModel() ai.Model { + modelID := strings.TrimSpace(os.Getenv("PI_AI_E2E_OPENAI_REASONING_TARGET_MODEL")) + if modelID == "" { + modelID = "gpt-5.2-codex" + } + baseURL := strings.TrimSpace(os.Getenv("PI_AI_E2E_OPENAI_BASE_URL")) + return ai.Model{ + ID: modelID, + Name: modelID, + API: ai.APIOpenAIResponses, + Provider: "openai", + BaseURL: baseURL, + Reasoning: true, + } +} + +func doubleNumberTool() ai.Tool { + return ai.Tool{ + Name: "double_number", + Description: "Doubles a number and returns the result", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "value": map[string]any{"type": "number"}, + }, + "required": []any{"value"}, + }, + } +} diff --git a/pkg/ai/e2e/parity_scaffolds_test.go 
b/pkg/ai/e2e/parity_scaffolds_test.go index 8be8c5a4..6aa77b71 100644 --- a/pkg/ai/e2e/parity_scaffolds_test.go +++ b/pkg/ai/e2e/parity_scaffolds_test.go @@ -40,11 +40,6 @@ func TestCrossProviderHandoffE2EParityScaffold(t *testing.T) { t.Skip("parity scaffold for cross-provider-handoff.test.ts pending runtime implementation") } -func TestOpenAIResponsesReasoningReplayE2EParityScaffold(t *testing.T) { - requirePIAIE2E(t) - t.Skip("parity scaffold for openai-responses-reasoning-replay-e2e.test.ts pending runtime implementation") -} - func TestGoogleGeminiCLIEmptyStreamE2EParityScaffold(t *testing.T) { requirePIAIE2E(t) t.Skip("parity scaffold for google-gemini-cli-empty-stream.test.ts pending runtime implementation") From 456d70d68172e5ad1afeb13187ec4d2fda96c142 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 09:54:30 +0000 Subject: [PATCH 64/75] Add cross-provider handoff parity e2e subset tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- docs/pkg-ai-runtime-migration.md | 3 +- docs/pkg-ai-test-parity.md | 2 +- pkg/ai/e2e/cross_provider_handoff_e2e_test.go | 183 ++++++++++++++++++ pkg/ai/e2e/parity_scaffolds_test.go | 5 - 4 files changed, 186 insertions(+), 7 deletions(-) create mode 100644 pkg/ai/e2e/cross_provider_handoff_e2e_test.go diff --git a/docs/pkg-ai-runtime-migration.md b/docs/pkg-ai-runtime-migration.md index 3c7f3e13..1022fc3d 100644 --- a/docs/pkg-ai-runtime-migration.md +++ b/docs/pkg-ai-runtime-migration.md @@ -66,13 +66,14 @@ The `pkg/ai/e2e` suite now includes live provider parity checks for: - usage total-token accounting (`total-tokens.test.ts` parity subset). - context-overflow detection (`context-overflow.test.ts` parity subset). - OpenAI Responses reasoning replay/handoff (`openai-responses-reasoning-replay-e2e.test.ts` subset). +- cross-provider handoff smoke coverage (`cross-provider-handoff.test.ts` subset). 
- Anthropic and Google complete/stream smoke coverage. Run with: ```bash PI_AI_E2E=1 OPENAI_API_KEY=... ANTHROPIC_API_KEY=... GEMINI_API_KEY=... \ - go test ./pkg/ai/e2e -run "TestGenerateE2E_OpenAI|TestAbortE2E_OpenAIStream|TestToolCallWithoutResultE2E_OpenAI|TestTotalTokensE2E_OpenAI|TestContextOverflowE2E_OpenAI|TestOpenAIReasoningReplayE2E_|TestGenerateE2E_Anthropic|TestGenerateE2E_Google" + go test ./pkg/ai/e2e -run "TestGenerateE2E_OpenAI|TestAbortE2E_OpenAIStream|TestToolCallWithoutResultE2E_OpenAI|TestTotalTokensE2E_OpenAI|TestContextOverflowE2E_OpenAI|TestOpenAIReasoningReplayE2E_|TestCrossProviderHandoffE2E_|TestGenerateE2E_Anthropic|TestGenerateE2E_Google" ``` Optional overrides: diff --git a/docs/pkg-ai-test-parity.md b/docs/pkg-ai-test-parity.md index 50b2e9bf..478f225e 100644 --- a/docs/pkg-ai-test-parity.md +++ b/docs/pkg-ai-test-parity.md @@ -20,6 +20,7 @@ Legend: - `total-tokens.test.ts` → ✅🧪 `pkg/ai/e2e/parity_openai_test.go` - `tokens.test.ts` → ✅🧪 `pkg/ai/e2e/abort_test.go` (OpenAI subset) - `openai-responses-reasoning-replay-e2e.test.ts` → ✅🧪 `pkg/ai/e2e/openai_reasoning_replay_e2e_test.go` (+ deterministic conversion assertions in `pkg/ai/providers/openai_responses_shared_test.go`) +- `cross-provider-handoff.test.ts` → ✅🧪 `pkg/ai/e2e/cross_provider_handoff_e2e_test.go` (OpenAI↔Anthropic subset) ### Provider/unit parity @@ -53,7 +54,6 @@ The following are currently kept as env-gated scaffolds in - 📝 `interleaved-thinking.test.ts` - 📝 `bedrock-models.test.ts` -- 📝 `cross-provider-handoff.test.ts` - 📝 `google-gemini-cli-empty-stream.test.ts` (full live parity) - 📝 `xhigh.test.ts` (live) - 📝 `zen.test.ts` diff --git a/pkg/ai/e2e/cross_provider_handoff_e2e_test.go b/pkg/ai/e2e/cross_provider_handoff_e2e_test.go new file mode 100644 index 00000000..a5459a8e --- /dev/null +++ b/pkg/ai/e2e/cross_provider_handoff_e2e_test.go @@ -0,0 +1,183 @@ +package e2e + +import ( + "os" + "strings" + "testing" + "time" + + "github.com/beeper/ai-bridge/pkg/ai" + 
"github.com/beeper/ai-bridge/pkg/ai/providers" +) + +func TestCrossProviderHandoffE2E_OpenAIConsumesAnthropicContext(t *testing.T) { + requirePIAIE2E(t) + apiKey := strings.TrimSpace(os.Getenv("OPENAI_API_KEY")) + if apiKey == "" { + t.Skip("OPENAI_API_KEY is not set") + } + model := openAIReasoningSourceModel() + providers.ResetAPIProviders() + + context := ai.Context{ + SystemPrompt: "You are a helpful assistant.", + Tools: []ai.Tool{doubleNumberTool()}, + Messages: []ai.Message{ + { + Role: ai.RoleUser, + Text: "Use the tool to double 21.", + Timestamp: time.Now().UnixMilli(), + }, + { + Role: ai.RoleAssistant, + Content: []ai.ContentBlock{ + { + Type: ai.ContentTypeThinking, + Thinking: "I should call the tool first.", + ThinkingSignature: "anthropic_thinking_signature", + }, + { + Type: ai.ContentTypeToolCall, + ID: "toolu_123", + Name: "double_number", + Arguments: map[string]any{"value": 21}, + }, + }, + Provider: "anthropic", + API: ai.APIAnthropicMessages, + Model: "claude-sonnet-4-5", + StopReason: ai.StopReasonToolUse, + Timestamp: time.Now().UnixMilli(), + }, + { + Role: ai.RoleToolResult, + ToolCallID: "toolu_123", + ToolName: "double_number", + Content: []ai.ContentBlock{ + {Type: ai.ContentTypeText, Text: "42"}, + }, + Timestamp: time.Now().UnixMilli(), + }, + { + Role: ai.RoleAssistant, + Content: []ai.ContentBlock{{Type: ai.ContentTypeText, Text: "The doubled value is 42."}}, + Provider: "anthropic", + API: ai.APIAnthropicMessages, + Model: "claude-sonnet-4-5", + Timestamp: time.Now().UnixMilli(), + }, + { + Role: ai.RoleUser, + Text: "Say hello to confirm handoff success.", + Timestamp: time.Now().UnixMilli(), + }, + }, + } + + response, err := ai.CompleteSimple(model, context, &ai.SimpleStreamOptions{ + StreamOptions: ai.StreamOptions{ + APIKey: apiKey, + MaxTokens: 256, + }, + Reasoning: ai.ThinkingHigh, + }) + if err != nil { + t.Fatalf("complete failed: %v", err) + } + if response.StopReason == ai.StopReasonError { + t.Fatalf("expected 
non-error response, got %q", response.ErrorMessage) + } + text := strings.ToLower(strings.TrimSpace(firstText(response))) + if text == "" { + t.Fatalf("expected non-empty text response") + } + if !strings.Contains(text, "hello") { + t.Fatalf("expected handoff confirmation to contain hello, got %q", text) + } +} + +func TestCrossProviderHandoffE2E_AnthropicConsumesOpenAIContext(t *testing.T) { + requirePIAIE2E(t) + apiKey := strings.TrimSpace(os.Getenv("ANTHROPIC_API_KEY")) + if apiKey == "" { + t.Skip("ANTHROPIC_API_KEY is not set") + } + model := anthropicE2EModel() + providers.ResetAPIProviders() + + context := ai.Context{ + SystemPrompt: "You are a helpful assistant.", + Tools: []ai.Tool{doubleNumberTool()}, + Messages: []ai.Message{ + { + Role: ai.RoleUser, + Text: "Use the tool to double 21.", + Timestamp: time.Now().UnixMilli(), + }, + { + Role: ai.RoleAssistant, + Content: []ai.ContentBlock{ + { + Type: ai.ContentTypeThinking, + Thinking: "Need to call tool.", + ThinkingSignature: `{"type":"reasoning","id":"rs_123","summary":[{"type":"summary_text","text":"call tool"}]}`, + }, + { + Type: ai.ContentTypeToolCall, + ID: "call_123|fc_456", + Name: "double_number", + Arguments: map[string]any{"value": 21}, + }, + }, + Provider: "openai", + API: ai.APIOpenAIResponses, + Model: "gpt-5-mini", + StopReason: ai.StopReasonToolUse, + Timestamp: time.Now().UnixMilli(), + }, + { + Role: ai.RoleToolResult, + ToolCallID: "call_123|fc_456", + ToolName: "double_number", + Content: []ai.ContentBlock{ + {Type: ai.ContentTypeText, Text: "42"}, + }, + Timestamp: time.Now().UnixMilli(), + }, + { + Role: ai.RoleAssistant, + Content: []ai.ContentBlock{{Type: ai.ContentTypeText, Text: "The doubled value is 42."}}, + Provider: "openai", + API: ai.APIOpenAIResponses, + Model: "gpt-5-mini", + Timestamp: time.Now().UnixMilli(), + }, + { + Role: ai.RoleUser, + Text: "Say hello to confirm handoff success.", + Timestamp: time.Now().UnixMilli(), + }, + }, + } + + response, err := 
ai.CompleteSimple(model, context, &ai.SimpleStreamOptions{ + StreamOptions: ai.StreamOptions{ + APIKey: apiKey, + MaxTokens: 256, + }, + Reasoning: ai.ThinkingHigh, + }) + if err != nil { + t.Fatalf("complete failed: %v", err) + } + if response.StopReason == ai.StopReasonError { + t.Fatalf("expected non-error response, got %q", response.ErrorMessage) + } + text := strings.ToLower(strings.TrimSpace(firstText(response))) + if text == "" { + t.Fatalf("expected non-empty text response") + } + if !strings.Contains(text, "hello") { + t.Fatalf("expected handoff confirmation to contain hello, got %q", text) + } +} diff --git a/pkg/ai/e2e/parity_scaffolds_test.go b/pkg/ai/e2e/parity_scaffolds_test.go index 6aa77b71..1aceb665 100644 --- a/pkg/ai/e2e/parity_scaffolds_test.go +++ b/pkg/ai/e2e/parity_scaffolds_test.go @@ -35,11 +35,6 @@ func TestAnthropicToolNameNormalizationE2EParityScaffold(t *testing.T) { t.Skip("parity scaffold for anthropic-tool-name-normalization.test.ts pending runtime implementation") } -func TestCrossProviderHandoffE2EParityScaffold(t *testing.T) { - requirePIAIE2E(t) - t.Skip("parity scaffold for cross-provider-handoff.test.ts pending runtime implementation") -} - func TestGoogleGeminiCLIEmptyStreamE2EParityScaffold(t *testing.T) { requirePIAIE2E(t) t.Skip("parity scaffold for google-gemini-cli-empty-stream.test.ts pending runtime implementation") From ac8ae1eb6ab9b2f8e1b8198275502fd1d5efb37d Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 09:56:44 +0000 Subject: [PATCH 65/75] Port OpenAI responses tool-result image conversion parity MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- docs/pkg-ai-test-parity.md | 2 +- pkg/ai/providers/openai_responses_shared.go | 28 ++++++ .../providers/openai_responses_shared_test.go | 86 +++++++++++++++++++ 3 files changed, 115 insertions(+), 1 deletion(-) diff --git a/docs/pkg-ai-test-parity.md b/docs/pkg-ai-test-parity.md 
index 478f225e..7f290477 100644 --- a/docs/pkg-ai-test-parity.md +++ b/docs/pkg-ai-test-parity.md @@ -37,7 +37,7 @@ Legend: - `tool-call-id-normalization.test.ts` → ✅ `pkg/ai/providers/openai_responses_shared_test.go`, `pkg/ai/providers/openai_completions_convert_test.go` - `anthropic-tool-name-normalization.test.ts` → ✅ `pkg/ai/providers/anthropic_test.go` - `cache-retention.test.ts` → ✅ `pkg/ai/providers/cache_retention_test.go` -- `image-tool-result.test.ts` → ✅ `pkg/ai/providers/openai_completions_test.go` +- `image-tool-result.test.ts` → ✅ `pkg/ai/providers/openai_completions_test.go`, `pkg/ai/providers/openai_responses_shared_test.go` - `unicode-surrogate.test.ts` → ✅ `pkg/ai/utils/sanitize_unicode_test.go` - `supports-xhigh.test.ts` / `xhigh.test.ts` → ✅ `pkg/ai/models_test.go` - `interleaved-thinking.test.ts` (deterministic parts) → ✅ `pkg/ai/providers/anthropic_test.go`, `pkg/ai/providers/amazon_bedrock_test.go` diff --git a/pkg/ai/providers/openai_responses_shared.go b/pkg/ai/providers/openai_responses_shared.go index 9cb00c59..95d8f0b0 100644 --- a/pkg/ai/providers/openai_responses_shared.go +++ b/pkg/ai/providers/openai_responses_shared.go @@ -2,6 +2,7 @@ package providers import ( "encoding/json" + "slices" "strings" "github.com/beeper/ai-bridge/pkg/ai" @@ -158,10 +159,14 @@ func ConvertResponsesMessages( } output := "(see attached image)" var textParts []string + var imageBlocks []ai.ContentBlock for _, block := range msg.Content { if block.Type == ai.ContentTypeText { textParts = append(textParts, block.Text) } + if block.Type == ai.ContentTypeImage { + imageBlocks = append(imageBlocks, block) + } } if len(textParts) > 0 { output = strings.Join(textParts, "\n") @@ -171,6 +176,29 @@ func ConvertResponsesMessages( "call_id": callID, "output": utils.SanitizeSurrogates(output), }) + if len(imageBlocks) > 0 && slices.Contains(model.Input, "image") { + content := make([]map[string]any, 0, len(imageBlocks)+1) + content = append(content, map[string]any{ + 
"type": "input_text", + "text": "Attached image(s) from tool result:", + }) + for _, image := range imageBlocks { + if strings.TrimSpace(image.Data) == "" || strings.TrimSpace(image.MimeType) == "" { + continue + } + content = append(content, map[string]any{ + "type": "input_image", + "detail": "auto", + "image_url": "data:" + image.MimeType + ";base64," + image.Data, + }) + } + if len(content) > 1 { + messages = append(messages, map[string]any{ + "role": "user", + "content": content, + }) + } + } } } return messages diff --git a/pkg/ai/providers/openai_responses_shared_test.go b/pkg/ai/providers/openai_responses_shared_test.go index ff705937..903407a1 100644 --- a/pkg/ai/providers/openai_responses_shared_test.go +++ b/pkg/ai/providers/openai_responses_shared_test.go @@ -179,3 +179,89 @@ func TestConvertResponsesMessages_DropsAbortedReasoningOnlyAssistant(t *testing. } } } + +func TestConvertResponsesMessages_ToolResultImageOnlyAddsImageUserMessage(t *testing.T) { + model := ai.Model{ + ID: "gpt-5-mini", + Provider: "openai", + API: ai.APIOpenAIResponses, + Input: []string{"text", "image"}, + } + context := ai.Context{ + Messages: []ai.Message{ + { + Role: ai.RoleToolResult, + ToolCallID: "call_123|fc_456", + ToolName: "get_circle", + Content: []ai.ContentBlock{ + { + Type: ai.ContentTypeImage, + Data: "abc123", + MimeType: "image/png", + }, + }, + }, + }, + } + + output := ConvertResponsesMessages(model, context, openAIToolCallProviders, nil) + if len(output) != 2 { + t.Fatalf("expected function_call_output plus image user message, got %d entries: %#v", len(output), output) + } + functionOutput := output[0] + if functionOutput["type"] != "function_call_output" { + t.Fatalf("expected first output to be function_call_output, got %#v", functionOutput) + } + if outputText, _ := functionOutput["output"].(string); outputText != "(see attached image)" { + t.Fatalf("expected image placeholder output text, got %q", outputText) + } + + userMessage := output[1] + if role, _ 
:= userMessage["role"].(string); role != "user" { + t.Fatalf("expected second output to be user message, got %#v", userMessage) + } + content, _ := userMessage["content"].([]map[string]any) + if len(content) != 2 { + t.Fatalf("expected text prefix + image content, got %#v", userMessage["content"]) + } + if content[0]["type"] != "input_text" { + t.Fatalf("expected first content part input_text, got %#v", content[0]) + } + if content[1]["type"] != "input_image" { + t.Fatalf("expected second content part input_image, got %#v", content[1]) + } + if imageURL, _ := content[1]["image_url"].(string); imageURL != "data:image/png;base64,abc123" { + t.Fatalf("unexpected image_url encoding: %q", imageURL) + } +} + +func TestConvertResponsesMessages_ToolResultTextAndImageKeepsTextOutput(t *testing.T) { + model := ai.Model{ + ID: "gpt-5-mini", + Provider: "openai", + API: ai.APIOpenAIResponses, + Input: []string{"text", "image"}, + } + context := ai.Context{ + Messages: []ai.Message{ + { + Role: ai.RoleToolResult, + ToolCallID: "call_123|fc_456", + ToolName: "get_circle_with_description", + Content: []ai.ContentBlock{ + {Type: ai.ContentTypeText, Text: "diameter is 100 pixels"}, + {Type: ai.ContentTypeImage, Data: "img64", MimeType: "image/png"}, + }, + }, + }, + } + + output := ConvertResponsesMessages(model, context, openAIToolCallProviders, nil) + if len(output) != 2 { + t.Fatalf("expected function_call_output plus image user message, got %d entries: %#v", len(output), output) + } + functionOutput := output[0] + if outputText, _ := functionOutput["output"].(string); outputText != "diameter is 100 pixels" { + t.Fatalf("expected text output to be preserved, got %q", outputText) + } +} From 9d620382e0ce695837fb42c247a8c3c64beddb97 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 09:58:20 +0000 Subject: [PATCH 66/75] Add OpenAI image tool-result parity e2e coverage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 
Co-authored-by: batuhan içöz --- docs/pkg-ai-runtime-migration.md | 3 +- docs/pkg-ai-test-parity.md | 2 +- pkg/ai/e2e/image_tool_result_e2e_test.go | 115 +++++++++++++++++++++++ pkg/ai/e2e/parity_scaffolds_test.go | 5 - 4 files changed, 118 insertions(+), 7 deletions(-) create mode 100644 pkg/ai/e2e/image_tool_result_e2e_test.go diff --git a/docs/pkg-ai-runtime-migration.md b/docs/pkg-ai-runtime-migration.md index 1022fc3d..dec01b3d 100644 --- a/docs/pkg-ai-runtime-migration.md +++ b/docs/pkg-ai-runtime-migration.md @@ -66,6 +66,7 @@ The `pkg/ai/e2e` suite now includes live provider parity checks for: - usage total-token accounting (`total-tokens.test.ts` parity subset). - context-overflow detection (`context-overflow.test.ts` parity subset). - OpenAI Responses reasoning replay/handoff (`openai-responses-reasoning-replay-e2e.test.ts` subset). +- tool-result image handling (`image-tool-result.test.ts` OpenAI subset). - cross-provider handoff smoke coverage (`cross-provider-handoff.test.ts` subset). - Anthropic and Google complete/stream smoke coverage. @@ -73,7 +74,7 @@ Run with: ```bash PI_AI_E2E=1 OPENAI_API_KEY=... ANTHROPIC_API_KEY=... GEMINI_API_KEY=... 
\ - go test ./pkg/ai/e2e -run "TestGenerateE2E_OpenAI|TestAbortE2E_OpenAIStream|TestToolCallWithoutResultE2E_OpenAI|TestTotalTokensE2E_OpenAI|TestContextOverflowE2E_OpenAI|TestOpenAIReasoningReplayE2E_|TestCrossProviderHandoffE2E_|TestGenerateE2E_Anthropic|TestGenerateE2E_Google" + go test ./pkg/ai/e2e -run "TestGenerateE2E_OpenAI|TestAbortE2E_OpenAIStream|TestToolCallWithoutResultE2E_OpenAI|TestTotalTokensE2E_OpenAI|TestContextOverflowE2E_OpenAI|TestOpenAIReasoningReplayE2E_|TestImageToolResultE2E_OpenAI|TestCrossProviderHandoffE2E_|TestGenerateE2E_Anthropic|TestGenerateE2E_Google" ``` Optional overrides: diff --git a/docs/pkg-ai-test-parity.md b/docs/pkg-ai-test-parity.md index 7f290477..6eb9c7f7 100644 --- a/docs/pkg-ai-test-parity.md +++ b/docs/pkg-ai-test-parity.md @@ -37,7 +37,7 @@ Legend: - `tool-call-id-normalization.test.ts` → ✅ `pkg/ai/providers/openai_responses_shared_test.go`, `pkg/ai/providers/openai_completions_convert_test.go` - `anthropic-tool-name-normalization.test.ts` → ✅ `pkg/ai/providers/anthropic_test.go` - `cache-retention.test.ts` → ✅ `pkg/ai/providers/cache_retention_test.go` -- `image-tool-result.test.ts` → ✅ `pkg/ai/providers/openai_completions_test.go`, `pkg/ai/providers/openai_responses_shared_test.go` +- `image-tool-result.test.ts` → ✅🧪 `pkg/ai/e2e/image_tool_result_e2e_test.go` (OpenAI subset) + deterministic conversion tests in `pkg/ai/providers/openai_completions_test.go`, `pkg/ai/providers/openai_responses_shared_test.go` - `unicode-surrogate.test.ts` → ✅ `pkg/ai/utils/sanitize_unicode_test.go` - `supports-xhigh.test.ts` / `xhigh.test.ts` → ✅ `pkg/ai/models_test.go` - `interleaved-thinking.test.ts` (deterministic parts) → ✅ `pkg/ai/providers/anthropic_test.go`, `pkg/ai/providers/amazon_bedrock_test.go` diff --git a/pkg/ai/e2e/image_tool_result_e2e_test.go b/pkg/ai/e2e/image_tool_result_e2e_test.go new file mode 100644 index 00000000..5329a86f --- /dev/null +++ b/pkg/ai/e2e/image_tool_result_e2e_test.go @@ -0,0 +1,115 @@ +package 
e2e + +import ( + "os" + "strings" + "testing" + "time" + + "github.com/beeper/ai-bridge/pkg/ai" + "github.com/beeper/ai-bridge/pkg/ai/providers" +) + +// 1x1 red PNG. +const redPixelBase64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADUlEQVR42mP8z8BQDwAE/wH+1J1iGQAAAABJRU5ErkJggg==" + +func TestImageToolResultE2E_OpenAI(t *testing.T) { + requirePIAIE2E(t) + apiKey := strings.TrimSpace(os.Getenv("OPENAI_API_KEY")) + if apiKey == "" { + t.Skip("OPENAI_API_KEY is not set") + } + model := openAIE2EModel() + model.Input = []string{"text", "image"} + providers.ResetAPIProviders() + + testCases := []struct { + name string + toolResult []ai.ContentBlock + prompt string + expectKeyword string + }{ + { + name: "image-only-tool-result", + toolResult: []ai.ContentBlock{ + {Type: ai.ContentTypeImage, Data: redPixelBase64, MimeType: "image/png"}, + }, + prompt: "Describe what you see in the tool result image. Mention the color.", + expectKeyword: "red", + }, + { + name: "text-and-image-tool-result", + toolResult: []ai.ContentBlock{ + {Type: ai.ContentTypeText, Text: "The shape has a diameter of 100 pixels."}, + {Type: ai.ContentTypeImage, Data: redPixelBase64, MimeType: "image/png"}, + }, + prompt: "Summarize the tool result details and mention any visible color.", + expectKeyword: "pixel", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + context := ai.Context{ + SystemPrompt: "You are a helpful assistant.", + Tools: []ai.Tool{doubleNumberTool()}, + Messages: []ai.Message{ + { + Role: ai.RoleUser, + Text: "Use the tool result to answer the next question.", + Timestamp: time.Now().UnixMilli(), + }, + { + Role: ai.RoleAssistant, + Content: []ai.ContentBlock{ + { + Type: ai.ContentTypeToolCall, + ID: "call_123|fc_456", + Name: "double_number", + Arguments: map[string]any{"value": 21}, + }, + }, + Provider: model.Provider, + API: model.API, + Model: model.ID, + StopReason: ai.StopReasonToolUse, + Timestamp: time.Now().UnixMilli(), + }, + { + 
Role: ai.RoleToolResult, + ToolCallID: "call_123|fc_456", + ToolName: "double_number", + Content: tc.toolResult, + Timestamp: time.Now().UnixMilli(), + }, + { + Role: ai.RoleUser, + Text: tc.prompt, + Timestamp: time.Now().UnixMilli(), + }, + }, + } + + response, err := ai.CompleteSimple(model, context, &ai.SimpleStreamOptions{ + StreamOptions: ai.StreamOptions{ + APIKey: apiKey, + MaxTokens: 512, + }, + Reasoning: ai.ThinkingMedium, + }) + if err != nil { + t.Fatalf("complete failed: %v", err) + } + if response.StopReason == ai.StopReasonError { + t.Fatalf("expected non-error response, got %q", response.ErrorMessage) + } + text := strings.ToLower(strings.TrimSpace(firstText(response))) + if text == "" { + t.Fatalf("expected non-empty text response") + } + if !strings.Contains(text, tc.expectKeyword) && !strings.Contains(text, "red") { + t.Fatalf("expected response to reference tool result content, got %q", text) + } + }) + } +} diff --git a/pkg/ai/e2e/parity_scaffolds_test.go b/pkg/ai/e2e/parity_scaffolds_test.go index 1aceb665..c4ace38b 100644 --- a/pkg/ai/e2e/parity_scaffolds_test.go +++ b/pkg/ai/e2e/parity_scaffolds_test.go @@ -55,11 +55,6 @@ func TestEmptyE2EParityScaffold(t *testing.T) { t.Skip("parity scaffold for empty.test.ts pending runtime implementation") } -func TestImageToolResultE2EParityScaffold(t *testing.T) { - requirePIAIE2E(t) - t.Skip("parity scaffold for image-tool-result.test.ts pending runtime implementation") -} - func TestGoogleGeminiCliClaudeThinkingHeaderE2EParityScaffold(t *testing.T) { requirePIAIE2E(t) t.Skip("parity scaffold for google-gemini-cli-claude-thinking-header.test.ts pending runtime implementation") From c482c598c84c8a005fc8bc063131c5dd97c49694 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 09:59:44 +0000 Subject: [PATCH 67/75] Add OpenAI tool-call ID normalization parity e2e MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- 
docs/pkg-ai-runtime-migration.md | 3 +- docs/pkg-ai-test-parity.md | 2 +- pkg/ai/e2e/parity_scaffolds_test.go | 5 -- .../tool_call_id_normalization_e2e_test.go | 87 +++++++++++++++++++ 4 files changed, 90 insertions(+), 7 deletions(-) create mode 100644 pkg/ai/e2e/tool_call_id_normalization_e2e_test.go diff --git a/docs/pkg-ai-runtime-migration.md b/docs/pkg-ai-runtime-migration.md index dec01b3d..cb3a2ca0 100644 --- a/docs/pkg-ai-runtime-migration.md +++ b/docs/pkg-ai-runtime-migration.md @@ -67,6 +67,7 @@ The `pkg/ai/e2e` suite now includes live provider parity checks for: - context-overflow detection (`context-overflow.test.ts` parity subset). - OpenAI Responses reasoning replay/handoff (`openai-responses-reasoning-replay-e2e.test.ts` subset). - tool-result image handling (`image-tool-result.test.ts` OpenAI subset). +- tool-call-id normalization (`tool-call-id-normalization.test.ts` OpenAI subset). - cross-provider handoff smoke coverage (`cross-provider-handoff.test.ts` subset). - Anthropic and Google complete/stream smoke coverage. @@ -74,7 +75,7 @@ Run with: ```bash PI_AI_E2E=1 OPENAI_API_KEY=... ANTHROPIC_API_KEY=... GEMINI_API_KEY=... 
\ - go test ./pkg/ai/e2e -run "TestGenerateE2E_OpenAI|TestAbortE2E_OpenAIStream|TestToolCallWithoutResultE2E_OpenAI|TestTotalTokensE2E_OpenAI|TestContextOverflowE2E_OpenAI|TestOpenAIReasoningReplayE2E_|TestImageToolResultE2E_OpenAI|TestCrossProviderHandoffE2E_|TestGenerateE2E_Anthropic|TestGenerateE2E_Google" + go test ./pkg/ai/e2e -run "TestGenerateE2E_OpenAI|TestAbortE2E_OpenAIStream|TestToolCallWithoutResultE2E_OpenAI|TestTotalTokensE2E_OpenAI|TestContextOverflowE2E_OpenAI|TestOpenAIReasoningReplayE2E_|TestImageToolResultE2E_OpenAI|TestToolCallIDNormalizationE2E_OpenAI|TestCrossProviderHandoffE2E_|TestGenerateE2E_Anthropic|TestGenerateE2E_Google" ``` Optional overrides: diff --git a/docs/pkg-ai-test-parity.md b/docs/pkg-ai-test-parity.md index 6eb9c7f7..e3c45aeb 100644 --- a/docs/pkg-ai-test-parity.md +++ b/docs/pkg-ai-test-parity.md @@ -21,6 +21,7 @@ Legend: - `tokens.test.ts` → ✅🧪 `pkg/ai/e2e/abort_test.go` (OpenAI subset) - `openai-responses-reasoning-replay-e2e.test.ts` → ✅🧪 `pkg/ai/e2e/openai_reasoning_replay_e2e_test.go` (+ deterministic conversion assertions in `pkg/ai/providers/openai_responses_shared_test.go`) - `cross-provider-handoff.test.ts` → ✅🧪 `pkg/ai/e2e/cross_provider_handoff_e2e_test.go` (OpenAI↔Anthropic subset) +- `tool-call-id-normalization.test.ts` → ✅🧪 `pkg/ai/e2e/tool_call_id_normalization_e2e_test.go` (OpenAI subset) + deterministic ID normalization tests in providers ### Provider/unit parity @@ -34,7 +35,6 @@ Legend: - `google-shared-gemini3-unsigned-tool-call.test.ts` → ✅ `pkg/ai/providers/google_shared_test.go` - `google-thinking-signature.test.ts` → ✅ `pkg/ai/providers/google_shared_test.go` - `transform-messages-copilot-openai-to-anthropic.test.ts` → ✅ `pkg/ai/providers/transform_messages_test.go` -- `tool-call-id-normalization.test.ts` → ✅ `pkg/ai/providers/openai_responses_shared_test.go`, `pkg/ai/providers/openai_completions_convert_test.go` - `anthropic-tool-name-normalization.test.ts` → ✅ `pkg/ai/providers/anthropic_test.go` - 
`cache-retention.test.ts` → ✅ `pkg/ai/providers/cache_retention_test.go` - `image-tool-result.test.ts` → ✅🧪 `pkg/ai/e2e/image_tool_result_e2e_test.go` (OpenAI subset) + deterministic conversion tests in `pkg/ai/providers/openai_completions_test.go`, `pkg/ai/providers/openai_responses_shared_test.go` diff --git a/pkg/ai/e2e/parity_scaffolds_test.go b/pkg/ai/e2e/parity_scaffolds_test.go index c4ace38b..d6edb893 100644 --- a/pkg/ai/e2e/parity_scaffolds_test.go +++ b/pkg/ai/e2e/parity_scaffolds_test.go @@ -25,11 +25,6 @@ func TestBedrockModelsE2EParityScaffold(t *testing.T) { t.Skip("parity scaffold for bedrock-models.test.ts pending runtime implementation") } -func TestToolCallIDNormalizationE2EParityScaffold(t *testing.T) { - requirePIAIE2E(t) - t.Skip("parity scaffold for tool-call-id-normalization.test.ts pending runtime implementation") -} - func TestAnthropicToolNameNormalizationE2EParityScaffold(t *testing.T) { requirePIAIE2E(t) t.Skip("parity scaffold for anthropic-tool-name-normalization.test.ts pending runtime implementation") diff --git a/pkg/ai/e2e/tool_call_id_normalization_e2e_test.go b/pkg/ai/e2e/tool_call_id_normalization_e2e_test.go new file mode 100644 index 00000000..5ca6633d --- /dev/null +++ b/pkg/ai/e2e/tool_call_id_normalization_e2e_test.go @@ -0,0 +1,87 @@ +package e2e + +import ( + "os" + "strings" + "testing" + "time" + + "github.com/beeper/ai-bridge/pkg/ai" + "github.com/beeper/ai-bridge/pkg/ai/providers" +) + +func TestToolCallIDNormalizationE2E_OpenAI(t *testing.T) { + requirePIAIE2E(t) + apiKey := strings.TrimSpace(os.Getenv("OPENAI_API_KEY")) + if apiKey == "" { + t.Skip("OPENAI_API_KEY is not set") + } + model := openAIReasoningSourceModel() + providers.ResetAPIProviders() + + rawToolCallID := "call_abc|item+/==" + context := ai.Context{ + SystemPrompt: "You are a helpful assistant.", + Tools: []ai.Tool{doubleNumberTool()}, + Messages: []ai.Message{ + { + Role: ai.RoleUser, + Text: "Use the tool to double 21.", + Timestamp: 
time.Now().UnixMilli(), + }, + { + Role: ai.RoleAssistant, + Content: []ai.ContentBlock{ + { + Type: ai.ContentTypeToolCall, + ID: rawToolCallID, + Name: "double_number", + Arguments: map[string]any{"value": 21}, + }, + }, + Provider: "openai", + API: ai.APIOpenAIResponses, + Model: "gpt-5-mini", + StopReason: ai.StopReasonToolUse, + Timestamp: time.Now().UnixMilli(), + }, + { + Role: ai.RoleToolResult, + ToolCallID: rawToolCallID, + ToolName: "double_number", + Content: []ai.ContentBlock{ + {Type: ai.ContentTypeText, Text: "42"}, + }, + Timestamp: time.Now().UnixMilli(), + }, + { + Role: ai.RoleUser, + Text: "What was the result? Answer with just the number.", + Timestamp: time.Now().UnixMilli(), + }, + }, + } + + response, err := ai.CompleteSimple(model, context, &ai.SimpleStreamOptions{ + StreamOptions: ai.StreamOptions{ + APIKey: apiKey, + MaxTokens: 256, + }, + Reasoning: ai.ThinkingHigh, + }) + if err != nil { + t.Fatalf("complete failed: %v", err) + } + if response.StopReason == ai.StopReasonError { + t.Fatalf("expected non-error response after normalization, got %q", response.ErrorMessage) + } + text := strings.ToLower(strings.TrimSpace(firstText(response))) + if text == "" { + t.Fatalf("expected non-empty text response") + } + if !strings.Contains(text, "42") && + !strings.Contains(text, "forty-two") && + !strings.Contains(text, "forty two") { + t.Fatalf("expected response to reference tool result, got %q", text) + } +} From dfe549baabdc9f96036a0124bc41f6be2188c4a4 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 10:01:07 +0000 Subject: [PATCH 68/75] Drop scaffolds covered by deterministic parity tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- docs/pkg-ai-test-parity.md | 2 -- pkg/ai/e2e/parity_scaffolds_test.go | 10 ---------- 2 files changed, 12 deletions(-) diff --git a/docs/pkg-ai-test-parity.md b/docs/pkg-ai-test-parity.md index e3c45aeb..59632960 100644 
--- a/docs/pkg-ai-test-parity.md +++ b/docs/pkg-ai-test-parity.md @@ -58,6 +58,4 @@ The following are currently kept as env-gated scaffolds in - 📝 `xhigh.test.ts` (live) - 📝 `zen.test.ts` - 📝 `empty.test.ts` -- 📝 `image-tool-result.test.ts` (live) -- 📝 `google-gemini-cli-claude-thinking-header.test.ts` (live) - 📝 `github-copilot-anthropic.test.ts` (live) diff --git a/pkg/ai/e2e/parity_scaffolds_test.go b/pkg/ai/e2e/parity_scaffolds_test.go index d6edb893..47314009 100644 --- a/pkg/ai/e2e/parity_scaffolds_test.go +++ b/pkg/ai/e2e/parity_scaffolds_test.go @@ -25,11 +25,6 @@ func TestBedrockModelsE2EParityScaffold(t *testing.T) { t.Skip("parity scaffold for bedrock-models.test.ts pending runtime implementation") } -func TestAnthropicToolNameNormalizationE2EParityScaffold(t *testing.T) { - requirePIAIE2E(t) - t.Skip("parity scaffold for anthropic-tool-name-normalization.test.ts pending runtime implementation") -} - func TestGoogleGeminiCLIEmptyStreamE2EParityScaffold(t *testing.T) { requirePIAIE2E(t) t.Skip("parity scaffold for google-gemini-cli-empty-stream.test.ts pending runtime implementation") @@ -50,11 +45,6 @@ func TestEmptyE2EParityScaffold(t *testing.T) { t.Skip("parity scaffold for empty.test.ts pending runtime implementation") } -func TestGoogleGeminiCliClaudeThinkingHeaderE2EParityScaffold(t *testing.T) { - requirePIAIE2E(t) - t.Skip("parity scaffold for google-gemini-cli-claude-thinking-header.test.ts pending runtime implementation") -} - func TestGithubCopilotAnthropicE2EParityScaffold(t *testing.T) { requirePIAIE2E(t) t.Skip("parity scaffold for github-copilot-anthropic.test.ts pending runtime implementation") From 215751cd0e64f3a55fa1c579606d1306d1345c70 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 10:02:43 +0000 Subject: [PATCH 69/75] Add OpenAI xhigh parity e2e subset MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- docs/pkg-ai-runtime-migration.md | 3 +- 
docs/pkg-ai-test-parity.md | 2 +- pkg/ai/e2e/parity_scaffolds_test.go | 5 --- pkg/ai/e2e/xhigh_e2e_test.go | 57 +++++++++++++++++++++++++++++ 4 files changed, 60 insertions(+), 7 deletions(-) create mode 100644 pkg/ai/e2e/xhigh_e2e_test.go diff --git a/docs/pkg-ai-runtime-migration.md b/docs/pkg-ai-runtime-migration.md index cb3a2ca0..a8851d0a 100644 --- a/docs/pkg-ai-runtime-migration.md +++ b/docs/pkg-ai-runtime-migration.md @@ -68,6 +68,7 @@ The `pkg/ai/e2e` suite now includes live provider parity checks for: - OpenAI Responses reasoning replay/handoff (`openai-responses-reasoning-replay-e2e.test.ts` subset). - tool-result image handling (`image-tool-result.test.ts` OpenAI subset). - tool-call-id normalization (`tool-call-id-normalization.test.ts` OpenAI subset). +- xhigh reasoning request path (`xhigh.test.ts` OpenAI subset). - cross-provider handoff smoke coverage (`cross-provider-handoff.test.ts` subset). - Anthropic and Google complete/stream smoke coverage. @@ -75,7 +76,7 @@ Run with: ```bash PI_AI_E2E=1 OPENAI_API_KEY=... ANTHROPIC_API_KEY=... GEMINI_API_KEY=... 
\ - go test ./pkg/ai/e2e -run "TestGenerateE2E_OpenAI|TestAbortE2E_OpenAIStream|TestToolCallWithoutResultE2E_OpenAI|TestTotalTokensE2E_OpenAI|TestContextOverflowE2E_OpenAI|TestOpenAIReasoningReplayE2E_|TestImageToolResultE2E_OpenAI|TestToolCallIDNormalizationE2E_OpenAI|TestCrossProviderHandoffE2E_|TestGenerateE2E_Anthropic|TestGenerateE2E_Google" + go test ./pkg/ai/e2e -run "TestGenerateE2E_OpenAI|TestAbortE2E_OpenAIStream|TestToolCallWithoutResultE2E_OpenAI|TestTotalTokensE2E_OpenAI|TestContextOverflowE2E_OpenAI|TestOpenAIReasoningReplayE2E_|TestImageToolResultE2E_OpenAI|TestToolCallIDNormalizationE2E_OpenAI|TestXhighE2E_OpenAIResponses|TestCrossProviderHandoffE2E_|TestGenerateE2E_Anthropic|TestGenerateE2E_Google" ``` Optional overrides: diff --git a/docs/pkg-ai-test-parity.md b/docs/pkg-ai-test-parity.md index 59632960..bb8289e0 100644 --- a/docs/pkg-ai-test-parity.md +++ b/docs/pkg-ai-test-parity.md @@ -22,6 +22,7 @@ Legend: - `openai-responses-reasoning-replay-e2e.test.ts` → ✅🧪 `pkg/ai/e2e/openai_reasoning_replay_e2e_test.go` (+ deterministic conversion assertions in `pkg/ai/providers/openai_responses_shared_test.go`) - `cross-provider-handoff.test.ts` → ✅🧪 `pkg/ai/e2e/cross_provider_handoff_e2e_test.go` (OpenAI↔Anthropic subset) - `tool-call-id-normalization.test.ts` → ✅🧪 `pkg/ai/e2e/tool_call_id_normalization_e2e_test.go` (OpenAI subset) + deterministic ID normalization tests in providers +- `xhigh.test.ts` → ✅🧪 `pkg/ai/e2e/xhigh_e2e_test.go` (OpenAI subset) + deterministic support checks in `pkg/ai/models_test.go` ### Provider/unit parity @@ -55,7 +56,6 @@ The following are currently kept as env-gated scaffolds in - 📝 `interleaved-thinking.test.ts` - 📝 `bedrock-models.test.ts` - 📝 `google-gemini-cli-empty-stream.test.ts` (full live parity) -- 📝 `xhigh.test.ts` (live) - 📝 `zen.test.ts` - 📝 `empty.test.ts` - 📝 `github-copilot-anthropic.test.ts` (live) diff --git a/pkg/ai/e2e/parity_scaffolds_test.go b/pkg/ai/e2e/parity_scaffolds_test.go index 
47314009..333d0545 100644 --- a/pkg/ai/e2e/parity_scaffolds_test.go +++ b/pkg/ai/e2e/parity_scaffolds_test.go @@ -30,11 +30,6 @@ func TestGoogleGeminiCLIEmptyStreamE2EParityScaffold(t *testing.T) { t.Skip("parity scaffold for google-gemini-cli-empty-stream.test.ts pending runtime implementation") } -func TestXhighE2EParityScaffold(t *testing.T) { - requirePIAIE2E(t) - t.Skip("parity scaffold for xhigh.test.ts pending runtime implementation") -} - func TestZenE2EParityScaffold(t *testing.T) { requirePIAIE2E(t) t.Skip("parity scaffold for zen.test.ts pending runtime implementation") diff --git a/pkg/ai/e2e/xhigh_e2e_test.go b/pkg/ai/e2e/xhigh_e2e_test.go new file mode 100644 index 00000000..91220a2b --- /dev/null +++ b/pkg/ai/e2e/xhigh_e2e_test.go @@ -0,0 +1,57 @@ +package e2e + +import ( + "os" + "strings" + "testing" + "time" + + "github.com/beeper/ai-bridge/pkg/ai" + "github.com/beeper/ai-bridge/pkg/ai/providers" +) + +func TestXhighE2E_OpenAIResponses(t *testing.T) { + requirePIAIE2E(t) + apiKey := strings.TrimSpace(os.Getenv("OPENAI_API_KEY")) + if apiKey == "" { + t.Skip("OPENAI_API_KEY is not set") + } + model := openAIReasoningTargetModel() + providers.ResetAPIProviders() + + response, err := ai.CompleteSimple(model, ai.Context{ + Messages: []ai.Message{ + { + Role: ai.RoleUser, + Text: "Think step by step and then answer: what is 17 + 25?", + Timestamp: time.Now().UnixMilli(), + }, + }, + }, &ai.SimpleStreamOptions{ + StreamOptions: ai.StreamOptions{ + APIKey: apiKey, + MaxTokens: 256, + }, + Reasoning: ai.ThinkingXHigh, + }) + if err != nil { + t.Fatalf("complete failed: %v", err) + } + if response.StopReason == ai.StopReasonError { + lower := strings.ToLower(response.ErrorMessage) + if strings.Contains(lower, "model") && + (strings.Contains(lower, "not found") || strings.Contains(lower, "does not exist") || strings.Contains(lower, "access")) { + t.Skipf("model not available for this API key: %s", response.ErrorMessage) + } + t.Fatalf("expected non-error 
response, got %q", response.ErrorMessage) + } + text := strings.ToLower(strings.TrimSpace(firstText(response))) + if text == "" { + t.Fatalf("expected non-empty text response") + } + if !strings.Contains(text, "42") && + !strings.Contains(text, "forty-two") && + !strings.Contains(text, "forty two") { + t.Fatalf("expected computed answer in response, got %q", text) + } +} From 781ed39cc12132ae4f871a416aedc9dd90989f63 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 10:03:43 +0000 Subject: [PATCH 70/75] Remove interleaved and bedrock e2e scaffolds MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- docs/pkg-ai-test-parity.md | 2 -- pkg/ai/e2e/parity_scaffolds_test.go | 10 ---------- 2 files changed, 12 deletions(-) diff --git a/docs/pkg-ai-test-parity.md b/docs/pkg-ai-test-parity.md index bb8289e0..38d37e9d 100644 --- a/docs/pkg-ai-test-parity.md +++ b/docs/pkg-ai-test-parity.md @@ -53,8 +53,6 @@ Legend: The following are currently kept as env-gated scaffolds in `pkg/ai/e2e/parity_scaffolds_test.go`: -- 📝 `interleaved-thinking.test.ts` -- 📝 `bedrock-models.test.ts` - 📝 `google-gemini-cli-empty-stream.test.ts` (full live parity) - 📝 `zen.test.ts` - 📝 `empty.test.ts` diff --git a/pkg/ai/e2e/parity_scaffolds_test.go b/pkg/ai/e2e/parity_scaffolds_test.go index 333d0545..fd0d8cea 100644 --- a/pkg/ai/e2e/parity_scaffolds_test.go +++ b/pkg/ai/e2e/parity_scaffolds_test.go @@ -15,16 +15,6 @@ func requirePIAIE2E(t *testing.T) { } } -func TestInterleavedThinkingE2EParityScaffold(t *testing.T) { - requirePIAIE2E(t) - t.Skip("parity scaffold for interleaved-thinking.test.ts pending runtime implementation") -} - -func TestBedrockModelsE2EParityScaffold(t *testing.T) { - requirePIAIE2E(t) - t.Skip("parity scaffold for bedrock-models.test.ts pending runtime implementation") -} - func TestGoogleGeminiCLIEmptyStreamE2EParityScaffold(t *testing.T) { requirePIAIE2E(t) t.Skip("parity scaffold for 
google-gemini-cli-empty-stream.test.ts pending runtime implementation") From b3341aea21752fbf2177a94958927b60b9ca4842 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 10:11:49 +0000 Subject: [PATCH 71/75] Add OpenAI empty-message parity e2e subset MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- docs/pkg-ai-runtime-migration.md | 3 +- docs/pkg-ai-test-parity.md | 2 +- pkg/ai/e2e/empty_e2e_test.go | 128 ++++++++++++++++++++++++++++ pkg/ai/e2e/parity_scaffolds_test.go | 5 -- 4 files changed, 131 insertions(+), 7 deletions(-) create mode 100644 pkg/ai/e2e/empty_e2e_test.go diff --git a/docs/pkg-ai-runtime-migration.md b/docs/pkg-ai-runtime-migration.md index a8851d0a..b99323b9 100644 --- a/docs/pkg-ai-runtime-migration.md +++ b/docs/pkg-ai-runtime-migration.md @@ -69,6 +69,7 @@ The `pkg/ai/e2e` suite now includes live provider parity checks for: - tool-result image handling (`image-tool-result.test.ts` OpenAI subset). - tool-call-id normalization (`tool-call-id-normalization.test.ts` OpenAI subset). - xhigh reasoning request path (`xhigh.test.ts` OpenAI subset). +- empty-message handling (`empty.test.ts` OpenAI subset). - cross-provider handoff smoke coverage (`cross-provider-handoff.test.ts` subset). - Anthropic and Google complete/stream smoke coverage. @@ -76,7 +77,7 @@ Run with: ```bash PI_AI_E2E=1 OPENAI_API_KEY=... ANTHROPIC_API_KEY=... GEMINI_API_KEY=... 
\ - go test ./pkg/ai/e2e -run "TestGenerateE2E_OpenAI|TestAbortE2E_OpenAIStream|TestToolCallWithoutResultE2E_OpenAI|TestTotalTokensE2E_OpenAI|TestContextOverflowE2E_OpenAI|TestOpenAIReasoningReplayE2E_|TestImageToolResultE2E_OpenAI|TestToolCallIDNormalizationE2E_OpenAI|TestXhighE2E_OpenAIResponses|TestCrossProviderHandoffE2E_|TestGenerateE2E_Anthropic|TestGenerateE2E_Google" + go test ./pkg/ai/e2e -run "TestGenerateE2E_OpenAI|TestAbortE2E_OpenAIStream|TestToolCallWithoutResultE2E_OpenAI|TestTotalTokensE2E_OpenAI|TestContextOverflowE2E_OpenAI|TestOpenAIReasoningReplayE2E_|TestImageToolResultE2E_OpenAI|TestToolCallIDNormalizationE2E_OpenAI|TestXhighE2E_OpenAIResponses|TestEmptyE2E_OpenAI|TestCrossProviderHandoffE2E_|TestGenerateE2E_Anthropic|TestGenerateE2E_Google" ``` Optional overrides: diff --git a/docs/pkg-ai-test-parity.md b/docs/pkg-ai-test-parity.md index 38d37e9d..69a86acd 100644 --- a/docs/pkg-ai-test-parity.md +++ b/docs/pkg-ai-test-parity.md @@ -23,6 +23,7 @@ Legend: - `cross-provider-handoff.test.ts` → ✅🧪 `pkg/ai/e2e/cross_provider_handoff_e2e_test.go` (OpenAI↔Anthropic subset) - `tool-call-id-normalization.test.ts` → ✅🧪 `pkg/ai/e2e/tool_call_id_normalization_e2e_test.go` (OpenAI subset) + deterministic ID normalization tests in providers - `xhigh.test.ts` → ✅🧪 `pkg/ai/e2e/xhigh_e2e_test.go` (OpenAI subset) + deterministic support checks in `pkg/ai/models_test.go` +- `empty.test.ts` → ✅🧪 `pkg/ai/e2e/empty_e2e_test.go` (OpenAI subset) ### Provider/unit parity @@ -55,5 +56,4 @@ The following are currently kept as env-gated scaffolds in - 📝 `google-gemini-cli-empty-stream.test.ts` (full live parity) - 📝 `zen.test.ts` -- 📝 `empty.test.ts` - 📝 `github-copilot-anthropic.test.ts` (live) diff --git a/pkg/ai/e2e/empty_e2e_test.go b/pkg/ai/e2e/empty_e2e_test.go new file mode 100644 index 00000000..ebee4eac --- /dev/null +++ b/pkg/ai/e2e/empty_e2e_test.go @@ -0,0 +1,128 @@ +package e2e + +import ( + "os" + "strings" + "testing" + "time" + + 
"github.com/beeper/ai-bridge/pkg/ai" + "github.com/beeper/ai-bridge/pkg/ai/providers" +) + +func TestEmptyE2E_OpenAI(t *testing.T) { + requirePIAIE2E(t) + apiKey := strings.TrimSpace(os.Getenv("OPENAI_API_KEY")) + if apiKey == "" { + t.Skip("OPENAI_API_KEY is not set") + } + model := openAIE2EModel() + providers.ResetAPIProviders() + + t.Run("empty-content-array", func(t *testing.T) { + response := completeOpenAISimple(t, model, apiKey, ai.Context{ + Messages: []ai.Message{ + { + Role: ai.RoleUser, + Content: []ai.ContentBlock{}, + Timestamp: time.Now().UnixMilli(), + }, + }, + }) + assertGracefulEmptyResponse(t, response) + }) + + t.Run("empty-string", func(t *testing.T) { + response := completeOpenAISimple(t, model, apiKey, ai.Context{ + Messages: []ai.Message{ + { + Role: ai.RoleUser, + Text: "", + Timestamp: time.Now().UnixMilli(), + }, + }, + }) + assertGracefulEmptyResponse(t, response) + }) + + t.Run("whitespace-only", func(t *testing.T) { + response := completeOpenAISimple(t, model, apiKey, ai.Context{ + Messages: []ai.Message{ + { + Role: ai.RoleUser, + Text: " \n\t ", + Timestamp: time.Now().UnixMilli(), + }, + }, + }) + assertGracefulEmptyResponse(t, response) + }) + + t.Run("empty-assistant-in-history", func(t *testing.T) { + response := completeOpenAISimple(t, model, apiKey, ai.Context{ + Messages: []ai.Message{ + { + Role: ai.RoleUser, + Text: "Hello, how are you?", + Timestamp: time.Now().UnixMilli(), + }, + { + Role: ai.RoleAssistant, + Content: []ai.ContentBlock{}, + API: model.API, + Provider: model.Provider, + Model: model.ID, + StopReason: ai.StopReasonStop, + Usage: ai.Usage{ + Input: 10, + Output: 0, + CacheRead: 0, + CacheWrite: 0, + TotalTokens: 10, + }, + Timestamp: time.Now().UnixMilli(), + }, + { + Role: ai.RoleUser, + Text: "Please respond this time.", + Timestamp: time.Now().UnixMilli(), + }, + }, + }) + assertGracefulEmptyResponse(t, response) + if response.StopReason != ai.StopReasonError && len(response.Content) == 0 { + 
t.Fatalf("expected non-empty assistant response content") + } + }) +} + +func completeOpenAISimple(t *testing.T, model ai.Model, apiKey string, context ai.Context) ai.Message { + t.Helper() + response, err := ai.CompleteSimple(model, context, &ai.SimpleStreamOptions{ + StreamOptions: ai.StreamOptions{ + APIKey: apiKey, + MaxTokens: 256, + }, + Reasoning: ai.ThinkingMedium, + }) + if err != nil { + t.Fatalf("complete failed: %v", err) + } + return response +} + +func assertGracefulEmptyResponse(t *testing.T, response ai.Message) { + t.Helper() + if response.Role != ai.RoleAssistant { + t.Fatalf("expected assistant role, got %q", response.Role) + } + if response.StopReason == ai.StopReasonError { + if strings.TrimSpace(response.ErrorMessage) == "" { + t.Fatalf("expected non-empty error message for error response") + } + return + } + if response.Content == nil { + t.Fatalf("expected content to be initialized") + } +} diff --git a/pkg/ai/e2e/parity_scaffolds_test.go b/pkg/ai/e2e/parity_scaffolds_test.go index fd0d8cea..e0a9530a 100644 --- a/pkg/ai/e2e/parity_scaffolds_test.go +++ b/pkg/ai/e2e/parity_scaffolds_test.go @@ -25,11 +25,6 @@ func TestZenE2EParityScaffold(t *testing.T) { t.Skip("parity scaffold for zen.test.ts pending runtime implementation") } -func TestEmptyE2EParityScaffold(t *testing.T) { - requirePIAIE2E(t) - t.Skip("parity scaffold for empty.test.ts pending runtime implementation") -} - func TestGithubCopilotAnthropicE2EParityScaffold(t *testing.T) { requirePIAIE2E(t) t.Skip("parity scaffold for github-copilot-anthropic.test.ts pending runtime implementation") From a87393a9e459a586a2e03e9bb6a84c9e8eb1cb61 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 10:35:40 +0000 Subject: [PATCH 72/75] Remove Gemini empty-stream e2e scaffold MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- docs/pkg-ai-test-parity.md | 1 - pkg/ai/e2e/parity_scaffolds_test.go | 5 ----- 2 files 
changed, 6 deletions(-) diff --git a/docs/pkg-ai-test-parity.md b/docs/pkg-ai-test-parity.md index 69a86acd..7a3c4e05 100644 --- a/docs/pkg-ai-test-parity.md +++ b/docs/pkg-ai-test-parity.md @@ -54,6 +54,5 @@ Legend: The following are currently kept as env-gated scaffolds in `pkg/ai/e2e/parity_scaffolds_test.go`: -- 📝 `google-gemini-cli-empty-stream.test.ts` (full live parity) - 📝 `zen.test.ts` - 📝 `github-copilot-anthropic.test.ts` (live) diff --git a/pkg/ai/e2e/parity_scaffolds_test.go b/pkg/ai/e2e/parity_scaffolds_test.go index e0a9530a..af7d2eeb 100644 --- a/pkg/ai/e2e/parity_scaffolds_test.go +++ b/pkg/ai/e2e/parity_scaffolds_test.go @@ -15,11 +15,6 @@ func requirePIAIE2E(t *testing.T) { } } -func TestGoogleGeminiCLIEmptyStreamE2EParityScaffold(t *testing.T) { - requirePIAIE2E(t) - t.Skip("parity scaffold for google-gemini-cli-empty-stream.test.ts pending runtime implementation") -} - func TestZenE2EParityScaffold(t *testing.T) { requirePIAIE2E(t) t.Skip("parity scaffold for zen.test.ts pending runtime implementation") From 2a787bccdc449729b74bc25db15d9b33c496ea8e Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 10:49:37 +0000 Subject: [PATCH 73/75] Align Anthropic runtime with Copilot header semantics MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- docs/pkg-ai-test-parity.md | 2 +- pkg/ai/e2e/parity_scaffolds_test.go | 4 -- pkg/ai/providers/anthropic_runtime.go | 70 ++++++++++++++++++++-- pkg/ai/providers/anthropic_runtime_test.go | 51 ++++++++++++++++ 4 files changed, 117 insertions(+), 10 deletions(-) diff --git a/docs/pkg-ai-test-parity.md b/docs/pkg-ai-test-parity.md index 7a3c4e05..9aedb3af 100644 --- a/docs/pkg-ai-test-parity.md +++ b/docs/pkg-ai-test-parity.md @@ -33,6 +33,7 @@ Legend: - `google-gemini-cli-retry-delay.test.ts` → ✅ `pkg/ai/providers/google_gemini_cli_test.go` - `google-gemini-cli-empty-stream.test.ts` → ✅ 
`pkg/ai/providers/google_gemini_cli_test.go` - `google-gemini-cli-claude-thinking-header.test.ts` → ✅ `pkg/ai/providers/google_gemini_cli_test.go` +- `github-copilot-anthropic.test.ts` → ✅ `pkg/ai/providers/anthropic_runtime_test.go`, `pkg/ai/providers/github_copilot_headers_test.go` - `google-tool-call-missing-args.test.ts` → ✅ `pkg/ai/providers/google_tool_call_missing_args_test.go` - `google-shared-gemini3-unsigned-tool-call.test.ts` → ✅ `pkg/ai/providers/google_shared_test.go` - `google-thinking-signature.test.ts` → ✅ `pkg/ai/providers/google_shared_test.go` @@ -55,4 +56,3 @@ The following are currently kept as env-gated scaffolds in `pkg/ai/e2e/parity_scaffolds_test.go`: - 📝 `zen.test.ts` -- 📝 `github-copilot-anthropic.test.ts` (live) diff --git a/pkg/ai/e2e/parity_scaffolds_test.go b/pkg/ai/e2e/parity_scaffolds_test.go index af7d2eeb..a373ade5 100644 --- a/pkg/ai/e2e/parity_scaffolds_test.go +++ b/pkg/ai/e2e/parity_scaffolds_test.go @@ -20,7 +20,3 @@ func TestZenE2EParityScaffold(t *testing.T) { t.Skip("parity scaffold for zen.test.ts pending runtime implementation") } -func TestGithubCopilotAnthropicE2EParityScaffold(t *testing.T) { - requirePIAIE2E(t) - t.Skip("parity scaffold for github-copilot-anthropic.test.ts pending runtime implementation") -} diff --git a/pkg/ai/providers/anthropic_runtime.go b/pkg/ai/providers/anthropic_runtime.go index 69ebd04d..88150a90 100644 --- a/pkg/ai/providers/anthropic_runtime.go +++ b/pkg/ai/providers/anthropic_runtime.go @@ -2,6 +2,7 @@ package providers import ( "context" + "slices" "strings" "time" @@ -12,6 +13,8 @@ import ( "github.com/beeper/ai-bridge/pkg/ai" ) +const anthropicFineGrainedToolStreamingBeta = "fine-grained-tool-streaming-2025-05-14" + func streamAnthropicMessages(model ai.Model, c ai.Context, options *ai.StreamOptions) *ai.AssistantMessageEventStream { anthropicOptions := AnthropicOptions{} if options != nil { @@ -77,8 +80,9 @@ func streamAnthropicMessagesWithOptions( } request := 
anthropicparam.Override[anthropic.MessageNewParams](payload) + clientConfig := buildAnthropicClientConfig(model, c, apiKey, betaHeader, options.StreamOptions.Headers) reqOptions := []anthropicoption.RequestOption{} - if isOAuthAnthropicToken(apiKey) || model.Provider == "github-copilot" { + if clientConfig.UseAuthToken { reqOptions = append(reqOptions, anthropicoption.WithAuthToken(apiKey)) } else { reqOptions = append(reqOptions, anthropicoption.WithAPIKey(apiKey)) @@ -86,11 +90,10 @@ func streamAnthropicMessagesWithOptions( if baseURL := strings.TrimSpace(model.BaseURL); baseURL != "" { reqOptions = append(reqOptions, anthropicoption.WithBaseURL(baseURL)) } - if betaHeader != "" { - reqOptions = append(reqOptions, anthropicoption.WithHeader("anthropic-beta", betaHeader)) + if clientConfig.BetaHeader != "" { + reqOptions = append(reqOptions, anthropicoption.WithHeader("anthropic-beta", clientConfig.BetaHeader)) } - reqOptions = appendAnthropicHeaderOptions(reqOptions, model.Headers) - reqOptions = appendAnthropicHeaderOptions(reqOptions, options.StreamOptions.Headers) + reqOptions = appendAnthropicHeaderOptions(reqOptions, clientConfig.Headers) client := anthropic.NewClient(reqOptions...) 
runCtx := options.StreamOptions.Ctx @@ -252,6 +255,63 @@ func isOAuthAnthropicToken(apiKey string) bool { return strings.Contains(apiKey, "sk-ant-oat") } +type anthropicClientConfig struct { + UseAuthToken bool + BetaHeader string + Headers map[string]string +} + +func buildAnthropicClientConfig( + model ai.Model, + context ai.Context, + apiKey string, + payloadBetaHeader string, + optionHeaders map[string]string, +) anthropicClientConfig { + isCopilot := model.Provider == "github-copilot" + isOAuth := isOAuthAnthropicToken(apiKey) + config := anthropicClientConfig{ + UseAuthToken: isCopilot || isOAuth, + Headers: map[string]string{}, + } + + betaFeatures := make([]string, 0, 2) + if !isCopilot { + betaFeatures = append(betaFeatures, anthropicFineGrainedToolStreamingBeta) + } + for _, token := range strings.Split(payloadBetaHeader, ",") { + trimmed := strings.TrimSpace(token) + if trimmed == "" { + continue + } + if !slices.Contains(betaFeatures, trimmed) { + betaFeatures = append(betaFeatures, trimmed) + } + } + if len(betaFeatures) > 0 { + config.BetaHeader = strings.Join(betaFeatures, ",") + } + + mergeHeaderMaps(config.Headers, model.Headers) + if isCopilot { + mergeHeaderMaps(config.Headers, BuildCopilotDynamicHeaders(context.Messages, HasCopilotVisionInput(context.Messages))) + } + mergeHeaderMaps(config.Headers, optionHeaders) + return config +} + +func mergeHeaderMaps(target map[string]string, maps ...map[string]string) { + for _, m := range maps { + for key, value := range m { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + continue + } + target[key] = trimmed + } + } +} + func supportsAdaptiveThinkingModel(modelID string) bool { id := strings.ToLower(strings.TrimSpace(modelID)) return strings.Contains(id, "opus-4-6") || strings.Contains(id, "opus-4.6") || diff --git a/pkg/ai/providers/anthropic_runtime_test.go b/pkg/ai/providers/anthropic_runtime_test.go index ef6a9a11..af3fc224 100644 --- a/pkg/ai/providers/anthropic_runtime_test.go +++ 
b/pkg/ai/providers/anthropic_runtime_test.go @@ -77,3 +77,54 @@ func TestSupportsAdaptiveThinkingModel(t *testing.T) { t.Fatalf("did not expect sonnet 4.5 to support adaptive thinking") } } + +func TestBuildAnthropicClientConfig_CopilotAuthAndHeaders(t *testing.T) { + model := ai.Model{ + ID: "claude-sonnet-4", + Provider: "github-copilot", + API: ai.APIAnthropicMessages, + Headers: map[string]string{ + "User-Agent": "GitHubCopilotChat/0.27.0", + "Copilot-Integration-Id": "vscode-chat", + "anthropic-dangerous-direct-browser-access": "true", + }, + } + ctx := ai.Context{ + Messages: []ai.Message{ + {Role: ai.RoleUser, Text: "Hello"}, + }, + } + config := buildAnthropicClientConfig(model, ctx, "tid_copilot_session_test_token", "", nil) + if !config.UseAuthToken { + t.Fatalf("expected copilot config to use bearer auth token") + } + if strings.Contains(config.BetaHeader, anthropicFineGrainedToolStreamingBeta) { + t.Fatalf("did not expect copilot beta header to include fine-grained tool streaming") + } + if config.Headers["X-Initiator"] != "user" { + t.Fatalf("expected X-Initiator=user, got %q", config.Headers["X-Initiator"]) + } + if config.Headers["Openai-Intent"] != "conversation-edits" { + t.Fatalf("expected Openai-Intent header, got %q", config.Headers["Openai-Intent"]) + } + if !strings.Contains(config.Headers["User-Agent"], "GitHubCopilotChat") { + t.Fatalf("expected copilot user-agent header, got %q", config.Headers["User-Agent"]) + } +} + +func TestBuildAnthropicClientConfig_CopilotInterleavedThinkingHeader(t *testing.T) { + model := ai.Model{Provider: "github-copilot"} + config := buildAnthropicClientConfig( + model, + ai.Context{Messages: []ai.Message{{Role: ai.RoleUser, Text: "hello"}}}, + "tid_copilot_session_test_token", + "interleaved-thinking-2025-05-14", + nil, + ) + if !strings.Contains(config.BetaHeader, "interleaved-thinking-2025-05-14") { + t.Fatalf("expected interleaved-thinking beta header, got %q", config.BetaHeader) + } + if 
strings.Contains(config.BetaHeader, anthropicFineGrainedToolStreamingBeta) { + t.Fatalf("did not expect copilot beta header to include %q", anthropicFineGrainedToolStreamingBeta) + } +} From 0293ee3330546d62cc2a22bf1e779a68dab54b54 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 4 Mar 2026 10:55:54 +0000 Subject: [PATCH 74/75] Remove remaining e2e scaffolds and document scope MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- docs/pkg-ai-test-parity.md | 8 +++----- pkg/ai/e2e/{parity_scaffolds_test.go => helpers_test.go} | 8 +------- 2 files changed, 4 insertions(+), 12 deletions(-) rename pkg/ai/e2e/{parity_scaffolds_test.go => helpers_test.go} (51%) diff --git a/docs/pkg-ai-test-parity.md b/docs/pkg-ai-test-parity.md index 9aedb3af..cfdc820b 100644 --- a/docs/pkg-ai-test-parity.md +++ b/docs/pkg-ai-test-parity.md @@ -50,9 +50,7 @@ Legend: - `oauth.ts` (provider/token helper semantics) → ✅ `pkg/ai/oauth/*_test.go` -### Remaining scaffolds in Go e2e suite +### Out of scope for this port -The following are currently kept as env-gated scaffolds in -`pkg/ai/e2e/parity_scaffolds_test.go`: - -- 📝 `zen.test.ts` +- `zen.test.ts` targets OpenCode Zen-specific provider/runtime behavior, which is + outside the `pi-mono/packages/ai/src` provider set and outside this port scope. 
diff --git a/pkg/ai/e2e/parity_scaffolds_test.go b/pkg/ai/e2e/helpers_test.go similarity index 51% rename from pkg/ai/e2e/parity_scaffolds_test.go rename to pkg/ai/e2e/helpers_test.go index a373ade5..ee017aaa 100644 --- a/pkg/ai/e2e/parity_scaffolds_test.go +++ b/pkg/ai/e2e/helpers_test.go @@ -8,15 +8,9 @@ import ( func requirePIAIE2E(t *testing.T) { t.Helper() if testing.Short() { - t.Skip("skipping e2e parity scaffolds in short mode") + t.Skip("skipping e2e tests in short mode") } if os.Getenv("PI_AI_E2E") == "" { t.Skip("set PI_AI_E2E=1 to enable ai package e2e tests") } } - -func TestZenE2EParityScaffold(t *testing.T) { - requirePIAIE2E(t) - t.Skip("parity scaffold for zen.test.ts pending runtime implementation") -} - From cb48fce1a51be2d7f85d7dfc65fdaed76b8074a0 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Thu, 5 Mar 2026 01:58:44 +0000 Subject: [PATCH 75/75] Make pkg/ai the primary connector runtime MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: batuhan içöz --- docs/pkg-ai-runtime-migration.md | 51 +-- pkg/connector/pkg_ai_provider_bridge.go | 23 +- pkg/connector/pkg_ai_provider_bridge_test.go | 8 +- pkg/connector/provider_openai.go | 191 +------- pkg/connector/provider_openai_pkg_ai_test.go | 1 - pkg/connector/streaming_runtime_selector.go | 420 +++++++++++------- .../streaming_runtime_selector_test.go | 99 +---- 7 files changed, 300 insertions(+), 493 deletions(-) diff --git a/docs/pkg-ai-runtime-migration.md b/docs/pkg-ai-runtime-migration.md index b99323b9..996767c7 100644 --- a/docs/pkg-ai-runtime-migration.md +++ b/docs/pkg-ai-runtime-migration.md @@ -1,45 +1,27 @@ # pkg/ai Runtime Migration Notes This repository now includes a standalone `pkg/ai` Go port of `pi-mono/packages/ai`, -plus controlled connector bridge paths that can route runtime execution to `pkg/ai`. +and the connector now runs on `pkg/ai` as the primary runtime. 
-## Feature flags +## Runtime architecture (current) -### Connector runtime selector +- Connector streaming path: + - `selectResponseFn` now selects `pkg_ai` by default for all non-audio prompts. + - `streamWithPkgAIBridge` is the primary streaming implementation. + - Tool-call rounds are executed via connector tool infrastructure and continued + through `pkg/ai` context updates. + - Audio prompts continue using Chat Completions path until audio parity lands in + `pkg/ai`. -- `PI_USE_PKG_AI_RUNTIME=1` - - Enables connector runtime selection path (`streamWithPkgAIBridge`). - - Keeps safe fallback to existing Responses/Chat Completions code paths. - -- `PI_USE_PKG_AI_RUNTIME_DRY_RUN=1` - - Runs optional `pkg/ai` dry-run stream consumption for diagnostics while still - executing the existing connector runtime path. - -### Provider runtime bridge - -- `PI_USE_PKG_AI_PROVIDER_RUNTIME=1` - - Enables `OpenAIProvider` bridging for: - - `GenerateStream(...)` via `tryGenerateStreamWithPkgAI(...)` - - `Generate(...)` via `tryGenerateWithPkgAI(...)` - - Includes guarded fallback for unresolved/stubbed provider APIs. - -## Current bridge behavior - -- Streaming (`PI_USE_PKG_AI_RUNTIME`): - - Controlled live pkg/ai event consumption is enabled for safe non-tool - scenarios. - - Falls back to legacy streaming runtime when bridge conditions are not met. - -- Provider abstraction (`PI_USE_PKG_AI_PROVIDER_RUNTIME`): - - Routes both streaming and non-streaming provider calls through pkg/ai where - possible. - - Preserves existing connector behavior on fallback-class errors. +- Provider abstraction: + - `OpenAIProvider.GenerateStream(...)` and `Generate(...)` route through + `tryGenerateStreamWithPkgAI(...)` / `tryGenerateWithPkgAI(...)` as primary paths. ## High-signal test commands ```bash go test ./pkg/ai/... 
-CGO_ENABLED=0 go test ./pkg/connector -run "TestPkgAIProviderRuntimeEnabled|TestInferProviderNameFromBaseURL|TestBuildPkgAIModelFromGenerateParams|TestShouldFallbackFromPkgAIEvent|TestShouldFallbackFromPkgAIError|TestTryGenerateStreamWithPkgAIReturnsRuntimeErrorEventsWhenProviderResolved|TestTryGenerateWithPkgAIFallsBackOnStubbedProviders|TestTryGenerateWithPkgAIReturnsRuntimeErrorWhenProviderResolved|TestGenerateResponseFromAIMessage|TestParseThinkingLevel|TestOpenAIProviderGenerate_UsesPkgAIBridgeWhenEnabled|TestPkgAIRuntimeEnabledFromEnv|TestChooseStreamingRuntimePath|TestPromptContainsToolCalls|TestShouldUsePkgAIBridgeStreaming|TestBuildPkgAIBridgeGenerateParams|TestPkgAIProviderBridgeCredentials|TestAIEventToStreamEvent_Mapping|TestStreamEventsFromAIStream|TestToAIContext_MapsMessagesAndTools" +go test ./pkg/connector -run "TestPkgAIProviderRuntimeEnabled|TestInferProviderNameFromBaseURL|TestBuildPkgAIModelFromGenerateParams|TestShouldFallbackFromPkgAIEvent|TestShouldFallbackFromPkgAIError|TestTryGenerateStreamWithPkgAIReturnsRuntimeErrorEventsWhenProviderResolved|TestTryGenerateWithPkgAIReturnsRuntimeErrorForGeminiCLI|TestTryGenerateWithPkgAIReturnsRuntimeErrorWhenProviderResolved|TestGenerateResponseFromAIMessage|TestParseThinkingLevel|TestOpenAIProviderGenerate_UsesPkgAIBridgeWhenEnabled|TestPkgAIRuntimeEnabledFromEnv|TestChooseStreamingRuntimePath|TestBuildPkgAIBridgeGenerateParams|TestPkgAIProviderBridgeCredentials|TestAIEventToStreamEvent_Mapping|TestStreamEventsFromAIStream|TestToAIContext_MapsMessagesAndTools" ``` ## Connector bridge env-gated provider validation @@ -47,7 +29,7 @@ CGO_ENABLED=0 go test ./pkg/connector -run "TestPkgAIProviderRuntimeEnabled|Test To validate real provider happy paths for connector bridge routing (OpenAI, Anthropic, Google), set credentials and: ```bash -PI_AI_E2E=1 CGO_ENABLED=0 go test ./pkg/connector -run "TestPkgAIProviderBridgeE2E_" +PI_AI_E2E=1 go test ./pkg/connector -run "TestPkgAIProviderBridgeE2E_" ``` Optional 
model overrides: @@ -91,5 +73,6 @@ Optional overrides: ## Notes -- Full integration remains feature-gated. -- Fallback behavior is intentional and required for incremental rollout safety. +- `pkg/ai` is now the default bridge runtime for connector turn execution. +- Legacy streaming/provider implementations remain in-tree for now, but are no + longer selected on the primary non-audio path. diff --git a/pkg/connector/pkg_ai_provider_bridge.go b/pkg/connector/pkg_ai_provider_bridge.go index 8bce3b40..cf2db90a 100644 --- a/pkg/connector/pkg_ai_provider_bridge.go +++ b/pkg/connector/pkg_ai_provider_bridge.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "errors" - "os" "strings" "time" @@ -13,8 +12,7 @@ import ( ) func pkgAIProviderRuntimeEnabled() bool { - value := strings.ToLower(strings.TrimSpace(os.Getenv("PI_USE_PKG_AI_PROVIDER_RUNTIME"))) - return value == "1" || value == "true" || value == "yes" || value == "on" + return true } func inferProviderNameFromBaseURL(baseURL string) string { @@ -167,17 +165,20 @@ func tryGenerateStreamWithPkgAI( stream, err = aipkg.Stream(model, aiContext, options) } if err != nil { - return nil, false + out := make(chan StreamEvent, 1) + out <- StreamEvent{Type: StreamEventError, Error: err} + close(out) + return out, true } mapped := streamEventsFromAIStream(ctx, stream) select { case first, ok := <-mapped: if !ok { - return nil, false - } - if shouldFallbackFromPkgAIEvent(first) { - return nil, false + out := make(chan StreamEvent, 1) + out <- StreamEvent{Type: StreamEventError, Error: errors.New("pkg/ai stream ended before emitting events")} + close(out) + return out, true } out := make(chan StreamEvent, 64) go func() { @@ -225,16 +226,10 @@ func tryGenerateWithPkgAI( message, err = aipkg.Complete(model, aiContext, options) } if err != nil { - if shouldFallbackFromPkgAIError(err) { - return nil, false, nil - } return nil, true, err } if message.StopReason == aipkg.StopReasonError && strings.TrimSpace(message.ErrorMessage) != 
"" { runtimeErr := errors.New(strings.TrimSpace(message.ErrorMessage)) - if shouldFallbackFromPkgAIError(runtimeErr) { - return nil, false, nil - } return nil, true, runtimeErr } return generateResponseFromAIMessage(message), true, nil diff --git a/pkg/connector/pkg_ai_provider_bridge_test.go b/pkg/connector/pkg_ai_provider_bridge_test.go index 86f89c95..da4fef54 100644 --- a/pkg/connector/pkg_ai_provider_bridge_test.go +++ b/pkg/connector/pkg_ai_provider_bridge_test.go @@ -9,14 +9,8 @@ import ( ) func TestPkgAIProviderRuntimeEnabled(t *testing.T) { - t.Setenv("PI_USE_PKG_AI_PROVIDER_RUNTIME", "true") if !pkgAIProviderRuntimeEnabled() { - t.Fatalf("expected runtime flag to be enabled") - } - - t.Setenv("PI_USE_PKG_AI_PROVIDER_RUNTIME", "0") - if pkgAIProviderRuntimeEnabled() { - t.Fatalf("expected runtime flag to be disabled") + t.Fatalf("expected pkg/ai provider runtime to be always enabled") } } diff --git a/pkg/connector/provider_openai.go b/pkg/connector/provider_openai.go index 35de563f..4ce059bb 100644 --- a/pkg/connector/provider_openai.go +++ b/pkg/connector/provider_openai.go @@ -268,194 +268,27 @@ func (o *OpenAIProvider) buildResponsesParams(params GenerateParams) responses.R // GenerateStream generates a streaming response from OpenAI using Responses API func (o *OpenAIProvider) GenerateStream(ctx context.Context, params GenerateParams) (<-chan StreamEvent, error) { - if pkgAIProviderRuntimeEnabled() { - if pkgAIEvents, ok := tryGenerateStreamWithPkgAI(ctx, o.baseURL, o.apiKey, params); ok { - o.log.Debug(). - Str("model", params.Model). - Msg("Using pkg/ai provider runtime for OpenAI stream") - return pkgAIEvents, nil - } - o.log.Warn(). + if pkgAIEvents, ok := tryGenerateStreamWithPkgAI(ctx, o.baseURL, o.apiKey, params); ok { + o.log.Debug(). Str("model", params.Model). 
- Msg("pkg/ai provider runtime fallback to existing OpenAI stream path") + Msg("Using pkg/ai provider runtime for OpenAI stream") + return pkgAIEvents, nil } - - events := make(chan StreamEvent, 100) - - go func() { - defer close(events) - - responsesParams := o.buildResponsesParams(params) - - // Create streaming request - stream := o.client.Responses.NewStreaming(ctx, responsesParams) - if stream == nil { - events <- StreamEvent{ - Type: StreamEventError, - Error: errors.New("failed to create streaming request"), - } - return - } - - var responseID string - - // Process stream events - for stream.Next() { - streamEvent := stream.Current() - - switch streamEvent.Type { - case "response.output_text.delta": - events <- StreamEvent{ - Type: StreamEventDelta, - Delta: streamEvent.Delta, - } - - case "response.reasoning_text.delta": - events <- StreamEvent{ - Type: StreamEventReasoning, - ReasoningDelta: streamEvent.Delta, - } - - case "response.function_call_arguments.done": - events <- StreamEvent{ - Type: StreamEventToolCall, - ToolCall: &ToolCallResult{ - ID: streamEvent.ItemID, - Name: streamEvent.Name, - Arguments: streamEvent.Arguments, - }, - } - - case "response.completed": - responseID = streamEvent.Response.ID - finishReason := "stop" - if streamEvent.Response.Status != "completed" { - finishReason = string(streamEvent.Response.Status) - } - - // Extract usage - var usage *UsageInfo - if streamEvent.Response.Usage.InputTokens > 0 || streamEvent.Response.Usage.OutputTokens > 0 { - usage = &UsageInfo{ - PromptTokens: int(streamEvent.Response.Usage.InputTokens), - CompletionTokens: int(streamEvent.Response.Usage.OutputTokens), - TotalTokens: int(streamEvent.Response.Usage.TotalTokens), - } - if streamEvent.Response.Usage.OutputTokensDetails.ReasoningTokens > 0 { - usage.ReasoningTokens = int(streamEvent.Response.Usage.OutputTokensDetails.ReasoningTokens) - } - } - - events <- StreamEvent{ - Type: StreamEventComplete, - FinishReason: finishReason, - ResponseID: 
responseID, - Usage: usage, - } - - case "error": - events <- StreamEvent{ - Type: StreamEventError, - Error: fmt.Errorf("API error: %s", streamEvent.Message), - } - return - } - } - - if err := stream.Err(); err != nil { - events <- StreamEvent{ - Type: StreamEventError, - Error: err, - } - } - }() - - return events, nil + return nil, errors.New("pkg/ai stream runtime unavailable") } // Generate performs a non-streaming generation using Responses API func (o *OpenAIProvider) Generate(ctx context.Context, params GenerateParams) (*GenerateResponse, error) { - if pkgAIProviderRuntimeEnabled() { - if pkgAIResp, handled, err := tryGenerateWithPkgAI(ctx, o.baseURL, o.apiKey, params); handled { - if err != nil { - return nil, fmt.Errorf("pkg/ai generation failed: %w", err) - } - o.log.Debug(). - Str("model", params.Model). - Msg("Using pkg/ai provider runtime for OpenAI generate") - return pkgAIResp, nil + if pkgAIResp, handled, err := tryGenerateWithPkgAI(ctx, o.baseURL, o.apiKey, params); handled { + if err != nil { + return nil, fmt.Errorf("pkg/ai generation failed: %w", err) } - o.log.Warn(). + o.log.Debug(). Str("model", params.Model). - Msg("pkg/ai provider runtime fallback to existing OpenAI generate path") + Msg("Using pkg/ai provider runtime for OpenAI generate") + return pkgAIResp, nil } - - // Responses input supports images and PDFs but not audio/video, so fall back to - // Chat Completions when unsupported media is present. 
- if hasUnsupportedResponsesUnifiedMessages(params.Messages) { - return o.generateChatCompletions(ctx, params) - } - - responsesParams := o.buildResponsesParams(params) - - // Make request - resp, err := o.client.Responses.New(ctx, responsesParams) - if err != nil { - return nil, fmt.Errorf("OpenAI generation failed: %w", err) - } - - // Extract response content - var content strings.Builder - var toolCalls []ToolCallResult - - var reasoning strings.Builder - for _, item := range resp.Output { - switch item := item.AsAny().(type) { - case responses.ResponseOutputMessage: - for _, contentPart := range item.Content { - switch part := contentPart.AsAny().(type) { - case responses.ResponseOutputText: - content.WriteString(part.Text) - } - } - case responses.ResponseReasoningItem: - // Handle reasoning model output - extract from summary - for _, summary := range item.Summary { - if summary.Text != "" { - reasoning.WriteString(summary.Text) - } - } - case responses.ResponseFunctionToolCall: - toolCalls = append(toolCalls, ToolCallResult{ - ID: item.ID, - Name: item.Name, - Arguments: item.Arguments, - }) - } - } - - // If no regular content but we have reasoning, use that as content - if content.Len() == 0 && reasoning.Len() > 0 { - content = reasoning - } - - finishReason := "stop" - if resp.Status != "completed" { - finishReason = string(resp.Status) - } - - return &GenerateResponse{ - Content: content.String(), - FinishReason: finishReason, - ResponseID: resp.ID, - ToolCalls: toolCalls, - Usage: UsageInfo{ - PromptTokens: int(resp.Usage.InputTokens), - CompletionTokens: int(resp.Usage.OutputTokens), - TotalTokens: int(resp.Usage.TotalTokens), - ReasoningTokens: int(resp.Usage.OutputTokensDetails.ReasoningTokens), - }, - }, nil + return nil, errors.New("pkg/ai generate runtime unavailable") } func (o *OpenAIProvider) generateChatCompletions(ctx context.Context, params GenerateParams) (*GenerateResponse, error) { diff --git 
a/pkg/connector/provider_openai_pkg_ai_test.go b/pkg/connector/provider_openai_pkg_ai_test.go index e3344daa..4adf352d 100644 --- a/pkg/connector/provider_openai_pkg_ai_test.go +++ b/pkg/connector/provider_openai_pkg_ai_test.go @@ -9,7 +9,6 @@ import ( ) func TestOpenAIProviderGenerate_UsesPkgAIBridgeWhenEnabled(t *testing.T) { - t.Setenv("PI_USE_PKG_AI_PROVIDER_RUNTIME", "true") t.Setenv("ANTHROPIC_API_KEY", "") t.Setenv("ANTHROPIC_OAUTH_TOKEN", "") diff --git a/pkg/connector/streaming_runtime_selector.go b/pkg/connector/streaming_runtime_selector.go index 24cd5f0f..46bc9ea7 100644 --- a/pkg/connector/streaming_runtime_selector.go +++ b/pkg/connector/streaming_runtime_selector.go @@ -2,13 +2,14 @@ package connector import ( "context" - "os" + "encoding/json" + "errors" "strconv" "strings" "time" + "github.com/beeper/ai-bridge/pkg/agents/tools" aipkg "github.com/beeper/ai-bridge/pkg/ai" - aiproviders "github.com/beeper/ai-bridge/pkg/ai/providers" airuntime "github.com/beeper/ai-bridge/pkg/runtime" "github.com/openai/openai-go/v3" "maunium.net/go/mautrix/bridgev2" @@ -24,26 +25,14 @@ const ( ) func pkgAIRuntimeEnabled() bool { - value := strings.ToLower(strings.TrimSpace(os.Getenv("PI_USE_PKG_AI_RUNTIME"))) - return value == "1" || value == "true" || value == "yes" || value == "on" -} - -func pkgAIRuntimeDryRunEnabled() bool { - value := strings.ToLower(strings.TrimSpace(os.Getenv("PI_USE_PKG_AI_RUNTIME_DRY_RUN"))) - return value == "1" || value == "true" || value == "yes" || value == "on" + return true } -func chooseStreamingRuntimePath(hasAudio bool, modelAPI ModelAPI, preferPkgAI bool) streamingRuntimePath { +func chooseStreamingRuntimePath(hasAudio bool, _ ModelAPI, _ bool) streamingRuntimePath { if hasAudio { return streamingRuntimeChatCompletions } - if preferPkgAI { - return streamingRuntimePkgAI - } - if modelAPI == ModelAPIChatCompletions { - return streamingRuntimeChatCompletions - } - return streamingRuntimeResponses + return streamingRuntimePkgAI } func 
(oc *AIClient) streamWithPkgAIBridge( @@ -71,55 +60,15 @@ func (oc *AIClient) streamWithPkgAIBridge( Str("ai_model_api", string(aiModel.API)). Str("ai_model_provider", string(aiModel.Provider)). Str("ai_model_id", aiModel.ID). - Msg("pkg/ai runtime bridge flag enabled; prepared adapter context/model and delegating to existing runtime path") - if pkgAIRuntimeDryRunEnabled() { - oc.runPkgAIBridgeDryRun(ctx, aiModel, aiContext) - } - if oc.shouldUsePkgAIBridgeStreaming(ctx, meta, prompt) { - if baseURL, apiKey, ok := oc.pkgAIProviderBridgeCredentials(); ok { - params := oc.buildPkgAIBridgeGenerateParams(meta, prompt) - if events, handled := tryGenerateStreamWithPkgAI(ctx, baseURL, apiKey, params); handled { - oc.loggerForContext(ctx).Debug(). - Str("model", params.Model). - Msg("Executing pkg/ai runtime bridge event stream path") - return oc.streamPkgAIBridgeEvents(ctx, evt, portal, meta, prompt, events) - } - oc.loggerForContext(ctx).Debug(). - Str("model", params.Model). - Msg("pkg/ai bridge event stream path requested fallback") - } - } - switch oc.resolveModelAPI(meta) { - case ModelAPIChatCompletions: - return oc.streamChatCompletions(ctx, evt, portal, meta, prompt) - default: - return oc.streamingResponseWithToolSchemaFallback(ctx, evt, portal, meta, prompt) - } -} + Msg("Using pkg/ai runtime bridge as primary streaming path") -func (oc *AIClient) runPkgAIBridgeDryRun(ctx context.Context, model aipkg.Model, aiContext aipkg.Context) { - aiproviders.RegisterBuiltInAPIProviders() - stream, err := aipkg.Stream(model, aiContext, &aipkg.StreamOptions{ - Ctx: ctx, - MaxTokens: model.MaxTokens, - }) - if err != nil { - oc.loggerForContext(ctx).Warn().Err(err).Msg("pkg/ai dry-run failed to create stream") - return - } - events := streamEventsFromAIStream(ctx, stream) - count := 0 - for evt := range events { - count++ - if evt.Type == StreamEventError { - oc.loggerForContext(ctx).Debug().Err(evt.Error).Int("event_count", count).Msg("pkg/ai dry-run produced error event") - 
return - } - if evt.Type == StreamEventComplete { - oc.loggerForContext(ctx).Debug().Int("event_count", count).Str("finish_reason", evt.FinishReason).Msg("pkg/ai dry-run completed") - return - } + baseURL, apiKey, ok := oc.pkgAIProviderBridgeCredentials() + if !ok { + return false, nil, errors.New("pkg/ai runtime requires OpenAI-compatible provider credentials") } + + params := oc.buildPkgAIBridgeGenerateParams(ctx, portal, meta, prompt) + return oc.streamPkgAIBridgeEvents(ctx, evt, portal, meta, prompt, baseURL, apiKey, params) } func buildPkgAIContext(systemPrompt string, prompt []openai.ChatCompletionMessageParamUnion) aipkg.Context { @@ -228,57 +177,96 @@ func (oc *AIClient) pkgAIProviderBridgeCredentials() (string, string, bool) { return provider.baseURL, provider.apiKey, true } -func (oc *AIClient) shouldUsePkgAIBridgeStreaming( +func (oc *AIClient) buildPkgAIBridgeGenerateParams( ctx context.Context, + portal *bridgev2.Portal, meta *PortalMetadata, prompt []openai.ChatCompletionMessageParamUnion, -) bool { - if meta != nil && meta.Capabilities.SupportsToolCalling { - if oc.selectedBuiltinToolCountSafe(ctx, meta) > 0 { - return false - } - if resolveAgentID(meta) != "" { - return false - } - } - if promptContainsToolCalls(prompt) { - return false +) GenerateParams { + return GenerateParams{ + Model: oc.effectiveModel(meta), + Messages: chatPromptToUnifiedMessages(prompt), + SystemPrompt: oc.effectivePrompt(meta), + Temperature: oc.effectiveTemperature(meta), + MaxCompletionTokens: oc.effectiveMaxTokens(meta), + ReasoningEffort: oc.effectiveReasoningEffort(meta), + Tools: oc.buildPkgAIBridgeTools(ctx, portal, meta), } - return true } -func (oc *AIClient) selectedBuiltinToolCountSafe(ctx context.Context, meta *PortalMetadata) (count int) { - defer func() { - if recover() != nil { - count = 0 +func (oc *AIClient) buildPkgAIBridgeTools( + ctx context.Context, + portal *bridgev2.Portal, + meta *PortalMetadata, +) []ToolDefinition { + definitions := 
append([]ToolDefinition(nil), oc.selectedBuiltinToolsForTurn(ctx, meta)...) + if meta == nil || !meta.Capabilities.SupportsToolCalling { + return dedupeToolDefinitionsByName(definitions) + } + + hasAgent := resolveAgentID(meta) != "" + if hasAgent && !hasBossAgent(meta) && !oc.isBuilderRoom(portal) { + for _, tool := range tools.SessionTools() { + if !oc.isToolEnabled(meta, tool.Name) { + continue + } + definitions = append(definitions, toToolDefinitionFromAgentTool(tool)) } - }() - return len(oc.selectedBuiltinToolsForTurn(ctx, meta)) + } + if hasBossAgent(meta) || oc.isBuilderRoom(portal) { + for _, tool := range tools.BossTools() { + if !oc.isToolEnabled(meta, tool.Name) { + continue + } + definitions = append(definitions, toToolDefinitionFromAgentTool(tool)) + } + } + return dedupeToolDefinitionsByName(definitions) } -func promptContainsToolCalls(prompt []openai.ChatCompletionMessageParamUnion) bool { - for _, msg := range prompt { - if msg.OfTool != nil { - return true +func dedupeToolDefinitionsByName(tools []ToolDefinition) []ToolDefinition { + if len(tools) == 0 { + return nil + } + deduped := make([]ToolDefinition, 0, len(tools)) + seen := make(map[string]struct{}, len(tools)) + for _, tool := range tools { + name := strings.TrimSpace(tool.Name) + if name == "" { + continue } - if msg.OfAssistant != nil && len(msg.OfAssistant.ToolCalls) > 0 { - return true + if _, exists := seen[name]; exists { + continue } + seen[name] = struct{}{} + deduped = append(deduped, tool) } - return false + return deduped } -func (oc *AIClient) buildPkgAIBridgeGenerateParams( - meta *PortalMetadata, - prompt []openai.ChatCompletionMessageParamUnion, -) GenerateParams { - return GenerateParams{ - Model: oc.effectiveModel(meta), - Messages: chatPromptToUnifiedMessages(prompt), - SystemPrompt: oc.effectivePrompt(meta), - Temperature: oc.effectiveTemperature(meta), - MaxCompletionTokens: oc.effectiveMaxTokens(meta), - ReasoningEffort: oc.effectiveReasoningEffort(meta), +func 
toToolDefinitionFromAgentTool(tool *tools.Tool) ToolDefinition { + if tool == nil { + return ToolDefinition{} + } + var parameters map[string]any + switch schema := tool.InputSchema.(type) { + case nil: + parameters = nil + case map[string]any: + parameters = schema + default: + blob, err := json.Marshal(schema) + if err == nil { + _ = json.Unmarshal(blob, ¶meters) + } + if len(parameters) == 0 { + parameters = nil + } + } + return ToolDefinition{ + Name: tool.Name, + Description: tool.Description, + Parameters: parameters, } } @@ -288,7 +276,9 @@ func (oc *AIClient) streamPkgAIBridgeEvents( portal *bridgev2.Portal, meta *PortalMetadata, prompt []openai.ChatCompletionMessageParamUnion, - events <-chan StreamEvent, + baseURL string, + apiKey string, + params GenerateParams, ) (bool, *ContextLengthError, error) { log := oc.loggerForContext(ctx).With(). Str("action", "stream_pkg_ai_bridge_events"). @@ -302,77 +292,175 @@ func (oc *AIClient) streamPkgAIBridgeEvents( isHeartbeat := prep.IsHeartbeat oc.emitUIStart(ctx, portal, state, meta) + currentParams := params + const maxToolRounds = 10 - for { - select { - case <-ctx.Done(): - state.finishReason = "cancelled" - if state.hasInitialMessageTarget() && state.accumulated.Len() > 0 { - oc.flushPartialStreamingMessage(context.Background(), portal, state, meta) - } - oc.uiEmitter(state).EmitUIAbort(ctx, portal, "cancelled") + for round := 0; ; round++ { + state.pendingFunctionOutputs = nil + toolCallsThisRound := make([]ToolCallResult, 0, 4) + activeTools := make(map[string]*activeToolCall) + var roundContent strings.Builder + events, handled := tryGenerateStreamWithPkgAI(ctx, baseURL, apiKey, currentParams) + if !handled { + err := errors.New("pkg/ai runtime stream was not handled by registered providers") + oc.uiEmitter(state).EmitUIError(ctx, portal, err.Error()) oc.emitUIFinish(ctx, portal, state, meta) - return false, nil, streamFailureError(state, ctx.Err()) - case event, ok := <-events: - if !ok { - 
state.completedAtMs = time.Now().UnixMilli() - oc.finalizeResponsesStream(ctx, log, portal, state, meta) - return true, nil, nil - } + return false, nil, streamFailureError(state, err) + } - oc.markMessageSendSuccess(ctx, portal, evt, state) - switch event.Type { - case StreamEventDelta: - touchTyping() - if err := oc.handleResponseOutputTextDelta( - ctx, - log, - portal, - state, - meta, - typingSignals, - isHeartbeat, - event.Delta, - "failed to send initial streaming message", - "Failed to send initial streaming message", - ); err != nil { - return false, nil, &PreDeltaError{Err: err} - } - case StreamEventReasoning: - touchTyping() - if err := oc.handleResponseReasoningTextDelta( - ctx, - log, - portal, - state, - meta, - isHeartbeat, - event.ReasoningDelta, - "failed to send initial streaming message", - "Failed to send initial streaming message", - ); err != nil { - return false, nil, &PreDeltaError{Err: err} + for { + select { + case <-ctx.Done(): + state.finishReason = "cancelled" + if state.hasInitialMessageTarget() && state.accumulated.Len() > 0 { + oc.flushPartialStreamingMessage(context.Background(), portal, state, meta) } - case StreamEventComplete: - if reason := strings.TrimSpace(event.FinishReason); reason != "" { - state.finishReason = reason - } - state.responseID = strings.TrimSpace(event.ResponseID) - if event.Usage != nil { - state.promptTokens = int64(event.Usage.PromptTokens) - state.completionTokens = int64(event.Usage.CompletionTokens) - state.reasoningTokens = int64(event.Usage.ReasoningTokens) - state.totalTokens = int64(event.Usage.TotalTokens) - oc.uiEmitter(state).EmitUIMessageMetadata(ctx, portal, oc.buildUIMessageMetadata(state, meta, true)) + oc.uiEmitter(state).EmitUIAbort(ctx, portal, "cancelled") + oc.emitUIFinish(ctx, portal, state, meta) + return false, nil, streamFailureError(state, ctx.Err()) + case event, ok := <-events: + if !ok { + if shouldContinueChatToolLoop(state.finishReason, len(toolCallsThisRound)) { + if round >= 
maxToolRounds { + err := errors.New("max pkg/ai tool call rounds reached") + oc.uiEmitter(state).EmitUIError(ctx, portal, err.Error()) + oc.emitUIFinish(ctx, portal, state, meta) + return false, nil, streamFailureError(state, err) + } + + assistantContent := make([]ContentPart, 0, 1) + if text := strings.TrimSpace(roundContent.String()); text != "" { + assistantContent = append(assistantContent, ContentPart{Type: ContentTypeText, Text: text}) + } + currentParams.Messages = append(currentParams.Messages, UnifiedMessage{ + Role: RoleAssistant, + Content: assistantContent, + ToolCalls: toolCallsThisRound, + }) + for _, output := range state.pendingFunctionOutputs { + currentParams.Messages = append(currentParams.Messages, UnifiedMessage{ + Role: RoleTool, + ToolCallID: output.callID, + Name: output.name, + Content: []ContentPart{ + {Type: ContentTypeText, Text: output.output}, + }, + }) + } + state.pendingFunctionOutputs = nil + state.needsTextSeparator = true + + if steerItems := oc.drainSteerQueue(state.roomID); len(steerItems) > 0 { + for _, item := range steerItems { + if item.pending.Type != pendingTypeText { + continue + } + userPrompt := strings.TrimSpace(item.prompt) + if userPrompt == "" { + userPrompt = strings.TrimSpace(item.pending.MessageBody) + } + if userPrompt == "" { + continue + } + currentParams.Messages = append(currentParams.Messages, UnifiedMessage{ + Role: RoleUser, + Content: []ContentPart{{Type: ContentTypeText, Text: userPrompt}}, + }) + } + } + goto nextRound + } + + state.completedAtMs = time.Now().UnixMilli() + oc.finalizeResponsesStream(ctx, log, portal, state, meta) + return true, nil, nil } - case StreamEventError: - if cle := ParseContextLengthError(event.Error); cle != nil { - return false, cle, nil + + oc.markMessageSendSuccess(ctx, portal, evt, state) + switch event.Type { + case StreamEventDelta: + touchTyping() + roundContent.WriteString(event.Delta) + if err := oc.handleResponseOutputTextDelta( + ctx, + log, + portal, + state, + 
meta, + typingSignals, + isHeartbeat, + event.Delta, + "failed to send initial streaming message", + "Failed to send initial streaming message", + ); err != nil { + return false, nil, &PreDeltaError{Err: err} + } + case StreamEventReasoning: + touchTyping() + if err := oc.handleResponseReasoningTextDelta( + ctx, + log, + portal, + state, + meta, + isHeartbeat, + event.ReasoningDelta, + "failed to send initial streaming message", + "Failed to send initial streaming message", + ); err != nil { + return false, nil, &PreDeltaError{Err: err} + } + case StreamEventToolCall: + if event.ToolCall == nil { + continue + } + toolCallID := strings.TrimSpace(event.ToolCall.ID) + if toolCallID == "" { + toolCallID = NewCallID() + } + toolName := strings.TrimSpace(event.ToolCall.Name) + arguments := normalizeToolArgsJSON(strings.TrimSpace(event.ToolCall.Arguments)) + toolCallsThisRound = append(toolCallsThisRound, ToolCallResult{ + ID: toolCallID, + Name: toolName, + Arguments: arguments, + }) + oc.handleFunctionCallArgumentsDone( + ctx, + log, + portal, + state, + meta, + activeTools, + toolCallID, + toolName, + arguments, + true, + " (pkg/ai)", + ) + case StreamEventComplete: + if reason := strings.TrimSpace(event.FinishReason); reason != "" { + state.finishReason = reason + } + state.responseID = strings.TrimSpace(event.ResponseID) + if event.Usage != nil { + state.promptTokens = int64(event.Usage.PromptTokens) + state.completionTokens = int64(event.Usage.CompletionTokens) + state.reasoningTokens = int64(event.Usage.ReasoningTokens) + state.totalTokens = int64(event.Usage.TotalTokens) + oc.uiEmitter(state).EmitUIMessageMetadata(ctx, portal, oc.buildUIMessageMetadata(state, meta, true)) + } + case StreamEventError: + if cle := ParseContextLengthError(event.Error); cle != nil { + return false, cle, nil + } + oc.uiEmitter(state).EmitUIError(ctx, portal, event.Error.Error()) + oc.emitUIFinish(ctx, portal, state, meta) + return false, nil, streamFailureError(state, event.Error) } - 
oc.uiEmitter(state).EmitUIError(ctx, portal, event.Error.Error()) - oc.emitUIFinish(ctx, portal, state, meta) - return false, nil, streamFailureError(state, event.Error) } } + + nextRound: } } diff --git a/pkg/connector/streaming_runtime_selector_test.go b/pkg/connector/streaming_runtime_selector_test.go index 23c9a51b..f36ff93a 100644 --- a/pkg/connector/streaming_runtime_selector_test.go +++ b/pkg/connector/streaming_runtime_selector_test.go @@ -8,37 +8,8 @@ import ( ) func TestPkgAIRuntimeEnabledFromEnv(t *testing.T) { - t.Setenv("PI_USE_PKG_AI_RUNTIME", "") - if pkgAIRuntimeEnabled() { - t.Fatalf("expected runtime flag disabled by default") - } - - t.Setenv("PI_USE_PKG_AI_RUNTIME", "1") - if !pkgAIRuntimeEnabled() { - t.Fatalf("expected runtime flag enabled for value 1") - } - - t.Setenv("PI_USE_PKG_AI_RUNTIME", "true") if !pkgAIRuntimeEnabled() { - t.Fatalf("expected runtime flag enabled for value true") - } - - t.Setenv("PI_USE_PKG_AI_RUNTIME", "off") - if pkgAIRuntimeEnabled() { - t.Fatalf("expected runtime flag disabled for value off") - } - - t.Setenv("PI_USE_PKG_AI_RUNTIME_DRY_RUN", "") - if pkgAIRuntimeDryRunEnabled() { - t.Fatalf("expected dry-run flag disabled by default") - } - t.Setenv("PI_USE_PKG_AI_RUNTIME_DRY_RUN", "yes") - if !pkgAIRuntimeDryRunEnabled() { - t.Fatalf("expected dry-run flag enabled for value yes") - } - t.Setenv("PI_USE_PKG_AI_RUNTIME_DRY_RUN", "0") - if pkgAIRuntimeDryRunEnabled() { - t.Fatalf("expected dry-run flag disabled for value 0") + t.Fatalf("expected pkg/ai runtime to be always enabled") } } @@ -49,11 +20,8 @@ func TestChooseStreamingRuntimePath(t *testing.T) { if got := chooseStreamingRuntimePath(false, ModelAPIResponses, true); got != streamingRuntimePkgAI { t.Fatalf("expected pkg_ai path when preferred and no audio, got %s", got) } - if got := chooseStreamingRuntimePath(false, ModelAPIChatCompletions, false); got != streamingRuntimeChatCompletions { - t.Fatalf("expected chat model api path, got %s", got) - } - if got 
:= chooseStreamingRuntimePath(false, ModelAPIResponses, false); got != streamingRuntimeResponses { - t.Fatalf("expected responses path fallback, got %s", got) + if got := chooseStreamingRuntimePath(false, ModelAPIChatCompletions, false); got != streamingRuntimePkgAI { + t.Fatalf("expected pkg_ai path regardless of model API, got %s", got) } } @@ -120,62 +88,6 @@ func TestBuildPkgAIContext_UsesSystemPromptAndMappedMessages(t *testing.T) { } } -func TestPromptContainsToolCalls(t *testing.T) { - if promptContainsToolCalls([]openai.ChatCompletionMessageParamUnion{ - openai.UserMessage("hello"), - }) { - t.Fatalf("did not expect tool call detection for plain user prompt") - } - if !promptContainsToolCalls([]openai.ChatCompletionMessageParamUnion{ - { - OfAssistant: &openai.ChatCompletionAssistantMessageParam{ - ToolCalls: []openai.ChatCompletionMessageToolCallUnionParam{ - { - OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{ - ID: "call_1", - Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{ - Name: "search", - Arguments: "{}", - }, - }, - }, - }, - }, - }, - }) { - t.Fatalf("expected assistant tool calls to be detected") - } - if !promptContainsToolCalls([]openai.ChatCompletionMessageParamUnion{ - openai.ToolMessage("tool result", "call_1"), - }) { - t.Fatalf("expected tool role messages to be detected") - } -} - -func TestShouldUsePkgAIBridgeStreaming(t *testing.T) { - client := &AIClient{} - if !client.shouldUsePkgAIBridgeStreaming(context.Background(), &PortalMetadata{}, []openai.ChatCompletionMessageParamUnion{ - openai.UserMessage("hello"), - }) { - t.Fatalf("expected bridge streaming to be enabled for non-tool prompt") - } - if !client.shouldUsePkgAIBridgeStreaming(context.Background(), &PortalMetadata{ - Capabilities: ModelCapabilities{SupportsToolCalling: true}, - }, []openai.ChatCompletionMessageParamUnion{ - openai.UserMessage("hello"), - }) { - t.Fatalf("expected bridge streaming enabled when tool calling has no active 
tools") - } - if client.shouldUsePkgAIBridgeStreaming(context.Background(), &PortalMetadata{ - Capabilities: ModelCapabilities{SupportsToolCalling: true}, - AgentID: "agent-1", - }, []openai.ChatCompletionMessageParamUnion{ - openai.UserMessage("hello"), - }) { - t.Fatalf("expected bridge streaming disabled when agent tool mode is active") - } -} - func TestBuildPkgAIBridgeGenerateParams(t *testing.T) { client := &AIClient{} meta := &PortalMetadata{ @@ -188,7 +100,7 @@ func TestBuildPkgAIBridgeGenerateParams(t *testing.T) { SupportsReasoning: true, }, } - params := client.buildPkgAIBridgeGenerateParams(meta, []openai.ChatCompletionMessageParamUnion{ + params := client.buildPkgAIBridgeGenerateParams(context.Background(), nil, meta, []openai.ChatCompletionMessageParamUnion{ openai.SystemMessage("ignored"), openai.UserMessage("hello"), openai.AssistantMessage("hi"), @@ -211,6 +123,9 @@ func TestBuildPkgAIBridgeGenerateParams(t *testing.T) { if len(params.Messages) != 2 { t.Fatalf("expected mapped user+assistant messages, got %d", len(params.Messages)) } + if len(params.Tools) != 0 { + t.Fatalf("expected no tools without tool-calling metadata, got %d", len(params.Tools)) + } } func TestPkgAIProviderBridgeCredentials(t *testing.T) {