diff --git a/.github/actions/checkout-eyrie/action.yml b/.github/actions/checkout-eyrie/action.yml index 98485f38..52268717 100644 --- a/.github/actions/checkout-eyrie/action.yml +++ b/.github/actions/checkout-eyrie/action.yml @@ -19,5 +19,11 @@ runs: echo "eyrie already present at $dest" exit 0 fi - git clone --depth=1 --branch "${{ inputs.ref }}" \ + ref="${{ inputs.ref }}" + # Fall back to main if the branch doesn't exist on eyrie + if ! git ls-remote --heads "https://github.com/GrayCodeAI/eyrie.git" "$ref" | grep -q .; then + echo "Branch '$ref' not found on eyrie, falling back to main" + ref="main" + fi + git clone --depth=1 --branch "$ref" \ "https://github.com/GrayCodeAI/eyrie.git" "$dest" diff --git a/cmd/bg_sessions.go b/cmd/bg_sessions.go index eac13281..b6a64d35 100644 --- a/cmd/bg_sessions.go +++ b/cmd/bg_sessions.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "encoding/json" "fmt" "os" @@ -125,7 +126,7 @@ func StartBGSession(prompt string, args []string) (*BGSessionInfo, error) { // Build command: hawk --print with all inherited flags cmdArgs := append([]string{"--print", "--session-id", id, prompt}, args...) - cmd := exec.Command("hawk", cmdArgs...) + cmd := exec.CommandContext(context.Background(), "hawk", cmdArgs...) cmd.Dir = cwd logF, err := os.Create(logFile) diff --git a/cmd/chat_commands.go b/cmd/chat_commands.go index e5d19f24..6e113f15 100644 --- a/cmd/chat_commands.go +++ b/cmd/chat_commands.go @@ -16,6 +16,7 @@ import ( "github.com/GrayCodeAI/eyrie/client" hawkconfig "github.com/GrayCodeAI/hawk/internal/config" "github.com/GrayCodeAI/hawk/internal/engine" + "github.com/GrayCodeAI/hawk/internal/engine/project" "github.com/GrayCodeAI/hawk/internal/feature/shellmode" "github.com/GrayCodeAI/hawk/internal/feature/taste" "github.com/GrayCodeAI/hawk/internal/intelligence/memory" @@ -243,7 +244,7 @@ func applySlashSuggestion(input string) string { } func gitOutput(args ...string) (string, error) { - out, err := exec.Command("git", args...).CombinedOutput() + out, err := exec.CommandContext(context.Background(), "git", args...).CombinedOutput() return strings.TrimSpace(string(out)), err } @@ -1085,12 +1086,12 @@ Generate the recap:`, summary.String()) arg := strings.TrimSpace(strings.TrimPrefix(text, "/context")) if arg == "init" { cwd, _ := os.Getwd() - pc := engine.NewProjectContext(cwd) + pc := project.NewProjectContext(cwd) return m.startPromptCommand("/context init", pc.InitPrompt()) } if arg == "show" { cwd, _ := os.Getwd() - pc := engine.NewProjectContext(cwd) + pc := project.NewProjectContext(cwd) content := pc.Load() if content == "" { m.messages = append(m.messages, displayMsg{role: "system", content: "No project context files found. Run /context init to generate."}) @@ -1871,7 +1872,7 @@ Generate the recap:`, summary.String()) m.messages = append(m.messages, displayMsg{role: "system", content: pluginsSummary(m.pluginRuntime)}) return m, nil case "/voice": - out, err := exec.Command("which", "whisper").CombinedOutput() + out, err := exec.CommandContext(context.Background(), "which", "whisper").CombinedOutput() if err != nil || strings.TrimSpace(string(out)) == "" { m.messages = append(m.messages, displayMsg{role: "error", content: "Voice requires whisper.cpp. Install with: brew install whisper-cpp"}) } else { @@ -2216,7 +2217,7 @@ Generate the recap:`, summary.String()) m.messages = append(m.messages, displayMsg{role: "error", content: "Blocked: command fails safety check"}) return m, nil } - out, err := exec.Command("sh", "-c", cmdStr).CombinedOutput() + out, err := exec.CommandContext(context.Background(), "sh", "-c", cmdStr).CombinedOutput() result := strings.TrimSpace(string(out)) if err != nil { result += "\n" + err.Error() @@ -2234,7 +2235,7 @@ Generate the recap:`, summary.String()) m.messages = append(m.messages, displayMsg{role: "error", content: "Blocked: command fails safety check"}) return m, nil } - out, err := exec.Command("sh", "-c", cmdStr).CombinedOutput() + out, err := exec.CommandContext(context.Background(), "sh", "-c", cmdStr).CombinedOutput() result := strings.TrimSpace(string(out)) if err != nil { result += "\n" + err.Error() @@ -2254,7 +2255,7 @@ Generate the recap:`, summary.String()) m.messages = append(m.messages, displayMsg{role: "error", content: "Blocked: command fails safety check"}) return m, nil } - out, _ := exec.Command("sh", "-c", cmdStr).CombinedOutput() + out, _ := exec.CommandContext(context.Background(), "sh", "-c", cmdStr).CombinedOutput() result := strings.TrimSpace(string(out)) if result == "" { m.messages = append(m.messages, displayMsg{role: "system", content: "No lint issues."}) @@ -2321,7 +2322,7 @@ Generate the recap:`, summary.String()) func explainCode(path string, line int) (string, error) { // Step 1: git blame to find the commit args := []string{"blame", "-L", fmt.Sprintf("%d,%d", line, line), "--porcelain", path} - out, err := exec.Command("git", args...).Output() + out, err := exec.CommandContext(context.Background(), "git", args...).Output() if err != nil { return "", fmt.Errorf("git blame failed: %w", err) } @@ -2335,13 +2336,13 @@ func explainCode(path string, line int) (string, error) { } // Step 2: get commit info - info, err := exec.Command("git", "log", "-1", "--format=%h %s (%an, %ar)", commitHash).Output() + info, err := exec.CommandContext(context.Background(), "git", "log", "-1", "--format=%h %s (%an, %ar)", commitHash).Output() if err != nil { return fmt.Sprintf("Commit: %s (details unavailable)", commitHash[:7]), nil } // Step 3: get the diff for context - diff, _ := exec.Command("git", "log", "-1", "--format=", "-p", "--", path, commitHash).Output() + diff, _ := exec.CommandContext(context.Background(), "git", "log", "-1", "--format=", "-p", "--", path, commitHash).Output() diffStr := string(diff) if len(diffStr) > 2000 { diffStr = diffStr[:2000] + "\n... (truncated)" diff --git a/cmd/chat_print.go b/cmd/chat_print.go index 68752cda..91e09c88 100644 --- a/cmd/chat_print.go +++ b/cmd/chat_print.go @@ -12,6 +12,7 @@ import ( "github.com/GrayCodeAI/eyrie/client" "github.com/GrayCodeAI/hawk/internal/engine" + "github.com/GrayCodeAI/hawk/internal/engine/lifecycle" "github.com/GrayCodeAI/hawk/internal/observability/logger" "github.com/GrayCodeAI/hawk/internal/session" ) @@ -63,9 +64,9 @@ func runPrint(text string) error { // Wire timeout if --timeout flag is set ctx := context.Background() if timeout > 0 { - cfg := engine.TimeoutConfig{Total: timeout, Countdown: true} + cfg := lifecycle.TimeoutConfig{Total: timeout, Countdown: true} var cancel context.CancelFunc - ctx, cancel = engine.WithTimeout(ctx, cfg) + ctx, cancel = lifecycle.WithTimeout(ctx, cfg) defer cancel() } diff --git a/cmd/clipboard.go b/cmd/clipboard.go index 5939db5b..d838df17 100644 --- a/cmd/clipboard.go +++ b/cmd/clipboard.go @@ -2,6 +2,7 @@ package cmd import ( "bytes" + "context" "fmt" "os/exec" "runtime" @@ -15,18 +16,18 @@ func copyToClipboard(text string) error { switch runtime.GOOS { case "darwin": - cmd = exec.Command("pbcopy") + cmd = exec.CommandContext(context.Background(), "pbcopy") case "linux": // Try xclip first, fall back to xsel if _, err := exec.LookPath("xclip"); err == nil { - cmd = exec.Command("xclip", "-selection", "clipboard") + cmd = exec.CommandContext(context.Background(), "xclip", "-selection", "clipboard") } else if _, err := exec.LookPath("xsel"); err == nil { - cmd = exec.Command("xsel", "--clipboard", "--input") + cmd = exec.CommandContext(context.Background(), "xsel", "--clipboard", "--input") } else { return fmt.Errorf("clipboard not available: install xclip or xsel") } case "windows": - cmd = exec.Command("clip.exe") + cmd = exec.CommandContext(context.Background(), "clip.exe") default: return fmt.Errorf("clipboard not supported on %s", runtime.GOOS) } @@ -42,17 +43,17 @@ func pasteFromClipboard() (string, error) { switch runtime.GOOS { case "darwin": - cmd = exec.Command("pbpaste") + cmd = exec.CommandContext(context.Background(), "pbpaste") case "linux": if _, err := exec.LookPath("xclip"); err == nil { - cmd = exec.Command("xclip", "-selection", "clipboard", "-o") + cmd = exec.CommandContext(context.Background(), "xclip", "-selection", "clipboard", "-o") } else if _, err := exec.LookPath("xsel"); err == nil { - cmd = exec.Command("xsel", "--clipboard", "--output") + cmd = exec.CommandContext(context.Background(), "xsel", "--clipboard", "--output") } else { return "", fmt.Errorf("clipboard not available: install xclip or xsel") } case "windows": - cmd = exec.Command("powershell.exe", "-command", "Get-Clipboard") + cmd = exec.CommandContext(context.Background(), "powershell.exe", "-command", "Get-Clipboard") default: return "", fmt.Errorf("clipboard not supported on %s", runtime.GOOS) } diff --git a/cmd/context_export.go b/cmd/context_export.go index 1b926688..d1b4d970 100644 --- a/cmd/context_export.go +++ b/cmd/context_export.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "fmt" "os" "os/exec" @@ -213,7 +214,7 @@ func gitContextInfo(dir string) string { // runGit executes a git command in the given directory. func runGit(dir string, args ...string) (string, error) { - cmd := exec.Command("git", args...) + cmd := exec.CommandContext(context.Background(), "git", args...) cmd.Dir = dir out, err := cmd.Output() if err != nil { diff --git a/cmd/daemon.go b/cmd/daemon.go index 435595ac..0bc69c7a 100644 --- a/cmd/daemon.go +++ b/cmd/daemon.go @@ -16,6 +16,7 @@ import ( hawkconfig "github.com/GrayCodeAI/hawk/internal/config" "github.com/GrayCodeAI/hawk/internal/daemon" "github.com/GrayCodeAI/hawk/internal/engine" + "github.com/GrayCodeAI/hawk/internal/netutil" "github.com/GrayCodeAI/hawk/internal/observability/logger" "github.com/spf13/cobra" ) @@ -92,7 +93,7 @@ func runDaemonStart(_ *cobra.Command, _ []string) error { return sess, nil } - srv := daemon.New(daemon.Config{Port: daemonPort, Host: "127.0.0.1", APIKey: apiKey}, factory) + srv := daemon.New(daemon.Config{Port: daemonPort, Host: netutil.LoopbackHost, APIKey: apiKey}, factory) addr, err := srv.Start() if err != nil { return err @@ -103,7 +104,7 @@ func runDaemonStart(_ *cobra.Command, _ []string) error { preheater.Start([]string{ "https://api.anthropic.com/v1/messages", "https://api.openai.com/v1/chat/completions", - fmt.Sprintf("http://127.0.0.1:%d/v1/health", daemonPort), + fmt.Sprintf("http://%s:%d/v1/health", netutil.LoopbackHost, daemonPort), }) defer preheater.Stop() diff --git a/cmd/exec.go b/cmd/exec.go index 74b52b9b..cb60ac9d 100644 --- a/cmd/exec.go +++ b/cmd/exec.go @@ -285,14 +285,14 @@ func persistExecSession(id, model, provider, userMsg, assistantMsg string) { } func createExecWorktree(repoDir, baseBranch, branch string) (string, error) { - cmd := exec.Command("mktemp", "-d") + cmd := exec.CommandContext(context.Background(), "mktemp", "-d") out, err := cmd.Output() if err != nil { return "", err } wtPath := strings.TrimSpace(string(out)) - gitCmd := exec.Command("git", "worktree", "add", "-b", branch, wtPath, baseBranch) + gitCmd := exec.CommandContext(context.Background(), "git", "worktree", "add", "-b", branch, wtPath, baseBranch) gitCmd.Dir = repoDir if errOut, err := gitCmd.CombinedOutput(); err != nil { return "", fmt.Errorf("%s: %w", strings.TrimSpace(string(errOut)), err) @@ -304,7 +304,7 @@ func cleanupExecWorktree(repoDir, wtPath string) { if wtPath == "" { return } - cmd := exec.Command("git", "worktree", "remove", "--force", wtPath) + cmd := exec.CommandContext(context.Background(), "git", "worktree", "remove", "--force", wtPath) cmd.Dir = repoDir _ = cmd.Run() } diff --git a/cmd/feedback.go b/cmd/feedback.go index ee29bc83..73db104f 100644 --- a/cmd/feedback.go +++ b/cmd/feedback.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "encoding/json" "fmt" "net/url" @@ -152,11 +153,11 @@ func openFeedbackIssue(report FeedbackReport) error { func openBrowser(url string) error { switch runtime.GOOS { case "darwin": - return exec.Command("open", url).Start() + return exec.CommandContext(context.Background(), "open", url).Start() case "linux": - return exec.Command("xdg-open", url).Start() + return exec.CommandContext(context.Background(), "xdg-open", url).Start() case "windows": - return exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start() + return exec.CommandContext(context.Background(), "rundll32", "url.dll,FileProtocolHandler", url).Start() default: return fmt.Errorf("unsupported platform") } diff --git a/cmd/mission.go b/cmd/mission.go index a691bad8..e64b7cd4 100644 --- a/cmd/mission.go +++ b/cmd/mission.go @@ -194,7 +194,7 @@ func parseFeatures(text string) []mission.Feature { } func getCurrentBranch(dir string) string { - cmd := exec.Command("git", "rev-parse", "--abbrev-ref", "HEAD") + cmd := exec.CommandContext(context.Background(), "git", "rev-parse", "--abbrev-ref", "HEAD") cmd.Dir = dir out, err := cmd.Output() if err != nil { diff --git a/cmd/notifications.go b/cmd/notifications.go index bed1ccb3..95ebd81a 100644 --- a/cmd/notifications.go +++ b/cmd/notifications.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "os/exec" "runtime" "time" @@ -19,12 +20,13 @@ func notifyCompletion(duration time.Duration) { switch runtime.GOOS { case "darwin": // macOS: use osascript for native notification - _ = exec.Command( + _ = exec.CommandContext( + context.Background(), "osascript", "-e", `display notification "`+msg+`" with title "Hawk"`, ).Start() case "linux": // Linux: use notify-send if available - _ = exec.Command("notify-send", "Hawk", msg).Start() + _ = exec.CommandContext(context.Background(), "notify-send", "Hawk", msg).Start() } } diff --git a/cmd/notify.go b/cmd/notify.go index 1f7a2e47..4c5e838c 100644 --- a/cmd/notify.go +++ b/cmd/notify.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "fmt" "os" "os/exec" @@ -135,10 +136,10 @@ func (n *Notifier) DesktopNotify(title, message string) error { case "darwin": script := fmt.Sprintf(`display notification "%s" with title "%s"`, escapeAppleScript(message), escapeAppleScript(title)) - cmd := exec.Command("osascript", "-e", script) + cmd := exec.CommandContext(context.Background(), "osascript", "-e", script) return cmd.Run() case "linux": - cmd := exec.Command("notify-send", title, message) + cmd := exec.CommandContext(context.Background(), "notify-send", title, message) return cmd.Run() case "windows": script := fmt.Sprintf(` @@ -150,7 +151,7 @@ $textNodes.Item(1).AppendChild($template.CreateTextNode('%s')) | Out-Null $toast = [Windows.UI.Notifications.ToastNotification]::new($template) [Windows.UI.Notifications.ToastNotificationManager]::CreateToastNotifier('hawk').Show($toast)`, escapePowerShell(title), escapePowerShell(message)) - cmd := exec.Command("powershell", "-Command", script) + cmd := exec.CommandContext(context.Background(), "powershell", "-Command", script) return cmd.Run() default: return fmt.Errorf("desktop notifications not supported on %s", runtime.GOOS) diff --git a/cmd/options.go b/cmd/options.go index 75501d10..e679354f 100644 --- a/cmd/options.go +++ b/cmd/options.go @@ -13,6 +13,8 @@ import ( "github.com/GrayCodeAI/eyrie/client" hawkconfig "github.com/GrayCodeAI/hawk/internal/config" "github.com/GrayCodeAI/hawk/internal/engine" + "github.com/GrayCodeAI/hawk/internal/engine/branching" + "github.com/GrayCodeAI/hawk/internal/engine/lifecycle" "github.com/GrayCodeAI/hawk/internal/eyrieclient" "github.com/GrayCodeAI/hawk/internal/intelligence/memory" "github.com/GrayCodeAI/hawk/internal/intelligence/repomap" @@ -261,14 +263,14 @@ func configureSession(sess *engine.Session, settings hawkconfig.Settings) error if settings.ModelRoles != nil { roles = *settings.ModelRoles } - sess.Cascade = engine.NewCascadeRouter(sess.Model(), roles) + sess.Cascade = branching.NewCascadeRouter(sess.Model(), roles) sess.Cascade.Enabled = true sess.Cascade.FrugalMode = settings.Frugal // Session lifecycle: self-improvement loop (learn from sessions) - sess.Lifecycle = &engine.SessionLifecycle{ - Memory: &engine.EvolvingMemoryAdapter{EM: memory.NewEvolvingMemory()}, - SkillStore: &engine.SkillDistillerAdapter{SD: sess.SkillDistiller}, + sess.Lifecycle = &lifecycle.SessionLifecycle{ + Memory: &lifecycle.EvolvingMemoryAdapter{EM: memory.NewEvolvingMemory()}, + SkillStore: &lifecycle.SkillDistillerAdapter{SD: sess.SkillDistiller}, } return nil diff --git a/cmd/review.go b/cmd/review.go index 5e5a5029..ddd5a2c5 100644 --- a/cmd/review.go +++ b/cmd/review.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "fmt" "os" "os/exec" @@ -80,7 +81,7 @@ func runReviewInit(_ *cobra.Command, _ []string) error { } func findGitDir() (string, error) { - out, err := exec.Command("git", "rev-parse", "--git-dir").Output() + out, err := exec.CommandContext(context.Background(), "git", "rev-parse", "--git-dir").Output() if err != nil { return "", fmt.Errorf("not a git repository (run from inside a git repo)") } diff --git a/cmd/review_analyze.go b/cmd/review_analyze.go index 40dab459..db5b6364 100644 --- a/cmd/review_analyze.go +++ b/cmd/review_analyze.go @@ -198,7 +198,7 @@ func runReviewAnalyze(_ *cobra.Command, args []string) error { func getAnalysisContent(patterns []string) (string, error) { // Use git ls-files to expand patterns, then read files. args := append([]string{"ls-files", "--"}, patterns...) - out, err := exec.Command("git", args...).Output() + out, err := exec.CommandContext(context.Background(), "git", args...).Output() if err != nil { // Fallback: treat patterns as literal file paths. var b strings.Builder @@ -244,7 +244,7 @@ func autoFixAnalysis(result *sightLib.Result) error { hawkBin = "hawk" } - cmd := exec.Command(hawkBin, "exec", "--auto", "full", b.String()) + cmd := exec.CommandContext(context.Background(), hawkBin, "exec", "--auto", "full", b.String()) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr return cmd.Run() diff --git a/cmd/review_fix.go b/cmd/review_fix.go index bdb72eaf..e64f0377 100644 --- a/cmd/review_fix.go +++ b/cmd/review_fix.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "fmt" "os" "os/exec" @@ -93,7 +94,7 @@ func fixReview(store *ReviewStore, r *ReviewRecord) error { hawkBin = "hawk" } - cmd := exec.Command(hawkBin, execArgs...) + cmd := exec.CommandContext(context.Background(), hawkBin, execArgs...) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr cmd.Stdin = os.Stdin diff --git a/cmd/review_refine.go b/cmd/review_refine.go index c4f14093..ff85b0a9 100644 --- a/cmd/review_refine.go +++ b/cmd/review_refine.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "fmt" "os" "os/exec" @@ -135,7 +136,7 @@ func fixReviewRefine(store *ReviewStore, r *ReviewRecord) error { } execArgs = append(execArgs, prompt) - cmd := exec.Command(hawkBin, execArgs...) + cmd := exec.CommandContext(context.Background(), hawkBin, execArgs...) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr @@ -159,14 +160,14 @@ func runReviewOnSHA(store *ReviewStore, sha string) error { args = append(args, "--model", refineModel) } - cmd := exec.Command(hawkBin, args...) + cmd := exec.CommandContext(context.Background(), hawkBin, args...) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr return cmd.Run() } func getLatestCommitSHA() string { - out, err := exec.Command("git", "rev-parse", "HEAD").Output() + out, err := exec.CommandContext(context.Background(), "git", "rev-parse", "HEAD").Output() if err != nil { return "" } diff --git a/cmd/review_run.go b/cmd/review_run.go index 0eb792b9..403cc7f3 100644 --- a/cmd/review_run.go +++ b/cmd/review_run.go @@ -41,7 +41,7 @@ func runReviewRun(_ *cobra.Command, args []string) error { // Resolve short SHA to full. if len(sha) < 40 { - out, err := exec.Command("git", "rev-parse", sha).Output() + out, err := exec.CommandContext(context.Background(), "git", "rev-parse", sha).Output() if err == nil { sha = strings.TrimSpace(string(out)) } @@ -140,10 +140,10 @@ func runReviewRun(_ *cobra.Command, args []string) error { func getCommitDiff(sha string) (string, error) { // For the first commit, diff against empty tree. - out, err := exec.Command("git", "diff-tree", "-p", sha).Output() + out, err := exec.CommandContext(context.Background(), "git", "diff-tree", "-p", sha).Output() if err != nil { // Fallback: diff against parent. - out, err = exec.Command("git", "diff", sha+"^", sha).Output() + out, err = exec.CommandContext(context.Background(), "git", "diff", sha+"^", sha).Output() if err != nil { return "", fmt.Errorf("git diff for %s: %w", sha[:8], err) } diff --git a/cmd/sight.go b/cmd/sight.go index d1b3b89a..ec4dafa2 100644 --- a/cmd/sight.go +++ b/cmd/sight.go @@ -128,7 +128,7 @@ func init() { func getDiff() (string, error) { if sightPR > 0 { // Use gh CLI to fetch the PR diff - out, err := exec.Command("gh", "pr", "diff", fmt.Sprintf("%d", sightPR)).Output() + out, err := exec.CommandContext(context.Background(), "gh", "pr", "diff", fmt.Sprintf("%d", sightPR)).Output() if err != nil { return "", fmt.Errorf("gh pr diff failed: %w", err) } @@ -139,10 +139,10 @@ func getDiff() (string, error) { if base == "" { base = "main" } - out, err := exec.Command("git", "diff", base+"...HEAD").Output() + out, err := exec.CommandContext(context.Background(), "git", "diff", base+"...HEAD").Output() if err != nil { // Fallback to two-dot syntax - out, err = exec.Command("git", "diff", base, "HEAD").Output() + out, err = exec.CommandContext(context.Background(), "git", "diff", base, "HEAD").Output() if err != nil { return "", fmt.Errorf("git diff %s...HEAD failed: %w", base, err) } diff --git a/cmd/sleep_prevent.go b/cmd/sleep_prevent.go index 9af24e03..f38e187f 100644 --- a/cmd/sleep_prevent.go +++ b/cmd/sleep_prevent.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "os/exec" "runtime" ) @@ -13,7 +14,7 @@ func preventSleep() func() { return func() {} } - cmd := exec.Command("caffeinate", "-i") + cmd := exec.CommandContext(context.Background(), "caffeinate", "-i") if err := cmd.Start(); err != nil { return func() {} } diff --git a/cmd/terminal_notify.go b/cmd/terminal_notify.go index f37fee29..7d8339a2 100644 --- a/cmd/terminal_notify.go +++ b/cmd/terminal_notify.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "fmt" "os" "os/exec" @@ -52,7 +53,7 @@ func sendTerminalNotification(title, body string) { case "apple": if runtime.GOOS == "darwin" { script := fmt.Sprintf(`display notification "%s" with title "%s"`, body, title) - _ = exec.Command("osascript", "-e", script).Start() + _ = exec.CommandContext(context.Background(), "osascript", "-e", script).Start() } default: // Generic: BEL character diff --git a/cmd/watch.go b/cmd/watch.go index e6024235..bf46344d 100644 --- a/cmd/watch.go +++ b/cmd/watch.go @@ -135,7 +135,7 @@ func gitDiffForFile(dir, path string) string { if err != nil { rel = path } - cmd := exec.Command("git", "diff", "--", rel) + cmd := exec.CommandContext(context.Background(), "git", "diff", "--", rel) cmd.Dir = dir out, err := cmd.CombinedOutput() if err != nil { diff --git a/integration_test.go b/integration_test.go index b4907000..817c8adf 100644 --- a/integration_test.go +++ b/integration_test.go @@ -5,11 +5,11 @@ import ( "encoding/json" "fmt" "net/http" - "net/http/httptest" "path/filepath" "testing" "github.com/GrayCodeAI/hawk/internal/provider/routing" + "github.com/GrayCodeAI/hawk/internal/testutil" "github.com/GrayCodeAI/inspect" "github.com/GrayCodeAI/sight" "github.com/GrayCodeAI/tok" @@ -105,7 +105,7 @@ func TestIntegration_SightReviewStoreRecall(t *testing.T) { func TestIntegration_InspectScanHTTPTest(t *testing.T) { // 1. Start a test HTTP server with known issues - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ts := testutil.NewLoopbackHTTPServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/html") // Intentionally missing security headers fmt.Fprint(w, ` @@ -259,7 +259,7 @@ func TestIntegration_FullPipeline(t *testing.T) { } // 5. Use inspect on a test server - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ts := testutil.NewLoopbackHTTPServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/html") fmt.Fprint(w, `OKOK`) })) diff --git a/internal/api/server.go b/internal/api/server.go index cd5f929a..c7c781d5 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -133,7 +133,7 @@ func (s *Server) Start(ctx context.Context) error { } s.mu.Unlock() - ln, err := net.Listen("tcp", s.addr) + ln, err := new(net.ListenConfig).Listen(ctx, "tcp", s.addr) if err != nil { return err } diff --git a/internal/bridge/sessioncapture/trace_control.go b/internal/bridge/sessioncapture/trace_control.go index 2c45d6e1..fee6e130 100644 --- a/internal/bridge/sessioncapture/trace_control.go +++ b/internal/bridge/sessioncapture/trace_control.go @@ -11,6 +11,7 @@ package sessioncapture import ( + "context" "fmt" "os" "os/exec" @@ -51,7 +52,7 @@ func (tc *TraceControl) Enable() (string, error) { return "Session capture is already enabled.", nil } - cmd := exec.Command("trace", "enable", "--agent", "hawk") + cmd := exec.CommandContext(context.Background(), "trace", "enable", "--agent", "hawk") cmd.Dir = tc.ProjectDir output, err := cmd.CombinedOutput() if err != nil { @@ -69,7 +70,7 @@ func (tc *TraceControl) Disable() (string, error) { return "Session capture is already disabled.", nil } - cmd := exec.Command("trace", "disable") + cmd := exec.CommandContext(context.Background(), "trace", "disable") cmd.Dir = tc.ProjectDir output, err := cmd.CombinedOutput() if err != nil { @@ -88,7 +89,7 @@ func (tc *TraceControl) Status() string { } // Get detailed status from trace - cmd := exec.Command("trace", "status") + cmd := exec.CommandContext(context.Background(), "trace", "status") cmd.Dir = tc.ProjectDir output, err := cmd.CombinedOutput() if err != nil { diff --git a/internal/cmdhistory/history.go b/internal/cmdhistory/history.go index 17bc5b58..04c519a8 100644 --- a/internal/cmdhistory/history.go +++ b/internal/cmdhistory/history.go @@ -5,6 +5,7 @@ package cmdhistory import ( + "context" "crypto/rand" "database/sql" "fmt" @@ -85,7 +86,8 @@ func (s *Store) Record(entry Entry) error { entry.CreatedAt = time.Now().UTC() } - _, err := s.db.Exec( + _, err := s.db.ExecContext( + context.Background(), `INSERT INTO entries (id, command, exit_code, duration_ms, cwd, git_branch, session_id, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, entry.ID, @@ -160,7 +162,7 @@ func (s *Store) Search(query string, opts SearchOpts) ([]Entry, error) { LIMIT ?`, where, ) - rows, err := s.db.Query(q, args...) + rows, err := s.db.QueryContext(context.Background(), q, args...) if err != nil { return nil, fmt.Errorf("search query: %w", err) } @@ -175,7 +177,8 @@ func (s *Store) Recent(n int) ([]Entry, error) { n = 20 } - rows, err := s.db.Query( + rows, err := s.db.QueryContext( + context.Background(), `SELECT id, command, exit_code, duration_ms, cwd, git_branch, session_id, created_at FROM entries ORDER BY created_at DESC @@ -195,7 +198,8 @@ func (s *Store) SearchByDir(dir string, limit int) ([]Entry, error) { limit = 50 } - rows, err := s.db.Query( + rows, err := s.db.QueryContext( + context.Background(), `SELECT id, command, exit_code, duration_ms, cwd, git_branch, session_id, created_at FROM entries WHERE cwd = ? @@ -215,7 +219,7 @@ func (s *Store) Stats() (*HistoryStats, error) { stats := &HistoryStats{} // Total and unique command counts, plus success rate. - err := s.db.QueryRow(` + err := s.db.QueryRowContext(context.Background(), ` SELECT COUNT(*) AS total, COUNT(DISTINCT command) AS uniq, @@ -230,7 +234,7 @@ func (s *Store) Stats() (*HistoryStats, error) { } // Top 10 commands by frequency. - cmdRows, err := s.db.Query(` + cmdRows, err := s.db.QueryContext(context.Background(), ` SELECT command, COUNT(*) AS cnt FROM entries GROUP BY command @@ -254,7 +258,7 @@ func (s *Store) Stats() (*HistoryStats, error) { } // Top 10 directories by frequency. - dirRows, err := s.db.Query(` + dirRows, err := s.db.QueryContext(context.Background(), ` SELECT cwd, COUNT(*) AS cnt FROM entries WHERE cwd != '' diff --git a/internal/cmdhistory/schema.go b/internal/cmdhistory/schema.go index d317851d..1a16ee22 100644 --- a/internal/cmdhistory/schema.go +++ b/internal/cmdhistory/schema.go @@ -1,6 +1,9 @@ package cmdhistory -import "database/sql" +import ( + "context" + "database/sql" +) const schemaSQL = ` CREATE TABLE IF NOT EXISTS entries ( @@ -41,9 +44,9 @@ END; // createSchema initializes the SQLite schema and enables WAL mode. func createSchema(db *sql.DB) error { - if _, err := db.Exec("PRAGMA journal_mode=WAL"); err != nil { + if _, err := db.ExecContext(context.Background(), "PRAGMA journal_mode=WAL"); err != nil { return err } - _, err := db.Exec(schemaSQL) + _, err := db.ExecContext(context.Background(), schemaSQL) return err } diff --git a/internal/config/config.go b/internal/config/config.go index 9cd41d86..e2df2e77 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,6 +1,7 @@ package config import ( + "context" "fmt" "os" "os/exec" @@ -95,7 +96,7 @@ func GitContext() string { } func gitCmd(args ...string) (string, error) { - out, err := exec.Command("git", args...).Output() + out, err := exec.CommandContext(context.Background(), "git", args...).Output() return strings.TrimSpace(string(out)), err } diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index acc8afa3..82b67e5f 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -16,6 +16,7 @@ import ( "time" "github.com/GrayCodeAI/hawk/internal/engine" + "github.com/GrayCodeAI/hawk/internal/netutil" ) const maxRequestBodyBytes = 1 << 20 @@ -47,7 +48,7 @@ type Config struct { func DefaultConfig() Config { return Config{ Port: 4590, - Host: "127.0.0.1", + Host: netutil.LoopbackHost, } } @@ -119,7 +120,7 @@ func New(cfg Config, factory SessionFactory) *Server { // Start begins serving in the background. Returns the listening address. func (s *Server) Start() (string, error) { - ln, err := net.Listen("tcp", s.addr) + ln, err := new(net.ListenConfig).Listen(context.Background(), "tcp", s.addr) if err != nil { return "", fmt.Errorf("daemon listen: %w", err) } diff --git a/internal/daemon/daemon_test.go b/internal/daemon/daemon_test.go index fbf1768e..0a151f2e 100644 --- a/internal/daemon/daemon_test.go +++ b/internal/daemon/daemon_test.go @@ -10,14 +10,22 @@ import ( "time" "github.com/GrayCodeAI/hawk/internal/engine" + "github.com/GrayCodeAI/hawk/internal/testutil" ) -func TestDaemon_StartStop(t *testing.T) { - srv := New(Config{Port: 0, Host: "127.0.0.1"}, nil) // port 0 = random free port +func startTestDaemon(t *testing.T, srv *Server) string { + t.Helper() addr, err := srv.Start() if err != nil { + testutil.SkipIfLoopbackUnavailable(t, err) t.Fatalf("Start failed: %v", err) } + return addr +} + +func TestDaemon_StartStop(t *testing.T) { + srv := New(Config{Port: 0, Host: testutil.LoopbackHost}, nil) // port 0 = random free port + addr := startTestDaemon(t, srv) defer srv.Stop(context.Background()) if addr == "" { @@ -26,11 +34,8 @@ func TestDaemon_StartStop(t *testing.T) { } func TestDaemon_Health(t *testing.T) { - srv := New(Config{Port: 0, Host: "127.0.0.1"}, nil) - addr, err := srv.Start() - if err != nil { - t.Fatalf("Start failed: %v", err) - } + srv := New(Config{Port: 0, Host: testutil.LoopbackHost}, nil) + addr := startTestDaemon(t, srv) defer srv.Stop(context.Background()) resp, err := http.Get("http://" + addr + "/v1/health") @@ -54,11 +59,8 @@ func TestDaemon_Health(t *testing.T) { } func TestDaemon_Chat_NoEngine(t *testing.T) { - srv := New(Config{Port: 0, Host: "127.0.0.1"}, nil) - addr, err := srv.Start() - if err != nil { - t.Fatalf("Start failed: %v", err) - } + srv := New(Config{Port: 0, Host: testutil.LoopbackHost}, nil) + addr := startTestDaemon(t, srv) defer srv.Stop(context.Background()) body, _ := json.Marshal(ChatRequest{Prompt: "hello"}) @@ -74,11 +76,8 @@ func TestDaemon_Chat_NoEngine(t *testing.T) { } func TestDaemon_ProtectedEndpointsRequireAPIKey(t *testing.T) { - srv := New(Config{Port: 0, Host: "127.0.0.1", APIKey: "secret"}, nil) - addr, err := srv.Start() - if err != nil { - t.Fatalf("Start failed: %v", err) - } + srv := New(Config{Port: 0, Host: testutil.LoopbackHost, APIKey: "secret"}, nil) + addr := startTestDaemon(t, srv) defer srv.Stop(context.Background()) body, _ := json.Marshal(ChatRequest{Prompt: "hello"}) @@ -108,11 +107,8 @@ func TestDaemon_ProtectedEndpointsRequireAPIKey(t *testing.T) { } func TestDaemon_RejectsOversizedBody(t *testing.T) { - srv := New(Config{Port: 0, Host: "127.0.0.1"}, nil) - addr, err := srv.Start() - if err != nil { - t.Fatalf("Start failed: %v", err) - } + srv := New(Config{Port: 0, Host: testutil.LoopbackHost}, nil) + addr := startTestDaemon(t, srv) defer srv.Stop(context.Background()) body := []byte(`{"prompt":"` + strings.Repeat("x", maxRequestBodyBytes+1) + `"}`) @@ -128,11 +124,8 @@ func TestDaemon_RejectsOversizedBody(t *testing.T) { } func TestDaemon_RejectsUnknownFields(t *testing.T) { - srv := New(Config{Port: 0, Host: "127.0.0.1"}, nil) - addr, err := srv.Start() - if err != nil { - t.Fatalf("Start failed: %v", err) - } + srv := New(Config{Port: 0, Host: testutil.LoopbackHost}, nil) + addr := startTestDaemon(t, srv) defer srv.Stop(context.Background()) resp, err := http.Post("http://"+addr+"/v1/chat", "application/json", bytes.NewReader([]byte(`{"prompt":"hello","unknown":true}`))) @@ -152,11 +145,8 @@ func TestDaemon_Chat_WithEngine(t *testing.T) { sess.MaxTurns = 1 return sess, nil } - srv := New(Config{Port: 0, Host: "127.0.0.1"}, factory) - addr, err := srv.Start() - if err != nil { - t.Fatalf("Start failed: %v", err) - } + srv := New(Config{Port: 0, Host: testutil.LoopbackHost}, factory) + addr := startTestDaemon(t, srv) defer srv.Stop(context.Background()) body, _ := json.Marshal(ChatRequest{Prompt: "hello", MaxTurns: 1}) @@ -172,11 +162,8 @@ func TestDaemon_Chat_WithEngine(t *testing.T) { } func TestDaemon_Chat_EmptyPrompt(t *testing.T) { - srv := New(Config{Port: 0, Host: "127.0.0.1"}, nil) - addr, err := srv.Start() - if err != nil { - t.Fatalf("Start failed: %v", err) - } + srv := New(Config{Port: 0, Host: testutil.LoopbackHost}, nil) + addr := startTestDaemon(t, srv) defer srv.Stop(context.Background()) body, _ := json.Marshal(ChatRequest{}) @@ -192,11 +179,8 @@ func TestDaemon_Chat_EmptyPrompt(t *testing.T) { } func TestDaemon_Sessions(t *testing.T) { - srv := New(Config{Port: 0, Host: "127.0.0.1"}, nil) - addr, err := srv.Start() - if err != nil { - t.Fatalf("Start failed: %v", err) - } + srv := New(Config{Port: 0, Host: testutil.LoopbackHost}, nil) + addr := startTestDaemon(t, srv) defer srv.Stop(context.Background()) resp, err := http.Get("http://" + addr + "/v1/sessions") @@ -211,11 +195,8 @@ func TestDaemon_Sessions(t *testing.T) { } func TestDaemon_GracefulShutdown(t *testing.T) { - srv := New(Config{Port: 0, Host: "127.0.0.1"}, nil) - _, err := srv.Start() - if err != nil { - t.Fatalf("Start failed: %v", err) - } + srv := New(Config{Port: 0, Host: testutil.LoopbackHost}, nil) + _ = startTestDaemon(t, srv) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() @@ -230,17 +211,14 @@ func TestDefaultConfig(t *testing.T) { if cfg.Port != 4590 { t.Errorf("DefaultConfig().Port = %d, want 4590", cfg.Port) } - if cfg.Host != "127.0.0.1" { - t.Errorf("DefaultConfig().Host = %q, want 127.0.0.1", cfg.Host) + if cfg.Host != testutil.LoopbackHost { + t.Errorf("DefaultConfig().Host = %q, want %q", cfg.Host, testutil.LoopbackHost) } } func TestDaemon_Stats(t *testing.T) { - srv := New(Config{Port: 0, Host: "127.0.0.1"}, nil) - addr, err := srv.Start() - if err != nil { - t.Fatalf("Start failed: %v", err) - } + srv := New(Config{Port: 0, Host: testutil.LoopbackHost}, nil) + addr := startTestDaemon(t, srv) defer srv.Stop(context.Background()) resp, err := http.Get("http://" + addr + "/v1/stats") @@ -255,11 +233,8 @@ func TestDaemon_Stats(t *testing.T) { } func TestDaemon_InvalidMethod(t *testing.T) { - srv := New(Config{Port: 0, Host: "127.0.0.1"}, nil) - addr, err := srv.Start() - if err != nil { - t.Fatalf("Start failed: %v", err) - } + srv := New(Config{Port: 0, Host: testutil.LoopbackHost}, nil) + addr := startTestDaemon(t, srv) defer srv.Stop(context.Background()) req, _ := http.NewRequest("DELETE", "http://"+addr+"/v1/health", nil) @@ -275,11 +250,8 @@ func TestDaemon_InvalidMethod(t *testing.T) { } func TestDaemon_InvalidJSON(t *testing.T) { - srv := New(Config{Port: 0, Host: "127.0.0.1"}, nil) - addr, err := srv.Start() - if err != nil { - t.Fatalf("Start failed: %v", err) - } + srv := New(Config{Port: 0, Host: testutil.LoopbackHost}, nil) + addr := startTestDaemon(t, srv) defer srv.Stop(context.Background()) resp, err := http.Post("http://"+addr+"/v1/chat", "application/json", bytes.NewReader([]byte("not json"))) @@ -294,11 +266,8 @@ func TestDaemon_InvalidJSON(t *testing.T) { } func TestDaemon_GetSession_MissingID(t *testing.T) { - srv := New(Config{Port: 0, Host: "127.0.0.1"}, nil) - addr, err := srv.Start() - if err != nil { - t.Fatalf("Start failed: %v", err) - } + srv := New(Config{Port: 0, Host: testutil.LoopbackHost}, nil) + addr := startTestDaemon(t, srv) defer srv.Stop(context.Background()) resp, err := http.Get("http://" + addr + "/v1/sessions/nonexistent-id") diff --git a/internal/daemon/routes_review.go b/internal/daemon/routes_review.go index d324664b..25e45fb2 100644 --- a/internal/daemon/routes_review.go +++ b/internal/daemon/routes_review.go @@ -1,6 +1,7 @@ package daemon import ( + "context" "fmt" "net/http" "os/exec" @@ -49,7 +50,7 @@ func (s *Server) handleReview(w http.ResponseWriter, r *http.Request) { if req.Concerns != "" { args = append(args, "--concerns", req.Concerns) } - _ = exec.Command("hawk", args...).Run() + _ = exec.CommandContext(context.Background(), "hawk", args...).Run() }() resp := ReviewResponse{ @@ -62,7 +63,7 @@ func (s *Server) handleReview(w http.ResponseWriter, r *http.Request) { func (s *Server) handleReviewStatus(w http.ResponseWriter, _ *http.Request) { // Run hawk review status and return output. - out, err := exec.Command("hawk", "review", "status").Output() + out, err := exec.CommandContext(context.Background(), "hawk", "review", "status").Output() if err != nil { writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return diff --git a/internal/engine/agent/agent_types.go b/internal/engine/agent/agent_types.go new file mode 100644 index 00000000..95edc26f --- /dev/null +++ b/internal/engine/agent/agent_types.go @@ -0,0 +1,21 @@ +package agent + +// SubAgentMode determines the capabilities and cost profile of a sub-agent. +type SubAgentMode string + +const ( + SubAgentExplore SubAgentMode = "explore" + SubAgentGeneral SubAgentMode = "general" +) + +// Sub-agent budget defaults per mode. +const ( + DefaultExploreTurns = 15 + DefaultGeneralTurns = 20 + MaxAgentDepth = 2 +) + +// ExploreTools are the read-only tools available to explore-mode sub-agents. +var ExploreTools = []string{ + "Glob", "Grep", "Read", "Bash", "LS", +} diff --git a/internal/engine/agent/aliases.go b/internal/engine/agent/aliases.go index f10345f2..0602af5b 100644 --- a/internal/engine/agent/aliases.go +++ b/internal/engine/agent/aliases.go @@ -1,24 +1,3 @@ -// Package agent is the Stage-1 namespace for sub-agent orchestration types. +// Package agent is the namespace for sub-agent orchestration types. // See ../REFACTOR_PLAN.md. package agent - -import "github.com/GrayCodeAI/hawk/internal/engine" - -type ( - SubAgentMode = engine.SubAgentMode - SubAgentConfig = engine.SubAgentConfig - SubAgentBudget = engine.SubAgentBudget - BackgroundAgentPool = engine.BackgroundAgentPool - BackgroundResult = engine.BackgroundResult -) - -func DefaultSubAgentConfig() SubAgentConfig { return engine.DefaultSubAgentConfig() } -func NewSubAgentBudget(mode SubAgentMode, cfg SubAgentConfig) *SubAgentBudget { - return engine.NewSubAgentBudget(mode, cfg) -} - -func FilterToolsForMode(mode SubAgentMode, available []string) []string { - return engine.FilterToolsForMode(mode, available) -} -func NewBackgroundAgentPool() *BackgroundAgentPool { return engine.NewBackgroundAgentPool() } -func FormatResults(results []BackgroundResult) string { return engine.FormatResults(results) } diff --git a/internal/engine/background_agent.go b/internal/engine/agent/background_agent.go similarity index 99% rename from internal/engine/background_agent.go rename to internal/engine/agent/background_agent.go index 4b18b47d..31e7abff 100644 --- a/internal/engine/background_agent.go +++ b/internal/engine/agent/background_agent.go @@ -1,4 +1,4 @@ -package engine +package agent import ( "context" diff --git a/internal/engine/background_agent_test.go b/internal/engine/agent/background_agent_test.go similarity index 99% rename from internal/engine/background_agent_test.go rename to internal/engine/agent/background_agent_test.go index 72be439d..33121f93 100644 --- a/internal/engine/background_agent_test.go +++ b/internal/engine/agent/background_agent_test.go @@ -1,4 +1,4 @@ -package engine +package agent import ( "context" diff --git a/internal/engine/subagent_budget.go b/internal/engine/agent/subagent_budget.go similarity index 99% rename from internal/engine/subagent_budget.go rename to internal/engine/agent/subagent_budget.go index 7689227e..f6462933 100644 --- a/internal/engine/subagent_budget.go +++ b/internal/engine/agent/subagent_budget.go @@ -1,4 +1,4 @@ -package engine +package agent // subagent_budget.go implements mode-based budget tracking and tool allowlists // for sub-agents, extracted from herm's production-grade sub-agent system. diff --git a/internal/engine/subagent_budget_test.go b/internal/engine/agent/subagent_budget_test.go similarity index 99% rename from internal/engine/subagent_budget_test.go rename to internal/engine/agent/subagent_budget_test.go index 47133064..c8e519aa 100644 --- a/internal/engine/subagent_budget_test.go +++ b/internal/engine/agent/subagent_budget_test.go @@ -1,4 +1,4 @@ -package engine +package agent import ( "testing" diff --git a/internal/engine/agent_reexports.go b/internal/engine/agent_reexports.go new file mode 100644 index 00000000..67673317 --- /dev/null +++ b/internal/engine/agent_reexports.go @@ -0,0 +1,38 @@ +// This file re-exports symbols from the agent sub-package so that existing +// callers of engine.SubAgentMode, engine.NewSubAgentBudget, etc. keep compiling +// during the Stage 2 migration. See REFACTOR_PLAN.md. +package engine + +import "github.com/GrayCodeAI/hawk/internal/engine/agent" + +type ( + SubAgentMode = agent.SubAgentMode + SubAgentConfig = agent.SubAgentConfig + SubAgentBudget = agent.SubAgentBudget + BackgroundAgentPool = agent.BackgroundAgentPool + BackgroundResult = agent.BackgroundResult +) + +const ( + SubAgentExplore = agent.SubAgentExplore + SubAgentGeneral = agent.SubAgentGeneral + DefaultExploreTurns = agent.DefaultExploreTurns + DefaultGeneralTurns = agent.DefaultGeneralTurns + MaxAgentDepth = agent.MaxAgentDepth +) + +var ( + ExploreTools = agent.ExploreTools + ModeToolAllowlist = agent.ModeToolAllowlist +) + +func DefaultSubAgentConfig() SubAgentConfig { return agent.DefaultSubAgentConfig() } +func NewSubAgentBudget(mode SubAgentMode, cfg SubAgentConfig) *SubAgentBudget { + return agent.NewSubAgentBudget(mode, cfg) +} + +func FilterToolsForMode(mode SubAgentMode, available []string) []string { + return agent.FilterToolsForMode(mode, available) +} +func NewBackgroundAgentPool() *BackgroundAgentPool { return agent.NewBackgroundAgentPool() } +func FormatResults(results []BackgroundResult) string { return agent.FormatResults(results) } diff --git a/internal/engine/agent.go b/internal/engine/agent_session_tool.go similarity index 81% rename from internal/engine/agent.go rename to internal/engine/agent_session_tool.go index e4d496f7..20570283 100644 --- a/internal/engine/agent.go +++ b/internal/engine/agent_session_tool.go @@ -8,26 +8,6 @@ import ( "github.com/GrayCodeAI/hawk/internal/tool" ) -// SubAgentMode determines the capabilities and cost profile of a sub-agent. -type SubAgentMode string - -const ( - SubAgentExplore SubAgentMode = "explore" - SubAgentGeneral SubAgentMode = "general" -) - -// Sub-agent budget defaults per mode. -const ( - DefaultExploreTurns = 15 - DefaultGeneralTurns = 20 - MaxAgentDepth = 2 -) - -// ExploreTools are the read-only tools available to explore-mode sub-agents. -var ExploreTools = []string{ - "Glob", "Grep", "Read", "Bash", "LS", -} - // WireAgentTool sets up sub-agent spawning with two modes: // - explore: fast/cheap model, read-only tools, higher turn budget // - general: full model, all tools, standard budget diff --git a/internal/engine/assumptions.go b/internal/engine/assumptions.go index 434873f2..1eb326f4 100644 --- a/internal/engine/assumptions.go +++ b/internal/engine/assumptions.go @@ -1,6 +1,7 @@ package engine import ( + "context" "fmt" "os" "os/exec" @@ -62,7 +63,7 @@ func (at *AssumptionTracker) VerifyCommandSucceeds(text, cmd string) { at.mu.Lock() defer at.mu.Unlock() a := Assumption{Text: text} - out, err := exec.Command("sh", "-c", cmd).CombinedOutput() + out, err := exec.CommandContext(context.Background(), "sh", "-c", cmd).CombinedOutput() if err == nil { a.Status = AssumptionConfirmed a.Proof = "command succeeded" diff --git a/internal/engine/auto_commit.go b/internal/engine/auto_commit.go index 1564a9ce..c039cebc 100644 --- a/internal/engine/auto_commit.go +++ b/internal/engine/auto_commit.go @@ -1,6 +1,7 @@ package engine import ( + "context" "fmt" "os/exec" "strings" @@ -25,7 +26,7 @@ func (ac *AutoCommitter) CommitIfChanged(description string) error { return nil } // Check if there are changes - cmd := exec.Command("git", "status", "--porcelain") + cmd := exec.CommandContext(context.Background(), "git", "status", "--porcelain") cmd.Dir = ac.RepoDir out, err := cmd.Output() if err != nil || len(strings.TrimSpace(string(out))) == 0 { @@ -33,7 +34,7 @@ func (ac *AutoCommitter) CommitIfChanged(description string) error { } // Stage all changes - stage := exec.Command("git", "add", "-A") + stage := exec.CommandContext(context.Background(), "git", "add", "-A") stage.Dir = ac.RepoDir if err := stage.Run(); err != nil { return err @@ -43,14 +44,14 @@ func (ac *AutoCommitter) CommitIfChanged(description string) error { msg := ac.generateMessage(description) // Commit - commit := exec.Command("git", "commit", "-m", msg, "--no-verify") + commit := exec.CommandContext(context.Background(), "git", "commit", "-m", msg, "--no-verify") commit.Dir = ac.RepoDir return commit.Run() } // Undo reverts the last auto-commit. func (ac *AutoCommitter) Undo() error { - cmd := exec.Command("git", "reset", "--soft", "HEAD~1") + cmd := exec.CommandContext(context.Background(), "git", "reset", "--soft", "HEAD~1") cmd.Dir = ac.RepoDir return cmd.Run() } diff --git a/internal/engine/branching/aliases.go b/internal/engine/branching/aliases.go index 207da5a4..6f0626f4 100644 --- a/internal/engine/branching/aliases.go +++ b/internal/engine/branching/aliases.go @@ -1,29 +1,4 @@ -// Package branching is the Stage-1 namespace for branching strategies, cascade, council, shadow, snowball. -// See ../REFACTOR_PLAN.md. package branching -import "github.com/GrayCodeAI/hawk/internal/engine" - -type ( - BranchMessage = engine.BranchMessage - ConversationBranch = engine.ConversationBranch - BranchManager = engine.BranchManager - CascadeRouter = engine.CascadeRouter - RoutingDecision = engine.RoutingDecision - ModelTier = engine.ModelTier - CouncilConfig = engine.CouncilConfig - CouncilResponse = engine.CouncilResponse - CouncilRanking = engine.CouncilRanking - CouncilResult = engine.CouncilResult - ShadowWorkspace = engine.ShadowWorkspace - SnowballDetector = engine.SnowballDetector -) - -var ( - NewBranchManager = engine.NewBranchManager - NewCascadeRouter = engine.NewCascadeRouter - RunCouncil = engine.RunCouncil - DefaultCouncilModels = engine.DefaultCouncilModels - NewShadowWorkspace = engine.NewShadowWorkspace - NewSnowballDetector = engine.NewSnowballDetector -) +// All branching symbols are now defined locally in this package. +// Council symbols remain in engine and can be accessed via the engine package directly. diff --git a/internal/engine/branching.go b/internal/engine/branching/branching.go similarity index 99% rename from internal/engine/branching.go rename to internal/engine/branching/branching.go index 183b62c2..457a925c 100644 --- a/internal/engine/branching.go +++ b/internal/engine/branching/branching.go @@ -1,4 +1,4 @@ -package engine +package branching import ( "crypto/rand" diff --git a/internal/engine/branching_test.go b/internal/engine/branching/branching_test.go similarity index 99% rename from internal/engine/branching_test.go rename to internal/engine/branching/branching_test.go index 8e4f3bb9..15f4db30 100644 --- a/internal/engine/branching_test.go +++ b/internal/engine/branching/branching_test.go @@ -1,4 +1,4 @@ -package engine +package branching import ( "encoding/json" diff --git a/internal/engine/cascade.go b/internal/engine/branching/cascade.go similarity index 98% rename from internal/engine/cascade.go rename to internal/engine/branching/cascade.go index f28ddd7f..a8c31108 100644 --- a/internal/engine/cascade.go +++ b/internal/engine/branching/cascade.go @@ -1,4 +1,4 @@ -package engine +package branching import ( "fmt" @@ -6,6 +6,7 @@ import ( "sync" "time" + "github.com/GrayCodeAI/hawk/internal/engine/cost" "github.com/GrayCodeAI/hawk/internal/provider/routing" eycatalog "github.com/GrayCodeAI/eyrie/catalog" @@ -113,8 +114,8 @@ func (cr *CascadeRouter) Savings() float64 { defer cr.mu.Unlock() var saved float64 for _, d := range cr.decisions { - origIn, _ := pricingForModel(d.OriginalModel) - selIn, _ := pricingForModel(d.SelectedModel) + origIn, _ := cost.ModelPricing(d.OriginalModel) + selIn, _ := cost.ModelPricing(d.SelectedModel) if selIn < origIn { // Rough estimate: assume 4000 input tokens per request. saved += (origIn - selIn) * 4000 / 1_000_000 diff --git a/internal/engine/cascade_test.go b/internal/engine/branching/cascade_test.go similarity index 99% rename from internal/engine/cascade_test.go rename to internal/engine/branching/cascade_test.go index 61e2f520..2d55decc 100644 --- a/internal/engine/cascade_test.go +++ b/internal/engine/branching/cascade_test.go @@ -1,4 +1,4 @@ -package engine +package branching import ( "testing" diff --git a/internal/engine/shadow.go b/internal/engine/branching/shadow.go similarity index 65% rename from internal/engine/shadow.go rename to internal/engine/branching/shadow.go index fdf8e05a..8e269496 100644 --- a/internal/engine/shadow.go +++ b/internal/engine/branching/shadow.go @@ -1,13 +1,24 @@ -package engine +package branching import ( + "context" "fmt" "os" "os/exec" "path/filepath" + "regexp" + "strconv" "strings" ) +// ValidationError represents a single validation issue. +type ValidationError struct { + File string + Line int + Column int + Message string +} + // ShadowWorkspace provides a temporary directory where file edits can be // validated (e.g. via `go vet`, `tsc`, `pylint`) without touching the // original source tree. @@ -95,7 +106,7 @@ func shadowValidateGo(tmpPath, origPath string) []ValidationError { defer func() { _ = os.Remove(modPath) }() } - cmd := exec.Command("go", "vet", "./...") + cmd := exec.CommandContext(context.Background(), "go", "vet", "./...") cmd.Dir = dir output, err := cmd.CombinedOutput() if err == nil { @@ -115,7 +126,7 @@ func shadowValidateGo(tmpPath, origPath string) []ValidationError { // shadowValidatePython runs `python3 -c "import py_compile; ..."` on the temp file. func shadowValidatePython(tmpPath, origPath string) []ValidationError { - cmd := exec.Command("python3", "-c", + cmd := exec.CommandContext(context.Background(), "python3", "-c", fmt.Sprintf("import py_compile; py_compile.compile('%s', doraise=True)", tmpPath)) output, err := cmd.CombinedOutput() if err == nil { @@ -131,7 +142,7 @@ func shadowValidatePython(tmpPath, origPath string) []ValidationError { // shadowValidateTS runs `npx tsc --noEmit` on the temp file. func shadowValidateTS(tmpPath, origPath string) []ValidationError { - cmd := exec.Command("npx", "tsc", "--noEmit", "--allowJs", tmpPath) + cmd := exec.CommandContext(context.Background(), "npx", "tsc", "--noEmit", "--allowJs", tmpPath) output, err := cmd.CombinedOutput() if err == nil { return nil @@ -147,3 +158,69 @@ func shadowValidateTS(tmpPath, origPath string) []ValidationError { } return parsed } + +// --- parser helpers (moved inline from validate.go for independence) --- + +var goErrorRe = regexp.MustCompile(`([^:]+\.go):(\d+):(\d+):\s*(.+)`) + +func parseGoErrors(output string) []ValidationError { + var errors []ValidationError + for _, line := range strings.Split(output, "\n") { + line = strings.TrimSpace(line) + if matches := goErrorRe.FindStringSubmatch(line); matches != nil { + lineNum, _ := strconv.Atoi(matches[2]) + colNum, _ := strconv.Atoi(matches[3]) + errors = append(errors, ValidationError{ + File: matches[1], + Line: lineNum, + Column: colNum, + Message: matches[4], + }) + } + } + return errors +} + +var ( + pythonLineRe = regexp.MustCompile(`File "([^"]+)", line (\d+)`) + pythonErrorRe = regexp.MustCompile(`(SyntaxError|IndentationError|TabError):\s*(.+)`) +) + +func parsePythonErrors(output, path string) []ValidationError { + var errors []ValidationError + lines := strings.Split(output, "\n") + lineNum := 0 + + for _, line := range lines { + if matches := pythonLineRe.FindStringSubmatch(line); matches != nil { + lineNum, _ = strconv.Atoi(matches[2]) + } + if matches := pythonErrorRe.FindStringSubmatch(line); matches != nil { + errors = append(errors, ValidationError{ + File: path, + Line: lineNum, + Message: matches[1] + ": " + matches[2], + }) + } + } + return errors +} + +var tsErrorRe = regexp.MustCompile(`\((\d+),(\d+)\):\s*error\s+\w+:\s*(.+)`) + +func parseTSErrors(output, path string) []ValidationError { + var errors []ValidationError + for _, line := range strings.Split(output, "\n") { + if matches := tsErrorRe.FindStringSubmatch(line); matches != nil { + lineNum, _ := strconv.Atoi(matches[1]) + colNum, _ := strconv.Atoi(matches[2]) + errors = append(errors, ValidationError{ + File: path, + Line: lineNum, + Column: colNum, + Message: matches[3], + }) + } + } + return errors +} diff --git a/internal/engine/shadow_test.go b/internal/engine/branching/shadow_test.go similarity index 99% rename from internal/engine/shadow_test.go rename to internal/engine/branching/shadow_test.go index 6229c2c9..36788ee4 100644 --- a/internal/engine/shadow_test.go +++ b/internal/engine/branching/shadow_test.go @@ -1,4 +1,4 @@ -package engine +package branching import ( "os" diff --git a/internal/engine/snowball.go b/internal/engine/branching/snowball.go similarity index 99% rename from internal/engine/snowball.go rename to internal/engine/branching/snowball.go index 6b8b6b19..8c9e5a50 100644 --- a/internal/engine/snowball.go +++ b/internal/engine/branching/snowball.go @@ -1,4 +1,4 @@ -package engine +package branching import ( "fmt" diff --git a/internal/engine/snowball_test.go b/internal/engine/branching/snowball_test.go similarity index 99% rename from internal/engine/snowball_test.go rename to internal/engine/branching/snowball_test.go index c0d63e8b..ad5fc82e 100644 --- a/internal/engine/snowball_test.go +++ b/internal/engine/branching/snowball_test.go @@ -1,4 +1,4 @@ -package engine +package branching import ( "strings" diff --git a/internal/engine/code/aliases.go b/internal/engine/code/aliases.go index a386ca13..8f20a0bb 100644 --- a/internal/engine/code/aliases.go +++ b/internal/engine/code/aliases.go @@ -1,34 +1,3 @@ -// Package code is the Stage-1 namespace for code-aware features -// (context extraction, lenses, actions, explainer). See ../REFACTOR_PLAN.md. +// Package code provides code-aware features: context extraction, +// lenses, actions, and explainer. See ../REFACTOR_PLAN.md. package code - -import "github.com/GrayCodeAI/hawk/internal/engine" - -type ( - Snippet = engine.CodeSnippet - Context = engine.CodeContext - ContextExtractor = engine.ContextExtractor - Lens = engine.CodeLens - LensGenerator = engine.LensGenerator - LensProvider = engine.CodeLensProvider - Action = engine.CodeAction - ActionDetector = engine.ActionDetector - ActionRule = engine.ActionRule - Explanation = engine.CodeExplanation - ExplanationSection = engine.ExplanationSection - Explainer = engine.CodeExplainer -) - -func NewContextExtractor(projectDir string, maxTokens int) *ContextExtractor { - return engine.NewContextExtractor(projectDir, maxTokens) -} -func FormatContext(ctx *Context) string { return engine.FormatContext(ctx) } -func NewLensProvider() *LensProvider { return engine.NewCodeLensProvider() } -func NewActionDetector() *ActionDetector { return engine.NewActionDetector() } -func NewExplainer() *Explainer { return engine.NewCodeExplainer() } -func FormatExplanation(exp *Explanation) string { return engine.FormatExplanation(exp) } -func FormatSuggestions(actions []Action, max int) string { - return engine.FormatSuggestions(actions, max) -} - -func ApplyFix(action Action, content string) (string, error) { return engine.ApplyFix(action, content) } diff --git a/internal/engine/code_actions.go b/internal/engine/code/code_actions.go similarity index 89% rename from internal/engine/code_actions.go rename to internal/engine/code/code_actions.go index 602e0f5b..f6d205ae 100644 --- a/internal/engine/code_actions.go +++ b/internal/engine/code/code_actions.go @@ -1,4 +1,4 @@ -package engine +package code import ( "fmt" @@ -10,26 +10,23 @@ import ( "text/template" ) -// CodeAction represents a suggested improvement for a specific location in code. type CodeAction struct { ID string Title string Description string File string Line int - Category string // "refactor", "performance", "security", "style", "fix" - Priority int // 1=high, 3=medium, 5=low - Fix string // suggested replacement code - Confidence float64 // 0.0 to 1.0 + Category string + Priority int + Fix string + Confidence float64 } -// ActionDetector scans code content and produces CodeAction suggestions. type ActionDetector struct { Rules []ActionRule mu sync.RWMutex } -// ActionRule defines a pattern-based detection rule for code actions. type ActionRule struct { ID string Name string @@ -39,10 +36,9 @@ type ActionRule struct { Antipattern *regexp.Regexp Priority int Message string - FixTemplate string // Go template for generating fix + FixTemplate string } -// NewActionDetector creates an ActionDetector pre-loaded with 25+ built-in rules. func NewActionDetector() *ActionDetector { ad := &ActionDetector{ Rules: builtinRules(), @@ -52,7 +48,6 @@ func NewActionDetector() *ActionDetector { func builtinRules() []ActionRule { return []ActionRule{ - // --- Go refactoring --- { ID: "go-err-wrap", Name: "Wrap error with context", @@ -113,7 +108,6 @@ func builtinRules() []ActionRule { Priority: 5, Message: "Naked return in function; consider using named returns explicitly", }, - // --- Go performance --- { ID: "go-append-loop", Name: "Pre-allocate slice", @@ -154,7 +148,6 @@ func builtinRules() []ActionRule { Priority: 5, Message: "Map access without ok check; consider comma-ok pattern", }, - // --- Python --- { ID: "py-bare-except", Name: "Avoid bare except", @@ -206,7 +199,6 @@ func builtinRules() []ActionRule { Priority: 5, Message: "Use f-string or .format() instead of % formatting", }, - // --- TypeScript --- { ID: "ts-any-type", Name: "Avoid any type", @@ -255,7 +247,6 @@ func builtinRules() []ActionRule { Priority: 3, Message: "Non-null assertion operator; add proper null check", }, - // --- Universal --- { ID: "todo-comment", Name: "Address TODO/FIXME/HACK comment", @@ -313,7 +304,6 @@ func builtinRules() []ActionRule { } } -// Detect runs all matching rules against the content and returns actions sorted by priority. func (ad *ActionDetector) Detect(path, content string) []CodeAction { ad.mu.RLock() defer ad.mu.RUnlock() @@ -327,12 +317,10 @@ func (ad *ActionDetector) Detect(path, content string) []CodeAction { continue } - // For multiline patterns, try matching on the full content if isMultilinePattern(rule) { matches := rule.Pattern.FindAllStringIndex(content, -1) for _, m := range matches { if rule.Antipattern != nil { - // Check surrounding context for antipattern start := m[0] if start > 200 { start = m[0] - 200 @@ -366,7 +354,6 @@ func (ad *ActionDetector) Detect(path, content string) []CodeAction { continue } - // Line-by-line matching for single-line patterns for i, line := range lines { if !rule.Pattern.MatchString(line) { continue @@ -400,7 +387,6 @@ func (ad *ActionDetector) Detect(path, content string) []CodeAction { return actions } -// DetectForDiff only checks added lines in a unified diff. func (ad *ActionDetector) DetectForDiff(diff string) []CodeAction { added := extractAddedLines(diff) if added.path == "" { @@ -409,7 +395,6 @@ func (ad *ActionDetector) DetectForDiff(diff string) []CodeAction { return ad.Detect(added.path, added.content) } -// FormatSuggestions formats code actions for display. func FormatSuggestions(actions []CodeAction, maxDisplay int) string { if len(actions) == 0 { return "" @@ -440,7 +425,6 @@ func FormatSuggestions(actions []CodeAction, maxDisplay int) string { return sb.String() } -// ApplyFix applies the suggested fix to the file content at the correct line. func ApplyFix(action CodeAction, content string) (string, error) { if action.Fix == "" { return "", fmt.Errorf("no fix available for action %s", action.ID) @@ -457,7 +441,6 @@ func ApplyFix(action CodeAction, content string) (string, error) { idx := action.Line - 1 original := lines[idx] - // Find the rule to get the pattern detector := NewActionDetector() var rule *ActionRule for i := range detector.Rules { @@ -475,14 +458,11 @@ func ApplyFix(action CodeAction, content string) (string, error) { } } - // Fallback: replace the entire line with the fix preserving indentation indent := extractIndent(original) lines[idx] = indent + strings.TrimSpace(action.Fix) return strings.Join(lines, "\n"), nil } -// --- helpers --- - func detectLanguageFromPath(path string) string { lower := strings.ToLower(path) switch { @@ -493,7 +473,7 @@ func detectLanguageFromPath(path string) string { case strings.HasSuffix(lower, ".ts"), strings.HasSuffix(lower, ".tsx"): return "typescript" case strings.HasSuffix(lower, ".js"), strings.HasSuffix(lower, ".jsx"): - return "typescript" // JS rules overlap significantly + return "typescript" case strings.HasSuffix(lower, ".rs"): return "rust" case strings.HasSuffix(lower, ".rb"): @@ -504,7 +484,6 @@ func detectLanguageFromPath(path string) string { } func isMultilinePattern(rule ActionRule) bool { - // Patterns that use (?s) or contain \n or brace matching are multiline p := rule.Pattern.String() return strings.Contains(p, "(?s)") || strings.Contains(p, `\{[^}]*`) || strings.Contains(p, `\{[^\}]`) } @@ -518,12 +497,10 @@ func generateFix(rule ActionRule, matched string) string { return "" } - // Simple template without Go template syntax if !strings.Contains(rule.FixTemplate, "{{") { return rule.FixTemplate } - // Parse and execute template tmpl, err := template.New("fix").Parse(rule.FixTemplate) if err != nil { return rule.FixTemplate @@ -542,7 +519,6 @@ func generateFix(rule ActionRule, matched string) string { } func computeConfidence(rule ActionRule) float64 { - // Higher priority rules tend to have higher confidence switch rule.Priority { case 1: return 0.9 @@ -560,17 +536,17 @@ func computeConfidence(rule ActionRule) float64 { func categoryIcon(category string) string { switch category { case "refactor": - return "\U0001f527" // wrench + return "\U0001f527" case "performance": - return "⚡" // lightning + return "⚡" case "security": - return "\U0001f512" // lock + return "\U0001f512" case "style": - return "\U0001f3a8" // palette + return "\U0001f3a8" case "fix": - return "\U0001f41b" // bug + return "\U0001f41b" default: - return "\U0001f4a1" // bulb + return "\U0001f4a1" } } @@ -600,7 +576,6 @@ func extractAddedLines(diff string) diffLines { continue } if strings.HasPrefix(line, "@@") { - // Parse the line number from @@ -a,b +c,d @@ parts := strings.Split(line, "+") if len(parts) >= 2 { numStr := strings.Split(parts[1], ",")[0] @@ -619,7 +594,7 @@ func extractAddedLines(diff string) diffLines { } } - _ = currentLine // used for line tracking during parse + _ = currentLine return diffLines{ path: path, diff --git a/internal/engine/code_actions_test.go b/internal/engine/code/code_actions_test.go similarity index 99% rename from internal/engine/code_actions_test.go rename to internal/engine/code/code_actions_test.go index 9b81465c..a82806b3 100644 --- a/internal/engine/code_actions_test.go +++ b/internal/engine/code/code_actions_test.go @@ -1,4 +1,4 @@ -package engine +package code import ( "strings" diff --git a/internal/engine/code_context.go b/internal/engine/code/code_context.go similarity index 78% rename from internal/engine/code_context.go rename to internal/engine/code/code_context.go index aca4aab4..d0845718 100644 --- a/internal/engine/code_context.go +++ b/internal/engine/code/code_context.go @@ -1,7 +1,8 @@ -package engine +package code import ( "bufio" + "context" "fmt" "os" "os/exec" @@ -12,18 +13,16 @@ import ( "sync" ) -// CodeSnippet represents an extracted piece of code with metadata. type CodeSnippet struct { - File string // relative file path - StartLine int // first line (1-based) - EndLine int // last line (1-based) - Content string // the actual code - Relevance float64 // 0.0 - 1.0 relevance score - Type string // "function", "type", "block", "import" - Symbol string // name of the symbol (function/type name) + File string + StartLine int + EndLine int + Content string + Relevance float64 + Type string + Symbol string } -// CodeContext holds a collection of relevant code snippets for a task. type CodeContext struct { Snippets []CodeSnippet TotalTokens int @@ -31,14 +30,12 @@ type CodeContext struct { mu sync.RWMutex } -// ContextExtractor intelligently extracts relevant code snippets for tasks. type ContextExtractor struct { ProjectDir string MaxTokens int mu sync.Mutex } -// NewContextExtractor creates a new extractor rooted at the given project directory. func NewContextExtractor(projectDir string, maxTokens int) *ContextExtractor { if maxTokens <= 0 { maxTokens = 8000 @@ -49,8 +46,6 @@ func NewContextExtractor(projectDir string, maxTokens int) *ContextExtractor { } } -// ExtractForTask analyzes a task description and extracts relevant code snippets -// that fit within the token budget. func (ce *ContextExtractor) ExtractForTask(task string) (*CodeContext, error) { ce.mu.Lock() defer ce.mu.Unlock() @@ -59,16 +54,13 @@ func (ce *ContextExtractor) ExtractForTask(task string) (*CodeContext, error) { Query: task, } - // Find relevant symbols based on the task description symbols := ce.FindRelevantSymbols(task, 20) if len(symbols) == 0 { return ctx, nil } - // Rank snippets by relevance to the task ranked := ce.RankSnippets(symbols, task) - // Fit within token budget totalTokens := 0 for _, snip := range ranked { tokens := codeCtxEstimateTokens(snip.Content) @@ -83,7 +75,6 @@ func (ce *ContextExtractor) ExtractForTask(task string) (*CodeContext, error) { return ctx, nil } -// ExtractFunction extracts a single function's complete code from the given file. func (ce *ContextExtractor) ExtractFunction(file, funcName string) (*CodeSnippet, error) { fullPath := ce.resolvePath(file) lines, err := readFileLines(fullPath) @@ -91,7 +82,6 @@ func (ce *ContextExtractor) ExtractFunction(file, funcName string) (*CodeSnippet return nil, fmt.Errorf("reading %s: %w", file, err) } - // Pattern matches: func Name(, func (receiver) Name( funcPattern := regexp.MustCompile(`^func\s+(\([^)]*\)\s+)?` + regexp.QuoteMeta(funcName) + `\s*[\[(]`) startLine := -1 @@ -105,7 +95,6 @@ func (ce *ContextExtractor) ExtractFunction(file, funcName string) (*CodeSnippet return nil, fmt.Errorf("function %s not found in %s", funcName, file) } - // Find the end of the function by counting braces endLine := findBlockEnd(lines, startLine) content := strings.Join(lines[startLine:endLine+1], "\n") @@ -120,7 +109,6 @@ func (ce *ContextExtractor) ExtractFunction(file, funcName string) (*CodeSnippet }, nil } -// ExtractType extracts a type definition and its methods from the given file. func (ce *ContextExtractor) ExtractType(file, typeName string) (*CodeSnippet, error) { fullPath := ce.resolvePath(file) lines, err := readFileLines(fullPath) @@ -128,7 +116,6 @@ func (ce *ContextExtractor) ExtractType(file, typeName string) (*CodeSnippet, er return nil, fmt.Errorf("reading %s: %w", file, err) } - // Find type declaration typePattern := regexp.MustCompile(`^type\s+` + regexp.QuoteMeta(typeName) + `\s+`) startLine := -1 for i, line := range lines { @@ -141,13 +128,11 @@ func (ce *ContextExtractor) ExtractType(file, typeName string) (*CodeSnippet, er return nil, fmt.Errorf("type %s not found in %s", typeName, file) } - // Find end of type definition endLine := startLine if strings.Contains(lines[startLine], "{") { endLine = findBlockEnd(lines, startLine) } - // Collect methods for this type methodPattern := regexp.MustCompile(`^func\s+\([^)]*\*?` + regexp.QuoteMeta(typeName) + `\)\s+`) var methodBlocks []string for i, line := range lines { @@ -176,7 +161,6 @@ func (ce *ContextExtractor) ExtractType(file, typeName string) (*CodeSnippet, er }, nil } -// ExtractImports extracts the import block from the given file. func (ce *ContextExtractor) ExtractImports(file string) (*CodeSnippet, error) { fullPath := ce.resolvePath(file) lines, err := readFileLines(fullPath) @@ -199,7 +183,6 @@ func (ce *ContextExtractor) ExtractImports(file string) (*CodeSnippet, error) { } break } else if strings.HasPrefix(trimmed, "import ") && !strings.Contains(trimmed, "(") { - // Single-line import startLine = i endLine = i break @@ -222,7 +205,6 @@ func (ce *ContextExtractor) ExtractImports(file string) (*CodeSnippet, error) { }, nil } -// ExtractSurrounding extracts N lines of context around a target line. func (ce *ContextExtractor) ExtractSurrounding(file string, line, contextLines int) (*CodeSnippet, error) { fullPath := ce.resolvePath(file) lines, err := readFileLines(fullPath) @@ -255,8 +237,6 @@ func (ce *ContextExtractor) ExtractSurrounding(file string, line, contextLines i }, nil } -// FindRelevantSymbols searches for symbols matching the query using grep and -// simple AST-like pattern matching. func (ce *ContextExtractor) FindRelevantSymbols(query string, limit int) []CodeSnippet { keywords := extractKeywords(query) if len(keywords) == 0 { @@ -265,16 +245,13 @@ func (ce *ContextExtractor) FindRelevantSymbols(query string, limit int) []CodeS var allSnippets []CodeSnippet - // Use grep to find files containing the keywords matchedFiles := ce.grepForFiles(keywords) - // For each matched file, extract relevant symbols for _, file := range matchedFiles { snippets := ce.extractSymbolsFromFile(file, keywords) allSnippets = append(allSnippets, snippets...) } - // Deduplicate by file+symbol seen := make(map[string]bool) var unique []CodeSnippet for _, s := range allSnippets { @@ -291,20 +268,17 @@ func (ce *ContextExtractor) FindRelevantSymbols(query string, limit int) []CodeS return unique } -// RankSnippets scores and sorts snippets by relevance to the query. func (ce *ContextExtractor) RankSnippets(snippets []CodeSnippet, query string) []CodeSnippet { keywords := extractKeywords(query) if len(keywords) == 0 { return snippets } - // Score each snippet for i := range snippets { score := scoreSnippet(&snippets[i], keywords) snippets[i].Relevance = score } - // Sort by relevance descending sort.Slice(snippets, func(i, j int) bool { return snippets[i].Relevance > snippets[j].Relevance }) @@ -312,7 +286,6 @@ func (ce *ContextExtractor) RankSnippets(snippets []CodeSnippet, query string) [ return snippets } -// FormatContext produces a markdown-formatted representation of the code context. func FormatContext(ctx *CodeContext) string { if ctx == nil || len(ctx.Snippets) == 0 { return "" @@ -341,13 +314,10 @@ func FormatContext(ctx *CodeContext) string { return sb.String() } -// codeCtxEstimateTokens gives a rough token count for a piece of content. -// Uses the approximation of ~4 characters per token for code. func codeCtxEstimateTokens(content string) int { if content == "" { return 0 } - // Rough approximation: 1 token ≈ 4 characters for code chars := len(content) tokens := (chars + 3) / 4 if tokens == 0 { @@ -356,8 +326,6 @@ func codeCtxEstimateTokens(content string) int { return tokens } -// --- Internal helpers --- - func (ce *ContextExtractor) resolvePath(file string) string { if filepath.IsAbs(file) { return file @@ -374,7 +342,6 @@ func readFileLines(path string) ([]string, error) { var lines []string scanner := bufio.NewScanner(f) - // Increase buffer size for long lines scanner.Buffer(make([]byte, 1024*1024), 1024*1024) for scanner.Scan() { lines = append(lines, scanner.Text()) @@ -385,7 +352,6 @@ func readFileLines(path string) ([]string, error) { return lines, nil } -// findBlockEnd finds the closing brace for a block starting at startLine. func findBlockEnd(lines []string, startLine int) int { depth := 0 for i := startLine; i < len(lines); i++ { @@ -400,13 +366,10 @@ func findBlockEnd(lines []string, startLine int) int { } } } - // If we never find the closing brace, return the last line return len(lines) - 1 } -// extractKeywords splits a task/query into meaningful keywords. func extractKeywords(query string) []string { - // Remove common stop words and extract meaningful terms stopWords := map[string]bool{ "the": true, "a": true, "an": true, "is": true, "are": true, "was": true, "were": true, "be": true, "been": true, "being": true, @@ -422,7 +385,6 @@ func extractKeywords(query string) []string { "your": true, "they": true, "them": true, "their": true, } - // Split on non-alphanumeric characters splitter := regexp.MustCompile(`[^a-zA-Z0-9_]+`) words := splitter.Split(strings.ToLower(query), -1) @@ -443,16 +405,14 @@ func extractKeywords(query string) []string { return keywords } -// grepForFiles finds files containing any of the keywords. func (ce *ContextExtractor) grepForFiles(keywords []string) []string { if len(keywords) == 0 { return nil } - // Build a grep pattern matching any keyword pattern := strings.Join(keywords, "|") - cmd := exec.Command("grep", "-rl", "--include=*.go", "-E", pattern, ce.ProjectDir) + cmd := exec.CommandContext(context.Background(), "grep", "-rl", "--include=*.go", "-E", pattern, ce.ProjectDir) out, err := cmd.Output() if err != nil { return nil @@ -464,7 +424,6 @@ func (ce *ContextExtractor) grepForFiles(keywords []string) []string { if line == "" { continue } - // Convert to relative path rel, err := filepath.Rel(ce.ProjectDir, line) if err != nil { rel = line @@ -475,15 +434,12 @@ func (ce *ContextExtractor) grepForFiles(keywords []string) []string { } } - // Limit to avoid processing too many files if len(files) > 30 { files = files[:30] } return files } -// extractSymbolsFromFile parses a Go file for function and type declarations -// that match the given keywords. func (ce *ContextExtractor) extractSymbolsFromFile(file string, keywords []string) []CodeSnippet { fullPath := ce.resolvePath(file) lines, err := readFileLines(fullPath) @@ -497,7 +453,6 @@ func (ce *ContextExtractor) extractSymbolsFromFile(file string, keywords []strin var snippets []CodeSnippet for i, line := range lines { - // Check for function declarations if matches := funcPattern.FindStringSubmatch(line); matches != nil { funcName := matches[1] if symbolMatchesKeywords(funcName, keywords) { @@ -514,7 +469,6 @@ func (ce *ContextExtractor) extractSymbolsFromFile(file string, keywords []strin } } - // Check for type declarations if matches := typePattern.FindStringSubmatch(line); matches != nil { typeName := matches[1] if symbolMatchesKeywords(typeName, keywords) { @@ -538,7 +492,6 @@ func (ce *ContextExtractor) extractSymbolsFromFile(file string, keywords []strin return snippets } -// symbolMatchesKeywords checks whether a symbol name matches any keyword. func symbolMatchesKeywords(symbol string, keywords []string) bool { lower := strings.ToLower(symbol) for _, kw := range keywords { @@ -549,7 +502,6 @@ func symbolMatchesKeywords(symbol string, keywords []string) bool { return false } -// scoreSnippet computes a relevance score for a snippet given the query keywords. func scoreSnippet(snip *CodeSnippet, keywords []string) float64 { if len(keywords) == 0 { return 0.0 @@ -559,36 +511,30 @@ func scoreSnippet(snip *CodeSnippet, keywords []string) float64 { lowerContent := strings.ToLower(snip.Content) lowerSymbol := strings.ToLower(snip.Symbol) - // Keyword overlap: how many keywords appear in the content matchCount := 0 for _, kw := range keywords { if strings.Contains(lowerContent, kw) { matchCount++ } - // Bonus if keyword appears in the symbol name if strings.Contains(lowerSymbol, kw) { score += 0.15 } } score += float64(matchCount) / float64(len(keywords)) * 0.5 - // Boost exported symbols (starts with uppercase) if len(snip.Symbol) > 0 && snip.Symbol[0] >= 'A' && snip.Symbol[0] <= 'Z' { score += 0.1 } - // Boost functions over other types (they tend to be more actionable) if snip.Type == "function" { score += 0.05 } - // Penalize very large snippets (they're less focused) lineCount := snip.EndLine - snip.StartLine + 1 if lineCount > 100 { score -= 0.1 } - // Clamp to [0.0, 1.0] if score > 1.0 { score = 1.0 } diff --git a/internal/engine/code_context_test.go b/internal/engine/code/code_context_test.go similarity index 99% rename from internal/engine/code_context_test.go rename to internal/engine/code/code_context_test.go index 5d2174e8..285dd02f 100644 --- a/internal/engine/code_context_test.go +++ b/internal/engine/code/code_context_test.go @@ -1,4 +1,4 @@ -package engine +package code import ( "os" diff --git a/internal/engine/code_explainer.go b/internal/engine/code/code_explainer.go similarity index 92% rename from internal/engine/code_explainer.go rename to internal/engine/code/code_explainer.go index b9759b41..fba30b97 100644 --- a/internal/engine/code_explainer.go +++ b/internal/engine/code/code_explainer.go @@ -1,4 +1,4 @@ -package engine +package code import ( "fmt" @@ -10,7 +10,6 @@ import ( "sync" ) -// CodeExplanation holds a structured explanation of a code element. type CodeExplanation struct { File string Symbol string @@ -21,26 +20,20 @@ type CodeExplanation struct { UsedBy []string } -// ExplanationSection is a titled portion of an explanation with optional code reference. type ExplanationSection struct { Title string Content string CodeRef string } -// CodeExplainer generates structured explanations of code using AST analysis -// and pattern recognition, without any LLM calls. type CodeExplainer struct { mu sync.Mutex } -// NewCodeExplainer creates a new CodeExplainer instance. func NewCodeExplainer() *CodeExplainer { return &CodeExplainer{} } -// ExplainFunction parses the given file content and generates a structured explanation -// for the named function. func (ce *CodeExplainer) ExplainFunction(file, content, funcName string) (*CodeExplanation, error) { ce.mu.Lock() defer ce.mu.Unlock() @@ -66,14 +59,10 @@ func (ce *CodeExplainer) ExplainFunction(file, content, funcName string) (*CodeE return nil, fmt.Errorf("function %q not found in %s", funcName, file) } - // Extract parameters params := explainerExtractParams(funcDecl) - // Extract return types returns := extractReturns(funcDecl) - // Extract doc comment docComment := explainerExtractDocComment(funcDecl) - // Build purpose paramTypes := make([]string, 0, len(params)) for _, p := range params { paramTypes = append(paramTypes, p[1]) @@ -83,16 +72,13 @@ func (ce *CodeExplainer) ExplainFunction(file, content, funcName string) (*CodeE purpose = docComment } - // Build sections var sections []ExplanationSection - // Purpose section sections = append(sections, ExplanationSection{ Title: "Purpose", Content: purpose, }) - // Parameters section if len(params) > 0 { var paramLines []string for _, p := range params { @@ -105,7 +91,6 @@ func (ce *CodeExplainer) ExplainFunction(file, content, funcName string) (*CodeE }) } - // Returns section if len(returns) > 0 { sections = append(sections, ExplanationSection{ Title: "Returns", @@ -113,7 +98,6 @@ func (ce *CodeExplainer) ExplainFunction(file, content, funcName string) (*CodeE }) } - // Control flow section funcBody := extractFuncBody(content, funcDecl, fset) controlFlow := DescribeControlFlow(funcBody) sections = append(sections, ExplanationSection{ @@ -121,14 +105,12 @@ func (ce *CodeExplainer) ExplainFunction(file, content, funcName string) (*CodeE Content: controlFlow, }) - // Error handling section errHandling := describeErrorHandling(funcBody) sections = append(sections, ExplanationSection{ Title: "Error Handling", Content: errHandling, }) - // Side effects section sideEffects := DetectSideEffects(funcBody) sideEffectStr := "None (pure function)" if len(sideEffects) > 0 { @@ -139,11 +121,9 @@ func (ce *CodeExplainer) ExplainFunction(file, content, funcName string) (*CodeE Content: sideEffectStr, }) - // Complexity cc := computeCyclomaticComplexity(funcBody) complexity := classifyComplexity(cc) - // Dependencies deps := extractDependencies(funcBody) return &CodeExplanation{ @@ -157,8 +137,6 @@ func (ce *CodeExplainer) ExplainFunction(file, content, funcName string) (*CodeE }, nil } -// ExplainType parses the given file content and generates a structured explanation -// for the named type. func (ce *CodeExplainer) ExplainType(file, content, typeName string) (*CodeExplanation, error) { ce.mu.Lock() defer ce.mu.Unlock() @@ -197,7 +175,6 @@ func (ce *CodeExplainer) ExplainType(file, content, typeName string) (*CodeExpla var sections []ExplanationSection - // Doc comment docComment := "" if genDecl.Doc != nil { docComment = strings.TrimSpace(genDecl.Doc.Text()) @@ -215,7 +192,6 @@ func (ce *CodeExplainer) ExplainType(file, content, typeName string) (*CodeExpla Content: purpose, }) - // Fields (for struct types) if st, ok := typeSpec.Type.(*ast.StructType); ok { var fieldLines []string for _, field := range st.Fields.List { @@ -225,7 +201,6 @@ func (ce *CodeExplainer) ExplainType(file, content, typeName string) (*CodeExpla fieldLines = append(fieldLines, fmt.Sprintf("- `%s %s` — %s", name.Name, typeStr, desc)) } if len(field.Names) == 0 { - // Embedded field fieldLines = append(fieldLines, fmt.Sprintf("- `%s` (embedded)", typeStr)) } } @@ -237,7 +212,6 @@ func (ce *CodeExplainer) ExplainType(file, content, typeName string) (*CodeExpla } } - // Interface methods if iface, ok := typeSpec.Type.(*ast.InterfaceType); ok { var methodLines []string for _, method := range iface.Methods.List { @@ -253,7 +227,6 @@ func (ce *CodeExplainer) ExplainType(file, content, typeName string) (*CodeExpla } } - // Methods on this type var methods []string for _, decl := range f.Decls { fd, ok := decl.(*ast.FuncDecl) @@ -262,7 +235,6 @@ func (ce *CodeExplainer) ExplainType(file, content, typeName string) (*CodeExpla } for _, recv := range fd.Recv.List { recvType := explainerExprToString(recv.Type) - // Strip pointer recvType = strings.TrimPrefix(recvType, "*") if recvType == typeName { methods = append(methods, fd.Name.Name) @@ -280,7 +252,6 @@ func (ce *CodeExplainer) ExplainType(file, content, typeName string) (*CodeExpla }) } - // Constructor pattern detection constructor := findConstructor(f, typeName) if constructor != "" { sections = append(sections, ExplanationSection{ @@ -289,7 +260,6 @@ func (ce *CodeExplainer) ExplainType(file, content, typeName string) (*CodeExpla }) } - // Interfaces implemented (heuristic) interfaces := detectImplementedInterfaces(f, typeName, methods) if len(interfaces) > 0 { sections = append(sections, ExplanationSection{ @@ -309,7 +279,6 @@ func (ce *CodeExplainer) ExplainType(file, content, typeName string) (*CodeExpla }, nil } -// ExplainFile generates a structured explanation of an entire file. func (ce *CodeExplainer) ExplainFile(path, content string) (*CodeExplanation, error) { ce.mu.Lock() defer ce.mu.Unlock() @@ -322,7 +291,6 @@ func (ce *CodeExplainer) ExplainFile(path, content string) (*CodeExplanation, er var sections []ExplanationSection - // Package purpose pkgName := f.Name.Name pkgDoc := "" if f.Doc != nil { @@ -337,7 +305,6 @@ func (ce *CodeExplainer) ExplainFile(path, content string) (*CodeExplanation, er Content: pkgPurpose, }) - // Exported API summary var exportedFuncs []string var exportedTypes []string var internalFuncs []string @@ -381,7 +348,6 @@ func (ce *CodeExplainer) ExplainFile(path, content string) (*CodeExplanation, er }) } - // Internal structure if len(internalTypes) > 0 || len(internalFuncs) > 0 { var internalLines []string for _, t := range internalTypes { @@ -396,7 +362,6 @@ func (ce *CodeExplainer) ExplainFile(path, content string) (*CodeExplanation, er }) } - // Key patterns patterns := detectPatterns(content) if len(patterns) > 0 { sections = append(sections, ExplanationSection{ @@ -419,8 +384,6 @@ func (ce *CodeExplainer) ExplainFile(path, content string) (*CodeExplanation, er }, nil } -// InferPurpose infers the purpose of a function from its name, parameter types, -// and return types using heuristic pattern matching. func InferPurpose(name string, params, returns []string) string { lower := strings.ToLower(name) words := splitCamelCase(name) @@ -433,7 +396,6 @@ func InferPurpose(name string, params, returns []string) string { object = strings.Join(words[1:], " ") } - // Check for common verb patterns hasError := containsType(returns, "error") hasBool := containsType(returns, "bool") @@ -493,7 +455,6 @@ func InferPurpose(name string, params, returns []string) string { return fmt.Sprintf("Registers or adds %s", lowerFirst(object)) } - // Fallback with return type context if strings.Contains(lower, "string") && containsType(returns, "string") { return fmt.Sprintf("Converts %s to its string representation", lowerFirst(name)) } @@ -506,8 +467,6 @@ func InferPurpose(name string, params, returns []string) string { return fmt.Sprintf("Performs %s", lowerFirst(strings.Join(words, " "))) } -// DescribeControlFlow analyzes function body text and returns a human-readable -// description of its control flow pattern. func DescribeControlFlow(funcBody string) string { hasFor := regexp.MustCompile(`\bfor\b`).MatchString(funcBody) hasRange := regexp.MustCompile(`\brange\b`).MatchString(funcBody) @@ -559,17 +518,14 @@ func DescribeControlFlow(funcBody string) string { if hasGoto { parts = append(parts, "with goto jumps") } - _ = hasRecursion // used above in broader context + _ = hasRecursion return strings.Join(parts, " ") } -// DetectSideEffects analyzes function body text and returns a list of detected -// side effects such as file I/O, network calls, goroutine spawning, etc. func DetectSideEffects(funcBody string) []string { var effects []string - // File I/O filePatterns := []string{ `os\.Open`, `os\.Create`, `os\.Remove`, `os\.Mkdir`, `os\.ReadFile`, `os\.WriteFile`, `os\.Stat`, @@ -583,7 +539,6 @@ func DetectSideEffects(funcBody string) []string { } } - // Network calls netPatterns := []string{ `http\.Get`, `http\.Post`, `http\.Do`, `net\.Dial`, `net\.Listen`, @@ -597,27 +552,22 @@ func DetectSideEffects(funcBody string) []string { } } - // Goroutine spawning if regexp.MustCompile(`\bgo\s+\w`).MatchString(funcBody) { effects = append(effects, "Goroutine spawning") } - // Mutex locking if regexp.MustCompile(`\.Lock\(\)|\.RLock\(\)`).MatchString(funcBody) { effects = append(effects, "Mutex locking") } - // Channel operations if regexp.MustCompile(`<-\s*\w|(\w+)\s*<-`).MatchString(funcBody) { effects = append(effects, "Channel communication") } - // Global/package-level mutation if regexp.MustCompile(`\b(os\.Setenv|os\.Exit|log\.Fatal)`).MatchString(funcBody) { effects = append(effects, "Process-level side effects") } - // Database operations dbPatterns := []string{ `\.Exec\(`, `\.Query\(`, `\.QueryRow\(`, `\.Begin\(`, `\.Commit\(`, `\.Rollback\(`, @@ -629,7 +579,6 @@ func DetectSideEffects(funcBody string) []string { } } - // Stdout/stderr writes if regexp.MustCompile(`fmt\.Print|fmt\.Fprint|os\.Stdout|os\.Stderr`).MatchString(funcBody) { effects = append(effects, "Standard output") } @@ -637,7 +586,6 @@ func DetectSideEffects(funcBody string) []string { return effects } -// FormatExplanation renders a CodeExplanation into a human-readable markdown-style string. func FormatExplanation(exp *CodeExplanation) string { var sb strings.Builder @@ -661,8 +609,6 @@ func FormatExplanation(exp *CodeExplanation) string { return sb.String() } -// --- Helper functions --- - func explainerExtractParams(fd *ast.FuncDecl) [][2]string { var params [][2]string if fd.Type.Params == nil { @@ -703,11 +649,9 @@ func explainerExtractDocComment(fd *ast.FuncDecl) string { return "" } text := strings.TrimSpace(fd.Doc.Text()) - // Remove the leading function name if the doc starts with it if strings.HasPrefix(text, fd.Name.Name+" ") { text = text[len(fd.Name.Name)+1:] } - // Take first sentence if idx := strings.Index(text, ". "); idx > 0 { return text[:idx+1] } @@ -895,7 +839,7 @@ func containsType(types []string, target string) bool { } func computeCyclomaticComplexity(body string) int { - cc := 1 // base complexity + cc := 1 patterns := []string{ `\bif\b`, `\belse if\b`, `\bfor\b`, `\bcase\b`, `&&`, `\|\|`, `\bselect\b`, @@ -951,7 +895,6 @@ func describeErrorHandling(body string) string { func extractDependencies(body string) []string { var deps []string - // Look for package.Function patterns re := regexp.MustCompile(`\b([a-z][a-z0-9]+)\.\w+`) matches := re.FindAllStringSubmatch(body, -1) seen := map[string]bool{} @@ -1019,7 +962,6 @@ func findConstructor(f *ast.File, typeName string) string { func detectImplementedInterfaces(f *ast.File, typeName string, methods []string) []string { var ifaces []string - // Common Go interfaces by method signature methodSet := map[string]bool{} for _, m := range methods { methodSet[m] = true diff --git a/internal/engine/code_explainer_test.go b/internal/engine/code/code_explainer_test.go similarity index 99% rename from internal/engine/code_explainer_test.go rename to internal/engine/code/code_explainer_test.go index 107e42dd..687aedff 100644 --- a/internal/engine/code_explainer_test.go +++ b/internal/engine/code/code_explainer_test.go @@ -1,4 +1,4 @@ -package engine +package code import ( "strings" diff --git a/internal/engine/code_lens.go b/internal/engine/code/code_lens.go similarity index 77% rename from internal/engine/code_lens.go rename to internal/engine/code/code_lens.go index bd793ca9..b66a2c67 100644 --- a/internal/engine/code_lens.go +++ b/internal/engine/code/code_lens.go @@ -1,6 +1,7 @@ -package engine +package code import ( + "context" "fmt" "os/exec" "regexp" @@ -11,27 +12,22 @@ import ( "unicode" ) -// CodeLens represents an inline annotation for a specific line in a file. type CodeLens struct { File string Line int Label string - Category string // "test_status", "complexity", "ownership", "age", "references", "coverage" + Category string Command string Tooltip string } -// LensGenerator is a function that produces code lenses for a given file and its content. type LensGenerator func(file, content string) []CodeLens -// CodeLensProvider manages a set of lens generators and produces annotations. type CodeLensProvider struct { Providers map[string]LensGenerator mu sync.RWMutex } -// NewCodeLensProvider creates a CodeLensProvider with built-in generators for -// test status, complexity, references, age, and coverage. func NewCodeLensProvider() *CodeLensProvider { p := &CodeLensProvider{ Providers: make(map[string]LensGenerator), @@ -44,14 +40,12 @@ func NewCodeLensProvider() *CodeLensProvider { return p } -// Register adds or replaces a named lens generator. func (p *CodeLensProvider) Register(name string, generator LensGenerator) { p.mu.Lock() defer p.mu.Unlock() p.Providers[name] = generator } -// Generate runs all registered providers and returns merged lenses sorted by line. func (p *CodeLensProvider) Generate(file, content string) []CodeLens { p.mu.RLock() defer p.mu.RUnlock() @@ -70,7 +64,6 @@ func (p *CodeLensProvider) Generate(file, content string) []CodeLens { return all } -// FilterByCategory returns only lenses matching the given category. func FilterByCategory(lenses []CodeLens, category string) []CodeLens { var result []CodeLens for _, l := range lenses { @@ -81,7 +74,6 @@ func FilterByCategory(lenses []CodeLens, category string) []CodeLens { return result } -// FormatLenses produces a human-readable summary of code lenses. func FormatLenses(file string, lenses []CodeLens) string { if len(lenses) == 0 { return fmt.Sprintf("Code Lenses for %s:\n (none)", file) @@ -94,11 +86,8 @@ func FormatLenses(file string, lenses []CodeLens) string { return strings.TrimRight(b.String(), "\n") } -// ---------- Built-in Generators ---------- - var testFuncRe = regexp.MustCompile(`(?m)^func\s+(Test\w+)\s*\(`) -// GenerateTestLens finds test functions and annotates them with last known status. func GenerateTestLens(file, content string) []CodeLens { if !strings.HasSuffix(file, "_test.go") { return nil @@ -134,21 +123,16 @@ func GenerateTestLens(file, content string) []CodeLens { return lenses } -// lookupTestStatus attempts to determine the last test result. -// In a real implementation this would query a test result cache. func lookupTestStatus(file, funcName string) string { - // Try running the test quickly to determine status dir := file if idx := strings.LastIndex(file, "/"); idx >= 0 { dir = file[:idx] } - cmd := exec.Command("go", "test", "-run", "^"+funcName+"$", "-count=1", "-timeout=10s", dir) + cmd := exec.CommandContext(context.Background(), "go", "test", "-run", "^"+funcName+"$", "-count=1", "-timeout=10s", dir) err := cmd.Run() if err == nil { return "PASS" } - // If the command fails it could be a real failure or an environment issue. - // Distinguish by exit code when possible. if exitErr, ok := err.(*exec.ExitError); ok { if exitErr.ExitCode() == 1 { return "FAIL" @@ -159,8 +143,6 @@ func lookupTestStatus(file, funcName string) string { var funcDeclRe = regexp.MustCompile(`(?m)^func\s+(?:\(\s*\w+\s+\*?\w+\s*\)\s+)?(\w+)\s*\(`) -// GenerateComplexityLens calculates cyclomatic complexity for each function and -// annotates functions that exceed a threshold of 5. func GenerateComplexityLens(file, content string) []CodeLens { const threshold = 5 var lenses []CodeLens @@ -184,7 +166,6 @@ func GenerateComplexityLens(file, content string) []CodeLens { return lenses } -// GenerateReferenceLens counts how many times each exported symbol is referenced. func GenerateReferenceLens(file, content string) []CodeLens { var lenses []CodeLens @@ -215,7 +196,6 @@ func GenerateReferenceLens(file, content string) []CodeLens { return lenses } -// GenerateAgeLens uses git blame to determine how recently each function was modified. func GenerateAgeLens(file, content string) []CodeLens { var lenses []CodeLens @@ -246,7 +226,6 @@ func GenerateAgeLens(file, content string) []CodeLens { return lenses } -// GenerateCoverageLens produces coverage annotations per function if coverage data is available. func GenerateCoverageLens(file, content string) []CodeLens { var lenses []CodeLens @@ -282,8 +261,6 @@ func GenerateCoverageLens(file, content string) []CodeLens { return lenses } -// ---------- Internal Helpers ---------- - type funcInfo struct { name string line int @@ -296,7 +273,6 @@ type symbolInfo struct { line int } -// extractFunctions parses Go source and extracts function declarations with their bodies. func extractFunctions(content string) []funcInfo { var funcs []funcInfo lines := strings.Split(content, "\n") @@ -336,10 +312,8 @@ func extractFunctions(content string) []funcInfo { return funcs } -// calculateCyclomaticComplexity computes a simplified cyclomatic complexity for a function body. func calculateCyclomaticComplexity(body string) int { cc := 1 - // Count decision points decisionPatterns := []string{ `\bif\b`, `\belse if\b`, @@ -349,7 +323,6 @@ func calculateCyclomaticComplexity(body string) int { `\b\|\|\b`, `\bselect\b`, } - // Use simpler token-based counting words := strings.Fields(body) for _, w := range words { switch w { @@ -357,14 +330,12 @@ func calculateCyclomaticComplexity(body string) int { cc++ } } - // Count && and || in the body _ = decisionPatterns cc += strings.Count(body, "&&") cc += strings.Count(body, "||") return cc } -// extractExportedSymbols finds exported function and type declarations. func extractExportedSymbols(content string) []symbolInfo { var symbols []symbolInfo lines := strings.Split(content, "\n") @@ -377,7 +348,6 @@ func extractExportedSymbols(content string) []symbolInfo { } continue } - // Check for exported type declarations trimmed := strings.TrimSpace(line) if strings.HasPrefix(trimmed, "type ") { parts := strings.Fields(trimmed) @@ -392,25 +362,21 @@ func extractExportedSymbols(content string) []symbolInfo { return symbols } -// countReferences counts occurrences of a symbol in the content (excluding its declaration). func countReferences(content, symbol string) int { - // Count all occurrences minus the declaration itself count := strings.Count(content, symbol) if count > 0 { - count-- // exclude the declaration + count-- } return count } -// blameEntry holds parsed git blame information for a single line. type blameEntry struct { line int date time.Time } -// getGitBlame runs git blame and returns parsed entries. func getGitBlame(file string) []blameEntry { - cmd := exec.Command("git", "blame", "--porcelain", file) + cmd := exec.CommandContext(context.Background(), "git", "blame", "--porcelain", file) out, err := cmd.Output() if err != nil { return nil @@ -431,7 +397,6 @@ func getGitBlame(file string) []blameEntry { }) } } - // Track the line number from the header parts := strings.Fields(l) if len(parts) >= 3 && len(parts[0]) == 40 { _, _ = fmt.Sscanf(parts[2], "%d", &lineNum) @@ -440,7 +405,6 @@ func getGitBlame(file string) []blameEntry { return entries } -// lookupAge finds the most recent modification date for lines near the given line. func lookupAge(entries []blameEntry, line int) string { if len(entries) == 0 { return "" @@ -456,7 +420,6 @@ func lookupAge(entries []blameEntry, line int) string { } if newest.IsZero() { - // Fall back to closest entry for _, e := range entries { if e.date.After(newest) { newest = e.date @@ -471,7 +434,6 @@ func lookupAge(entries []blameEntry, line int) string { return lensFormatDuration(time.Since(newest)) } -// lensFormatDuration formats a duration into a human-readable age string. func lensFormatDuration(d time.Duration) string { days := int(d.Hours() / 24) if days == 0 { @@ -493,7 +455,6 @@ func lensFormatDuration(d time.Duration) string { return fmt.Sprintf("%dy", days/365) } -// isRecent returns true if the age string indicates a recent modification (< 7 days). func isRecent(age string) bool { if age == "just now" { return true @@ -509,17 +470,14 @@ func isRecent(age string) bool { return false } -// loadCoverageData attempts to load coverage information for the given file. -// Returns a map of function name to coverage percentage, or nil if unavailable. func loadCoverageData(file string) map[string]float64 { - // Look for a coverage.out file in the same directory dir := file if idx := strings.LastIndex(file, "/"); idx >= 0 { dir = file[:idx] } coverFile := dir + "/coverage.out" - cmd := exec.Command("cat", coverFile) + cmd := exec.CommandContext(context.Background(), "cat", coverFile) out, err := cmd.Output() if err != nil { return nil @@ -528,7 +486,6 @@ func loadCoverageData(file string) map[string]float64 { return parseCoverageProfile(string(out), file) } -// parseCoverageProfile parses Go coverage profile output and returns per-function coverage. func parseCoverageProfile(profile, file string) map[string]float64 { result := make(map[string]float64) lines := strings.Split(profile, "\n") @@ -545,7 +502,6 @@ func parseCoverageProfile(profile, file string) map[string]float64 { if !strings.Contains(line, ":") || strings.HasPrefix(line, "mode:") { continue } - // Format: file:startLine.startCol,endLine.endCol statements count parts := strings.Fields(line) if len(parts) < 3 { continue @@ -554,7 +510,6 @@ func parseCoverageProfile(profile, file string) map[string]float64 { if !strings.Contains(loc, file) { continue } - // Parse start and end lines colonIdx := strings.LastIndex(loc, ":") if colonIdx < 0 { continue @@ -576,7 +531,6 @@ func parseCoverageProfile(profile, file string) map[string]float64 { return nil } - // This is a simplified mapping; full implementation would correlate with function line ranges _ = result _ = blocks return nil diff --git a/internal/engine/code_lens_test.go b/internal/engine/code/code_lens_test.go similarity index 99% rename from internal/engine/code_lens_test.go rename to internal/engine/code/code_lens_test.go index 9c2b3bb2..456bc0e5 100644 --- a/internal/engine/code_lens_test.go +++ b/internal/engine/code/code_lens_test.go @@ -1,4 +1,4 @@ -package engine +package code import ( "strings" diff --git a/internal/engine/code_reexports.go b/internal/engine/code_reexports.go new file mode 100644 index 00000000..c00d5c5f --- /dev/null +++ b/internal/engine/code_reexports.go @@ -0,0 +1,34 @@ +package engine + +import "github.com/GrayCodeAI/hawk/internal/engine/code" + +type ( + CodeSnippet = code.CodeSnippet + CodeContext = code.CodeContext + ContextExtractor = code.ContextExtractor + CodeLens = code.CodeLens + LensGenerator = code.LensGenerator + CodeLensProvider = code.CodeLensProvider + CodeAction = code.CodeAction + ActionDetector = code.ActionDetector + ActionRule = code.ActionRule + CodeExplanation = code.CodeExplanation + ExplanationSection = code.ExplanationSection + CodeExplainer = code.CodeExplainer +) + +func NewContextExtractor(projectDir string, maxTokens int) *ContextExtractor { + return code.NewContextExtractor(projectDir, maxTokens) +} +func FormatContext(ctx *CodeContext) string { return code.FormatContext(ctx) } +func NewCodeLensProvider() *CodeLensProvider { return code.NewCodeLensProvider() } +func NewActionDetector() *ActionDetector { return code.NewActionDetector() } +func NewCodeExplainer() *CodeExplainer { return code.NewCodeExplainer() } +func FormatExplanation(exp *CodeExplanation) string { return code.FormatExplanation(exp) } +func FormatSuggestions(actions []CodeAction, max int) string { + return code.FormatSuggestions(actions, max) +} + +func ApplyFix(action CodeAction, content string) (string, error) { + return code.ApplyFix(action, content) +} diff --git a/internal/engine/compact/aliases.go b/internal/engine/compact/aliases.go index 46d1b1e5..25f50c3e 100644 --- a/internal/engine/compact/aliases.go +++ b/internal/engine/compact/aliases.go @@ -1,104 +1,37 @@ -// Package compact is the Stage-1 namespace for the engine package's -// compaction-related types and functions. It currently re-exports the -// canonical symbols from package engine as type aliases and var aliases; -// no implementation lives here yet. -// -// New code in hawk should import this package instead of reaching into -// engine for compact symbols. When Stage 2 of the engine split lands, -// the implementation will move into this directory and the engine package -// will become the alias re-exporter (the inverse of the current direction). -// -// See REFACTOR_PLAN.md at the engine package root for the full split plan. +// Package compact provides compaction strategies, types, and helpers +// for context-window management. See ../REFACTOR_PLAN.md. package compact -import "github.com/GrayCodeAI/hawk/internal/engine" - -// Strategy is the contract every compaction strategy implements. -type Strategy = engine.CompactStrategy - // Result is the outcome of a compaction pass. -type Result = engine.CompactResult +type Result = CompactResult // Config tunes compaction behaviour. -type Config = engine.CompactConfig +type Config = CompactConfig // Variant identifies which compaction prompt variant to render. -type Variant = engine.CompactVariant - -// Registry stores strategies by name for runtime selection. -type Registry = engine.StrategyRegistry - -// AutoCompactor decides when and how to compact based on context pressure. -type AutoCompactor = engine.AutoCompactor - -// SmartCompactStrategy is the default LLM-driven compactor. -type SmartCompactStrategy = engine.SmartCompactStrategy - -// TruncateStrategy drops oldest messages first; cheap but lossy. -type TruncateStrategy = engine.TruncateStrategy - -// MicroCompactStrategy collapses adjacent short messages. -type MicroCompactStrategy = engine.MicroCompactStrategy - -// MicroCompactConfig tunes the micro-compactor. -type MicroCompactConfig = engine.MicroCompactConfig - -// SessionMemoryStrategy distils conversation into a compact memory blob. -type SessionMemoryStrategy = engine.SessionMemoryStrategy - -// SessionMemoryConfig tunes the session-memory compactor. -type SessionMemoryConfig = engine.SessionMemoryConfig - -// APICompactStrategy compacts at the API-call boundary (provider-specific). -type APICompactStrategy = engine.APICompactStrategy - -// APICompactConfig tunes the API-boundary compactor. -type APICompactConfig = engine.APICompactConfig - -// FileTracker remembers which files have been read/modified during a session; -// used by file-aware compactors to keep the relevant ones. -type FileTracker = engine.FileTracker - -// --------------------------------------------------------------------------- -// Constructors / defaults. -// --------------------------------------------------------------------------- - -// NewAutoCompactor constructs an auto-compactor with the given config. -func NewAutoCompactor(config Config) *AutoCompactor { - return engine.NewAutoCompactor(config) -} - -// NewFileTracker returns an empty file tracker. -func NewFileTracker() *FileTracker { - return engine.NewFileTracker() -} +type Variant = CompactVariant // DefaultConfig returns the default top-level compaction config. func DefaultConfig() Config { - return engine.DefaultCompactConfig() + return DefaultCompactConfig() } // DefaultMicroConfig returns the default micro-compactor config. func DefaultMicroConfig() MicroCompactConfig { - return engine.DefaultMicroCompactConfig() -} - -// DefaultSessionMemoryConfig returns the default session-memory compactor config. -func DefaultSessionMemoryConfig() SessionMemoryConfig { - return engine.DefaultSessionMemoryConfig() + return DefaultMicroCompactConfig() } // DefaultAPIConfig returns the default API-boundary compactor config. func DefaultAPIConfig() APICompactConfig { - return engine.DefaultAPICompactConfig() + return DefaultAPICompactConfig() } // BuildPrompt renders the compaction prompt template for the given variant. func BuildPrompt(variant Variant) string { - return engine.BuildCompactPrompt(variant) + return BuildCompactPrompt(variant) } // FormatSummary normalises a raw LLM summary for display. func FormatSummary(raw string) string { - return engine.FormatCompactSummary(raw) + return FormatCompactSummary(raw) } diff --git a/internal/engine/compact_api.go b/internal/engine/compact/api.go similarity index 59% rename from internal/engine/compact_api.go rename to internal/engine/compact/api.go index 4a8491b7..d1ab6449 100644 --- a/internal/engine/compact_api.go +++ b/internal/engine/compact/api.go @@ -1,38 +1,11 @@ -package engine +package compact import ( - "context" - "github.com/GrayCodeAI/eyrie/client" -) - -// APICompactStrategy uses API-level context edits to clear thinking blocks -// and old tool inputs without mutating local message content. -type APICompactStrategy struct{} -func (s *APICompactStrategy) Name() string { return "api_compact" } - -func (s *APICompactStrategy) ShouldTrigger(msgs []client.EyrieMessage, tokenCount, threshold int) bool { - if tokenCount < 180000 { - return false - } - return countClearableToolResults(msgs) > 5 -} - -func (s *APICompactStrategy) Compact(ctx context.Context, sess *Session) (*CompactResult, error) { - tokensBefore := EstimateTokens(sess.messages) - result := apiCompactMessages(sess.messages, DefaultAPICompactConfig()) - tokensAfter := EstimateTokens(result) - - return &CompactResult{ - Messages: result, - TokensBefore: tokensBefore, - TokensAfter: tokensAfter, - Strategy: "api_compact", - }, nil -} + "github.com/GrayCodeAI/hawk/internal/engine/token" +) -// APICompactConfig controls API-level compaction. type APICompactConfig struct { TriggerTokens int KeepTargetTokens int @@ -41,7 +14,6 @@ type APICompactConfig struct { PreserveMutating bool } -// DefaultAPICompactConfig returns defaults matching the archive. func DefaultAPICompactConfig() APICompactConfig { return APICompactConfig{ TriggerTokens: 180000, @@ -52,17 +24,14 @@ func DefaultAPICompactConfig() APICompactConfig { } } -// mutatingTools are tools whose inputs should not be cleared (they modify state). var mutatingTools = map[string]bool{ "Edit": true, "Write": true, "NotebookEdit": true, } -// apiCompactMessages clears thinking content and old tool inputs/results -// for non-mutating tools when context is very large. -func apiCompactMessages(msgs []client.EyrieMessage, cfg APICompactConfig) []client.EyrieMessage { - totalTokens := EstimateTokens(msgs) +func APICompactMessages(msgs []client.EyrieMessage, cfg APICompactConfig) []client.EyrieMessage { + totalTokens := token.EstimateTokens(msgs) if totalTokens < cfg.TriggerTokens { return msgs } @@ -82,9 +51,9 @@ func apiCompactMessages(msgs []client.EyrieMessage, cfg APICompactConfig) []clie m := &result[i] if cfg.ClearThinking && m.Role == "assistant" && isThinkingMessage(*m) { - before := estimateMessageTokens(*m) + before := token.EstimateMessageTokens(*m) m.Content = "[Thinking content cleared]" - freed += before - estimateMessageTokens(*m) + freed += before - token.EstimateMessageTokens(*m) continue } @@ -109,7 +78,7 @@ func apiCompactMessages(msgs []client.EyrieMessage, cfg APICompactConfig) []clie } if m.ToolResult != nil && m.ToolResult.Content != "[Old tool result content cleared]" { - toolName := toolNameForResult(*m, result) + toolName := ToolNameForResult(*m, result) if !mutatingTools[toolName] { before := len(m.ToolResult.Content) / 4 if before > 100 { @@ -127,11 +96,11 @@ func apiCompactMessages(msgs []client.EyrieMessage, cfg APICompactConfig) []clie return result } -func countClearableToolResults(msgs []client.EyrieMessage) int { +func CountClearableToolResults(msgs []client.EyrieMessage) int { count := 0 for _, m := range msgs { if m.ToolResult != nil && m.ToolResult.Content != "[Old tool result content cleared]" { - toolName := toolNameForResult(m, msgs) + toolName := ToolNameForResult(m, msgs) if !mutatingTools[toolName] { count++ } diff --git a/internal/engine/compact_files.go b/internal/engine/compact/files.go similarity index 70% rename from internal/engine/compact_files.go rename to internal/engine/compact/files.go index ca7c6a55..7e5ebe6c 100644 --- a/internal/engine/compact_files.go +++ b/internal/engine/compact/files.go @@ -1,4 +1,4 @@ -package engine +package compact import ( "fmt" @@ -9,14 +9,11 @@ import ( "github.com/GrayCodeAI/eyrie/client" ) -// FileTracker maintains a cumulative record of files read and modified -// across the session lifetime, persisting through compactions. type FileTracker struct { - ReadFiles map[string]int // path -> count of reads - ModifiedFiles map[string]int // path -> count of modifications + ReadFiles map[string]int + ModifiedFiles map[string]int } -// NewFileTracker creates a new FileTracker with initialized maps. func NewFileTracker() *FileTracker { return &FileTracker{ ReadFiles: make(map[string]int), @@ -24,7 +21,6 @@ func NewFileTracker() *FileTracker { } } -// RecordRead notes that a file was read. func (ft *FileTracker) RecordRead(path string) { if path == "" { return @@ -32,7 +28,6 @@ func (ft *FileTracker) RecordRead(path string) { ft.ReadFiles[path]++ } -// RecordModified notes that a file was modified. func (ft *FileTracker) RecordModified(path string) { if path == "" { return @@ -40,8 +35,6 @@ func (ft *FileTracker) RecordModified(path string) { ft.ModifiedFiles[path]++ } -// ExtractFromMessages scans messages for tool calls and extracts file paths. -// Looks at Read tool calls for reads, Write/Edit for modifications. func (ft *FileTracker) ExtractFromMessages(messages []client.EyrieMessage) { for _, msg := range messages { if msg.Role != "assistant" { @@ -63,12 +56,6 @@ func (ft *FileTracker) ExtractFromMessages(messages []client.EyrieMessage) { } } -// FormatForSummary returns a text block suitable for injection into compaction summaries. -// Format: -// -// Read: path1.go (3x), path2.go (1x) -// Modified: path3.go (2x), path4.go (1x) -// func (ft *FileTracker) FormatForSummary() string { if len(ft.ReadFiles) == 0 && len(ft.ModifiedFiles) == 0 { return "" @@ -93,8 +80,6 @@ func (ft *FileTracker) FormatForSummary() string { return sb.String() } -// ParseFromSummary extracts previously tracked files from a compaction summary -// containing blocks, merging with current state. func (ft *FileTracker) ParseFromSummary(summary string) { start := strings.Index(summary, "") end := strings.Index(summary, "") @@ -115,7 +100,6 @@ func (ft *FileTracker) ParseFromSummary(summary string) { } } -// Merge combines another FileTracker's data into this one. func (ft *FileTracker) Merge(other *FileTracker) { if other == nil { return @@ -128,9 +112,7 @@ func (ft *FileTracker) Merge(other *FileTracker) { } } -// formatPathCounts formats a map of path->count into "path1.go (3x), path2.go (1x)" style. func formatPathCounts(m map[string]int) string { - // Sort paths for deterministic output paths := make([]string, 0, len(m)) for p := range m { paths = append(paths, p) @@ -144,7 +126,6 @@ func formatPathCounts(m map[string]int) string { return strings.Join(parts, ", ") } -// parsePathLine parses "path1.go (3x), path2.go (1x)" into the target map. func parsePathLine(line string, target map[string]int) { entries := strings.Split(line, ", ") for _, entry := range entries { @@ -153,10 +134,8 @@ func parsePathLine(line string, target map[string]int) { continue } - // Parse "path (Nx)" format parenIdx := strings.LastIndex(entry, " (") if parenIdx < 0 { - // No count, treat as single occurrence target[entry]++ continue } @@ -173,3 +152,40 @@ func parsePathLine(line string, target map[string]int) { target[path] += count } } + +func canonicalToolName(name string) string { + switch strings.ToLower(name) { + case "bash": + return "Bash" + case "file_read", "read": + return "Read" + case "file_write", "write": + return "Write" + case "file_edit", "edit": + return "Edit" + case "ls": + return "LS" + case "glob": + return "Glob" + case "grep": + return "Grep" + case "web_fetch", "webfetch": + return "WebFetch" + case "web_search", "websearch": + return "WebSearch" + case "tool_search", "toolsearch": + return "ToolSearch" + default: + return name + } +} + +func pathArgument(args map[string]interface{}) (string, bool) { + if p, ok := args["path"].(string); ok && p != "" { + return p, true + } + if p, ok := args["file_path"].(string); ok && p != "" { + return p, true + } + return "", false +} diff --git a/internal/engine/compact/files_test.go b/internal/engine/compact/files_test.go new file mode 100644 index 00000000..76082d2f --- /dev/null +++ b/internal/engine/compact/files_test.go @@ -0,0 +1,228 @@ +package compact + +import ( + "strings" + "testing" + + "github.com/GrayCodeAI/eyrie/client" +) + +func TestFileTracker_NewFileTracker(t *testing.T) { + t.Parallel() + ft := NewFileTracker() + if ft == nil { + t.Fatal("NewFileTracker returned nil") + } + if ft.ReadFiles == nil || ft.ModifiedFiles == nil { + t.Error("maps should be initialized") + } + if len(ft.ReadFiles) != 0 || len(ft.ModifiedFiles) != 0 { + t.Error("new tracker should have empty maps") + } +} + +func TestFileTracker_RecordRead(t *testing.T) { + t.Parallel() + ft := NewFileTracker() + + ft.RecordRead("main.go") + ft.RecordRead("main.go") + ft.RecordRead("config.go") + + if ft.ReadFiles["main.go"] != 2 { + t.Errorf("expected 2 reads for main.go, got %d", ft.ReadFiles["main.go"]) + } + if ft.ReadFiles["config.go"] != 1 { + t.Errorf("expected 1 read for config.go, got %d", ft.ReadFiles["config.go"]) + } + if len(ft.ReadFiles) != 2 { + t.Errorf("expected 2 entries, got %d", len(ft.ReadFiles)) + } +} + +func TestFileTracker_RecordRead_Empty(t *testing.T) { + t.Parallel() + ft := NewFileTracker() + ft.RecordRead("") // should not panic + if len(ft.ReadFiles) != 0 { + t.Error("empty path should not be recorded") + } +} + +func TestFileTracker_RecordModified(t *testing.T) { + t.Parallel() + ft := NewFileTracker() + ft.RecordModified("edit.go") + if ft.ModifiedFiles["edit.go"] != 1 { + t.Errorf("expected 1 mod for edit.go, got %d", ft.ModifiedFiles["edit.go"]) + } +} + +func TestFileTracker_RecordModified_Empty(t *testing.T) { + t.Parallel() + ft := NewFileTracker() + ft.RecordModified("") // should not panic +} + +func TestFileTracker_ExtractFromMessages(t *testing.T) { + t.Parallel() + ft := NewFileTracker() + messages := []client.EyrieMessage{ + { + Role: "assistant", + ToolUse: []client.ToolCall{ + {Name: "Read", Arguments: map[string]interface{}{"path": "main.go"}}, + {Name: "Write", Arguments: map[string]interface{}{"file_path": "output.go"}}, + }, + }, + } + ft.ExtractFromMessages(messages) + if ft.ReadFiles["main.go"] != 1 { + t.Errorf("expected 1 read for main.go, got %d", ft.ReadFiles["main.go"]) + } + if ft.ModifiedFiles["output.go"] != 1 { + t.Errorf("expected 1 mod for output.go, got %d", ft.ModifiedFiles["output.go"]) + } +} + +func TestFileTracker_ExtractFromMessages_SkipNonAssistant(t *testing.T) { + t.Parallel() + ft := NewFileTracker() + messages := []client.EyrieMessage{ + {Role: "user", ToolUse: []client.ToolCall{{Name: "Read", Arguments: map[string]interface{}{"path": "x.go"}}}}, + } + ft.ExtractFromMessages(messages) + if len(ft.ReadFiles) != 0 { + t.Error("should skip non-assistant messages") + } +} + +func TestFileTracker_Merge(t *testing.T) { + t.Parallel() + a := NewFileTracker() + a.RecordRead("main.go") + a.RecordRead("main.go") + + b := NewFileTracker() + b.RecordRead("main.go") + b.RecordRead("config.go") + + a.Merge(b) + if a.ReadFiles["main.go"] != 3 { + t.Errorf("expected 3 reads for main.go, got %d", a.ReadFiles["main.go"]) + } + if a.ReadFiles["config.go"] != 1 { + t.Errorf("expected 1 read for config.go, got %d", a.ReadFiles["config.go"]) + } +} + +func TestFileTracker_Merge_Nil(t *testing.T) { + t.Parallel() + ft := NewFileTracker() + ft.Merge(nil) // should not panic +} + +func TestFileTracker_FormatForSummary(t *testing.T) { + t.Parallel() + ft := NewFileTracker() + ft.RecordRead("a.go") + ft.RecordModified("b.go") + result := ft.FormatForSummary() + if !strings.Contains(result, "") { + t.Error("expected tag") + } + if !strings.Contains(result, "Read:") { + t.Error("expected Read section") + } + if !strings.Contains(result, "Modified:") { + t.Error("expected Modified section") + } + if !strings.Contains(result, "") { + t.Error("expected closing tag") + } +} + +func TestFileTracker_FormatForSummary_Empty(t *testing.T) { + t.Parallel() + ft := NewFileTracker() + if ft.FormatForSummary() != "" { + t.Error("empty tracker should return empty string") + } +} + +func TestFileTracker_ParseFromSummary(t *testing.T) { + t.Parallel() + ft := NewFileTracker() + summary := ` +Read: main.go (2x), config.go (1x) +Modified: edit.go (1x) +` + ft.ParseFromSummary(summary) + if ft.ReadFiles["main.go"] != 2 { + t.Errorf("expected 2 reads for main.go, got %d", ft.ReadFiles["main.go"]) + } + if ft.ReadFiles["config.go"] != 1 { + t.Errorf("expected 1 read for config.go, got %d", ft.ReadFiles["config.go"]) + } + if ft.ModifiedFiles["edit.go"] != 1 { + t.Errorf("expected 1 mod for edit.go, got %d", ft.ModifiedFiles["edit.go"]) + } +} + +func TestFileTracker_ParseFromSummary_NoMatch(t *testing.T) { + t.Parallel() + ft := NewFileTracker() + ft.ParseFromSummary("no tags here") // should not panic +} + +func TestFileTracker_ParseFromSummary_AbsentBlock(t *testing.T) { + t.Parallel() + ft := NewFileTracker() + ft.ParseFromSummary("\n") // empty block +} + +func TestFileTracker_CanonicalToolName(t *testing.T) { + tests := []struct{ input, expected string }{ + {"bash", "Bash"}, + {"Read", "Read"}, + {"file_read", "Read"}, + {"file_write", "Write"}, + {"EDIT", "Edit"}, + {"ls", "LS"}, + {"Glob", "Glob"}, + {"web_fetch", "WebFetch"}, + {"web_search", "WebSearch"}, + {"Tool_Search", "ToolSearch"}, + {"Unknown", "Unknown"}, + } + for _, tt := range tests { + got := canonicalToolName(tt.input) + if got != tt.expected { + t.Errorf("canonicalToolName(%q) = %q, want %q", tt.input, got, tt.expected) + } + } +} + +func TestFileTracker_CumulativeAcrossOperations(t *testing.T) { + t.Parallel() + ft := NewFileTracker() + ft.RecordRead("a.go") + ft.RecordRead("b.go") + ft.RecordModified("a.go") + + ft.RecordRead("a.go") + ft.RecordModified("b.go") + + if ft.ReadFiles["a.go"] != 2 { + t.Errorf("a.go reads: expected 2, got %d", ft.ReadFiles["a.go"]) + } + if ft.ReadFiles["b.go"] != 1 { + t.Errorf("b.go reads: expected 1, got %d", ft.ReadFiles["b.go"]) + } + if ft.ModifiedFiles["a.go"] != 1 { + t.Errorf("a.go mods: expected 1, got %d", ft.ModifiedFiles["a.go"]) + } + if ft.ModifiedFiles["b.go"] != 1 { + t.Errorf("b.go mods: expected 1, got %d", ft.ModifiedFiles["b.go"]) + } +} diff --git a/internal/engine/compact/micro.go b/internal/engine/compact/micro.go new file mode 100644 index 00000000..712fc463 --- /dev/null +++ b/internal/engine/compact/micro.go @@ -0,0 +1,94 @@ +package compact + +import ( + "time" + + "github.com/GrayCodeAI/eyrie/client" +) + +type MicroCompactConfig struct { + CompactableTools map[string]bool + TimeGapMins float64 + KeepRecent int +} + +func DefaultMicroCompactConfig() MicroCompactConfig { + return MicroCompactConfig{ + CompactableTools: compactableTools, + TimeGapMins: 60, + KeepRecent: 3, + } +} + +type resultInfo struct { + index int + toolName string +} + +func MicrocompactMessages(msgs []client.EyrieMessage, cfg MicroCompactConfig) []client.EyrieMessage { + var compactableResults []resultInfo + for i, m := range msgs { + if m.ToolResult == nil { + continue + } + toolName := ToolNameForResult(m, msgs) + if cfg.CompactableTools[toolName] { + compactableResults = append(compactableResults, resultInfo{index: i, toolName: toolName}) + } + } + + if len(compactableResults) <= cfg.KeepRecent { + return msgs + } + + toClear := len(compactableResults) - cfg.KeepRecent + clearSet := make(map[int]bool, toClear) + for i := 0; i < toClear; i++ { + clearSet[compactableResults[i].index] = true + } + + result := make([]client.EyrieMessage, len(msgs)) + copy(result, msgs) + for idx := range clearSet { + result[idx] = client.EyrieMessage{ + Role: result[idx].Role, + ToolResult: &client.ToolResult{ + ToolUseID: result[idx].ToolResult.ToolUseID, + Content: "[Old tool result content cleared]", + IsError: result[idx].ToolResult.IsError, + }, + } + } + + return result +} + +func ToolNameForResult(m client.EyrieMessage, msgs []client.EyrieMessage) string { + if m.ToolResult == nil { + return "" + } + targetID := m.ToolResult.ToolUseID + for i := len(msgs) - 1; i >= 0; i-- { + for _, tc := range msgs[i].ToolUse { + if tc.ID == targetID { + return tc.Name + } + } + } + return "" +} + +func HasTimeGap(msgs []client.EyrieMessage, threshold time.Duration) bool { + lastTextIdx := -1 + for i := len(msgs) - 1; i >= 0; i-- { + if HasTextContent(msgs[i]) && msgs[i].Role == "assistant" { + lastTextIdx = i + break + } + } + if lastTextIdx < 0 { + return false + } + messagesSinceText := len(msgs) - lastTextIdx - 1 + return messagesSinceText > 20 || threshold == 0 +} diff --git a/internal/engine/compact_prompt.go b/internal/engine/compact/prompt.go similarity index 84% rename from internal/engine/compact_prompt.go rename to internal/engine/compact/prompt.go index 4a9e2b03..82e49672 100644 --- a/internal/engine/compact_prompt.go +++ b/internal/engine/compact/prompt.go @@ -1,7 +1,4 @@ -package engine - -// CompactPrompt provides the system and user prompts used during LLM-based compaction. -// Ported from hawk-archive src/services/compact/prompt.ts. +package compact const noToolsPreamble = `CRITICAL: Respond with TEXT ONLY. Do NOT call any tools. @@ -70,7 +67,6 @@ const summaryTemplate = `Now provide your summary inside tags using EX ## Next Step - [based on most recent user messages, what should happen next — include direct quotes if user gave specific direction]` -// BuildCompactPrompt constructs the full compaction prompt for LLM-based summarization. func BuildCompactPrompt(variant CompactVariant) string { var analysis string switch variant { @@ -82,32 +78,27 @@ func BuildCompactPrompt(variant CompactVariant) string { return noToolsPreamble + analysis + "\n\n" + summaryTemplate } -// CompactVariant determines which compaction prompt style to use. type CompactVariant int const ( - CompactBase CompactVariant = iota // Full conversation - CompactPartial // Recent messages only - CompactUpTo // Prefix summarization + CompactBase CompactVariant = iota + CompactPartial + CompactUpTo ) -// FormatCompactSummary strips the drafting block and extracts the content. func FormatCompactSummary(raw string) string { - // Strip ... block start := indexOf(raw, "") end := indexOf(raw, "") if start >= 0 && end > start { raw = raw[:start] + raw[end+len(""):] } - // Extract ... content sumStart := indexOf(raw, "") sumEnd := indexOf(raw, "") if sumStart >= 0 && sumEnd > sumStart { return raw[sumStart+len("") : sumEnd] } - // If no tags, return as-is (fallback) return raw } diff --git a/internal/engine/compact_prompt_test.go b/internal/engine/compact/prompt_test.go similarity index 55% rename from internal/engine/compact_prompt_test.go rename to internal/engine/compact/prompt_test.go index 6d93176e..272f7e30 100644 --- a/internal/engine/compact_prompt_test.go +++ b/internal/engine/compact/prompt_test.go @@ -1,4 +1,4 @@ -package engine +package compact import ( "strings" @@ -28,44 +28,42 @@ func TestBuildCompactPrompt_Partial(t *testing.T) { func TestFormatCompactSummary_WithTags(t *testing.T) { raw := ` This is my internal analysis that should be stripped. -I'm thinking through the conversation... - -The user asked to implement a login feature. -Files modified: auth.go, handler.go. -Next step: add tests. +## Goal +- test task ` - result := FormatCompactSummary(raw) - if strings.Contains(result, "internal analysis") { - t.Error("analysis block should be stripped") + if strings.Contains(result, "analysis") { + t.Error("should strip block") } - if !strings.Contains(result, "login feature") { - t.Error("summary content should be preserved") + if !strings.Contains(result, "## Goal") { + t.Error("should keep content") } - if !strings.Contains(result, "add tests") { - t.Error("next step should be preserved") + if !strings.Contains(result, "test task") { + t.Error("should keep summary text") } } func TestFormatCompactSummary_NoTags(t *testing.T) { - raw := "Just a plain summary without any tags." + raw := "plain text response" result := FormatCompactSummary(raw) if result != raw { - t.Errorf("should return as-is when no tags, got %q", result) + t.Errorf("expected %q, got %q", raw, result) } } -func TestFormatCompactSummary_OnlyAnalysis(t *testing.T) { - raw := `thinking... -The actual summary content here.` - +func TestFormatCompactSummary_SummaryOnly(t *testing.T) { + raw := "just this" result := FormatCompactSummary(raw) - if strings.Contains(result, "thinking") { - t.Error("analysis should be stripped") + if result != "just this" { + t.Errorf("expected 'just this', got %q", result) } - if !strings.Contains(result, "actual summary") { - t.Error("remaining content should be kept") +} + +func TestBuildCompactPrompt_UpTo(t *testing.T) { + prompt := BuildCompactPrompt(CompactUpTo) + if !strings.Contains(prompt, "Chronologically") { + t.Error("UpTo variant should use base analysis (default)") } } diff --git a/internal/engine/compact/session_memory.go b/internal/engine/compact/session_memory.go new file mode 100644 index 00000000..b2d19137 --- /dev/null +++ b/internal/engine/compact/session_memory.go @@ -0,0 +1,92 @@ +package compact + +import ( + "os" + "path/filepath" + "strings" + + "github.com/GrayCodeAI/eyrie/client" + + "github.com/GrayCodeAI/hawk/internal/engine/token" +) + +type SessionMemoryConfig struct { + MinTokens int + MinTextBlockMessages int + MaxTokens int +} + +func DefaultSessionMemoryConfig() SessionMemoryConfig { + return SessionMemoryConfig{ + MinTokens: 10000, + MinTextBlockMessages: 5, + MaxTokens: 40000, + } +} + +func CalculateMessagesToKeepIndex(msgs []client.EyrieMessage, cfg SessionMemoryConfig) int { + if len(msgs) == 0 { + return 0 + } + + tokenCount := 0 + textBlocks := 0 + idx := len(msgs) - 1 + + for idx >= 0 { + tokenCount += token.EstimateMessageTokens(msgs[idx]) + if HasTextContent(msgs[idx]) { + textBlocks++ + } + + if tokenCount >= cfg.MinTokens && textBlocks >= cfg.MinTextBlockMessages { + break + } + if tokenCount >= cfg.MaxTokens { + break + } + idx-- + } + + if idx < 0 { + idx = 0 + } + return idx +} + +func FilterCompactBoundaries(msgs []client.EyrieMessage) []client.EyrieMessage { + result := make([]client.EyrieMessage, 0, len(msgs)) + for _, m := range msgs { + if IsCompactBoundary(m) { + continue + } + result = append(result, m) + } + return result +} + +func IsCompactBoundary(m client.EyrieMessage) bool { + if m.Role != "user" { + return false + } + return strings.HasPrefix(m.Content, "[Session memory summary]") || + strings.HasPrefix(m.Content, "[Conversation summary]") || + strings.HasPrefix(m.Content, "[Earlier conversation compacted") +} + +func SessionMemoryPath(sessionID string) string { + home, _ := os.UserHomeDir() + if sessionID != "" { + return filepath.Join(home, ".hawk", "sessions", sessionID, "memory.md") + } + return filepath.Join(home, ".hawk", "memory.md") +} + +func ReadSessionMemory(sessionID string) (string, error) { + path := SessionMemoryPath(sessionID) + data, err := os.ReadFile(path) + if err != nil { + return "", err + } + return string(data), nil +} diff --git a/internal/engine/compact/strategy.go b/internal/engine/compact/strategy.go new file mode 100644 index 00000000..9f01a5ee --- /dev/null +++ b/internal/engine/compact/strategy.go @@ -0,0 +1,92 @@ +package compact + +import ( + "strings" + + "github.com/GrayCodeAI/eyrie/client" +) + +type CompactResult struct { + Messages []client.EyrieMessage + Summary string + TokensBefore int + TokensAfter int + Strategy string +} + +type CompactConfig struct { + AutoEnabled bool + ContextWindowSize int + AutoCompactBuffer int + MaxOutputTokens int + MaxFailures int +} + +func DefaultCompactConfig() CompactConfig { + return CompactConfig{ + AutoEnabled: true, + ContextWindowSize: 200000, + AutoCompactBuffer: 13000, + MaxOutputTokens: 20000, + MaxFailures: 3, + } +} + +var compactableTools = map[string]bool{ + "Bash": true, + "Read": true, + "Grep": true, + "Glob": true, + "WebFetch": true, + "WebSearch": true, + "Edit": true, + "Write": true, + "LS": true, + "ToolSearch": true, +} + +func IsCompactableTool(name string) bool { + return compactableTools[name] +} + +func AdjustIndexToPreserveAPIInvariants(msgs []client.EyrieMessage, startIdx int) int { + if startIdx <= 0 { + return 0 + } + if startIdx >= len(msgs) { + return len(msgs) + } + + idx := startIdx + for idx > 0 { + msg := msgs[idx] + if msg.ToolResult != nil { + idx-- + continue + } + if msg.Role == "assistant" && len(msg.ToolUse) > 0 { + resultCount := len(msg.ToolUse) + needed := 0 + for j := idx + 1; j < len(msgs) && needed < resultCount; j++ { + if msgs[j].ToolResult != nil { + needed++ + } else { + break + } + } + if needed < resultCount { + idx-- + continue + } + } + break + } + return idx +} + +func HasTextContent(m client.EyrieMessage) bool { + if m.ToolResult != nil { + return false + } + return strings.TrimSpace(m.Content) != "" +} diff --git a/internal/engine/compact/strategy_test.go b/internal/engine/compact/strategy_test.go new file mode 100644 index 00000000..4f0a0b8a --- /dev/null +++ b/internal/engine/compact/strategy_test.go @@ -0,0 +1,222 @@ +package compact + +import ( + "strings" + "testing" + + "github.com/GrayCodeAI/eyrie/client" + + "github.com/GrayCodeAI/hawk/internal/engine/token" +) + +func TestCompactEstimateTokens(t *testing.T) { + msgs := []client.EyrieMessage{ + {Role: "user", Content: "Hello world"}, + {Role: "assistant", Content: strings.Repeat("x", 400)}, + } + tokens := token.EstimateTokens(msgs) + if tokens < 1 { + t.Errorf("expected at least 1 token, got %d", tokens) + } + shortMsgs := []client.EyrieMessage{ + {Role: "user", Content: "hi"}, + } + shortTokens := token.EstimateTokens(shortMsgs) + if tokens <= shortTokens { + t.Errorf("expected more tokens for longer input: %d vs %d", tokens, shortTokens) + } +} + +func TestAdjustIndexToPreserveAPIInvariants(t *testing.T) { + tests := []struct { + name string + msgs []client.EyrieMessage + startIdx int + wantIdx int + }{ + { + name: "empty messages", + msgs: nil, + startIdx: 0, + wantIdx: 0, + }, + { + name: "no tool pairs", + msgs: []client.EyrieMessage{ + {Role: "user", Content: "hello"}, + {Role: "assistant", Content: "hi"}, + {Role: "user", Content: "bye"}, + }, + startIdx: 1, + wantIdx: 1, + }, + { + name: "tool_result at startIdx - moves back past tool_use", + msgs: []client.EyrieMessage{ + {Role: "user", Content: "hello"}, + {Role: "assistant", Content: "", ToolUse: []client.ToolCall{{ID: "t1", Name: "Bash"}}}, + {Role: "user", ToolResult: &client.ToolResult{ToolUseID: "t1", Content: "output"}}, + {Role: "assistant", Content: "done"}, + }, + startIdx: 2, + wantIdx: 1, + }, + { + name: "at boundary already", + msgs: []client.EyrieMessage{ + {Role: "user", Content: "hello"}, + {Role: "assistant", Content: "response"}, + }, + startIdx: 1, + wantIdx: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := AdjustIndexToPreserveAPIInvariants(tt.msgs, tt.startIdx) + if got != tt.wantIdx { + t.Errorf("AdjustIndexToPreserveAPIInvariants() = %d, want %d", got, tt.wantIdx) + } + }) + } +} + +func TestMicrocompactMessages(t *testing.T) { + msgs := []client.EyrieMessage{ + {Role: "user", Content: "read file.go"}, + {Role: "assistant", ToolUse: []client.ToolCall{{ID: "t1", Name: "Read"}}}, + {Role: "user", ToolResult: &client.ToolResult{ToolUseID: "t1", Content: "package main\nfunc main() {}"}}, + {Role: "assistant", Content: "Here's the file content"}, + {Role: "user", Content: "now read another"}, + {Role: "assistant", ToolUse: []client.ToolCall{{ID: "t2", Name: "Read"}}}, + {Role: "user", ToolResult: &client.ToolResult{ToolUseID: "t2", Content: "package utils\nfunc Helper() {}"}}, + {Role: "assistant", Content: "Here's the second file"}, + {Role: "user", Content: "and another"}, + {Role: "assistant", ToolUse: []client.ToolCall{{ID: "t3", Name: "Read"}}}, + {Role: "user", ToolResult: &client.ToolResult{ToolUseID: "t3", Content: "package config\nfunc Load() {}"}}, + {Role: "assistant", Content: "Here's the third"}, + {Role: "user", Content: "one more"}, + {Role: "assistant", ToolUse: []client.ToolCall{{ID: "t4", Name: "Read"}}}, + {Role: "user", ToolResult: &client.ToolResult{ToolUseID: "t4", Content: "package api\nfunc Serve() {}"}}, + {Role: "assistant", Content: "Here's the fourth"}, + } + + cfg := MicroCompactConfig{ + CompactableTools: compactableTools, + TimeGapMins: 0, + KeepRecent: 2, + } + + result := MicrocompactMessages(msgs, cfg) + if len(result) != len(msgs) { + t.Fatalf("message count changed: got %d, want %d", len(result), len(msgs)) + } + + clearedCount := 0 + for _, m := range result { + if m.ToolResult != nil && m.ToolResult.Content == "[Old tool result content cleared]" { + clearedCount++ + } + } + if clearedCount != 2 { + t.Errorf("expected 2 cleared results, got %d", clearedCount) + } + + if result[10].ToolResult.Content == "[Old tool result content cleared]" { + t.Error("third-to-last result should be preserved") + } + if result[14].ToolResult.Content == "[Old tool result content cleared]" { + t.Error("last result should be preserved") + } +} + +func TestAPICompactMessages(t *testing.T) { + msgs := []client.EyrieMessage{ + {Role: "user", Content: "hello"}, + {Role: "assistant", ToolUse: []client.ToolCall{{ID: "t1", Name: "Bash", Arguments: map[string]interface{}{"command": strings.Repeat("x", 1000)}}}}, + {Role: "user", ToolResult: &client.ToolResult{ToolUseID: "t1", Content: strings.Repeat("output ", 1000)}}, + {Role: "assistant", Content: "done"}, + } + + cfg := APICompactConfig{ + TriggerTokens: 0, + KeepTargetTokens: 100, + ClearToolInputs: true, + ClearThinking: true, + PreserveMutating: true, + } + + result := APICompactMessages(msgs, cfg) + if len(result) != len(msgs) { + t.Fatalf("message count changed") + } + + if result[2].ToolResult.Content != "[Old tool result content cleared]" { + t.Error("expected tool result to be cleared") + } +} + +func TestAPICompactPreservesMutatingTools(t *testing.T) { + msgs := []client.EyrieMessage{ + {Role: "user", Content: "edit file"}, + {Role: "assistant", ToolUse: []client.ToolCall{{ID: "t1", Name: "Edit", Arguments: map[string]interface{}{"old_string": strings.Repeat("x", 1000), "new_string": "y"}}}}, + {Role: "user", ToolResult: &client.ToolResult{ToolUseID: "t1", Content: strings.Repeat("edited ", 500)}}, + {Role: "assistant", Content: "edited"}, + } + + cfg := APICompactConfig{ + TriggerTokens: 0, + KeepTargetTokens: 100, + ClearToolInputs: true, + ClearThinking: true, + PreserveMutating: true, + } + + result := APICompactMessages(msgs, cfg) + if result[2].ToolResult.Content == "[Old tool result content cleared]" { + t.Error("mutating tool result should be preserved") + } +} + +func TestCalculateMessagesToKeepIndex(t *testing.T) { + msgs := []client.EyrieMessage{ + {Role: "user", Content: strings.Repeat("hello ", 100)}, + {Role: "assistant", Content: strings.Repeat("response ", 100)}, + {Role: "user", Content: strings.Repeat("follow up ", 100)}, + {Role: "assistant", Content: strings.Repeat("answer ", 100)}, + {Role: "user", Content: strings.Repeat("more ", 100)}, + {Role: "assistant", Content: strings.Repeat("final ", 100)}, + } + + cfg := SessionMemoryConfig{ + MinTokens: 50, + MinTextBlockMessages: 2, + MaxTokens: 5000, + } + + idx := CalculateMessagesToKeepIndex(msgs, cfg) + if idx >= len(msgs) { + t.Errorf("keep index should be within messages range, got %d", idx) + } + if idx < 0 { + t.Errorf("keep index should be non-negative, got %d", idx) + } +} + +func TestFilterCompactBoundaries(t *testing.T) { + msgs := []client.EyrieMessage{ + {Role: "user", Content: "[Session memory summary]\nold stuff"}, + {Role: "assistant", Content: "Understood."}, + {Role: "user", Content: "real message"}, + {Role: "assistant", Content: "real response"}, + } + + filtered := FilterCompactBoundaries(msgs) + if len(filtered) != 3 { + t.Errorf("expected 3 messages after filtering, got %d", len(filtered)) + } + if filtered[0].Content != "Understood." { + t.Errorf("expected first kept message to be 'Understood.', got %q", filtered[0].Content) + } +} diff --git a/internal/engine/compaction_trigger.go b/internal/engine/compact/trigger.go similarity index 51% rename from internal/engine/compaction_trigger.go rename to internal/engine/compact/trigger.go index 82213578..378b946d 100644 --- a/internal/engine/compaction_trigger.go +++ b/internal/engine/compact/trigger.go @@ -1,25 +1,22 @@ -package engine +package compact import "time" -// CompactionTrigger monitors token usage and triggers compaction proactively. type CompactionTrigger struct { - Threshold float64 // trigger at this % of context window (e.g. 0.8 = 80%) - WindowSize int // total context window tokens + Threshold float64 + WindowSize int LastCompact time.Time - MinInterval time.Duration // don't compact more often than this + MinInterval time.Duration } -// NewCompactionTrigger creates a trigger with sensible defaults for solo dev use. func NewCompactionTrigger(windowSize int) *CompactionTrigger { return &CompactionTrigger{ - Threshold: 0.75, // compact at 75% full + Threshold: 0.75, WindowSize: windowSize, MinInterval: 30 * time.Second, } } -// ShouldCompact returns true if current token usage warrants compaction. func (ct *CompactionTrigger) ShouldCompact(currentTokens int) bool { if ct.WindowSize <= 0 { return false @@ -31,7 +28,6 @@ func (ct *CompactionTrigger) ShouldCompact(currentTokens int) bool { return usage >= ct.Threshold } -// MarkCompacted records that compaction just happened. func (ct *CompactionTrigger) MarkCompacted() { ct.LastCompact = time.Now() } diff --git a/internal/engine/compact_api_engine.go b/internal/engine/compact_api_engine.go new file mode 100644 index 00000000..955c0045 --- /dev/null +++ b/internal/engine/compact_api_engine.go @@ -0,0 +1,33 @@ +package engine + +import ( + "context" + + "github.com/GrayCodeAI/eyrie/client" + + "github.com/GrayCodeAI/hawk/internal/engine/compact" +) + +type APICompactStrategy struct{} + +func (s *APICompactStrategy) Name() string { return "api_compact" } + +func (s *APICompactStrategy) ShouldTrigger(msgs []client.EyrieMessage, tokenCount, threshold int) bool { + if tokenCount < 180000 { + return false + } + return compact.CountClearableToolResults(msgs) > 5 +} + +func (s *APICompactStrategy) Compact(ctx context.Context, sess *Session) (*CompactResult, error) { + tokensBefore := EstimateTokens(sess.messages) + result := compact.APICompactMessages(sess.messages, DefaultAPICompactConfig()) + tokensAfter := EstimateTokens(result) + + return &CompactResult{ + Messages: result, + TokensBefore: tokensBefore, + TokensAfter: tokensAfter, + Strategy: "api_compact", + }, nil +} diff --git a/internal/engine/compact_files_test.go b/internal/engine/compact_files_test.go deleted file mode 100644 index b9ab2f02..00000000 --- a/internal/engine/compact_files_test.go +++ /dev/null @@ -1,243 +0,0 @@ -package engine - -import ( - "strings" - "testing" - - "github.com/GrayCodeAI/eyrie/client" -) - -func TestFileTracker_NewFileTracker(t *testing.T) { - t.Parallel() - ft := NewFileTracker() - if ft == nil { - t.Fatal("NewFileTracker returned nil") - } - if ft.ReadFiles == nil || ft.ModifiedFiles == nil { - t.Error("maps should be initialized") - } - if len(ft.ReadFiles) != 0 || len(ft.ModifiedFiles) != 0 { - t.Error("new tracker should have empty maps") - } -} - -func TestFileTracker_RecordRead(t *testing.T) { - t.Parallel() - ft := NewFileTracker() - - ft.RecordRead("main.go") - ft.RecordRead("main.go") - ft.RecordRead("config.go") - ft.RecordRead("") // empty path should be ignored - - if ft.ReadFiles["main.go"] != 2 { - t.Errorf("main.go reads = %d, want 2", ft.ReadFiles["main.go"]) - } - if ft.ReadFiles["config.go"] != 1 { - t.Errorf("config.go reads = %d, want 1", ft.ReadFiles["config.go"]) - } - if _, exists := ft.ReadFiles[""]; exists { - t.Error("empty path should not be tracked") - } -} - -func TestFileTracker_RecordModified(t *testing.T) { - t.Parallel() - ft := NewFileTracker() - - ft.RecordModified("main.go") - ft.RecordModified("main.go") - ft.RecordModified("main.go") - ft.RecordModified("") - - if ft.ModifiedFiles["main.go"] != 3 { - t.Errorf("main.go modifications = %d, want 3", ft.ModifiedFiles["main.go"]) - } - if len(ft.ModifiedFiles) != 1 { - t.Errorf("expected 1 modified file, got %d", len(ft.ModifiedFiles)) - } -} - -func TestFileTracker_ExtractFromMessages(t *testing.T) { - t.Parallel() - ft := NewFileTracker() - - messages := []client.EyrieMessage{ - {Role: "user", Content: "read main.go"}, - {Role: "assistant", ToolUse: []client.ToolCall{ - {Name: "Read", Arguments: map[string]interface{}{"file_path": "/src/main.go"}}, - {Name: "Edit", Arguments: map[string]interface{}{"file_path": "/src/config.go"}}, - }}, - {Role: "assistant", ToolUse: []client.ToolCall{ - {Name: "Write", Arguments: map[string]interface{}{"file_path": "/src/new.go"}}, - {Name: "Read", Arguments: map[string]interface{}{"file_path": "/src/main.go"}}, - }}, - } - - ft.ExtractFromMessages(messages) - - if ft.ReadFiles["/src/main.go"] != 2 { - t.Errorf("main.go reads = %d, want 2", ft.ReadFiles["/src/main.go"]) - } - if ft.ModifiedFiles["/src/config.go"] != 1 { - t.Errorf("config.go mods = %d, want 1", ft.ModifiedFiles["/src/config.go"]) - } - if ft.ModifiedFiles["/src/new.go"] != 1 { - t.Errorf("new.go mods = %d, want 1", ft.ModifiedFiles["/src/new.go"]) - } -} - -func TestFileTracker_FormatForSummary(t *testing.T) { - t.Parallel() - - t.Run("empty tracker", func(t *testing.T) { - t.Parallel() - ft := NewFileTracker() - if got := ft.FormatForSummary(); got != "" { - t.Errorf("FormatForSummary() = %q, want empty", got) - } - }) - - t.Run("with files", func(t *testing.T) { - t.Parallel() - ft := NewFileTracker() - ft.RecordRead("main.go") - ft.RecordRead("main.go") - ft.RecordModified("config.go") - - result := ft.FormatForSummary() - if !strings.Contains(result, "") { - t.Error("should contain tag") - } - if !strings.Contains(result, "") { - t.Error("should contain tag") - } - if !strings.Contains(result, "Read:") { - t.Error("should contain Read: section") - } - if !strings.Contains(result, "Modified:") { - t.Error("should contain Modified: section") - } - if !strings.Contains(result, "main.go") { - t.Error("should contain main.go") - } - }) -} - -func TestFileTracker_ParseFromSummary(t *testing.T) { - t.Parallel() - - t.Run("valid summary", func(t *testing.T) { - t.Parallel() - ft := NewFileTracker() - summary := `Some context here. - -Read: main.go (2x), config.go (1x) -Modified: handler.go (3x) - -More context.` - - ft.ParseFromSummary(summary) - - if ft.ReadFiles["main.go"] != 2 { - t.Errorf("main.go reads = %d, want 2", ft.ReadFiles["main.go"]) - } - if ft.ReadFiles["config.go"] != 1 { - t.Errorf("config.go reads = %d, want 1", ft.ReadFiles["config.go"]) - } - if ft.ModifiedFiles["handler.go"] != 3 { - t.Errorf("handler.go mods = %d, want 3", ft.ModifiedFiles["handler.go"]) - } - }) - - t.Run("no tracked-files block", func(t *testing.T) { - t.Parallel() - ft := NewFileTracker() - ft.ParseFromSummary("just a regular summary with no tracking data") - if len(ft.ReadFiles) != 0 || len(ft.ModifiedFiles) != 0 { - t.Error("should not parse anything from summary without tracked-files") - } - }) - - t.Run("empty block", func(t *testing.T) { - t.Parallel() - ft := NewFileTracker() - ft.ParseFromSummary("\n") - if len(ft.ReadFiles) != 0 || len(ft.ModifiedFiles) != 0 { - t.Error("should not parse anything from empty block") - } - }) -} - -func TestFileTracker_Merge(t *testing.T) { - t.Parallel() - - t.Run("merge into empty", func(t *testing.T) { - t.Parallel() - ft1 := NewFileTracker() - ft2 := NewFileTracker() - ft2.RecordRead("a.go") - ft2.RecordModified("b.go") - - ft1.Merge(ft2) - - if ft1.ReadFiles["a.go"] != 1 { - t.Errorf("a.go reads = %d, want 1", ft1.ReadFiles["a.go"]) - } - if ft1.ModifiedFiles["b.go"] != 1 { - t.Errorf("b.go mods = %d, want 1", ft1.ModifiedFiles["b.go"]) - } - }) - - t.Run("merge with overlap", func(t *testing.T) { - t.Parallel() - ft1 := NewFileTracker() - ft1.RecordRead("shared.go") - ft1.RecordRead("shared.go") - - ft2 := NewFileTracker() - ft2.RecordRead("shared.go") - - ft1.Merge(ft2) - - if ft1.ReadFiles["shared.go"] != 3 { - t.Errorf("shared.go reads = %d, want 3", ft1.ReadFiles["shared.go"]) - } - }) - - t.Run("merge nil", func(t *testing.T) { - t.Parallel() - ft1 := NewFileTracker() - ft1.RecordRead("x.go") - ft1.Merge(nil) - if ft1.ReadFiles["x.go"] != 1 { - t.Error("merge nil should not change tracker") - } - }) -} - -func TestFileTracker_RoundTrip(t *testing.T) { - t.Parallel() - ft1 := NewFileTracker() - ft1.RecordRead("main.go") - ft1.RecordRead("main.go") - ft1.RecordRead("config.go") - ft1.RecordModified("handler.go") - ft1.RecordModified("handler.go") - ft1.RecordModified("handler.go") - - summary := ft1.FormatForSummary() - - ft2 := NewFileTracker() - ft2.ParseFromSummary(summary) - - if ft2.ReadFiles["main.go"] != 2 { - t.Errorf("round-trip: main.go reads = %d, want 2", ft2.ReadFiles["main.go"]) - } - if ft2.ReadFiles["config.go"] != 1 { - t.Errorf("round-trip: config.go reads = %d, want 1", ft2.ReadFiles["config.go"]) - } - if ft2.ModifiedFiles["handler.go"] != 3 { - t.Errorf("round-trip: handler.go mods = %d, want 3", ft2.ModifiedFiles["handler.go"]) - } -} diff --git a/internal/engine/compact_micro.go b/internal/engine/compact_micro.go deleted file mode 100644 index 552462d2..00000000 --- a/internal/engine/compact_micro.go +++ /dev/null @@ -1,141 +0,0 @@ -package engine - -import ( - "context" - "time" - - "github.com/GrayCodeAI/eyrie/client" -) - -// MicroCompactStrategy clears old tool result content while preserving message structure. -type MicroCompactStrategy struct{} - -func (s *MicroCompactStrategy) Name() string { return "micro" } - -// ShouldTrigger fires when there are enough messages with compactable tool results -// and sufficient time has passed since the last assistant message (cache is cold). -func (s *MicroCompactStrategy) ShouldTrigger(msgs []client.EyrieMessage, tokenCount, threshold int) bool { - if tokenCount < threshold/2 { - return false - } - compactableCount := 0 - for _, m := range msgs { - if m.ToolResult != nil && isCompactableTool(toolNameForResult(m, msgs)) { - compactableCount++ - } - } - if compactableCount < 5 { - return false - } - return hasTimeGap(msgs, 60*time.Minute) -} - -func (s *MicroCompactStrategy) Compact(ctx context.Context, sess *Session) (*CompactResult, error) { - tokensBefore := EstimateTokens(sess.messages) - result := microcompactMessages(sess.messages, DefaultMicroCompactConfig()) - tokensAfter := EstimateTokens(result) - - return &CompactResult{ - Messages: result, - TokensBefore: tokensBefore, - TokensAfter: tokensAfter, - Strategy: "micro", - }, nil -} - -// MicroCompactConfig controls micro-compaction behavior. -type MicroCompactConfig struct { - CompactableTools map[string]bool - TimeGapMins float64 - KeepRecent int -} - -// DefaultMicroCompactConfig returns the default micro-compaction settings. -func DefaultMicroCompactConfig() MicroCompactConfig { - return MicroCompactConfig{ - CompactableTools: compactableTools, - TimeGapMins: 60, - KeepRecent: 3, - } -} - -// microcompactMessages clears old tool result content from compactable tools, -// keeping the most recent N results intact. -func microcompactMessages(msgs []client.EyrieMessage, cfg MicroCompactConfig) []client.EyrieMessage { - type resultInfo struct { - index int - toolName string - } - - var compactableResults []resultInfo - for i, m := range msgs { - if m.ToolResult == nil { - continue - } - toolName := toolNameForResult(m, msgs) - if cfg.CompactableTools[toolName] { - compactableResults = append(compactableResults, resultInfo{index: i, toolName: toolName}) - } - } - - if len(compactableResults) <= cfg.KeepRecent { - return msgs - } - - toClear := len(compactableResults) - cfg.KeepRecent - clearSet := make(map[int]bool, toClear) - for i := 0; i < toClear; i++ { - clearSet[compactableResults[i].index] = true - } - - result := make([]client.EyrieMessage, len(msgs)) - copy(result, msgs) - for idx := range clearSet { - result[idx] = client.EyrieMessage{ - Role: result[idx].Role, - ToolResult: &client.ToolResult{ - ToolUseID: result[idx].ToolResult.ToolUseID, - Content: "[Old tool result content cleared]", - IsError: result[idx].ToolResult.IsError, - }, - } - } - - return result -} - -// toolNameForResult finds the tool name for a tool_result message by scanning -// backward for the matching tool_use. -func toolNameForResult(m client.EyrieMessage, msgs []client.EyrieMessage) string { - if m.ToolResult == nil { - return "" - } - targetID := m.ToolResult.ToolUseID - for i := len(msgs) - 1; i >= 0; i-- { - for _, tc := range msgs[i].ToolUse { - if tc.ID == targetID { - return tc.Name - } - } - } - return "" -} - -// hasTimeGap checks if there's a gap >= threshold since the last assistant message, -// indicating the cache is likely cold. -func hasTimeGap(msgs []client.EyrieMessage, threshold time.Duration) bool { - // In the absence of timestamps on messages, use message count as a proxy. - // More than 20 messages since last meaningful text exchange suggests a cold cache. - lastTextIdx := -1 - for i := len(msgs) - 1; i >= 0; i-- { - if hasTextContent(msgs[i]) && msgs[i].Role == "assistant" { - lastTextIdx = i - break - } - } - if lastTextIdx < 0 { - return false - } - messagesSinceText := len(msgs) - lastTextIdx - 1 - return messagesSinceText > 20 || threshold == 0 -} diff --git a/internal/engine/compact_micro_engine.go b/internal/engine/compact_micro_engine.go new file mode 100644 index 00000000..e83ec029 --- /dev/null +++ b/internal/engine/compact_micro_engine.go @@ -0,0 +1,43 @@ +package engine + +import ( + "context" + "time" + + "github.com/GrayCodeAI/eyrie/client" + + "github.com/GrayCodeAI/hawk/internal/engine/compact" +) + +type MicroCompactStrategy struct{} + +func (s *MicroCompactStrategy) Name() string { return "micro" } + +func (s *MicroCompactStrategy) ShouldTrigger(msgs []client.EyrieMessage, tokenCount, threshold int) bool { + if tokenCount < threshold/2 { + return false + } + compactableCount := 0 + for _, m := range msgs { + if m.ToolResult != nil && compact.IsCompactableTool(compact.ToolNameForResult(m, msgs)) { + compactableCount++ + } + } + if compactableCount < 5 { + return false + } + return compact.HasTimeGap(msgs, 60*time.Minute) +} + +func (s *MicroCompactStrategy) Compact(ctx context.Context, sess *Session) (*CompactResult, error) { + tokensBefore := EstimateTokens(sess.messages) + result := compact.MicrocompactMessages(sess.messages, DefaultMicroCompactConfig()) + tokensAfter := EstimateTokens(result) + + return &CompactResult{ + Messages: result, + TokensBefore: tokensBefore, + TokensAfter: tokensAfter, + Strategy: "micro", + }, nil +} diff --git a/internal/engine/compact_reexports.go b/internal/engine/compact_reexports.go new file mode 100644 index 00000000..49566372 --- /dev/null +++ b/internal/engine/compact_reexports.go @@ -0,0 +1,45 @@ +// This file re-exports symbols from the compact sub-package so that existing +// callers of engine.* keep compiling during the Stage 2 migration. +// See REFACTOR_PLAN.md. +package engine + +import ( + "github.com/GrayCodeAI/eyrie/client" + + "github.com/GrayCodeAI/hawk/internal/engine/compact" +) + +type CompactVariant = compact.CompactVariant + +const ( + CompactBase = compact.CompactBase + CompactPartial = compact.CompactPartial + CompactUpTo = compact.CompactUpTo +) + +type ( + CompactResult = compact.CompactResult + CompactConfig = compact.CompactConfig + FileTracker = compact.FileTracker + MicroCompactConfig = compact.MicroCompactConfig + APICompactConfig = compact.APICompactConfig + SessionMemoryConfig = compact.SessionMemoryConfig + CompactionTrigger = compact.CompactionTrigger +) + +func DefaultCompactConfig() CompactConfig { return compact.DefaultCompactConfig() } +func DefaultMicroCompactConfig() MicroCompactConfig { return compact.DefaultMicroCompactConfig() } +func DefaultAPICompactConfig() APICompactConfig { return compact.DefaultAPICompactConfig() } +func DefaultSessionMemoryConfig() SessionMemoryConfig { return compact.DefaultSessionMemoryConfig() } +func NewFileTracker() *FileTracker { return compact.NewFileTracker() } +func NewCompactionTrigger(windowSize int) *CompactionTrigger { + return compact.NewCompactionTrigger(windowSize) +} + +func BuildCompactPrompt(variant CompactVariant) string { return compact.BuildCompactPrompt(variant) } +func FormatCompactSummary(raw string) string { return compact.FormatCompactSummary(raw) } +func IsCompactableTool(name string) bool { return compact.IsCompactableTool(name) } +func AdjustIndexToPreserveAPIInvariants(msgs []client.EyrieMessage, startIdx int) int { + return compact.AdjustIndexToPreserveAPIInvariants(msgs, startIdx) +} +func HasTextContent(m client.EyrieMessage) bool { return compact.HasTextContent(m) } diff --git a/internal/engine/compact_session_memory.go b/internal/engine/compact_session_memory.go deleted file mode 100644 index 8d30c821..00000000 --- a/internal/engine/compact_session_memory.go +++ /dev/null @@ -1,159 +0,0 @@ -package engine - -import ( - "context" - "fmt" - "os" - "path/filepath" - "strings" - - "github.com/GrayCodeAI/eyrie/client" -) - -// SessionMemoryStrategy uses the session memory file as a compaction summary -// instead of making an LLM call. -type SessionMemoryStrategy struct{} - -func (s *SessionMemoryStrategy) Name() string { return "session_memory" } - -func (s *SessionMemoryStrategy) ShouldTrigger(msgs []client.EyrieMessage, tokenCount, threshold int) bool { - if tokenCount < threshold { - return false - } - memFile := sessionMemoryPath("") - info, err := os.Stat(memFile) - if err != nil || info.Size() < 100 { - return false - } - return true -} - -func (s *SessionMemoryStrategy) Compact(ctx context.Context, sess *Session) (*CompactResult, error) { - memContent, err := readSessionMemory("") - if err != nil { - return nil, fmt.Errorf("reading session memory: %w", err) - } - if strings.TrimSpace(memContent) == "" { - return nil, fmt.Errorf("session memory is empty") - } - - tokensBefore := EstimateTokens(sess.messages) - - cfg := DefaultSessionMemoryConfig() - keepIdx := calculateMessagesToKeepIndex(sess.messages, cfg) - keepIdx = adjustIndexToPreserveAPIInvariants(sess.messages, keepIdx) - - if keepIdx >= len(sess.messages)-2 { - return nil, fmt.Errorf("not enough messages to compact") - } - - kept := sess.messages[keepIdx:] - kept = filterCompactBoundaries(kept) - - result := make([]client.EyrieMessage, 0, len(kept)+2) - result = append(result, client.EyrieMessage{ - Role: "user", - Content: "[Session memory summary]\n" + memContent + "\n\n[Continue from the recent messages below.]", - }) - result = append(result, client.EyrieMessage{ - Role: "assistant", - Content: "Understood. I have the context from the session memory above. Continuing with the recent conversation.", - }) - result = append(result, kept...) - - tokensAfter := EstimateTokens(result) - - return &CompactResult{ - Messages: result, - Summary: memContent, - TokensBefore: tokensBefore, - TokensAfter: tokensAfter, - Strategy: "session_memory", - }, nil -} - -// SessionMemoryConfig controls session memory compaction thresholds. -type SessionMemoryConfig struct { - MinTokens int - MinTextBlockMessages int - MaxTokens int -} - -// DefaultSessionMemoryConfig returns defaults matching the archive. -func DefaultSessionMemoryConfig() SessionMemoryConfig { - return SessionMemoryConfig{ - MinTokens: 10000, - MinTextBlockMessages: 5, - MaxTokens: 40000, - } -} - -// calculateMessagesToKeepIndex walks backward from the end of messages -// until we have enough tokens and text-block messages to keep. -func calculateMessagesToKeepIndex(msgs []client.EyrieMessage, cfg SessionMemoryConfig) int { - if len(msgs) == 0 { - return 0 - } - - tokenCount := 0 - textBlocks := 0 - idx := len(msgs) - 1 - - for idx >= 0 { - tokenCount += estimateMessageTokens(msgs[idx]) - if hasTextContent(msgs[idx]) { - textBlocks++ - } - - if tokenCount >= cfg.MinTokens && textBlocks >= cfg.MinTextBlockMessages { - break - } - if tokenCount >= cfg.MaxTokens { - break - } - idx-- - } - - if idx < 0 { - idx = 0 - } - return idx -} - -// filterCompactBoundaries removes old compact boundary messages from kept messages. -func filterCompactBoundaries(msgs []client.EyrieMessage) []client.EyrieMessage { - result := make([]client.EyrieMessage, 0, len(msgs)) - for _, m := range msgs { - if isCompactBoundary(m) { - continue - } - result = append(result, m) - } - return result -} - -func isCompactBoundary(m client.EyrieMessage) bool { - if m.Role != "user" { - return false - } - return strings.HasPrefix(m.Content, "[Session memory summary]") || - strings.HasPrefix(m.Content, "[Conversation summary]") || - strings.HasPrefix(m.Content, "[Earlier conversation compacted") -} - -func sessionMemoryPath(sessionID string) string { - home, _ := os.UserHomeDir() - if sessionID != "" { - return filepath.Join(home, ".hawk", "sessions", sessionID, "memory.md") - } - return filepath.Join(home, ".hawk", "memory.md") -} - -func readSessionMemory(sessionID string) (string, error) { - path := sessionMemoryPath(sessionID) - data, err := os.ReadFile(path) - if err != nil { - return "", err - } - return string(data), nil -} diff --git a/internal/engine/compact_session_memory_engine.go b/internal/engine/compact_session_memory_engine.go new file mode 100644 index 00000000..fdeaff8f --- /dev/null +++ b/internal/engine/compact_session_memory_engine.go @@ -0,0 +1,72 @@ +package engine + +import ( + "context" + "fmt" + "os" + "strings" + + "github.com/GrayCodeAI/eyrie/client" + + "github.com/GrayCodeAI/hawk/internal/engine/compact" +) + +type SessionMemoryStrategy struct{} + +func (s *SessionMemoryStrategy) Name() string { return "session_memory" } + +func (s *SessionMemoryStrategy) ShouldTrigger(msgs []client.EyrieMessage, tokenCount, threshold int) bool { + if tokenCount < threshold { + return false + } + memFile := compact.SessionMemoryPath("") + info, err := os.Stat(memFile) + if err != nil || info.Size() < 100 { + return false + } + return true +} + +func (s *SessionMemoryStrategy) Compact(ctx context.Context, sess *Session) (*CompactResult, error) { + memContent, err := compact.ReadSessionMemory("") + if err != nil { + return nil, fmt.Errorf("reading session memory: %w", err) + } + if strings.TrimSpace(memContent) == "" { + return nil, fmt.Errorf("session memory is empty") + } + + tokensBefore := EstimateTokens(sess.messages) + + cfg := DefaultSessionMemoryConfig() + keepIdx := compact.CalculateMessagesToKeepIndex(sess.messages, cfg) + keepIdx = compact.AdjustIndexToPreserveAPIInvariants(sess.messages, keepIdx) + + if keepIdx >= len(sess.messages)-2 { + return nil, fmt.Errorf("not enough messages to compact") + } + + kept := sess.messages[keepIdx:] + kept = compact.FilterCompactBoundaries(kept) + + result := make([]client.EyrieMessage, 0, len(kept)+2) + result = append(result, client.EyrieMessage{ + Role: "user", + Content: "[Session memory summary]\n" + memContent + "\n\n[Continue from the recent messages below.]", + }) + result = append(result, client.EyrieMessage{ + Role: "assistant", + Content: "Understood. I have the context from the session memory above. Continuing with the recent conversation.", + }) + result = append(result, kept...) + + tokensAfter := EstimateTokens(result) + + return &CompactResult{ + Messages: result, + Summary: memContent, + TokensBefore: tokensBefore, + TokensAfter: tokensAfter, + Strategy: "session_memory", + }, nil +} diff --git a/internal/engine/compact_split.go b/internal/engine/compact_split.go index b861c62d..2d622bf7 100644 --- a/internal/engine/compact_split.go +++ b/internal/engine/compact_split.go @@ -30,7 +30,7 @@ func (s *Session) SplitTurnNeeded(keepCount int) bool { tail := s.messages[len(s.messages)-keepCount:] totalTokens := 0 for _, msg := range tail { - totalTokens += estimateMessageTokens(msg) + totalTokens += EstimateMessageTokens(msg) } if len(tail) == 0 { return false @@ -43,7 +43,7 @@ func (s *Session) SplitTurnNeeded(keepCount int) bool { // Check if any single message in the tail exceeds the budget for _, msg := range tail { - if estimateMessageTokens(msg) > budget { + if EstimateMessageTokens(msg) > budget { return true } } @@ -67,7 +67,7 @@ func (s *Session) splitTurnCompact() { tail := s.messages[len(s.messages)-keepEnd:] totalTokens := 0 for _, msg := range tail { - totalTokens += estimateMessageTokens(msg) + totalTokens += EstimateMessageTokens(msg) } avgTokens := totalTokens / len(tail) budget := avgTokens * 3 @@ -77,7 +77,7 @@ func (s *Session) splitTurnCompact() { oversizedIdx := -1 for i, msg := range tail { - if estimateMessageTokens(msg) > budget { + if EstimateMessageTokens(msg) > budget { oversizedIdx = i break } diff --git a/internal/engine/compact_strategy.go b/internal/engine/compact_strategy.go deleted file mode 100644 index 9cb40bc9..00000000 --- a/internal/engine/compact_strategy.go +++ /dev/null @@ -1,172 +0,0 @@ -package engine - -import ( - "context" - "encoding/json" - "fmt" - "strings" - - "github.com/GrayCodeAI/eyrie/client" -) - -// CompactStrategy defines a conversation compaction approach. -type CompactStrategy interface { - Name() string - ShouldTrigger(msgs []client.EyrieMessage, tokenCount, threshold int) bool - Compact(ctx context.Context, s *Session) (*CompactResult, error) -} - -// CompactResult holds the outcome of a compaction operation. -type CompactResult struct { - Messages []client.EyrieMessage - Summary string - TokensBefore int - TokensAfter int - Strategy string -} - -// CompactConfig controls auto-compaction behavior. -type CompactConfig struct { - AutoEnabled bool - ContextWindowSize int - AutoCompactBuffer int - MaxOutputTokens int - MaxFailures int -} - -// DefaultCompactConfig returns sensible defaults matching the archive behavior. -func DefaultCompactConfig() CompactConfig { - return CompactConfig{ - AutoEnabled: true, - ContextWindowSize: 200000, - AutoCompactBuffer: 13000, - MaxOutputTokens: 20000, - MaxFailures: 3, - } -} - -// StrategyRegistry manages compaction strategies in priority order. -type StrategyRegistry struct { - strategies []CompactStrategy - config CompactConfig -} - -// NewStrategyRegistry creates a registry with default strategies. -func NewStrategyRegistry(config CompactConfig) *StrategyRegistry { - r := &StrategyRegistry{config: config} - r.strategies = []CompactStrategy{ - &MicroCompactStrategy{}, - &SessionMemoryStrategy{}, - &SmartCompactStrategy{}, - &TruncateStrategy{}, - } - return r -} - -// SelectStrategy picks the highest-priority strategy whose trigger fires. -func (r *StrategyRegistry) SelectStrategy(msgs []client.EyrieMessage, tokenCount int) CompactStrategy { - threshold := r.config.ContextWindowSize - r.config.AutoCompactBuffer - r.config.MaxOutputTokens - for _, s := range r.strategies { - if s.ShouldTrigger(msgs, tokenCount, threshold) { - return s - } - } - return &TruncateStrategy{} -} - -// EstimateTokens provides a rough token estimate for messages. -func EstimateTokens(msgs []client.EyrieMessage) int { - total := 0 - for _, m := range msgs { - total += estimateMessageTokens(m) - } - return total -} - -func estimateMessageTokens(m client.EyrieMessage) int { - tokens := CountTokens(m.Content) - for _, tc := range m.ToolUse { - tokens += CountTokens(tc.Name) - for _, v := range tc.Arguments { - switch val := v.(type) { - case string: - tokens += CountTokens(val) - default: - if encoded, err := json.Marshal(v); err == nil { - tokens += CountTokens(string(encoded)) - } else { - // Fallback to string conversion for unknown types - tokens += CountTokens(fmt.Sprintf("%v", v)) - } - } - } - } - if m.ToolResult != nil { - tokens += CountTokens(m.ToolResult.Content) - } - return tokens -} - -// compactableTools are tools whose old results can be safely cleared. -var compactableTools = map[string]bool{ - "Bash": true, - "Read": true, - "Grep": true, - "Glob": true, - "WebFetch": true, - "WebSearch": true, - "Edit": true, - "Write": true, - "LS": true, - "ToolSearch": true, -} - -// isCompactableTool returns true if the tool's results can be cleared during micro-compaction. -func isCompactableTool(name string) bool { - return compactableTools[name] -} - -// adjustIndexToPreserveAPIInvariants walks backward from startIdx to ensure -// tool_use/tool_result pairs are never split. -func adjustIndexToPreserveAPIInvariants(msgs []client.EyrieMessage, startIdx int) int { - if startIdx <= 0 { - return 0 - } - if startIdx >= len(msgs) { - return len(msgs) - } - - idx := startIdx - for idx > 0 { - msg := msgs[idx] - if msg.ToolResult != nil { - idx-- - continue - } - if msg.Role == "assistant" && len(msg.ToolUse) > 0 { - resultCount := len(msg.ToolUse) - needed := 0 - for j := idx + 1; j < len(msgs) && needed < resultCount; j++ { - if msgs[j].ToolResult != nil { - needed++ - } else { - break - } - } - if needed < resultCount { - idx-- - continue - } - } - break - } - return idx -} - -// hasTextContent returns true if the message contains meaningful text (not just tool results). -func hasTextContent(m client.EyrieMessage) bool { - if m.ToolResult != nil { - return false - } - return strings.TrimSpace(m.Content) != "" -} diff --git a/internal/engine/compact_strategy_engine.go b/internal/engine/compact_strategy_engine.go new file mode 100644 index 00000000..3ba70bca --- /dev/null +++ b/internal/engine/compact_strategy_engine.go @@ -0,0 +1,39 @@ +package engine + +import ( + "context" + + "github.com/GrayCodeAI/eyrie/client" +) + +type CompactStrategy interface { + Name() string + ShouldTrigger(msgs []client.EyrieMessage, tokenCount, threshold int) bool + Compact(ctx context.Context, s *Session) (*CompactResult, error) +} + +type StrategyRegistry struct { + strategies []CompactStrategy + config CompactConfig +} + +func NewStrategyRegistry(config CompactConfig) *StrategyRegistry { + r := &StrategyRegistry{config: config} + r.strategies = []CompactStrategy{ + &MicroCompactStrategy{}, + &SessionMemoryStrategy{}, + &SmartCompactStrategy{}, + &TruncateStrategy{}, + } + return r +} + +func (r *StrategyRegistry) SelectStrategy(msgs []client.EyrieMessage, tokenCount int) CompactStrategy { + threshold := r.config.ContextWindowSize - r.config.AutoCompactBuffer - r.config.MaxOutputTokens + for _, s := range r.strategies { + if s.ShouldTrigger(msgs, tokenCount, threshold) { + return s + } + } + return &TruncateStrategy{} +} diff --git a/internal/engine/compact_strategy_test.go b/internal/engine/compact_strategy_test.go index 8fa8b44e..9dc4b472 100644 --- a/internal/engine/compact_strategy_test.go +++ b/internal/engine/compact_strategy_test.go @@ -11,188 +11,14 @@ import ( "github.com/GrayCodeAI/hawk/internal/observability/metrics" ) -func TestCompactEstimateTokens(t *testing.T) { - msgs := []client.EyrieMessage{ - {Role: "user", Content: "Hello world"}, - {Role: "assistant", Content: strings.Repeat("x", 400)}, - } - tokens := EstimateTokens(msgs) - if tokens < 1 { - t.Errorf("expected at least 1 token, got %d", tokens) - } - // Longer input should produce more tokens - shortMsgs := []client.EyrieMessage{ - {Role: "user", Content: "hi"}, - } - shortTokens := EstimateTokens(shortMsgs) - if tokens <= shortTokens { - t.Errorf("expected more tokens for longer input: %d vs %d", tokens, shortTokens) - } -} - -func TestAdjustIndexToPreserveAPIInvariants(t *testing.T) { - tests := []struct { - name string - msgs []client.EyrieMessage - startIdx int - wantIdx int - }{ - { - name: "empty messages", - msgs: nil, - startIdx: 0, - wantIdx: 0, - }, - { - name: "no tool pairs", - msgs: []client.EyrieMessage{ - {Role: "user", Content: "hello"}, - {Role: "assistant", Content: "hi"}, - {Role: "user", Content: "bye"}, - }, - startIdx: 1, - wantIdx: 1, - }, - { - name: "tool_result at startIdx - moves back past tool_use", - msgs: []client.EyrieMessage{ - {Role: "user", Content: "hello"}, - {Role: "assistant", Content: "", ToolUse: []client.ToolCall{{ID: "t1", Name: "Bash"}}}, - {Role: "user", ToolResult: &client.ToolResult{ToolUseID: "t1", Content: "output"}}, - {Role: "assistant", Content: "done"}, - }, - startIdx: 2, - wantIdx: 1, - }, - { - name: "at boundary already", - msgs: []client.EyrieMessage{ - {Role: "user", Content: "hello"}, - {Role: "assistant", Content: "response"}, - }, - startIdx: 1, - wantIdx: 1, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := adjustIndexToPreserveAPIInvariants(tt.msgs, tt.startIdx) - if got != tt.wantIdx { - t.Errorf("adjustIndexToPreserveAPIInvariants() = %d, want %d", got, tt.wantIdx) - } - }) - } -} - -func TestMicrocompactMessages(t *testing.T) { - msgs := []client.EyrieMessage{ - {Role: "user", Content: "read file.go"}, - {Role: "assistant", ToolUse: []client.ToolCall{{ID: "t1", Name: "Read"}}}, - {Role: "user", ToolResult: &client.ToolResult{ToolUseID: "t1", Content: "package main\nfunc main() {}"}}, - {Role: "assistant", Content: "Here's the file content"}, - {Role: "user", Content: "now read another"}, - {Role: "assistant", ToolUse: []client.ToolCall{{ID: "t2", Name: "Read"}}}, - {Role: "user", ToolResult: &client.ToolResult{ToolUseID: "t2", Content: "package utils\nfunc Helper() {}"}}, - {Role: "assistant", Content: "Here's the second file"}, - {Role: "user", Content: "and another"}, - {Role: "assistant", ToolUse: []client.ToolCall{{ID: "t3", Name: "Read"}}}, - {Role: "user", ToolResult: &client.ToolResult{ToolUseID: "t3", Content: "package config\nfunc Load() {}"}}, - {Role: "assistant", Content: "Here's the third"}, - {Role: "user", Content: "one more"}, - {Role: "assistant", ToolUse: []client.ToolCall{{ID: "t4", Name: "Read"}}}, - {Role: "user", ToolResult: &client.ToolResult{ToolUseID: "t4", Content: "package api\nfunc Serve() {}"}}, - {Role: "assistant", Content: "Here's the fourth"}, - } - - cfg := MicroCompactConfig{ - CompactableTools: compactableTools, - TimeGapMins: 0, - KeepRecent: 2, - } - - result := microcompactMessages(msgs, cfg) - if len(result) != len(msgs) { - t.Fatalf("message count changed: got %d, want %d", len(result), len(msgs)) - } - - clearedCount := 0 - for _, m := range result { - if m.ToolResult != nil && m.ToolResult.Content == "[Old tool result content cleared]" { - clearedCount++ - } - } - if clearedCount != 2 { - t.Errorf("expected 2 cleared results, got %d", clearedCount) - } - - // Last 2 results should be preserved - if result[10].ToolResult.Content == "[Old tool result content cleared]" { - t.Error("third-to-last result should be preserved") - } - if result[14].ToolResult.Content == "[Old tool result content cleared]" { - t.Error("last result should be preserved") - } -} - func TestSessionMemoryStrategy_ShouldTrigger(t *testing.T) { s := &SessionMemoryStrategy{} msgs := makeMessages(50) - // Without a memory file, should not trigger if s.ShouldTrigger(msgs, 200000, 150000) { t.Error("should not trigger without memory file") } } -func TestAPICompactMessages(t *testing.T) { - msgs := []client.EyrieMessage{ - {Role: "user", Content: "hello"}, - {Role: "assistant", ToolUse: []client.ToolCall{{ID: "t1", Name: "Bash", Arguments: map[string]interface{}{"command": strings.Repeat("x", 1000)}}}}, - {Role: "user", ToolResult: &client.ToolResult{ToolUseID: "t1", Content: strings.Repeat("output ", 1000)}}, - {Role: "assistant", Content: "done"}, - } - - cfg := APICompactConfig{ - TriggerTokens: 0, - KeepTargetTokens: 100, - ClearToolInputs: true, - ClearThinking: true, - PreserveMutating: true, - } - - result := apiCompactMessages(msgs, cfg) - if len(result) != len(msgs) { - t.Fatalf("message count changed") - } - - if result[2].ToolResult.Content != "[Old tool result content cleared]" { - t.Error("expected tool result to be cleared") - } -} - -func TestAPICompactPreservesMutatingTools(t *testing.T) { - msgs := []client.EyrieMessage{ - {Role: "user", Content: "edit file"}, - {Role: "assistant", ToolUse: []client.ToolCall{{ID: "t1", Name: "Edit", Arguments: map[string]interface{}{"old_string": strings.Repeat("x", 1000), "new_string": "y"}}}}, - {Role: "user", ToolResult: &client.ToolResult{ToolUseID: "t1", Content: strings.Repeat("edited ", 500)}}, - {Role: "assistant", Content: "edited"}, - } - - cfg := APICompactConfig{ - TriggerTokens: 0, - KeepTargetTokens: 100, - ClearToolInputs: true, - ClearThinking: true, - PreserveMutating: true, - } - - result := apiCompactMessages(msgs, cfg) - // Edit tool results should NOT be cleared - if result[2].ToolResult.Content == "[Old tool result content cleared]" { - t.Error("mutating tool result should be preserved") - } -} - func TestAutoCompactor_CircuitBreaker(t *testing.T) { cfg := DefaultCompactConfig() cfg.MaxFailures = 2 @@ -230,48 +56,6 @@ func TestStrategyRegistry_SelectStrategy(t *testing.T) { } } -func TestCalculateMessagesToKeepIndex(t *testing.T) { - msgs := []client.EyrieMessage{ - {Role: "user", Content: strings.Repeat("hello ", 100)}, - {Role: "assistant", Content: strings.Repeat("response ", 100)}, - {Role: "user", Content: strings.Repeat("follow up ", 100)}, - {Role: "assistant", Content: strings.Repeat("answer ", 100)}, - {Role: "user", Content: strings.Repeat("more ", 100)}, - {Role: "assistant", Content: strings.Repeat("final ", 100)}, - } - - cfg := SessionMemoryConfig{ - MinTokens: 50, - MinTextBlockMessages: 2, - MaxTokens: 5000, - } - - idx := calculateMessagesToKeepIndex(msgs, cfg) - if idx >= len(msgs) { - t.Errorf("keep index should be within messages range, got %d", idx) - } - if idx < 0 { - t.Errorf("keep index should be non-negative, got %d", idx) - } -} - -func TestFilterCompactBoundaries(t *testing.T) { - msgs := []client.EyrieMessage{ - {Role: "user", Content: "[Session memory summary]\nold stuff"}, - {Role: "assistant", Content: "Understood."}, - {Role: "user", Content: "real message"}, - {Role: "assistant", Content: "real response"}, - } - - filtered := filterCompactBoundaries(msgs) - if len(filtered) != 3 { - t.Errorf("expected 3 messages after filtering, got %d", len(filtered)) - } - if filtered[0].Content != "Understood." { - t.Errorf("expected first kept message to be 'Understood.', got %q", filtered[0].Content) - } -} - func TestTruncateStrategy(t *testing.T) { sess := &Session{ messages: makeMessages(100), @@ -293,8 +77,6 @@ func TestTruncateStrategy(t *testing.T) { } } -// Helper functions - func makeMessages(n int) []client.EyrieMessage { msgs := make([]client.EyrieMessage, n) for i := range msgs { diff --git a/internal/engine/control/aliases.go b/internal/engine/control/aliases.go index 2284088d..87532c2e 100644 --- a/internal/engine/control/aliases.go +++ b/internal/engine/control/aliases.go @@ -1,44 +1,9 @@ -// Package control is the Stage-1 namespace for engine control-flow safety -// types — loop detection, stall detection, backtracking. See ../REFACTOR_PLAN.md. +// Package control provides engine control-flow safety types — loop +// detection, stall detection, and backtracking. +// +// Public types: LoopDetector, StallEntry, StallResult, StallDetector, +// DecisionPoint, BacktrackEngine. +// +// Public functions: NewLoopDetector, NewStallDetector, NewBacktrackEngine. +// Public constants: DoomLoopThreshold. package control - -import "github.com/GrayCodeAI/hawk/internal/engine" - -// LoopDetector watches for repeated tool-call patterns indicating the agent -// is stuck in a doom loop. -type LoopDetector = engine.LoopDetector - -// DoomLoopThreshold is the number of identical recent actions that flips a -// LoopDetector into "stuck" state. -const DoomLoopThreshold = engine.DoomLoopThreshold - -// NewLoopDetector returns a detector with the given sliding-window size and -// max-repeats threshold. -func NewLoopDetector(windowSize, maxRepeats int) *LoopDetector { - return engine.NewLoopDetector(windowSize, maxRepeats) -} - -// StallEntry is one observed action in the stall window. -type StallEntry = engine.StallEntry - -// StallResult is the verdict of a single stall check. -type StallResult = engine.StallResult - -// StallDetector flags long stretches of no observable progress. -type StallDetector = engine.StallDetector - -// NewStallDetector returns a detector with default thresholds. -func NewStallDetector() *StallDetector { - return engine.NewStallDetector() -} - -// DecisionPoint is a snapshot the agent can return to. -type DecisionPoint = engine.DecisionPoint - -// BacktrackEngine manages decision points and the rollback path. -type BacktrackEngine = engine.BacktrackEngine - -// NewBacktrackEngine returns a fresh backtrack engine. -func NewBacktrackEngine() *BacktrackEngine { - return engine.NewBacktrackEngine() -} diff --git a/internal/engine/backtrack.go b/internal/engine/control/backtrack.go similarity index 99% rename from internal/engine/backtrack.go rename to internal/engine/control/backtrack.go index 67786991..0f383720 100644 --- a/internal/engine/backtrack.go +++ b/internal/engine/control/backtrack.go @@ -1,4 +1,4 @@ -package engine +package control import ( "fmt" diff --git a/internal/engine/backtrack_test.go b/internal/engine/control/backtrack_test.go similarity index 99% rename from internal/engine/backtrack_test.go rename to internal/engine/control/backtrack_test.go index e46f612d..978f15a3 100644 --- a/internal/engine/backtrack_test.go +++ b/internal/engine/control/backtrack_test.go @@ -1,4 +1,4 @@ -package engine +package control import ( "strings" diff --git a/internal/engine/loop_detect.go b/internal/engine/control/loop_detect.go similarity index 99% rename from internal/engine/loop_detect.go rename to internal/engine/control/loop_detect.go index 4b7b0978..1d54cbd4 100644 --- a/internal/engine/loop_detect.go +++ b/internal/engine/control/loop_detect.go @@ -1,4 +1,4 @@ -package engine +package control import ( "crypto/sha256" diff --git a/internal/engine/loop_detect_test.go b/internal/engine/control/loop_detect_test.go similarity index 98% rename from internal/engine/loop_detect_test.go rename to internal/engine/control/loop_detect_test.go index e97f6b25..8cd28193 100644 --- a/internal/engine/loop_detect_test.go +++ b/internal/engine/control/loop_detect_test.go @@ -1,4 +1,4 @@ -package engine +package control import ( "fmt" diff --git a/internal/engine/stall_detector.go b/internal/engine/control/stall_detector.go similarity index 99% rename from internal/engine/stall_detector.go rename to internal/engine/control/stall_detector.go index 78486009..5bc6b0c7 100644 --- a/internal/engine/stall_detector.go +++ b/internal/engine/control/stall_detector.go @@ -1,4 +1,4 @@ -package engine +package control import ( "crypto/sha256" diff --git a/internal/engine/stall_detector_test.go b/internal/engine/control/stall_detector_test.go similarity index 99% rename from internal/engine/stall_detector_test.go rename to internal/engine/control/stall_detector_test.go index aae5e974..20567304 100644 --- a/internal/engine/stall_detector_test.go +++ b/internal/engine/control/stall_detector_test.go @@ -1,4 +1,4 @@ -package engine +package control import ( "fmt" diff --git a/internal/engine/control_reexports.go b/internal/engine/control_reexports.go new file mode 100644 index 00000000..9fdd8249 --- /dev/null +++ b/internal/engine/control_reexports.go @@ -0,0 +1,20 @@ +package engine + +import "github.com/GrayCodeAI/hawk/internal/engine/control" + +type ( + LoopDetector = control.LoopDetector + StallEntry = control.StallEntry + StallResult = control.StallResult + StallDetector = control.StallDetector + DecisionPoint = control.DecisionPoint + BacktrackEngine = control.BacktrackEngine +) + +const DoomLoopThreshold = control.DoomLoopThreshold + +var ( + NewLoopDetector = control.NewLoopDetector + NewStallDetector = control.NewStallDetector + NewBacktrackEngine = control.NewBacktrackEngine +) diff --git a/internal/engine/cost/aliases.go b/internal/engine/cost/aliases.go index 30b693bd..8170cc06 100644 --- a/internal/engine/cost/aliases.go +++ b/internal/engine/cost/aliases.go @@ -1,54 +1,3 @@ -// Package cost is the Stage-1 namespace for cost-tracking types and -// functions in package engine. See ../REFACTOR_PLAN.md. -// -// New code in hawk should import this package instead of reaching into -// engine for cost symbols. Implementation will move here in Stage 2. +// Package cost provides cost tracking, optimisation, and display +// for the hawk engine. See ../REFACTOR_PLAN.md. package cost - -import ( - "github.com/GrayCodeAI/hawk/internal/engine" - analytics "github.com/GrayCodeAI/hawk/internal/observability" -) - -// Cost is the canonical cost record (input/output tokens + USD). -type Cost = engine.Cost - -// Optimizer recommends cheaper models / shorter prompts when costs trend up. -type Optimizer = engine.CostOptimizer - -// Tracker accumulates per-session cost and persists it to analytics. -type Tracker = engine.CostTracker - -// RequestCost is the cost of a single LLM request. -type RequestCost = engine.RequestCost - -// ModelPrice is a per-million-token price tuple for a single model. -type ModelPrice = engine.ModelPrice - -// Recommendation is an Optimizer's suggested change. -type Recommendation = engine.Recommendation - -// NewOptimizer returns a fresh cost optimizer. -func NewOptimizer() *Optimizer { - return engine.NewCostOptimizer() -} - -// NewTracker returns a tracker scoped to the given session. -func NewTracker(sessionID string) *Tracker { - return engine.NewCostTracker(sessionID) -} - -// LoadHistory reads persisted cost entries from analytics storage. -func LoadHistory() ([]analytics.CostEntry, error) { - return engine.LoadCostHistory() -} - -// FormatDisplay renders a USD value for terminal display. -func FormatDisplay(totalUSD float64) string { - return engine.FormatCostDisplay(totalUSD) -} - -// ModelPricing returns input + output USD-per-million-token prices for a model. -func ModelPricing(modelName string) (inputPricePerM, outputPricePerM float64) { - return engine.ModelPricing(modelName) -} diff --git a/internal/engine/cost.go b/internal/engine/cost/cost.go similarity index 61% rename from internal/engine/cost.go rename to internal/engine/cost/cost.go index fd8d559c..85d07b21 100644 --- a/internal/engine/cost.go +++ b/internal/engine/cost/cost.go @@ -1,15 +1,10 @@ -package engine +package cost import ( "fmt" "sync" ) -func pricingForModel(model string) (float64, float64) { - return ModelPricing(model) -} - -// Cost tracks token usage and estimated cost. type Cost struct { mu sync.Mutex Model string @@ -20,42 +15,37 @@ type Cost struct { TotalCostUSD float64 } -// Add records token usage from a response. func (c *Cost) Add(prompt, completion int) { c.mu.Lock() defer c.mu.Unlock() c.PromptTokens += prompt c.CompletionTokens += completion - inPrice, outPrice := pricingForModel(c.Model) + inPrice, outPrice := ModelPricing(c.Model) c.TotalCostUSD += float64(prompt)*inPrice/1_000_000 + float64(completion)*outPrice/1_000_000 } -// AddCacheTokens records cache token usage (priced at ~10% of input). func (c *Cost) AddCacheTokens(read, write int) { c.mu.Lock() defer c.mu.Unlock() c.CacheReadTokens += read c.CacheWriteTokens += write - inPrice, _ := pricingForModel(c.Model) - c.TotalCostUSD += float64(read) * inPrice * 0.1 / 1_000_000 // cache reads are ~10% of input price - c.TotalCostUSD += float64(write) * inPrice * 1.25 / 1_000_000 // cache writes are ~125% of input price + inPrice, _ := ModelPricing(c.Model) + c.TotalCostUSD += float64(read) * inPrice * 0.1 / 1_000_000 + c.TotalCostUSD += float64(write) * inPrice * 1.25 / 1_000_000 } -// Total returns the estimated total cost in USD. func (c *Cost) Total() float64 { c.mu.Lock() defer c.mu.Unlock() return c.TotalCostUSD } -// TotalUSD returns the estimated total cost (same as Total — unified pricing). func (c *Cost) TotalUSD() float64 { c.mu.Lock() defer c.mu.Unlock() return c.TotalCostUSD } -// Summary returns a formatted cost string. func (c *Cost) Summary() string { c.mu.Lock() defer c.mu.Unlock() diff --git a/internal/engine/cost_display.go b/internal/engine/cost/cost_display.go similarity index 75% rename from internal/engine/cost_display.go rename to internal/engine/cost/cost_display.go index 0a0076b2..198b041a 100644 --- a/internal/engine/cost_display.go +++ b/internal/engine/cost/cost_display.go @@ -1,8 +1,7 @@ -package engine +package cost import "fmt" -// FormatCostDisplay returns a compact cost string for the status bar. func FormatCostDisplay(totalUSD float64) string { if totalUSD <= 0 { return "" diff --git a/internal/engine/cost/cost_display_test.go b/internal/engine/cost/cost_display_test.go new file mode 100644 index 00000000..c7f83fe7 --- /dev/null +++ b/internal/engine/cost/cost_display_test.go @@ -0,0 +1,30 @@ +package cost + +import ( + "testing" +) + +func TestFormatCostDisplay(t *testing.T) { + t.Parallel() + tests := []struct { + name string + cost float64 + want string + }{ + {"zero", 0, ""}, + {"negative", -1.0, ""}, + {"sub-cent", 0.005, "$0.0050"}, + {"cents", 0.15, "$0.150"}, + {"dollar", 2.5, "$2.50"}, + {"large", 100.0, "$100.00"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := FormatCostDisplay(tt.cost) + if got != tt.want { + t.Errorf("FormatCostDisplay(%f) = %q, want %q", tt.cost, got, tt.want) + } + }) + } +} diff --git a/internal/engine/cost_optimizer.go b/internal/engine/cost/cost_optimizer.go similarity index 82% rename from internal/engine/cost_optimizer.go rename to internal/engine/cost/cost_optimizer.go index ff2b8c6a..a0692c8e 100644 --- a/internal/engine/cost_optimizer.go +++ b/internal/engine/cost/cost_optimizer.go @@ -1,4 +1,4 @@ -package engine +package cost import ( "fmt" @@ -10,7 +10,6 @@ import ( "github.com/GrayCodeAI/hawk/internal/provider/routing" ) -// CostOptimizer analyzes usage patterns and suggests ways to reduce API costs. type CostOptimizer struct { History []RequestCost Recommendations []Recommendation @@ -18,20 +17,18 @@ type CostOptimizer struct { mu sync.RWMutex } -// RequestCost records the cost details of a single API request. type RequestCost struct { Model string Provider string InputTokens int OutputTokens int CostUSD float64 - TaskType string // "chat", "code", "review", "summarize" + TaskType string Duration time.Duration CacheHit bool Timestamp time.Time } -// ModelPrice holds the pricing for a model per million tokens. type ModelPrice struct { InputPerMillion float64 OutputPerMillion float64 @@ -39,16 +36,14 @@ type ModelPrice struct { CacheWritePerMillion float64 } -// Recommendation represents a cost optimization recommendation. type Recommendation struct { - Type string // "model_switch", "caching", "compression", "batching" + Type string Description string - EstimatedSavings float64 // USD per day - Priority string // "high", "medium", "low" - Action string // what to do + EstimatedSavings float64 + Priority string + Action string } -// NewCostOptimizer creates a CostOptimizer with default pricing for common models. func NewCostOptimizer() *CostOptimizer { return &CostOptimizer{ History: make([]RequestCost, 0), @@ -84,18 +79,34 @@ func NewCostOptimizer() *CostOptimizer { CacheReadPerMillion: 0.015, CacheWritePerMillion: 0.1875, }, + "tier:opus": { + InputPerMillion: 15.0, + OutputPerMillion: 75.0, + CacheReadPerMillion: 1.5, + CacheWritePerMillion: 18.75, + }, + "tier:sonnet": { + InputPerMillion: 3.0, + OutputPerMillion: 15.0, + CacheReadPerMillion: 0.3, + CacheWritePerMillion: 3.75, + }, + "tier:haiku": { + InputPerMillion: 0.25, + OutputPerMillion: 1.25, + CacheReadPerMillion: 0.025, + CacheWritePerMillion: 0.3125, + }, }, } } -// Record adds a RequestCost entry to the history. func (co *CostOptimizer) Record(cost RequestCost) { co.mu.Lock() defer co.mu.Unlock() co.History = append(co.History, cost) } -// Analyze scans the history for optimization opportunities and returns recommendations. func (co *CostOptimizer) Analyze() []Recommendation { co.mu.Lock() defer co.mu.Unlock() @@ -107,29 +118,22 @@ func (co *CostOptimizer) Analyze() []Recommendation { return recommendations } - // Model downgrade analysis: simple tasks on expensive models recommendations = append(recommendations, co.analyzeModelDowngrade()...) - // Caching analysis: repeated requests without cache hits recommendations = append(recommendations, co.analyzeCaching()...) - // Compression analysis: large input tokens recommendations = append(recommendations, co.analyzeCompression()...) - // Batching analysis: many small sequential calls recommendations = append(recommendations, co.analyzeBatching()...) - // Time-of-day analysis: recommend scheduling non-urgent work recommendations = append(recommendations, co.analyzeScheduling()...) - // Token reduction: output tokens consistently high recommendations = append(recommendations, co.analyzeTokenReduction()...) co.Recommendations = recommendations return recommendations } -// analyzeModelDowngrade checks if simple tasks use expensive models. func (co *CostOptimizer) analyzeModelDowngrade() []Recommendation { var recs []Recommendation @@ -148,7 +152,6 @@ func (co *CostOptimizer) analyzeModelDowngrade() []Recommendation { } if expensiveSimpleCount > 0 { - // Estimate savings: assume haiku would cost ~5% of opus, ~8% of sonnet estimatedSavings := expensiveSimpleCost * 0.9 days := co.historyDays() if days < 1 { @@ -168,7 +171,6 @@ func (co *CostOptimizer) analyzeModelDowngrade() []Recommendation { return recs } -// analyzeCaching checks for repeated requests without cache hits. func (co *CostOptimizer) analyzeCaching() []Recommendation { var recs []Recommendation @@ -188,7 +190,6 @@ func (co *CostOptimizer) analyzeCaching() []Recommendation { if totalRequests > 5 { cacheRate := float64(cacheHits) / float64(totalRequests) if cacheRate < 0.3 { - // Could save ~80% on input costs with caching estimatedSavings := totalInputCost * 0.8 days := co.historyDays() if days < 1 { @@ -209,7 +210,6 @@ func (co *CostOptimizer) analyzeCaching() []Recommendation { return recs } -// analyzeCompression checks if average input tokens exceed threshold. func (co *CostOptimizer) analyzeCompression() []Recommendation { var recs []Recommendation @@ -231,7 +231,6 @@ func (co *CostOptimizer) analyzeCompression() []Recommendation { avgInput := totalInput / len(co.History) if avgInput > 5000 && largeInputCount > 0 { - // Compression could save ~30% on large inputs estimatedSavings := largeInputCost * 0.3 days := co.historyDays() if days < 1 { @@ -251,7 +250,6 @@ func (co *CostOptimizer) analyzeCompression() []Recommendation { return recs } -// analyzeBatching checks for many small sequential calls. func (co *CostOptimizer) analyzeBatching() []Recommendation { var recs []Recommendation @@ -259,14 +257,12 @@ func (co *CostOptimizer) analyzeBatching() []Recommendation { return recs } - // Sort by timestamp sorted := make([]RequestCost, len(co.History)) copy(sorted, co.History) sort.Slice(sorted, func(i, j int) bool { return sorted[i].Timestamp.Before(sorted[j].Timestamp) }) - // Look for clusters of small requests within 60 seconds var batchableCount int var batchableCost float64 for i := 1; i < len(sorted); i++ { @@ -278,7 +274,6 @@ func (co *CostOptimizer) analyzeBatching() []Recommendation { } if batchableCount >= 5 { - // Batching overhead reduction ~20% estimatedSavings := batchableCost * 0.2 days := co.historyDays() if days < 1 { @@ -298,7 +293,6 @@ func (co *CostOptimizer) analyzeBatching() []Recommendation { return recs } -// analyzeScheduling recommends off-peak scheduling for non-urgent work. func (co *CostOptimizer) analyzeScheduling() []Recommendation { var recs []Recommendation @@ -306,7 +300,6 @@ func (co *CostOptimizer) analyzeScheduling() []Recommendation { return recs } - // Check what fraction of requests happen during business hours (9-17) var peakCount int var totalCost float64 for _, rc := range co.History { @@ -319,8 +312,7 @@ func (co *CostOptimizer) analyzeScheduling() []Recommendation { peakRatio := float64(peakCount) / float64(len(co.History)) if peakRatio > 0.7 { - // Batch API is ~50% cheaper for non-urgent work - nonUrgentFraction := 0.3 // assume 30% of work is non-urgent + nonUrgentFraction := 0.3 estimatedSavings := totalCost * nonUrgentFraction * 0.5 days := co.historyDays() if days < 1 { @@ -333,14 +325,13 @@ func (co *CostOptimizer) analyzeScheduling() []Recommendation { Description: fmt.Sprintf("%.0f%% of requests during peak hours — schedule non-urgent work off-peak", peakRatio*100), EstimatedSavings: dailySavings, Priority: "low", - Action: "Use batch API for non-urgent tasks to take advantage of 50% cost reduction", + Action: "Use batch API for non-urgent tasks to take advantage of 50%% cost reduction", }) } return recs } -// analyzeTokenReduction checks if output tokens are consistently high. func (co *CostOptimizer) analyzeTokenReduction() []Recommendation { var recs []Recommendation @@ -361,7 +352,6 @@ func (co *CostOptimizer) analyzeTokenReduction() []Recommendation { highRatio := float64(highOutputCount) / float64(len(co.History)) if highRatio > 0.5 { - // Shorter system prompts could reduce output by ~20% estimatedSavings := totalOutputCost * 0.2 days := co.historyDays() if days < 1 { @@ -381,7 +371,6 @@ func (co *CostOptimizer) analyzeTokenReduction() []Recommendation { return recs } -// DailyCost returns the sum of costs from the last 24 hours. func (co *CostOptimizer) DailyCost() float64 { co.mu.RLock() defer co.mu.RUnlock() @@ -396,7 +385,6 @@ func (co *CostOptimizer) DailyCost() float64 { return total } -// WeeklyCost returns the sum of costs from the last 7 days. func (co *CostOptimizer) WeeklyCost() float64 { co.mu.RLock() defer co.mu.RUnlock() @@ -411,7 +399,6 @@ func (co *CostOptimizer) WeeklyCost() float64 { return total } -// CostByModel returns a map of model name to total cost. func (co *CostOptimizer) CostByModel() map[string]float64 { co.mu.RLock() defer co.mu.RUnlock() @@ -423,7 +410,6 @@ func (co *CostOptimizer) CostByModel() map[string]float64 { return result } -// CostByTaskType returns a map of task type to total cost. func (co *CostOptimizer) CostByTaskType() map[string]float64 { co.mu.RLock() defer co.mu.RUnlock() @@ -435,7 +421,6 @@ func (co *CostOptimizer) CostByTaskType() map[string]float64 { return result } -// ProjectSavings calculates the total estimated daily savings if all recommendations are applied. func (co *CostOptimizer) ProjectSavings(recommendations []Recommendation) float64 { var total float64 for _, r := range recommendations { @@ -444,7 +429,6 @@ func (co *CostOptimizer) ProjectSavings(recommendations []Recommendation) float6 return total } -// FormatReport generates a formatted cost report string. func (co *CostOptimizer) FormatReport() string { co.mu.RLock() defer co.mu.RUnlock() @@ -456,7 +440,6 @@ func (co *CostOptimizer) FormatReport() string { b.WriteString("Cost Report (Last 7 Days):\n") b.WriteString(fmt.Sprintf("Total: $%.2f\n", weeklyCost)) - // By Model modelCosts := make(map[string]float64) for _, rc := range co.History { cutoff := time.Now().Add(-7 * 24 * time.Hour) @@ -477,7 +460,6 @@ func (co *CostOptimizer) FormatReport() string { } } - // By Task taskCosts := make(map[string]float64) for _, rc := range co.History { cutoff := time.Now().Add(-7 * 24 * time.Hour) @@ -498,18 +480,17 @@ func (co *CostOptimizer) FormatReport() string { } } - // Recommendations if len(co.Recommendations) > 0 { b.WriteString("\nRecommendations:\n") for _, rec := range co.Recommendations { - icon := "\U0001f535" // blue circle + icon := "\U0001f535" label := "LOW" switch rec.Priority { case "high": - icon = "\U0001f7e2" // green circle + icon = "\U0001f7e2" label = "HIGH" case "medium": - icon = "\U0001f7e1" // yellow circle + icon = "\U0001f7e1" label = "MED" } b.WriteString(fmt.Sprintf("%s %s: %s (saves ~$%.2f/day)\n", icon, label, rec.Action, rec.EstimatedSavings)) @@ -523,7 +504,6 @@ func (co *CostOptimizer) FormatReport() string { return b.String() } -// WhatIf calculates what the total cost would have been if all requests used the given model. func (co *CostOptimizer) WhatIf(model string) float64 { co.mu.RLock() defer co.mu.RUnlock() @@ -538,8 +518,6 @@ func (co *CostOptimizer) WhatIf(model string) float64 { return total } -// Helper methods - func (co *CostOptimizer) normalizeModel(model string) string { if info, ok := routing.Find(model); ok && info.Name != "" { return info.Name @@ -608,7 +586,6 @@ func (co *CostOptimizer) projectSavingsLocked(recommendations []Recommendation) return total } -// keyValue is used for sorting maps by value. type keyValue struct { Key string Value float64 @@ -620,7 +597,7 @@ func sortMapByValue(m map[string]float64) []keyValue { kvs = append(kvs, keyValue{k, v}) } sort.Slice(kvs, func(i, j int) bool { - return kvs[i].Value > kvs[j].Value // descending + return kvs[i].Value > kvs[j].Value }) return kvs } diff --git a/internal/engine/cost_optimizer_test.go b/internal/engine/cost/cost_optimizer_test.go similarity index 93% rename from internal/engine/cost_optimizer_test.go rename to internal/engine/cost/cost_optimizer_test.go index 8bac46e5..c4d32d15 100644 --- a/internal/engine/cost_optimizer_test.go +++ b/internal/engine/cost/cost_optimizer_test.go @@ -1,11 +1,27 @@ -package engine +package cost import ( "strings" "testing" "time" + + eycatalog "github.com/GrayCodeAI/eyrie/catalog" + "github.com/GrayCodeAI/hawk/internal/provider/routing" ) +const testProvider = "anthropic" + +func testTierModels(t *testing.T, provider string) (haiku, sonnet, opus string) { + t.Helper() + haiku = routing.PreferredModelForTier(provider, eycatalog.TierHaiku, "") + sonnet = routing.PreferredModelForTier(provider, eycatalog.TierSonnet, "") + opus = routing.PreferredModelForTier(provider, eycatalog.TierOpus, "") + if haiku == "" || sonnet == "" || opus == "" { + t.Fatalf("eyrie catalog missing tier models for provider %q", provider) + } + return haiku, sonnet, opus +} + func TestNewCostOptimizer(t *testing.T) { co := NewCostOptimizer() if co == nil { @@ -17,8 +33,8 @@ func TestNewCostOptimizer(t *testing.T) { if len(co.Recommendations) != 0 { t.Errorf("expected empty recommendations, got %d", len(co.Recommendations)) } - if len(co.ModelPricing) != 5 { - t.Errorf("expected 5 model pricings, got %d", len(co.ModelPricing)) + if len(co.ModelPricing) != 8 { + t.Errorf("expected 8 model pricings, got %d", len(co.ModelPricing)) } // Verify specific pricing @@ -199,8 +215,8 @@ func TestWhatIf(t *testing.T) { if haikuCost <= 0 || sonnetCost <= 0 { t.Fatalf("WhatIf returned non-positive costs: haiku=%.4f sonnet=%.4f", haikuCost, sonnetCost) } - if haikuCost >= sonnetCost { - t.Errorf("WhatIf haiku (%.4f) should be cheaper than sonnet (%.4f)", haikuCost, sonnetCost) + if haikuCost > sonnetCost { + t.Errorf("WhatIf haiku (%.4f) should not be more expensive than sonnet (%.4f)", haikuCost, sonnetCost) } } @@ -435,8 +451,8 @@ func TestWhatIfAllModels(t *testing.T) { if haikuCost <= 0 || sonnetCost <= 0 { t.Fatalf("WhatIf returned non-positive: haiku=%f sonnet=%f", haikuCost, sonnetCost) } - if haikuCost >= sonnetCost { - t.Errorf("WhatIf haiku (%.4f) should be cheaper than sonnet (%.4f)", haikuCost, sonnetCost) + if haikuCost > sonnetCost { + t.Errorf("WhatIf haiku (%.4f) should not be more expensive than sonnet (%.4f)", haikuCost, sonnetCost) } } diff --git a/internal/engine/cost/cost_table.go b/internal/engine/cost/cost_table.go new file mode 100644 index 00000000..4b6cf889 --- /dev/null +++ b/internal/engine/cost/cost_table.go @@ -0,0 +1,24 @@ +package cost + +import "github.com/GrayCodeAI/hawk/internal/provider/routing" + +// Tier-based default pricing when catalog data is unavailable. +var tierDefaults = map[routing.CostTier][2]float64{ + routing.CostTierCheap: {0.15, 0.60}, + routing.CostTierMid: {3.0, 15.0}, + routing.CostTierExpensive: {15.0, 75.0}, +} + +func ModelPricing(modelName string) (inputPricePerM, outputPricePerM float64) { + info, ok := routing.Find(modelName) + if ok && (info.InputPrice > 0 || info.OutputPrice > 0) { + return info.InputPrice, info.OutputPrice + } + // Fall back to tier-based defaults so routing decisions still produce + // meaningful cost estimates even when the catalog lacks pricing data. + tier := routing.CostTierOf(modelName) + if d, ok := tierDefaults[tier]; ok { + return d[0], d[1] + } + return 3.0, 15.0 +} diff --git a/internal/engine/cost_tracker.go b/internal/engine/cost/cost_tracker.go similarity index 82% rename from internal/engine/cost_tracker.go rename to internal/engine/cost/cost_tracker.go index 73342579..40ba9381 100644 --- a/internal/engine/cost_tracker.go +++ b/internal/engine/cost/cost_tracker.go @@ -1,4 +1,4 @@ -package engine +package cost import ( "encoding/json" @@ -10,8 +10,6 @@ import ( analytics "github.com/GrayCodeAI/hawk/internal/observability" ) -// CostTracker records per-request cost entries for analytics and optimization. -// Data is appended to ~/.hawk/cost.jsonl for cross-session analysis. type CostTracker struct { mu sync.Mutex entries []analytics.CostEntry @@ -19,7 +17,6 @@ type CostTracker struct { filePath string } -// NewCostTracker creates a tracker that persists to ~/.hawk/cost.jsonl. func NewCostTracker(sessionID string) *CostTracker { home, _ := os.UserHomeDir() return &CostTracker{ @@ -28,7 +25,6 @@ func NewCostTracker(sessionID string) *CostTracker { } } -// Record adds a cost entry and persists it. func (ct *CostTracker) Record(entry analytics.CostEntry) error { ct.mu.Lock() defer ct.mu.Unlock() @@ -40,7 +36,6 @@ func (ct *CostTracker) Record(entry analytics.CostEntry) error { return ct.appendToFile(entry) } -// SessionTotal returns total USD spent in the current session. func (ct *CostTracker) SessionTotal() float64 { ct.mu.Lock() defer ct.mu.Unlock() @@ -51,7 +46,6 @@ func (ct *CostTracker) SessionTotal() float64 { return total } -// Entries returns all recorded entries for this session. func (ct *CostTracker) Entries() []analytics.CostEntry { ct.mu.Lock() defer ct.mu.Unlock() @@ -60,7 +54,6 @@ func (ct *CostTracker) Entries() []analytics.CostEntry { return out } -// LoadHistory reads all historical cost entries from the JSONL file. func LoadCostHistory() ([]analytics.CostEntry, error) { home, _ := os.UserHomeDir() path := filepath.Join(home, ".hawk", "cost.jsonl") diff --git a/internal/engine/cost_tracker_test.go b/internal/engine/cost/cost_tracker_test.go similarity index 99% rename from internal/engine/cost_tracker_test.go rename to internal/engine/cost/cost_tracker_test.go index 5a89bbd5..1753509c 100644 --- a/internal/engine/cost_tracker_test.go +++ b/internal/engine/cost/cost_tracker_test.go @@ -1,4 +1,4 @@ -package engine +package cost import ( "os" diff --git a/internal/engine/cost_display_test.go b/internal/engine/cost_display_test.go deleted file mode 100644 index afd296bd..00000000 --- a/internal/engine/cost_display_test.go +++ /dev/null @@ -1,101 +0,0 @@ -package engine - -import ( - "context" - "strings" - "testing" - "time" -) - -func TestFormatCostDisplay(t *testing.T) { - t.Parallel() - tests := []struct { - name string - cost float64 - want string - }{ - {"zero", 0, ""}, - {"negative", -1.0, ""}, - {"sub-cent", 0.005, "$0.0050"}, - {"cents", 0.15, "$0.150"}, - {"dollar", 2.5, "$2.50"}, - {"large", 100.0, "$100.00"}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - got := FormatCostDisplay(tt.cost) - if got != tt.want { - t.Errorf("FormatCostDisplay(%f) = %q, want %q", tt.cost, got, tt.want) - } - }) - } -} - -func TestDefaultTimeoutConfig(t *testing.T) { - t.Parallel() - cfg := DefaultTimeoutConfig() - if cfg.PerTurn != 60*time.Second { - t.Errorf("PerTurn = %v, want 60s", cfg.PerTurn) - } - if cfg.PerTool != 120*time.Second { - t.Errorf("PerTool = %v, want 120s", cfg.PerTool) - } - if cfg.Total != 0 { - t.Errorf("Total = %v, want 0 (no default deadline)", cfg.Total) - } -} - -func TestWithTimeout(t *testing.T) { - t.Parallel() - - t.Run("with total", func(t *testing.T) { - t.Parallel() - cfg := TimeoutConfig{Total: 5 * time.Second} - ctx, cancel := WithTimeout(context.Background(), cfg) - defer cancel() - - deadline, ok := ctx.Deadline() - if !ok { - t.Error("context should have deadline") - } - if time.Until(deadline) > 6*time.Second { - t.Error("deadline too far in future") - } - }) - - t.Run("without total", func(t *testing.T) { - t.Parallel() - cfg := TimeoutConfig{Total: 0} - ctx, cancel := WithTimeout(context.Background(), cfg) - defer cancel() - - _, ok := ctx.Deadline() - if ok { - t.Error("context should not have deadline when Total=0") - } - }) -} - -func TestRemainingTime(t *testing.T) { - t.Parallel() - cfg := TimeoutConfig{Total: 10 * time.Second} - ctx, cancel := WithTimeout(context.Background(), cfg) - defer cancel() - - remaining := RemainingTime(ctx) - if remaining == "" { - t.Error("RemainingTime should return non-empty string") - } - if !strings.Contains(remaining, "s") && !strings.Contains(remaining, "m") { - t.Errorf("RemainingTime() = %q, expected time unit", remaining) - } -} - -func TestRemainingTime_WithoutDeadline(t *testing.T) { - t.Parallel() - remaining := RemainingTime(context.Background()) - if remaining != "" { - t.Errorf("RemainingTime() = %q, want empty for no deadline", remaining) - } -} diff --git a/internal/engine/cost_reexports.go b/internal/engine/cost_reexports.go new file mode 100644 index 00000000..9d3decda --- /dev/null +++ b/internal/engine/cost_reexports.go @@ -0,0 +1,23 @@ +package engine + +import ( + "github.com/GrayCodeAI/hawk/internal/engine/cost" + analytics "github.com/GrayCodeAI/hawk/internal/observability" +) + +type ( + Cost = cost.Cost + CostOptimizer = cost.CostOptimizer + CostTracker = cost.CostTracker + RequestCost = cost.RequestCost + ModelPrice = cost.ModelPrice + Recommendation = cost.Recommendation +) + +func NewCostOptimizer() *CostOptimizer { return cost.NewCostOptimizer() } +func NewCostTracker(sessionID string) *CostTracker { return cost.NewCostTracker(sessionID) } +func LoadCostHistory() ([]analytics.CostEntry, error) { return cost.LoadCostHistory() } +func FormatCostDisplay(totalUSD float64) string { return cost.FormatCostDisplay(totalUSD) } +func ModelPricing(modelName string) (inputPricePerM, outputPricePerM float64) { + return cost.ModelPricing(modelName) +} diff --git a/internal/engine/cost_table.go b/internal/engine/cost_table.go deleted file mode 100644 index 28a0fb21..00000000 --- a/internal/engine/cost_table.go +++ /dev/null @@ -1,12 +0,0 @@ -package engine - -import "github.com/GrayCodeAI/hawk/internal/provider/routing" - -// ModelPricing returns input/output price per million tokens for a model. -func ModelPricing(modelName string) (inputPricePerM, outputPricePerM float64) { - info, ok := routing.Find(modelName) - if !ok { - return 3.0, 15.0 // conservative default - } - return info.InputPrice, info.OutputPrice -} diff --git a/internal/engine/ctxmgr/aliases.go b/internal/engine/ctxmgr/aliases.go index b5957be2..07cac927 100644 --- a/internal/engine/ctxmgr/aliases.go +++ b/internal/engine/ctxmgr/aliases.go @@ -1,59 +1,5 @@ -// Package ctxmgr is the Stage-1 namespace for context budget, decay, packing, +// Package ctxmgr is the namespace for context budget, decay, packing, // providers, visualisation, and read-only context. See ../REFACTOR_PLAN.md. // // Named "ctxmgr" (not "context") to avoid shadowing the stdlib context package. package ctxmgr - -import ( - "time" - - "github.com/GrayCodeAI/hawk/internal/engine" -) - -type ( - ContextBudget = engine.ContextBudget - ContextAllocation = engine.ContextAllocation - ContextDecay = engine.ContextDecay - DecayEntry = engine.DecayEntry - DecayStats = engine.DecayStats - PackingStrategy = engine.PackingStrategy - ContextPacker = engine.ContextPacker - ScoredMessage = engine.ScoredMessage - PackingResult = engine.PackingResult - ContextProvider = engine.ContextProvider - ContextItem = engine.ContextItem - ContextManager = engine.ContextManager - GitContextProvider = engine.GitContextProvider - FileContextProvider = engine.FileContextProvider - ErrorContextProvider = engine.ErrorContextProvider - DependencyContextProvider = engine.DependencyContextProvider - ContextVisualizer = engine.ContextVisualizer - ContextSection = engine.ContextSection - VizContextItem = engine.VizContextItem - ContextSnapshot = engine.ContextSnapshot - ReadOnlyContext = engine.ReadOnlyContext - ContextFile = engine.ContextFile - ContextFileOption = engine.ContextFileOption - ContextStats = engine.ContextStats -) - -func NewContextBudget(contextSize int) *ContextBudget { return engine.NewContextBudget(contextSize) } - -func NewContextDecay(halfLife time.Duration) *ContextDecay { return engine.NewContextDecay(halfLife) } - -func NewContextPacker(maxTokens int) *ContextPacker { return engine.NewContextPacker(maxTokens) } - -func NewContextManager(budget int) *ContextManager { return engine.NewContextManager(budget) } - -func NewContextVisualizer(max int) *ContextVisualizer { return engine.NewContextVisualizer(max) } - -func NewReadOnlyContext(maxBudget int) *ReadOnlyContext { return engine.NewReadOnlyContext(maxBudget) } - -func FormatContextItems(items []ContextItem) string { return engine.FormatContextItems(items) } - -func PrioritizeItems(items []ContextItem, budget int) []ContextItem { - return engine.PrioritizeItems(items, budget) -} -func SuggestFiles(projectDir string) []string { return engine.SuggestFiles(projectDir) } -func WithPinned() ContextFileOption { return engine.WithPinned() } -func WithAutoRefresh() ContextFileOption { return engine.WithAutoRefresh() } diff --git a/internal/engine/context_budget.go b/internal/engine/ctxmgr/context_budget.go similarity index 99% rename from internal/engine/context_budget.go rename to internal/engine/ctxmgr/context_budget.go index b5a78bc6..5c5c7598 100644 --- a/internal/engine/context_budget.go +++ b/internal/engine/ctxmgr/context_budget.go @@ -1,4 +1,4 @@ -package engine +package ctxmgr import ( "fmt" diff --git a/internal/engine/context_budget_test.go b/internal/engine/ctxmgr/context_budget_test.go similarity index 99% rename from internal/engine/context_budget_test.go rename to internal/engine/ctxmgr/context_budget_test.go index 01a4687e..0308cc24 100644 --- a/internal/engine/context_budget_test.go +++ b/internal/engine/ctxmgr/context_budget_test.go @@ -1,4 +1,4 @@ -package engine +package ctxmgr import ( "strings" diff --git a/internal/engine/context_collapse.go b/internal/engine/ctxmgr/context_collapse.go similarity index 99% rename from internal/engine/context_collapse.go rename to internal/engine/ctxmgr/context_collapse.go index c4fa5777..f83806f9 100644 --- a/internal/engine/context_collapse.go +++ b/internal/engine/ctxmgr/context_collapse.go @@ -1,4 +1,4 @@ -package engine +package ctxmgr import ( "fmt" diff --git a/internal/engine/context_collapse_test.go b/internal/engine/ctxmgr/context_collapse_test.go similarity index 99% rename from internal/engine/context_collapse_test.go rename to internal/engine/ctxmgr/context_collapse_test.go index 69a0beee..799a556f 100644 --- a/internal/engine/context_collapse_test.go +++ b/internal/engine/ctxmgr/context_collapse_test.go @@ -1,4 +1,4 @@ -package engine +package ctxmgr import ( "testing" diff --git a/internal/engine/context_decay.go b/internal/engine/ctxmgr/context_decay.go similarity index 99% rename from internal/engine/context_decay.go rename to internal/engine/ctxmgr/context_decay.go index e97aef1a..973a82ce 100644 --- a/internal/engine/context_decay.go +++ b/internal/engine/ctxmgr/context_decay.go @@ -1,4 +1,4 @@ -package engine +package ctxmgr import ( "fmt" diff --git a/internal/engine/context_decay_test.go b/internal/engine/ctxmgr/context_decay_test.go similarity index 99% rename from internal/engine/context_decay_test.go rename to internal/engine/ctxmgr/context_decay_test.go index 588d0d6a..30d900f1 100644 --- a/internal/engine/context_decay_test.go +++ b/internal/engine/ctxmgr/context_decay_test.go @@ -1,4 +1,4 @@ -package engine +package ctxmgr import ( "strings" diff --git a/internal/engine/context_packer.go b/internal/engine/ctxmgr/context_packer.go similarity index 99% rename from internal/engine/context_packer.go rename to internal/engine/ctxmgr/context_packer.go index 1b65cde8..5348839a 100644 --- a/internal/engine/context_packer.go +++ b/internal/engine/ctxmgr/context_packer.go @@ -1,4 +1,4 @@ -package engine +package ctxmgr import ( "fmt" diff --git a/internal/engine/context_packer_test.go b/internal/engine/ctxmgr/context_packer_test.go similarity index 99% rename from internal/engine/context_packer_test.go rename to internal/engine/ctxmgr/context_packer_test.go index 679cf699..1bc517a8 100644 --- a/internal/engine/context_packer_test.go +++ b/internal/engine/ctxmgr/context_packer_test.go @@ -1,4 +1,4 @@ -package engine +package ctxmgr import ( "math" diff --git a/internal/engine/context_providers.go b/internal/engine/ctxmgr/context_providers.go similarity index 99% rename from internal/engine/context_providers.go rename to internal/engine/ctxmgr/context_providers.go index 767236cf..e7847d1d 100644 --- a/internal/engine/context_providers.go +++ b/internal/engine/ctxmgr/context_providers.go @@ -1,4 +1,4 @@ -package engine +package ctxmgr import ( "context" diff --git a/internal/engine/context_providers_test.go b/internal/engine/ctxmgr/context_providers_test.go similarity index 99% rename from internal/engine/context_providers_test.go rename to internal/engine/ctxmgr/context_providers_test.go index 461b8e6f..2daaf502 100644 --- a/internal/engine/context_providers_test.go +++ b/internal/engine/ctxmgr/context_providers_test.go @@ -1,4 +1,4 @@ -package engine +package ctxmgr import ( "context" diff --git a/internal/engine/context_viz.go b/internal/engine/ctxmgr/context_viz.go similarity index 99% rename from internal/engine/context_viz.go rename to internal/engine/ctxmgr/context_viz.go index 8c4df59f..63e1cef2 100644 --- a/internal/engine/context_viz.go +++ b/internal/engine/ctxmgr/context_viz.go @@ -1,4 +1,4 @@ -package engine +package ctxmgr import ( "fmt" diff --git a/internal/engine/context_viz_test.go b/internal/engine/ctxmgr/context_viz_test.go similarity index 99% rename from internal/engine/context_viz_test.go rename to internal/engine/ctxmgr/context_viz_test.go index 5b7984a0..df507940 100644 --- a/internal/engine/context_viz_test.go +++ b/internal/engine/ctxmgr/context_viz_test.go @@ -1,4 +1,4 @@ -package engine +package ctxmgr import ( "strings" diff --git a/internal/engine/readonly_context.go b/internal/engine/ctxmgr/readonly_context.go similarity index 99% rename from internal/engine/readonly_context.go rename to internal/engine/ctxmgr/readonly_context.go index 8517593b..98d901f2 100644 --- a/internal/engine/readonly_context.go +++ b/internal/engine/ctxmgr/readonly_context.go @@ -1,4 +1,4 @@ -package engine +package ctxmgr import ( "fmt" diff --git a/internal/engine/readonly_context_test.go b/internal/engine/ctxmgr/readonly_context_test.go similarity index 99% rename from internal/engine/readonly_context_test.go rename to internal/engine/ctxmgr/readonly_context_test.go index 14c279c9..84822d7d 100644 --- a/internal/engine/readonly_context_test.go +++ b/internal/engine/ctxmgr/readonly_context_test.go @@ -1,4 +1,4 @@ -package engine +package ctxmgr import ( "fmt" diff --git a/internal/engine/diff/aliases.go b/internal/engine/diff/aliases.go index 69da1c12..0c2c471c 100644 --- a/internal/engine/diff/aliases.go +++ b/internal/engine/diff/aliases.go @@ -2,48 +2,8 @@ // summariser, test selector, and 3-way merge. See ../REFACTOR_PLAN.md. package diff -import "github.com/GrayCodeAI/hawk/internal/engine" - type ( - PendingChange = engine.PendingChange - DiffSandbox = engine.DiffSandbox - StagingArea = engine.StagingArea - StagedChange = engine.StagedChange - StagedHunk = engine.StagedHunk - Preview = engine.DiffPreview - FileChange = engine.FileChange - Hunk = engine.DiffHunk - Line = engine.DiffLine - ChangeStats = engine.ChangeStats - Summary = engine.DiffSummary - FileSummary = engine.FileSummary - Summarizer = engine.DiffSummarizer - TestSelector = engine.TestSelector - SelectedTests = engine.SelectedTests - Diff3Result = engine.Diff3Result - Diff3Conflict = engine.Diff3Conflict - Diff3Stats = engine.Diff3Stats - Diff3Region = engine.Diff3Region - Edit = engine.Edit + Preview = DiffPreview + Summary = DiffSummary + Summarizer = DiffSummarizer ) - -func NewDiffSandbox() *DiffSandbox { return engine.NewDiffSandbox() } -func NewStagingArea() *StagingArea { return engine.NewStagingArea() } -func NewDiffPreview() *Preview { return engine.NewDiffPreview() } -func NewSummarizer() *Summarizer { return engine.NewDiffSummarizer() } -func NewTestSelector(projectDir string) *TestSelector { return engine.NewTestSelector(projectDir) } -func ComputeDiff(old, new string) []Hunk { return engine.ComputeDiff(old, new) } -func ComputeMyersDiff(a, b []string) []Line { return engine.ComputeMyersDiff(a, b) } -func RenderUnified(change *FileChange) string { return engine.RenderUnified(change) } -func Merge3(base, ours, theirs string) *Diff3Result { return engine.Merge3(base, ours, theirs) } -func MergeClean(base, ours, theirs string) (string, bool) { - return engine.MergeClean(base, ours, theirs) -} -func FormatConflictMarkers(c Diff3Conflict) string { return engine.FormatConflictMarkers(c) } -func LCS(a, b []string) []string { return engine.LCS(a, b) } -func EditScript(from, to []string) []Edit { return engine.EditScript(from, to) } -func BuildDependencyGraph(dir string) map[string][]string { return engine.BuildDependencyGraph(dir) } - -func GenerateTestCommand(s *SelectedTests, lang string) string { - return engine.GenerateTestCommand(s, lang) -} diff --git a/internal/engine/diff3.go b/internal/engine/diff/diff3.go similarity index 99% rename from internal/engine/diff3.go rename to internal/engine/diff/diff3.go index f59f9e80..e504b1b2 100644 --- a/internal/engine/diff3.go +++ b/internal/engine/diff/diff3.go @@ -1,4 +1,4 @@ -package engine +package diff import ( "fmt" diff --git a/internal/engine/diff3_test.go b/internal/engine/diff/diff3_test.go similarity index 99% rename from internal/engine/diff3_test.go rename to internal/engine/diff/diff3_test.go index eb8d5a7f..4a99d6c4 100644 --- a/internal/engine/diff3_test.go +++ b/internal/engine/diff/diff3_test.go @@ -1,4 +1,4 @@ -package engine +package diff import ( "strings" diff --git a/internal/engine/diff_preview.go b/internal/engine/diff/diff_preview.go similarity index 99% rename from internal/engine/diff_preview.go rename to internal/engine/diff/diff_preview.go index 5ccf99c2..1b8c921b 100644 --- a/internal/engine/diff_preview.go +++ b/internal/engine/diff/diff_preview.go @@ -1,4 +1,4 @@ -package engine +package diff import ( "crypto/rand" diff --git a/internal/engine/diff_preview_test.go b/internal/engine/diff/diff_preview_test.go similarity index 99% rename from internal/engine/diff_preview_test.go rename to internal/engine/diff/diff_preview_test.go index 97ef5868..c1db48e9 100644 --- a/internal/engine/diff_preview_test.go +++ b/internal/engine/diff/diff_preview_test.go @@ -1,4 +1,4 @@ -package engine +package diff import ( "strings" diff --git a/internal/engine/diff_staging.go b/internal/engine/diff/diff_staging.go similarity index 99% rename from internal/engine/diff_staging.go rename to internal/engine/diff/diff_staging.go index 45c3e673..5292c059 100644 --- a/internal/engine/diff_staging.go +++ b/internal/engine/diff/diff_staging.go @@ -1,4 +1,4 @@ -package engine +package diff import ( "fmt" diff --git a/internal/engine/diff_staging_test.go b/internal/engine/diff/diff_staging_test.go similarity index 99% rename from internal/engine/diff_staging_test.go rename to internal/engine/diff/diff_staging_test.go index 1e377674..36ff4370 100644 --- a/internal/engine/diff_staging_test.go +++ b/internal/engine/diff/diff_staging_test.go @@ -1,4 +1,4 @@ -package engine +package diff import ( "os" diff --git a/internal/engine/diff_summarizer.go b/internal/engine/diff/diff_summarizer.go similarity index 99% rename from internal/engine/diff_summarizer.go rename to internal/engine/diff/diff_summarizer.go index 51f54c88..4cfdc227 100644 --- a/internal/engine/diff_summarizer.go +++ b/internal/engine/diff/diff_summarizer.go @@ -1,4 +1,4 @@ -package engine +package diff import ( "fmt" diff --git a/internal/engine/diff_summarizer_test.go b/internal/engine/diff/diff_summarizer_test.go similarity index 99% rename from internal/engine/diff_summarizer_test.go rename to internal/engine/diff/diff_summarizer_test.go index 37335804..e80d05ff 100644 --- a/internal/engine/diff_summarizer_test.go +++ b/internal/engine/diff/diff_summarizer_test.go @@ -1,4 +1,4 @@ -package engine +package diff import ( "strings" diff --git a/internal/engine/diff_test_selector.go b/internal/engine/diff/diff_test_selector.go similarity index 99% rename from internal/engine/diff_test_selector.go rename to internal/engine/diff/diff_test_selector.go index fef309f2..dfc99494 100644 --- a/internal/engine/diff_test_selector.go +++ b/internal/engine/diff/diff_test_selector.go @@ -1,4 +1,4 @@ -package engine +package diff import ( "bufio" diff --git a/internal/engine/diff_test_selector_test.go b/internal/engine/diff/diff_test_selector_test.go similarity index 99% rename from internal/engine/diff_test_selector_test.go rename to internal/engine/diff/diff_test_selector_test.go index bd69ed41..c566bbee 100644 --- a/internal/engine/diff_test_selector_test.go +++ b/internal/engine/diff/diff_test_selector_test.go @@ -1,4 +1,4 @@ -package engine +package diff import ( "os" diff --git a/internal/engine/diffsandbox.go b/internal/engine/diff/diffsandbox.go similarity index 99% rename from internal/engine/diffsandbox.go rename to internal/engine/diff/diffsandbox.go index 97014e63..211fb31d 100644 --- a/internal/engine/diffsandbox.go +++ b/internal/engine/diff/diffsandbox.go @@ -1,4 +1,4 @@ -package engine +package diff import ( "fmt" diff --git a/internal/engine/diffsandbox_test.go b/internal/engine/diff/diffsandbox_test.go similarity index 99% rename from internal/engine/diffsandbox_test.go rename to internal/engine/diff/diffsandbox_test.go index c6b15a85..b7bd140e 100644 --- a/internal/engine/diffsandbox_test.go +++ b/internal/engine/diff/diffsandbox_test.go @@ -1,4 +1,4 @@ -package engine +package diff import ( "os" diff --git a/internal/engine/diff_reexports.go b/internal/engine/diff_reexports.go new file mode 100644 index 00000000..d1bd3706 --- /dev/null +++ b/internal/engine/diff_reexports.go @@ -0,0 +1,55 @@ +package engine + +import "github.com/GrayCodeAI/hawk/internal/engine/diff" + +// Types from diff sub-package. + +type ( + PendingChange = diff.PendingChange + DiffSandbox = diff.DiffSandbox + StagingArea = diff.StagingArea + StagedChange = diff.StagedChange + StagedHunk = diff.StagedHunk + DiffPreview = diff.DiffPreview + FileChange = diff.FileChange + DiffHunk = diff.DiffHunk + DiffLine = diff.DiffLine + ChangeStats = diff.ChangeStats + DiffSummary = diff.DiffSummary + FileSummary = diff.FileSummary + DiffSummarizer = diff.DiffSummarizer + TestSelector = diff.TestSelector + SelectedTests = diff.SelectedTests + Diff3Result = diff.Diff3Result + Diff3Conflict = diff.Diff3Conflict + Diff3Stats = diff.Diff3Stats + Diff3Region = diff.Diff3Region + Edit = diff.Edit +) + +// Short-name aliases. + +type ( + Preview = diff.DiffPreview + Summarizer = diff.DiffSummarizer +) + +// Functions. + +var ( + NewDiffSandbox = diff.NewDiffSandbox + NewStagingArea = diff.NewStagingArea + NewDiffPreview = diff.NewDiffPreview + NewDiffSummarizer = diff.NewDiffSummarizer + NewTestSelector = diff.NewTestSelector + ComputeDiff = diff.ComputeDiff + ComputeMyersDiff = diff.ComputeMyersDiff + RenderUnified = diff.RenderUnified + Merge3 = diff.Merge3 + MergeClean = diff.MergeClean + FormatConflictMarkers = diff.FormatConflictMarkers + LCS = diff.LCS + EditScript = diff.EditScript + BuildDependencyGraph = diff.BuildDependencyGraph + GenerateTestCommand = diff.GenerateTestCommand +) diff --git a/internal/engine/doc_updater_test.go b/internal/engine/doc_updater_test.go deleted file mode 100644 index c87787a8..00000000 --- a/internal/engine/doc_updater_test.go +++ /dev/null @@ -1,631 +0,0 @@ -package engine - -import ( - "os" - "path/filepath" - "strings" - "testing" -) - -func TestNewDocUpdater(t *testing.T) { - du := NewDocUpdater() - if du == nil { - t.Fatal("NewDocUpdater returned nil") - } -} - -func TestDetectStaleDocumentation_SignatureChanged(t *testing.T) { - du := NewDocUpdater() - - oldContent := `package main - -// ValidateToken validates a JWT token -func ValidateToken(token string) bool { - return true -} -` - - newContent := `package main - -// ValidateToken validates a JWT token -func ValidateToken(ctx context.Context, token string) bool { - return true -} -` - - updates := du.DetectStaleDocumentation("src/auth.go", oldContent, newContent) - if len(updates) == 0 { - t.Fatal("expected at least one update for signature change") - } - - found := false - for _, u := range updates { - if u.Symbol == "ValidateToken" { - found = true - if !strings.Contains(u.Reason, "signature_changed") { - t.Errorf("expected reason to contain 'signature_changed', got %q", u.Reason) - } - if u.File != "src/auth.go" { - t.Errorf("expected file 'src/auth.go', got %q", u.File) - } - if u.OldDoc == "" { - t.Error("expected OldDoc to be set") - } - if u.NewDoc == "" { - t.Error("expected NewDoc to be set") - } - } - } - if !found { - t.Error("did not find update for ValidateToken") - } -} - -func TestDetectStaleDocumentation_NewParams(t *testing.T) { - du := NewDocUpdater() - - oldContent := `package main - -// ProcessData handles data processing -func ProcessData(data []byte) error { - return nil -} -` - - newContent := `package main - -// ProcessData handles data processing -func ProcessData(data []byte, timeout int, retries int) error { - return nil -} -` - - updates := du.DetectStaleDocumentation("src/process.go", oldContent, newContent) - if len(updates) == 0 { - t.Fatal("expected at least one update for new params") - } - - found := false - for _, u := range updates { - if u.Symbol == "ProcessData" { - found = true - if !strings.Contains(u.Reason, "signature_changed") && !strings.Contains(u.Reason, "new_params") { - t.Errorf("expected reason to be signature_changed or new_params, got %q", u.Reason) - } - } - } - if !found { - t.Error("did not find update for ProcessData") - } -} - -func TestDetectStaleDocumentation_OutdatedReference(t *testing.T) { - du := NewDocUpdater() - - oldContent := `package main - -// ProcessRequest uses oldHelper to process incoming requests -func ProcessRequest(r *http.Request) error { - return nil -} - -func oldHelper() {} -` - - newContent := `package main - -// ProcessRequest uses oldHelper to process incoming requests -func ProcessRequest(r *http.Request) error { - return nil -} -` - - updates := du.DetectStaleDocumentation("src/handler.go", oldContent, newContent) - if len(updates) == 0 { - t.Fatal("expected at least one update for outdated reference") - } - - found := false - for _, u := range updates { - if u.Symbol == "ProcessRequest" && u.Reason == "outdated_reference" { - found = true - if !strings.Contains(u.NewDoc, "[removed:oldHelper]") { - t.Errorf("expected NewDoc to contain removed marker, got %q", u.NewDoc) - } - } - } - if !found { - t.Error("did not find outdated_reference update for ProcessRequest") - } -} - -func TestDetectStaleDocumentation_NoChanges(t *testing.T) { - du := NewDocUpdater() - - content := `package main - -// Hello says hello -func Hello(name string) string { - return "hello " + name -} -` - - updates := du.DetectStaleDocumentation("src/hello.go", content, content) - if len(updates) != 0 { - t.Errorf("expected no updates when content is unchanged, got %d", len(updates)) - } -} - -func TestGenerateDocUpdate(t *testing.T) { - du := NewDocUpdater() - - tests := []struct { - name string - funcName string - signature string - oldDoc string - wantSub string - }{ - { - name: "add context param", - funcName: "ValidateToken", - signature: "(ctx context.Context, token string) bool", - oldDoc: "// ValidateToken validates a JWT token", - wantSub: "context", - }, - { - name: "add new named param", - funcName: "Process", - signature: "(data []byte, limit int) error", - oldDoc: "// Process processes the data", - wantSub: "limit", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := du.GenerateDocUpdate(tt.funcName, tt.signature, tt.oldDoc) - if !strings.Contains(result, tt.wantSub) { - t.Errorf("GenerateDocUpdate() = %q, want substring %q", result, tt.wantSub) - } - if !strings.HasPrefix(result, "//") { - t.Errorf("GenerateDocUpdate() should start with //, got %q", result) - } - }) - } -} - -func TestGenerateDocUpdate_PreservesPrefix(t *testing.T) { - du := NewDocUpdater() - - result := du.GenerateDocUpdate("Foo", "(bar int) string", "// Foo does something") - if !strings.HasPrefix(result, "// Foo") { - t.Errorf("expected doc to preserve 'Foo' prefix, got %q", result) - } -} - -func TestScanProjectForStaleDocs(t *testing.T) { - du := NewDocUpdater() - - // Create a temp project - dir := t.TempDir() - - // Create a file with a reference to a non-existent symbol - content := `package main - -// HandleRequest uses NonExistentProcessor to handle requests -func HandleRequest() error { - return nil -} - -// DoWork performs work -func DoWork() { -} -` - err := os.WriteFile(filepath.Join(dir, "main.go"), []byte(content), 0o644) - if err != nil { - t.Fatal(err) - } - - updates := du.ScanProjectForStaleDocs(dir) - - found := false - for _, u := range updates { - if u.Symbol == "HandleRequest" && u.Reason == "outdated_reference" { - found = true - } - } - if !found { - t.Error("expected to find outdated_reference for HandleRequest referencing NonExistentProcessor") - } -} - -func TestScanProjectForStaleDocs_NoStale(t *testing.T) { - du := NewDocUpdater() - - dir := t.TempDir() - - content := `package main - -// DoWork performs work -func DoWork() { -} - -// Helper assists DoWork -func Helper() { -} -` - err := os.WriteFile(filepath.Join(dir, "main.go"), []byte(content), 0o644) - if err != nil { - t.Fatal(err) - } - - updates := du.ScanProjectForStaleDocs(dir) - - // Helper references DoWork which exists, so no stale docs - for _, u := range updates { - if u.Symbol == "Helper" { - t.Errorf("unexpected stale doc for Helper: %+v", u) - } - } -} - -func TestScanProjectForStaleDocs_SkipsVendor(t *testing.T) { - du := NewDocUpdater() - - dir := t.TempDir() - - // Create vendor directory with stale docs - vendorDir := filepath.Join(dir, "vendor") - os.MkdirAll(vendorDir, 0o755) - - vendorContent := `package vendor - -// BadFunc uses MissingThing -func BadFunc() {} -` - os.WriteFile(filepath.Join(vendorDir, "bad.go"), []byte(vendorContent), 0o644) - - // Create main file without issues - mainContent := `package main - -// DoWork performs work -func DoWork() { -} -` - os.WriteFile(filepath.Join(dir, "main.go"), []byte(mainContent), 0o644) - - updates := du.ScanProjectForStaleDocs(dir) - - for _, u := range updates { - if strings.Contains(u.File, "vendor") { - t.Errorf("should skip vendor directory, found update: %+v", u) - } - } -} - -func TestFormatUpdates_Empty(t *testing.T) { - du := NewDocUpdater() - - result := du.FormatUpdates(nil) - if result != "No stale documentation found." { - t.Errorf("unexpected output for empty updates: %q", result) - } -} - -func TestFormatUpdates_Multiple(t *testing.T) { - du := NewDocUpdater() - - updates := []DocUpdate{ - { - File: "src/auth.go", - Line: 15, - OldDoc: "// ValidateToken validates a JWT token", - NewDoc: "// ValidateToken validates a JWT token using the provided context", - Symbol: "ValidateToken", - Reason: "signature_changed (added ctx parameter)", - }, - { - File: "src/handler.go", - Line: 42, - OldDoc: "// ProcessRequest uses oldHelper", - NewDoc: "", - Symbol: "ProcessRequest", - Reason: "outdated_reference", - }, - { - File: "src/data.go", - Line: 8, - OldDoc: "// Transform transforms data", - NewDoc: "// Transform transforms data with limit", - Symbol: "Transform", - Reason: "new_params", - }, - } - - result := du.FormatUpdates(updates) - - if !strings.Contains(result, "Stale Documentation (3 items):") { - t.Error("expected header with count") - } - if !strings.Contains(result, "src/auth.go:15 — ValidateToken") { - t.Error("expected first entry") - } - if !strings.Contains(result, "signature_changed (added ctx parameter)") { - t.Error("expected reason for first entry") - } - if !strings.Contains(result, "src/handler.go:42 — ProcessRequest") { - t.Error("expected second entry") - } - if !strings.Contains(result, "outdated_reference") { - t.Error("expected reason for second entry") - } - if !strings.Contains(result, "src/data.go:8 — Transform") { - t.Error("expected third entry") - } - if !strings.Contains(result, `"// ValidateToken validates a JWT token"`) { - t.Error("expected old doc quoted") - } - if !strings.Contains(result, `"// ValidateToken validates a JWT token using the provided context"`) { - t.Error("expected new doc quoted") - } -} - -func TestFormatUpdates_SingleItem(t *testing.T) { - du := NewDocUpdater() - - updates := []DocUpdate{ - { - File: "main.go", - Line: 5, - OldDoc: "// Run runs", - NewDoc: "// Run runs with options", - Symbol: "Run", - Reason: "new_params", - }, - } - - result := du.FormatUpdates(updates) - if !strings.Contains(result, "Stale Documentation (1 items):") { - t.Error("expected header with count 1") - } - if !strings.Contains(result, "main.go:5 — Run") { - t.Error("expected entry") - } -} - -func TestApplyUpdates(t *testing.T) { - du := NewDocUpdater() - - content := `package main - -// ValidateToken validates a JWT token -func ValidateToken(ctx context.Context, token string) bool { - return true -} - -// ProcessData handles data -func ProcessData(data []byte) error { - return nil -} -` - - updates := []DocUpdate{ - { - File: "main.go", - Line: 4, - OldDoc: "// ValidateToken validates a JWT token", - NewDoc: "// ValidateToken validates a JWT token using the provided context", - Symbol: "ValidateToken", - Reason: "signature_changed", - }, - } - - result := du.ApplyUpdates(updates, content) - - if !strings.Contains(result, "// ValidateToken validates a JWT token using the provided context") { - t.Error("expected updated doc to be applied") - } - if strings.Contains(result, "// ValidateToken validates a JWT token\n") { - t.Error("old doc should have been replaced") - } - // Other docs should be untouched - if !strings.Contains(result, "// ProcessData handles data") { - t.Error("unrelated docs should not be modified") - } -} - -func TestApplyUpdates_MultipleUpdates(t *testing.T) { - du := NewDocUpdater() - - content := `package main - -// Foo does foo -func Foo(a int) {} - -// Bar does bar -func Bar(b int) {} -` - - updates := []DocUpdate{ - { - File: "main.go", - Line: 4, - OldDoc: "// Foo does foo", - NewDoc: "// Foo does foo with options", - Symbol: "Foo", - Reason: "new_params", - }, - { - File: "main.go", - Line: 7, - OldDoc: "// Bar does bar", - NewDoc: "// Bar does bar with context", - Symbol: "Bar", - Reason: "new_params", - }, - } - - result := du.ApplyUpdates(updates, content) - - if !strings.Contains(result, "// Foo does foo with options") { - t.Error("expected Foo doc to be updated") - } - if !strings.Contains(result, "// Bar does bar with context") { - t.Error("expected Bar doc to be updated") - } -} - -func TestApplyUpdates_EmptyUpdates(t *testing.T) { - du := NewDocUpdater() - - content := `package main - -// Hello says hello -func Hello() {} -` - - result := du.ApplyUpdates(nil, content) - if result != content { - t.Error("content should be unchanged with no updates") - } -} - -func TestApplyUpdates_SkipsEmptyDoc(t *testing.T) { - du := NewDocUpdater() - - content := `package main - -// Hello says hello -func Hello() {} -` - - updates := []DocUpdate{ - { - File: "main.go", - Line: 4, - OldDoc: "", - NewDoc: "// Hello says hello world", - Symbol: "Hello", - Reason: "new_params", - }, - } - - result := du.ApplyUpdates(updates, content) - // Should not modify since OldDoc is empty (can't find what to replace) - if !strings.Contains(result, "// Hello says hello") { - t.Error("should not modify when OldDoc is empty") - } -} - -func TestDocUpdParseFunctions(t *testing.T) { - content := `package main - -// Add adds two numbers -func Add(a, b int) int { - return a + b -} - -// Greet greets the user -func (s *Server) Greet(name string) string { - return "hello " + name -} - -func noDoc() {} -` - - funcs := docUpdParseFunctions(content) - - if _, ok := funcs["Add"]; !ok { - t.Error("expected to find Add function") - } - if funcs["Add"].Doc != "// Add adds two numbers" { - t.Errorf("unexpected doc for Add: %q", funcs["Add"].Doc) - } - - if _, ok := funcs["Greet"]; !ok { - t.Error("expected to find Greet function (method)") - } - if funcs["Greet"].Doc != "// Greet greets the user" { - t.Errorf("unexpected doc for Greet: %q", funcs["Greet"].Doc) - } - - if _, ok := funcs["noDoc"]; !ok { - t.Error("expected to find noDoc function") - } - if funcs["noDoc"].Doc != "" { - t.Errorf("expected empty doc for noDoc, got %q", funcs["noDoc"].Doc) - } -} - -func TestDocUpdExtractParams(t *testing.T) { - tests := []struct { - sig string - want int - params []string - }{ - {"(a int, b string) error", 2, []string{"a int", "b string"}}, - {"() error", 0, nil}, - {"(ctx context.Context) error", 1, []string{"ctx context.Context"}}, - {"(data []byte, opts ...Option) error", 2, []string{"data []byte", "opts ...Option"}}, - } - - for _, tt := range tests { - params := docUpdExtractParams(tt.sig) - if len(params) != tt.want { - t.Errorf("docUpdExtractParams(%q): got %d params, want %d: %v", tt.sig, len(params), tt.want, params) - } - if tt.params != nil { - for i, p := range tt.params { - if i < len(params) && params[i] != p { - t.Errorf("docUpdExtractParams(%q)[%d]: got %q, want %q", tt.sig, i, params[i], p) - } - } - } - } -} - -func TestDocUpdDetectSignatureChangeDetail(t *testing.T) { - detail := docUpdDetectSignatureChangeDetail( - "(token string) bool", - "(ctx context.Context, token string) bool", - ) - if !strings.Contains(detail, "added ctx parameter") { - t.Errorf("expected 'added ctx parameter', got %q", detail) - } - - detail = docUpdDetectSignatureChangeDetail( - "(ctx context.Context, token string) bool", - "(token string) bool", - ) - if !strings.Contains(detail, "removed ctx parameter") { - t.Errorf("expected 'removed ctx parameter', got %q", detail) - } -} - -func TestDocUpdaterConcurrentAccess(t *testing.T) { - du := NewDocUpdater() - - oldContent := `package main - -// Work does work -func Work(a int) {} -` - newContent := `package main - -// Work does work -func Work(a int, b int) {} -` - - done := make(chan bool, 10) - for i := 0; i < 10; i++ { - go func() { - _ = du.DetectStaleDocumentation("file.go", oldContent, newContent) - done <- true - }() - } - - for i := 0; i < 10; i++ { - <-done - } -} diff --git a/internal/engine/docgen_test.go b/internal/engine/docgen_test.go deleted file mode 100644 index 350fc834..00000000 --- a/internal/engine/docgen_test.go +++ /dev/null @@ -1,761 +0,0 @@ -package engine - -import ( - "os" - "path/filepath" - "strings" - "testing" - "time" -) - -// createTestGoFile creates a temporary Go source file for testing. -func createTestGoFile(t *testing.T, dir, filename, content string) { - t.Helper() - err := os.WriteFile(filepath.Join(dir, filename), []byte(content), 0o644) - if err != nil { - t.Fatalf("failed to create test file %s: %v", filename, err) - } -} - -func TestParseGoPackage_ExtractsFunctionsAndTypes(t *testing.T) { - dir := t.TempDir() - - src := `// Package auth provides authentication utilities. -package auth - -import "time" - -// Claims represents JWT claims. -type Claims struct { - // UserID is the unique user identifier. - UserID string - // ExpiresAt is the token expiry time. - ExpiresAt time.Time -} - -// Validator checks tokens. -type Validator interface { - // Validate checks if a token is valid. - Validate(token string) error -} - -// ValidateToken validates a JWT token and returns the claims. -func ValidateToken(token string) (*Claims, error) { - return nil, nil -} - -// NewClaims creates new claims for a user. -func NewClaims(userID string, ttl time.Duration) *Claims { - return nil -} - -// helper is an unexported function. -func helper() {} -` - createTestGoFile(t, dir, "auth.go", src) - - dg := NewDocGenerator(dir) - pkg, err := dg.parseGoPackage(dir) - if err != nil { - t.Fatalf("parseGoPackage failed: %v", err) - } - - if pkg == nil { - t.Fatal("expected non-nil package doc") - } - - if pkg.Name != "auth" { - t.Errorf("expected package name 'auth', got '%s'", pkg.Name) - } - - if pkg.Description != "Package auth provides authentication utilities." { - t.Errorf("unexpected package description: %s", pkg.Description) - } - - // Should have 2 exported functions (helper is unexported) - if len(pkg.Functions) != 2 { - t.Errorf("expected 2 exported functions, got %d", len(pkg.Functions)) - for _, f := range pkg.Functions { - t.Logf(" function: %s", f.Name) - } - } - - // Check ValidateToken - var validateFn *FunctionDoc - for i := range pkg.Functions { - if pkg.Functions[i].Name == "ValidateToken" { - validateFn = &pkg.Functions[i] - break - } - } - if validateFn == nil { - t.Fatal("expected to find ValidateToken function") - } - if !validateFn.Exported { - t.Error("ValidateToken should be exported") - } - if validateFn.Description != "ValidateToken validates a JWT token and returns the claims." { - t.Errorf("unexpected description: %s", validateFn.Description) - } - if len(validateFn.Parameters) != 1 { - t.Errorf("expected 1 parameter, got %d", len(validateFn.Parameters)) - } else { - if validateFn.Parameters[0].Name != "token" { - t.Errorf("expected param name 'token', got '%s'", validateFn.Parameters[0].Name) - } - } - if !strings.Contains(validateFn.Returns, "*Claims") { - t.Errorf("expected returns to contain '*Claims', got '%s'", validateFn.Returns) - } - - // Should have 2 types: Claims (struct) and Validator (interface) - if len(pkg.Types) != 2 { - t.Errorf("expected 2 types, got %d", len(pkg.Types)) - for _, typ := range pkg.Types { - t.Logf(" type: %s (%s)", typ.Name, typ.Kind) - } - } - - var claimsType *TypeDoc - var validatorType *TypeDoc - for i := range pkg.Types { - switch pkg.Types[i].Name { - case "Claims": - claimsType = &pkg.Types[i] - case "Validator": - validatorType = &pkg.Types[i] - } - } - - if claimsType == nil { - t.Fatal("expected to find Claims type") - } - if claimsType.Kind != "struct" { - t.Errorf("expected Claims kind 'struct', got '%s'", claimsType.Kind) - } - if len(claimsType.Fields) != 2 { - t.Errorf("expected 2 fields in Claims, got %d", len(claimsType.Fields)) - } - - if validatorType == nil { - t.Fatal("expected to find Validator type") - } - if validatorType.Kind != "interface" { - t.Errorf("expected Validator kind 'interface', got '%s'", validatorType.Kind) - } -} - -func TestRenderMarkdown_ProducesValidOutput(t *testing.T) { - doc := &ProjectDoc{ - Name: "myproject", - Description: "A test project for documentation.", - Architecture: "Simple package layout.", - QuickStart: "import \"myproject\"", - GeneratedAt: time.Date(2025, 1, 15, 10, 0, 0, 0, time.UTC), - Packages: []PackageDoc{ - { - Name: "auth", - Path: "pkg/auth", - Description: "Authentication package.", - Functions: []FunctionDoc{ - { - Name: "ValidateToken", - Signature: "func ValidateToken(token string) (*Claims, error)", - Description: "Validates a JWT token.", - Exported: true, - }, - }, - Types: []TypeDoc{ - { - Name: "Claims", - Kind: "struct", - Description: "JWT claims.", - Fields: []FieldDoc{ - {Name: "UserID", Type: "string", Desc: "User identifier"}, - }, - Methods: []FunctionDoc{ - { - Name: "IsExpired", - Signature: "func (c *Claims) IsExpired() bool", - Exported: true, - }, - }, - }, - }, - }, - }, - } - - md := RenderMarkdown(doc) - - // Check title - if !strings.Contains(md, "# myproject") { - t.Error("markdown should contain project title") - } - - // Check description - if !strings.Contains(md, "A test project for documentation.") { - t.Error("markdown should contain description") - } - - // Check architecture section - if !strings.Contains(md, "## Architecture") { - t.Error("markdown should contain Architecture section") - } - - // Check quick start section - if !strings.Contains(md, "## Quick Start") { - t.Error("markdown should contain Quick Start section") - } - - // Check package section - if !strings.Contains(md, "### package auth") { - t.Error("markdown should contain package header") - } - - // Check function signature - if !strings.Contains(md, "func ValidateToken(token string) (*Claims, error)") { - t.Error("markdown should contain function signature") - } - - // Check type - if !strings.Contains(md, "`type Claims struct`") { - t.Error("markdown should contain type header") - } - - // Check field table - if !strings.Contains(md, "| UserID | string | User identifier |") { - t.Error("markdown should contain field table row") - } - - // Check methods - if !strings.Contains(md, "func (c *Claims) IsExpired() bool") { - t.Error("markdown should contain method signature") - } - - // Check timestamp - if !strings.Contains(md, "2025-01-15") { - t.Error("markdown should contain generation timestamp") - } - - // Check valid markdown structure (no double blank lines at start) - if strings.HasPrefix(md, "\n") { - t.Error("markdown should not start with blank line") - } -} - -func TestRenderHTML_ContainsExpectedElements(t *testing.T) { - doc := &ProjectDoc{ - Name: "testproject", - Description: "HTML test project.", - GeneratedAt: time.Now(), - Packages: []PackageDoc{ - { - Name: "core", - Description: "Core functionality.", - Functions: []FunctionDoc{ - { - Name: "Init", - Signature: "func Init() error", - Description: "Initializes the system.", - Exported: true, - }, - }, - Types: []TypeDoc{ - { - Name: "Config", - Kind: "struct", - Description: "Configuration.", - Fields: []FieldDoc{ - {Name: "Port", Type: "int", Desc: "Server port"}, - }, - }, - }, - }, - }, - } - - html := RenderHTML(doc) - - // Check HTML structure - if !strings.Contains(html, "") { - t.Error("HTML should contain DOCTYPE") - } - if !strings.Contains(html, "testproject - Documentation") { - t.Error("HTML should contain title element") - } - - // Check navigation - if !strings.Contains(html, "