diff --git a/.claude/wip/grpc-plan.md b/.claude/wip/grpc-plan.md new file mode 100644 index 0000000..c9b2291 --- /dev/null +++ b/.claude/wip/grpc-plan.md @@ -0,0 +1,379 @@ + +gRPC Tunnel to Replace snyk-broker +Context + +Axon currently uses a forked snyk-broker (Node.js) to tunnel HTTP traffic between the Cortex backend and customer-network agents via WebSocket (Primus). This has been fragile — the Node.js WebSocket stack has many failure modes, adds build complexity (Node.js in Docker image), and includes much unused snyk-broker code. We're replacing it with a native Go gRPC bidirectional streaming tunnel that is more reliable, simpler, and eliminates the Node.js dependency. + +Key architectural change: the RegistrationReflector is eliminated. All accept file rule matching, header injection, auth handling, variable resolution, plugin execution, and _POOL load balancing are handled by a standalone RequestExecutor component that takes an AcceptFileRule and applies it to a given HTTP request. This is a single-concern component that works identically regardless of relay mode. +Phase 1: Proto Definitions + Server Skeleton with BROKER_SERVER Compat +Proto definitions + +Create /src/axon/proto/tunnel/tunnel.proto: + +service TunnelService { + rpc Tunnel(stream TunnelClientMessage) returns (stream TunnelServerMessage); +} + +// Client → Server envelope +message TunnelClientMessage { + oneof message { + ClientHello hello = 1; + Heartbeat heartbeat = 2; + HttpResponse http_response = 3; + } +} + +// Server → Client envelope +message TunnelServerMessage { + oneof message { + ServerHello hello = 1; + Heartbeat heartbeat = 2; + HttpRequest http_request = 3; + } +} + +message ClientHello { + string broker_token = 1; // Cortex-API-issued token, used for BROKER_SERVER dispatch routing + string client_version = 2; + string tenant_id = 3; // from Cortex API registration + string integration = 4; // e.g. "github", "jira" + string alias = 5; // integration alias + string instance_id = 6; // unique agent instance ID (from config.InstanceId) + string cortex_api_token = 7; // for server-side validation (optional, JWT) + map metadata = 8; +} +message ServerHello { + string server_id = 1; // server hostname (or UUID fallback) — client uses for metrics tagging and dedup + int32 heartbeat_interval_ms = 2; + string stream_id = 3; // server-generated UUID for this specific stream, used in BROKER_SERVER notifications +} +message Heartbeat { int64 timestamp_ms = 1; } +message HttpRequest { string request_id = 1; string method = 2; string path = 3; map headers = 4; bytes body = 5; int32 chunk_index = 6; bool is_final = 7; } +message HttpResponse { string request_id = 1; int32 status_code = 2; map headers = 3; bytes body = 4; int32 chunk_index = 5; bool is_final = 6; } + +Server module (/src/axon/server/) + +New Go module with its own go.mod. Dependencies: grpc, protobuf, zap, fx, uber/tally + tally/prometheus, golang-jwt/jwt/v5. + +Files to create: + + server/go.mod, server/go.sum + server/Makefile — proto generation, build, test targets + server/cmd/main.go — fx app bootstrap, gRPC server + HTTP server startup + server/config/config.go — config struct: GrpcPort, HttpPort, BrokerServerURL, JWTPublicKeyPath, HeartbeatIntervalMs, env-var driven + server/tunnel/service.go — implements TunnelServiceServer. On stream open: read ClientHello, optionally validate cortex_api_token (JWT), check broker_token for collision, store {tenant_id, integration, alias, instance_id} mapping, call POST /broker-server/client-connected with token + SHA-256 hash, send ServerHello. On stream close: call client-deleted. Send heartbeats on interval. Close stream if client misses 2 heartbeats. + server/tunnel/client_registry.go — thread-safe sync.RWMutex map of hashedToken → clientEntry{tenantId, integration, alias, instanceId, streams[]}. Supports multiple tunnels per token. Provides GetIdentity(hashedToken) for metrics tagging. Tracks instance_id per stream to distinguish multiple agent instances for the same integration. + server/broker/broker_server_client.go — HTTP client wrapping BROKER_SERVER API: ServerConnected(), ServerDeleted(), ClientConnected(token, hashedToken, metadata), ClientDeleted(token, hashedToken) + server/metrics/metrics.go — uber/tally scope + prometheus reporter. All metrics tagged with {server_id, tenant_id, integration, alias} where applicable. Metrics: + tunnel.connections.active (gauge) — currently open tunnel streams + tunnel.connections.total (counter) — total tunnel connections over lifetime + tunnel.heartbeat.sent / tunnel.heartbeat.received (counters) + tunnel.heartbeat.missed (counter) — heartbeats expected but not received + tunnel.dispatch.count (counter, by method, status_code) — HTTP requests dispatched through tunnel + tunnel.dispatch.duration_ms (histogram) — end-to-end dispatch latency + tunnel.dispatch.inflight (gauge) — currently pending requests + tunnel.dispatch.errors (counter, by error_type) — dispatch failures (timeout, no_tunnel, stream_error) + tunnel.dispatch.bytes_sent / tunnel.dispatch.bytes_received (counter) — traffic volume + tunnel.auth.failures (counter) — JWT validation failures + tunnel.stream.duration_seconds (histogram) — how long tunnel streams stay alive + +Token and identity flow + +The Cortex API issues the broker token (as it does today). The client passes the token plus identity metadata to the gRPC server in ClientHello. The server stores the mapping for metrics tagging and dispatch routing. + +Flow: + + Client calls Cortex API POST /api/v1/relay/register → receives server_uri, token, plus identity info (tenant_id, integration, alias) + Client opens gRPC stream, sends ClientHello with broker_token + tenant_id + integration + alias + instance_id + Server validates (optional JWT check on cortex_api_token), checks for token collision, stores: + + broker_token → { tenant_id, integration, alias, instance_id, stream_handles[] } + + Server calls POST /broker-server/client-connected with the token + SHA-256 hash + Server returns ServerHello with server_id and heartbeat_interval_ms + +This means: + + All server metrics tagged with {tenant_id, integration, alias} — zero external lookups + The Cortex API controls token issuance (existing behavior preserved) + Server just registers and maps the client-provided token to its identity + client_registry.go maps hashedToken → clientEntry{tenantId, integration, alias, instanceId, streams[]} + On token collision (different client claiming same token), server rejects with error + +Files to modify: + + /src/axon/Makefile — add server targets + +Verification + + Unit tests for client_registry.go (add/remove/lookup) + Unit tests for broker_server_client.go using httptest.Server + Integration test: start server, connect test gRPC client with ClientHello, verify mock BROKER_SERVER receives client-connected, disconnect, verify client-deleted + +Phase 2: RequestExecutor + Server HTTP Dispatch + Client Tunnel +RequestExecutor — standalone accept file rule engine + +A new standalone component at agent/server/requestexecutor/ that is the single point of responsibility for applying accept file rules to HTTP requests. This replaces all the logic currently split across the reflector, snyk-broker, and accept file rendering. It has no knowledge of gRPC, WebSocket, or tunnel mechanics. + +Responsibilities: + + Rule matching: given an HTTP method + path, find the first matching accept file rule (method match or "any", path glob/wildcard match) + URL rewriting: rewrite the request URL using the matched rule's origin + Auth injection: apply the rule's auth scheme (bearer, basic, custom header) + Header injection: resolve and inject custom headers from the rule, including: + Environment variable expansion: ${VAR}, ${VAR:default} + Plugin execution: ${plugin:name} → runs executable, captures output + _POOL load balancing: when an origin or variable uses _POOL suffix (e.g., GITHUB_API_POOL=https://api1.github.com,https://api2.github.com), parse comma-separated values and round-robin across them per-request + TLS handling: support HTTPS origins with configurable CA cert and HttpDisableTLS + Request execution: execute the rewritten request via http.Client and return the response + +Interface: + +// RequestExecutor applies accept file rules to execute HTTP requests +type RequestExecutor interface { + // Execute matches the request against accept file rules, rewrites it, + // and executes it against the target origin. Returns the response. + Execute(ctx context.Context, method, path string, headers map[string]string, body []byte) (*ExecutorResponse, error) +} + +type ExecutorResponse struct { + StatusCode int + Headers map[string]string + Body []byte // or io.ReadCloser for streaming +} + +Files to create: + + agent/server/requestexecutor/executor.go — RequestExecutor interface and implementation + agent/server/requestexecutor/rule_matcher.go — accept file rule matching (method + path glob) + agent/server/requestexecutor/pool.go — _POOL variable parsing and round-robin selection + agent/server/requestexecutor/executor_test.go — comprehensive tests + +Reuses existing code: + + agent/server/snykbroker/acceptfile/accept_file.go — accept file parsing, rendering, rule structures + agent/server/snykbroker/acceptfile/resolver.go — variable resolution (env vars, defaults). The varIsSet() function already checks _POOL suffix (line 277) but doesn't implement rotation — we add that. + agent/server/snykbroker/acceptfile/plugin.go — plugin execution for ${plugin:name} headers + +_POOL implementation details: + + Currently resolver.go:277 only checks os.Getenv(varName+"_POOL") != "" for variable presence + New pool.go adds: parse comma-separated pool values, maintain atomic round-robin counter per pool variable, return next value on each resolution + Pool resolution is transparent to the rest of the system — the resolver returns a single value each time + +Server-side HTTP dispatch + +Files to create: + + server/dispatch/handler.go — HTTP handler on /broker/:token/*path. Extracts token, SHA-256 hashes, looks up client stream(s), generates UUID request_id, sends HttpRequest down stream, waits for HttpResponse (with timeout), writes HTTP response back. Round-robins across multiple tunnels for same token. + server/dispatch/pending_requests.go — map of request_id → chan HttpResponse with timeout cleanup + +Files to modify: + + server/tunnel/service.go — on receiving HttpResponse from client stream, deliver to pending request channel + server/cmd/main.go — mount dispatch HTTP handler + +Client-side gRPC tunnel + +Extract RelayInstanceManager interface to a shared location: + +Files to create: + + agent/server/relay/interfaces.go — extracted RelayInstanceManager interface + agent/server/grpctunnel/tunnel_client.go — implements RelayInstanceManager: + Start(): call registration.Register() for server URI + token + identity (tenant_id, integration, alias), render accept file, create RequestExecutor from rendered rules, open gRPC connection, open N tunnel streams (default 2) each sending ClientHello with broker_token + identity, receive ServerHello, start heartbeat + request handler goroutines + Request handler: receive HttpRequest, delegate to RequestExecutor.Execute(), send HttpResponse back (chunked if large) + Reconnection: on stream error, exponential backoff (1s → 30s max) + Restart(): close all streams, re-register, re-establish + Close(): cancel contexts, wait for goroutines + agent/server/grpctunnel/module.go — fx module mirroring snykbroker/module.go + +Files to modify: + + agent/server/snykbroker/relay_instance_manager.go — import interface from server/relay/interfaces.go + agent/config/config.go — add fields: RelayMode (enum: snyk-broker | grpc-tunnel, default snyk-broker), TunnelCount (int, default 2), env vars RELAY_MODE, TUNNEL_COUNT. BROKER_SERVER_URL is reused as the gRPC server address when RELAY_MODE=grpc-tunnel (same env var, different meaning per mode). + agent/cmd/stack.go or agent/cmd/serve.go — conditional module selection: when RelayMode == "grpc-tunnel" use grpctunnel.Module, else snykbroker.Module + agent/server/snykbroker/module.go — update for interface extraction + +Chunked streaming for large bodies + +Bodies are split into chunks (max 1MB each) sent as a sequence of messages sharing the same request_id. The first chunk carries status_code + headers (for responses) or method + path + headers (for requests). Subsequent chunks carry only body, chunk_index, and is_final. The receiving side reassembles by buffering chunks in order until is_final=true. For small payloads (≤1MB), a single message with chunk_index=0, is_final=true is sent — no overhead. + +Implementation: + + server/dispatch/pending_requests.go — accumulates chunks per request_id, resolves the pending request channel only when is_final=true + RequestExecutor returns response body; tunnel_client.go chunks it into 1MB pieces for sending + server/dispatch/handler.go — reassembles chunks and writes the full HTTP response back to the BROKER_SERVER caller + +Key design notes + + No reflector needed. The RequestExecutor handles all accept file logic natively: rule matching, header injection, auth, variable resolution, plugins, and _POOL rotation. The reflector is eliminated from the gRPC tunnel path entirely. + The RequestExecutor is a standalone component with a single concern — it knows nothing about tunnels, gRPC, or WebSocket. It takes a request and produces a response using accept file rules. + Multi-tunnel with server dedup: Client opens N tunnel streams (default 2). After receiving ServerHello, the client checks server_id — if it's already connected to that server, it closes the duplicate stream and retries (with backoff + jitter to land on a different LB target). This ensures tunnels are spread across distinct server instances for real fault isolation. Client tags all metrics with server_id so we can track which servers each agent is connected to. + Server server_id: Set from HOSTNAME env var (standard in k8s pods). Falls back to a generated UUID if HOSTNAME is unset or "localhost". Returned in every ServerHello. + +Verification + + Unit test RequestExecutor: given accept file rules + HTTP request, verify: + Correct rule matching (method + path) + Origin URL rewriting + Bearer/basic auth injection + Custom header resolution (env vars, plugins) + _POOL round-robin rotation + TLS configuration + Rejection when no rule matches + Unit test pending_requests.go: request/response correlation, chunking, and timeout + Integration test: server + mock BROKER_SERVER + client with test accept file → end-to-end HTTP dispatch through tunnel + Flag switching test: RELAY_MODE=snyk-broker starts old path, RELAY_MODE=grpc-tunnel starts new path + +Phase 3: Auth, Hardening, Dockerfiles +JWT Authentication + + server/auth/jwt.go — gRPC stream interceptor validating JWT bearer token from ClientHello. Loads public key from config file path. Disabled when no key configured. + +Graceful shutdown + + Server: SIGTERM → POST /broker-server/server-deleted, drain active tunnels (30s), stop gRPC + Client: Close() → cancel stream contexts, wait for goroutines, clean up temp files + +Health endpoints + + Server: /healthz returning 200 when running + Server: /broker/:token/systemcheck — tunnel health check to client and return result + +Dockerfiles + + server/docker/Dockerfile — multi-stage Go build, no Node.js + docker/Dockerfile.grpc — lighter agent image without Node.js/snyk-broker + +Client-side observability + + Client metrics (prometheus/client_golang): grpc_tunnel_connections_active (gauge), grpc_tunnel_requests_total (counter), grpc_tunnel_reconnects_total (counter), grpc_tunnel_request_duration_ms (histogram) + All tagged with server_id (from ServerHello) + tenant_id + integration + alias + Enables per-server-instance visibility from the client side + +Verification + + JWT unit tests: valid/invalid/expired tokens + Graceful shutdown test: SIGTERM → verify callbacks fire, no goroutine leaks + Adapt existing test/relay/relay_test.sh for RELAY_MODE=grpc-tunnel + +Phase 4: Migration and Cleanup + + Change RelayMode default from snyk-broker to grpc-tunnel + Add deprecation warning for snyk-broker mode + Remove Node.js + snyk-broker from main docker/Dockerfile + Remove RegistrationReflector and related config (HttpRelayReflectorMode, ReflectorWebSocketUpgrade) + Update README.relay.md + Keep snykbroker package for rollback, mark deprecated + +Verification + + Full regression suite with new default + Docker image size reduction (~200MB+ savings) + Backward compat: RELAY_MODE=snyk-broker still works + +Failure Modes and Mitigations +1. Spoofing / unauthorized connections + +Risk is low. BROKER_SERVER only dispatches to tokens issued by the Cortex API. A rogue client registering with a random token creates a dead-end entry — Cortex will never dispatch to it. Guessing a valid UUID (122 bits entropy) is impractical. The worst case is DoS via mass registration of junk entries. + +Mitigations: + + Rate-limit new tunnel connections per source IP (prevents registry flooding) + JWT validation (when enabled) confirms client identity — useful for metrics accuracy, not critical for security + TLS on gRPC listener prevents token interception in transit (configurable, default on in production) + +2. Multiple connections and reconnection races + +Problem: Client reconnects before server detects old stream is dead → two registry entries for same token. Or client-deleted for old stream arrives at BROKER_SERVER after client-connected for new stream → BROKER_SERVER drops the client. + +Mitigations: + + Each stream gets a server-generated stream_id (UUID). The client-connected and client-deleted calls to BROKER_SERVER include this stream_id so they can be correlated — a client-deleted for stream A doesn't affect stream B. + On ClientHello with a token that already exists in the registry from the same (tenant_id, instance_id): this is a reconnect. Add the new stream handle to the existing entry. The old dead stream will be cleaned up by heartbeat timeout. + On ClientHello with a token claimed by a different (tenant_id, instance_id): reject with ALREADY_EXISTS error. + Registry operations for the same token are serialized (per-token mutex or channel) to prevent connect/disconnect races. + +3. BROKER_SERVER notification durability + +Problem: client-connected POST fails → BROKER_SERVER can't route to client. client-deleted POST fails → BROKER_SERVER routes to dead tunnel. Server crashes → no cleanup notifications sent. + +Mitigations: + + Dispatch is ready immediately. On ClientHello, the server registers the stream in client_registry and sends ServerHello right away — the tunnel is live and can dispatch requests. The client-connected POST to BROKER_SERVER happens asynchronously. The registry entry tracks a brokerServerRegistered flag (starts false, set true on success). This means traffic can flow even if BROKER_SERVER is temporarily down. + client-connected retries indefinitely with backoff in a background goroutine. Backoff starts at 1s, caps at 30s. Continues until success or stream close. The entry stays in registry and dispatches traffic regardless of BROKER_SERVER notification status. + client-deleted retries with backoff (3 attempts). If it still fails, log an error. The TTL mechanism (below) handles cleanup. + Periodic re-registration: Server sends POST /broker-server/client-connected for all active connections every N minutes (e.g., 5 min). This acts as a TTL refresh — if BROKER_SERVER has a stale entry, it gets corrected. If the server crashed and restarted, it re-registers its clients. + Server lifecycle: server-connected on startup, server-deleted on graceful shutdown. On crash, BROKER_SERVER should have a TTL for server entries (server-side concern, not ours to implement, but we should document the expectation). + Idempotency: All BROKER_SERVER calls should be idempotent. Duplicate client-connected calls with same token are no-ops. This is critical for the periodic re-registration pattern. + +4. Heartbeat and stale connection detection + +Problem: Heartbeats prove stream liveness, not client health. A deadlocked client process keeps the TCP connection alive but never processes requests. Large chunked responses block the ordered gRPC stream, preventing heartbeats from flowing. + +Mitigations: + + Two-layer keepalive: + gRPC keepalive (transport level): keepalive.ServerParameters{Time: 30s, Timeout: 10s} + keepalive.EnforcementPolicy{MinTime: 15s}. Detects dead TCP connections (half-open sockets, network partitions). + Application heartbeat (tunnel level): Server sends Heartbeat every heartbeat_interval_ms (default 30s). Client must respond within 2 * heartbeat_interval_ms. Detects hung client processes. + Heartbeat vs. chunked response ordering: Since gRPC streams are ordered, a large response in progress blocks heartbeats. Two options: + Option A (recommended): Interleave heartbeat responses between chunks. The chunking sender checks if a heartbeat is pending and sends the heartbeat response before the next chunk. This adds minimal complexity to the chunk sender. + Option B: Use the gRPC keepalive as the liveness signal during long transfers. If the TCP connection is alive, the transfer is making progress. App heartbeats are skipped while an inflight request is active on that stream. + Dispatch health check: The server can periodically send a lightweight HttpRequest with a special path (e.g., /__tunnel_health) that the client responds to immediately. This validates the full request/response path, not just stream liveness. Run every 5 minutes. + Heartbeat latency tracking: Server tracks round-trip heartbeat latency per stream as a metric. Sudden latency spikes indicate degraded connections. + +5. Server rolling deploys and thundering herd + +Problem: Server instance goes down → all its tunnels die → all clients reconnect simultaneously to other instances → overload. + +Mitigations: + + Graceful drain: On SIGTERM, server sends a GoAway-style message to clients (could be a special ServerHello with a reconnect hint) and stops accepting new connections. Existing requests complete (30s drain). Clients begin reconnecting with jitter before the server fully shuts down. + Client reconnect jitter: On disconnect, client waits random(0, 5s) before first reconnect attempt. Combined with exponential backoff, this spreads reconnections over time. + Connection rate limiting on server: Server limits new tunnel connections to N per second (configurable) to prevent overload during mass reconnection. + +6. Partial chunk delivery + +Problem: Client sends chunks 0 and 1 of a response, then stream dies before chunk 2 (is_final=true). Server holds incomplete response data, the dispatch caller blocks forever. + +Mitigations: + + Dispatch timeout: Every pending request has a deadline (configurable, default 60s). If is_final is not received by deadline, the partial response is discarded and the caller gets a 504 Gateway Timeout. + Chunk cleanup on stream close: When a stream closes, all pending requests that were dispatched on that stream are immediately failed with 502 Bad Gateway. The pending_requests.go map tracks which stream each request was dispatched on. + Retry on other stream: If the token has multiple active streams and one fails mid-response, the server can retry the request on another stream (if the request is idempotent — determined by method: GET/HEAD are safe to retry, POST/PUT/DELETE are not). + +7. Token collision vs. legitimate reconnect + +Problem: Same token arrives in a new ClientHello. Is it a reconnect (same agent, new stream) or a hijack (different agent, stolen token)? + +Decision logic: + + Same (tenant_id, instance_id) as existing entry → reconnect. Add stream, keep existing identity. + Same tenant_id, different instance_id → new instance for same integration. This is valid (e.g., scaling up agents). Add stream to same token entry, track both instance_ids. + Different tenant_id → collision/hijack. Reject with error. Log security event. + +8. BROKER_SERVER unavailability + +Problem: BROKER_SERVER is down or unreachable. + +Mitigations: + + Traffic flows regardless. Tunnel streams accept and dispatch requests immediately, independent of BROKER_SERVER notification status. The client-connected and server-connected calls retry indefinitely in background goroutines. + Health endpoint reports status. /healthz returns 200 but includes broker_server_connected: true/false in the response body for visibility. + Periodic re-registration (every 5 min) ensures BROKER_SERVER eventually learns about all active connections once it recovers. + +Critical Files Reference +File Role +relay_instance_manager.go Current RelayInstanceManager interface + snyk-broker impl; interface to extract +module.go DI module pattern to replicate for grpctunnel +stack.go fx stack wiring; needs conditional module selection +config.go Agent config; needs RelayMode, TunnelServerAddress, TunnelCount +accept_file.go Accept file parsing/rendering — reuse for RequestExecutor +resolver.go Variable resolution + _POOL detection (line 277) — extend for pool rotation +plugin.go Plugin execution for ${plugin:name} — reuse in RequestExecutor +reflector.go HTTP reflector — eliminated in gRPC tunnel mode, logic moves to RequestExecutor +registration.go Registration interface — reused by grpctunnel +cortex-axon-agent.proto Existing proto — reference for style/conventions +Dockerfile Current Docker build — to be updated in Phase 4 +Makefile Proto generation pattern to replicate for server diff --git a/.claude/wip/grpc-server.md b/.claude/wip/grpc-server.md new file mode 100644 index 0000000..e67040c --- /dev/null +++ b/.claude/wip/grpc-server.md @@ -0,0 +1,62 @@ +# Standadlone GRPC Server + +Goal: replace snyk-broker with a native-Go implemented Grpc streaming interface for tunneling HTTP traffic. + +## Current architecture + +This project currently communicates with the Cortex background by hosting an instance of the snyk-broker component. This connects to another snyk-broker in the Cortex infrastructure via a websocket. It's only functions are to: + +* Initiate a two-way websocket tunnel by reaching out to the server +* Maintain the health of this tunnel +* Serve as a dumb pipe for HTTP traffic from the server side to these agent instances + +So the flow is: + +- Client [Axon] initializes `snyk-broker` +- `snyk-broker` contacts the server side (e.g. https://relay.cortex.io) and establises a websocket tunnel +- Server side then can dispatch HTTP calls that come through the tunnel then are executed by Axon inside the customer's network +- Responses are then played back through the tunnel. + +On the server side, snyk-broker interfaces with an HTTP server called the `BROKER_SERVER` which it communicates a set of operations to: + +- server-connected / deleted: a server-side snyk-broker instance is registering itself with the broker server +- client-connected / deleted: a client-side instance has regsered with the server-side snyk-broker and this information is sent to the broker-serve. this includes a BROKER_TOKEN which can be used to route traffic back to a specific client instance + +On the server side, the BROKER_SERVER also supports a dispatching operation like: + +`GET http://broker-server:8080/broker/$token/some/path` + +The broker server then takes the token and determines which snyk-broker server instance owns that connection. It then envelopes the HTTP request, and that is sent over a websocket to the client side. The client side then compares the path and method to it's accept.json file and uses that information to either reject the call (no matching config) or rewrite it to a local call, then tunnel the response back. + +### Problems with this + +- There is a lot of code in snyk-broker we don't care about. We only want the HTTP tunneling, none of the other stuff +- Snyk-broker is written in node which complicates development and installation +- The node websocket stack is complicated and we've seen very fragile. The semantics between server and client have been difficult to get right and there are a lot of failure modes that have been difficult to anticipate. + +## New plan + +Rather than take on this complexity, I'd like to instead move to an exactly compatable system that is written in Go and GRPC with the following high level architecture: + +- build a set of protobuf service files that define the flow between client and server +- in the axon project add a new root folder called server that implements the server side +- in the axon /agent folder we implement a new client to talk to this server. based on a flag we will instantiate either that or the existing snyk relay_instance_manager. ideally this is a very abstract interface so most of axon has no idea which we have injected +- this should be designed for durability of connection + - one of the message types is "heartbeat" and both sides regularly send the other a heartbeat to validate a working tunnel + - when a side can't heartbeat it should aggressively try to establish a new tunnel + - client side should support multiple (probably 2) of these running concurrently eg if one tunnel has problems, can switch to a healthy tunnel, kill the exisitng one and re-establish. since server side is sticky hopefully this will allow connecting to multiple remote heads. + - goal: we should never need to restart client instances to get them to reconnect, they should know they are in a disconnected state +- should support bearer authentication, starting with expecting a valid non-expired JWT signed by cortex. this should be optional to start with and the server side should support specifying a JWT public secret file for validating the JWT against. we don't need to protect all traffic but ideally we can require a valid cortex token. +- need to add a new server/docker/Dockerfile for building just the server component. +- the server side should be able to safely run any number of server components. +- the server side should emit prometheus metrics for it's primary operations. It should use uber/tally as the main interface for emitting metrics from code, backed by a prometheus recorder. +- the server side should emit structured zap JSON logging. + +### Investigation + +- the client routing side should support the existing accept.json format. See examples in agent/server/snykbroker/accept_files for the format to handle. This stack will replace the snyk-broker and reflector pieces of the exisitng architecture. + + +- please investigate the BROKER_SERVER interface here: https://github.com/cortexapps/snyk-broker/blob/16805ee1f3318c783df7ed35085ec9aa941bff6e/lib/server/infra/dispatcher.ts#L178. we want the server to support interacting with a server that supports this interface, given a hostport eg BROKER_SERVER_URL +- In https://github.com/cortexapps/snyk-broker/ undertand the usage of the BROKER_TOKEN raw and hashed versions in the API. Each server instance will need to keep track of it's client connections and the raw and hashed token for each. +- We want to be very paranoid about how to deal with problems, for example if the GCP load balancer doesn't have long enough TTLs can we recognize infrastructure closing our ports; can we handle the server side instances rolling, etc. \ No newline at end of file diff --git a/.claude/wip/grpc-tunnel-e2e-test.md b/.claude/wip/grpc-tunnel-e2e-test.md new file mode 100644 index 0000000..0f3e2ab --- /dev/null +++ b/.claude/wip/grpc-tunnel-e2e-test.md @@ -0,0 +1,79 @@ +# gRPC Tunnel E2E Test + +## Status: PASSING (both proxy and no-proxy modes) + +## Fixes Applied to Get Tests Passing + +### 1. `python:3.8-alpine` → `python:3.13-alpine` +Python 3.8 is EOL and the Docker image was removed from Docker Hub. +- **File**: `agent/test/relay/docker-compose.grpc.yml` + +### 2. Missing `CORTEX_TENANT_ID` env var +The server requires `tenant_id` in ClientHello but the docker-compose didn't set `CORTEX_TENANT_ID` for the axon-relay container. +- **File**: `agent/test/relay/docker-compose.grpc.yml` — added `CORTEX_TENANT_ID: test-tenant` + +### 3. Separated gRPC TLS from HTTP TLS config +`DISABLE_TLS` controlled both gRPC transport credentials and HTTP client TLS verification. When running with proxy (`CA_CERT_PATH` set), `http_client.go` panicked: "Cannot use custom CA cert with TLS verification disabled". Added a new `GRPC_INSECURE` config field specifically for gRPC tunnel connections. +- **Files**: + - `agent/config/config.go` — added `GrpcInsecure bool` field, read from `GRPC_INSECURE` env var + - `agent/server/grpctunnel/tunnel_client.go` — uses `GrpcInsecure` instead of `HttpDisableTLS` + - `agent/test/relay/docker-compose.grpc.yml` — uses `GRPC_INSECURE: "true"` instead of `DISABLE_TLS: "true"` + +### 4. Removed snyk-broker-specific header check +The test checked for `x-axon-relay-instance` header which is injected by the snyk-broker reflector, not the gRPC tunnel path. +- **File**: `agent/test/relay/relay_test.grpc.sh` + +### 5. macOS compatibility fix +`stat -c%s` doesn't work on macOS (BSD stat). Changed to `wc -c <` which is portable. +- **File**: `agent/test/relay/relay_test.grpc.sh` + +## Running the Tests + +```bash +# No-proxy mode +cd agent/test/relay && PROXY=0 ./relay_test.grpc.sh + +# With proxy mode +cd agent/test/relay && PROXY=1 ./relay_test.grpc.sh + +# Both (via Makefile) +cd agent && make grpc-relay-test +``` + +## Test Architecture + +``` + Host + | + v + grpc-tunnel-server (HTTP :8080, gRPC :50052) + | + gRPC bidirectional stream + | + v + axon-relay (RELAY_MODE=grpc-tunnel) + | + HTTP request execution + | + v + python-server (:80, serves /tmp) + or GitHub (HTTPS) + or cortex-fake (:8081, echo endpoint) +``` + +## Test Cases +1. Text file relay (write to /tmp, fetch via tunnel) +2. Binary file relay (1MB, SHA-256 checksum verification) +3. HTTPS relay (GitHub README fetch) +4. Proxy header injection (PROXY=1 only) — verifies `x-proxy-mitmproxy` +5. Accept file header injection (PROXY=1 only) — verifies `added-fake-server` +6. Plugin header injection (PROXY=1 only) — verifies `HOME=/root` +7. gRPC tunnel stream establishment (PROXY=1 only) — log check + +## Remaining Phase 2 Tasks +- None — Phase 2 is complete (code + e2e tests passing) + +## Phase 3 & 4 (Not Started) +- Phase 3: JWT auth, graceful shutdown hardening, health endpoints +- Phase 4: Migration, cleanup +- Plan doc: `.claude/wip/grpc-plan.md` diff --git a/.vscode/launch.json b/.vscode/launch.json index 08dba90..789c165 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -64,20 +64,42 @@ "console": "integratedTerminal", }, + { - "name": "Launch Go EV-Data Example", + "name": "Launch Agent (relay)", "type": "go", "request": "launch", "mode": "auto", - "program": "${workspaceFolder}/examples/go/axon-ev-sync/main.go", + "program": "${workspaceFolder}/agent/main.go", + "args": [ + "relay", + "-i", + "github", + "--alias", + "relay2", + "-v" + ], + "envFile": "${workspaceFolder}/.env", "env": { - "LOG_LEVEL": "info" - } - }, - + "SCAFFOLD_DIR": "${workspaceFolder}/scaffold", + "SCAFFOLD_DOCKER_IMAGE": "cortex-axon-agent:local", + "PORT" : "7399", + "BUILTIN_PLUGIN_DIR": "${workspaceFolder}/agent/server/snykbroker/plugins", + } + }, + { + "name": "Launch gRPC Tunnel Server", + "type": "go", + "request": "launch", + "mode": "auto", + "program": "${workspaceFolder}/server/cmd/main.go", + "env": { + "LOG_LEVEL": "info" + } + }, { - "name": "Launch Agent (relay)", + "name": "Launch Agent (relay gRPC)", "type": "go", "request": "launch", "mode": "auto", @@ -87,18 +109,24 @@ "-i", "github", "--alias", - "relay2", + "relay-grpc", "-v" ], "envFile": "${workspaceFolder}/.env", "env": { "SCAFFOLD_DIR": "${workspaceFolder}/scaffold", "SCAFFOLD_DOCKER_IMAGE": "cortex-axon-agent:local", - "PORT" : "7399", + "PORT": "7399", "BUILTIN_PLUGIN_DIR": "${workspaceFolder}/agent/server/snykbroker/plugins", - } + "RELAY_MODE": "grpc-tunnel", + "BROKER_SERVER_URL": "localhost:50052", + "BROKER_TOKEN": "4f49654b-000-0000-000-9deef1d9f2f6", + "CORTEX_TENANT_ID": "1", + "GRPC_INSECURE": "true", + "TUNNEL_COUNT": "1", + "CORTEX_API_TOKEN": "abc-123" + } }, - ] } \ No newline at end of file diff --git a/Makefile b/Makefile index 0a04bcc..615527c 100644 --- a/Makefile +++ b/Makefile @@ -4,6 +4,7 @@ all: setup proto: $(MAKE) -C agent proto + $(MAKE) -C server proto $(MAKE) -C sdks/python proto .PHONY: proto @@ -25,6 +26,7 @@ setup: test: $(MAKE) -C agent test + $(MAKE) -C server test @echo "TODO: sdk go test" $(MAKE) -C sdks/go test $(MAKE) -C scaffold test diff --git a/agent/Makefile b/agent/Makefile index 8caf507..9a6f461 100644 --- a/agent/Makefile +++ b/agent/Makefile @@ -11,7 +11,20 @@ GO_FILES := $(patsubst proto/%.proto,$(GENERATED_PATH)/%.pb.go,$(PROTO_FILES)) GOPATH ?= $(HOME)/go GOBIN ?= $(GOPATH)/bin -proto: setup $(GO_FILES) version +TUNNEL_PROTO = ../proto/tunnel/tunnel.proto +TUNNEL_GENERATED = $(GENERATED_DIR)/github.com/cortexapps/axon/tunnelpb + +proto: setup $(GO_FILES) tunnel-proto version + +tunnel-proto: $(TUNNEL_PROTO) + @echo "Generating tunnel protobuf for agent" + @mkdir -p $(GENERATED_DIR) + @protoc -I=../proto/tunnel \ + --go_out=$(GENERATED_DIR) --go-grpc_out=$(GENERATED_DIR) \ + --go_opt=Mtunnel.proto=github.com/cortexapps/axon/tunnelpb \ + --go-grpc_opt=Mtunnel.proto=github.com/cortexapps/axon/tunnelpb \ + $(TUNNEL_PROTO) +.PHONY: tunnel-proto version: $(GO_SDK_DIR)/version/agentversion.txt $(PYTHON_SDK_DIR)/cortex_axon/agentversion.py @@ -83,7 +96,17 @@ relay-test-with-proxy: relay-test: relay-test-no-proxy relay-test-with-proxy -.PHONY: relay-test relay-test-no-proxy relay-test-with-proxy +grpc-relay-test-no-proxy: + @echo "Running gRPC relay tests: no proxy" + cd test/relay && export PROXY=0 && ./relay_test.grpc.sh + +grpc-relay-test-with-proxy: + @echo "Running gRPC relay tests: with proxy" + cd test/relay && export PROXY=1 && ./relay_test.grpc.sh + +grpc-relay-test: grpc-relay-test-no-proxy grpc-relay-test-with-proxy + +.PHONY: relay-test relay-test-no-proxy relay-test-with-proxy grpc-relay-test grpc-relay-test-no-proxy grpc-relay-test-with-proxy run: proto go run main.go serve diff --git a/agent/cmd/grpc-tunnel-server/main.go b/agent/cmd/grpc-tunnel-server/main.go new file mode 100644 index 0000000..1454e18 --- /dev/null +++ b/agent/cmd/grpc-tunnel-server/main.go @@ -0,0 +1,341 @@ +// gRPC tunnel server for E2E testing. +// This server mimics the Cortex-side tunnel endpoint that accepts +// gRPC connections from Axon agents and dispatches HTTP requests through them. +package main + +import ( + "context" + "encoding/json" + "fmt" + "io" + "log" + "net" + "net/http" + "os" + "strings" + "sync" + "time" + + "github.com/cortexapps/axon/.generated/proto/github.com/cortexapps/axon/tunnelpb" + "github.com/google/uuid" + "google.golang.org/grpc" +) + +// TunnelServer implements the gRPC TunnelService. +type TunnelServer struct { + tunnelpb.UnimplementedTunnelServiceServer + + mu sync.RWMutex + streams map[string]*tunnelStream // keyed by broker token +} + +type tunnelStream struct { + stream tunnelpb.TunnelService_TunnelServer + hello *tunnelpb.ClientHello + streamID string + + // Pending requests waiting for responses + pendingMu sync.Mutex + pending map[string]chan *tunnelpb.HttpResponse +} + +func NewTunnelServer() *TunnelServer { + return &TunnelServer{ + streams: make(map[string]*tunnelStream), + } +} + +func (s *TunnelServer) Tunnel(stream tunnelpb.TunnelService_TunnelServer) error { + // Wait for ClientHello + msg, err := stream.Recv() + if err != nil { + log.Printf("Failed to receive first message: %v", err) + return err + } + + hello := msg.GetHello() + if hello == nil { + log.Printf("First message was not ClientHello") + return fmt.Errorf("first message must be ClientHello") + } + + streamID := uuid.New().String() + ts := &tunnelStream{ + stream: stream, + hello: hello, + streamID: streamID, + pending: make(map[string]chan *tunnelpb.HttpResponse), + } + + // Register stream + s.mu.Lock() + s.streams[hello.BrokerToken] = ts + s.mu.Unlock() + + log.Printf("Tunnel stream established: token=%s alias=%s integration=%s streamID=%s", + hello.BrokerToken, hello.Alias, hello.Integration, streamID) + + // Send ServerHello + serverHello := &tunnelpb.TunnelServerMessage{ + Message: &tunnelpb.TunnelServerMessage_Hello{ + Hello: &tunnelpb.ServerHello{ + ServerId: getServerID(), + HeartbeatIntervalMs: 30000, + StreamId: streamID, + }, + }, + } + if err := stream.Send(serverHello); err != nil { + s.removeStream(hello.BrokerToken) + return err + } + + // Handle incoming messages + for { + msg, err := stream.Recv() + if err == io.EOF { + log.Printf("Tunnel stream closed (EOF): token=%s", hello.BrokerToken) + break + } + if err != nil { + log.Printf("Tunnel stream error: token=%s err=%v", hello.BrokerToken, err) + break + } + + switch m := msg.Message.(type) { + case *tunnelpb.TunnelClientMessage_Heartbeat: + // Respond to heartbeat + hb := &tunnelpb.TunnelServerMessage{ + Message: &tunnelpb.TunnelServerMessage_Heartbeat{ + Heartbeat: &tunnelpb.Heartbeat{TimestampMs: time.Now().UnixMilli()}, + }, + } + if err := stream.Send(hb); err != nil { + log.Printf("Failed to send heartbeat response: %v", err) + } + + case *tunnelpb.TunnelClientMessage_HttpResponse: + ts.handleResponse(m.HttpResponse) + } + } + + s.removeStream(hello.BrokerToken) + return nil +} + +func (s *TunnelServer) removeStream(token string) { + s.mu.Lock() + delete(s.streams, token) + s.mu.Unlock() + log.Printf("Tunnel stream removed: token=%s", token) +} + +func (s *TunnelServer) getStream(token string) *tunnelStream { + s.mu.RLock() + defer s.mu.RUnlock() + return s.streams[token] +} + +func (s *TunnelServer) streamCount() int { + s.mu.RLock() + defer s.mu.RUnlock() + return len(s.streams) +} + +func (ts *tunnelStream) handleResponse(resp *tunnelpb.HttpResponse) { + ts.pendingMu.Lock() + ch, ok := ts.pending[resp.RequestId] + ts.pendingMu.Unlock() + + if ok { + ch <- resp + } else { + log.Printf("Received response for unknown request: %s", resp.RequestId) + } +} + +func (ts *tunnelStream) sendRequest(ctx context.Context, req *tunnelpb.HttpRequest) (*tunnelpb.HttpResponse, error) { + // Create response channel + respChan := make(chan *tunnelpb.HttpResponse, 1) + + ts.pendingMu.Lock() + ts.pending[req.RequestId] = respChan + ts.pendingMu.Unlock() + + defer func() { + ts.pendingMu.Lock() + delete(ts.pending, req.RequestId) + ts.pendingMu.Unlock() + }() + + // Send request + msg := &tunnelpb.TunnelServerMessage{ + Message: &tunnelpb.TunnelServerMessage_HttpRequest{ + HttpRequest: req, + }, + } + if err := ts.stream.Send(msg); err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + + // Wait for response with timeout + timeout := 60 * time.Second + if req.TimeoutMs > 0 { + timeout = time.Duration(req.TimeoutMs) * time.Millisecond + } + + select { + case resp := <-respChan: + return resp, nil + case <-time.After(timeout): + return nil, fmt.Errorf("request timeout after %v", timeout) + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +// HTTPHandler handles dispatch requests from the test +type HTTPHandler struct { + server *TunnelServer +} + +func (h *HTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + path := r.URL.Path + + // Handle healthz endpoint + if path == "/healthz" { + h.handleHealthz(w, r) + return + } + + // Handle broker dispatch: /broker/{token}/{path...} + if strings.HasPrefix(path, "/broker/") { + h.handleDispatch(w, r) + return + } + + http.NotFound(w, r) +} + +func (h *HTTPHandler) handleHealthz(w http.ResponseWriter, r *http.Request) { + count := h.server.streamCount() + resp := map[string]interface{}{ + "status": "ok", + "streams": count, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) +} + +func (h *HTTPHandler) handleDispatch(w http.ResponseWriter, r *http.Request) { + // Parse /broker/{token}/{path...} + parts := strings.SplitN(strings.TrimPrefix(r.URL.Path, "/broker/"), "/", 2) + if len(parts) < 1 { + http.Error(w, "missing token", http.StatusBadRequest) + return + } + + token := parts[0] + path := "/" + if len(parts) > 1 { + path = "/" + parts[1] + } + + // Find stream for token + ts := h.server.getStream(token) + if ts == nil { + http.Error(w, fmt.Sprintf("no tunnel for token: %s", token), http.StatusBadGateway) + return + } + + // Read request body + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "failed to read body", http.StatusInternalServerError) + return + } + + // Convert headers + headers := make(map[string]string) + for k, v := range r.Header { + if len(v) > 0 { + headers[k] = v[0] + } + } + + // Create HTTP request + req := &tunnelpb.HttpRequest{ + RequestId: uuid.New().String(), + Method: r.Method, + Path: path, + Headers: headers, + Body: body, + ChunkIndex: 0, + IsFinal: true, + TimeoutMs: 60000, + } + + // Send through tunnel + resp, err := ts.sendRequest(r.Context(), req) + if err != nil { + http.Error(w, fmt.Sprintf("tunnel request failed: %v", err), http.StatusBadGateway) + return + } + + // Write response + for k, v := range resp.Headers { + w.Header().Set(k, v) + } + w.WriteHeader(int(resp.StatusCode)) + w.Write(resp.Body) +} + +func getServerID() string { + if id := os.Getenv("HOSTNAME"); id != "" { + return id + } + return uuid.New().String() +} + +func main() { + grpcPort := os.Getenv("GRPC_PORT") + if grpcPort == "" { + grpcPort = "50051" + } + + httpPort := os.Getenv("HTTP_PORT") + if httpPort == "" { + httpPort = "8080" + } + + // Create tunnel server + tunnelServer := NewTunnelServer() + + // Start gRPC server + grpcLis, err := net.Listen("tcp", ":"+grpcPort) + if err != nil { + log.Fatalf("Failed to listen on gRPC port %s: %v", grpcPort, err) + } + + grpcServer := grpc.NewServer() + tunnelpb.RegisterTunnelServiceServer(grpcServer, tunnelServer) + + go func() { + log.Printf("gRPC server listening on :%s", grpcPort) + if err := grpcServer.Serve(grpcLis); err != nil { + log.Fatalf("gRPC server failed: %v", err) + } + }() + + // Start HTTP server + httpHandler := &HTTPHandler{server: tunnelServer} + httpServer := &http.Server{ + Addr: ":" + httpPort, + Handler: httpHandler, + } + + log.Printf("HTTP server listening on :%s", httpPort) + if err := httpServer.ListenAndServe(); err != nil { + log.Fatalf("HTTP server failed: %v", err) + } +} diff --git a/agent/cmd/relay.go b/agent/cmd/relay.go index 9d15582..f369462 100644 --- a/agent/cmd/relay.go +++ b/agent/cmd/relay.go @@ -6,6 +6,7 @@ import ( "github.com/cortexapps/axon/common" "github.com/cortexapps/axon/config" + "github.com/cortexapps/axon/server/grpctunnel" "github.com/cortexapps/axon/server/handler" "github.com/cortexapps/axon/server/snykbroker" "github.com/spf13/cobra" @@ -81,11 +82,16 @@ func init() { } func buildRelayStack(cmd *cobra.Command, cfg config.AgentConfig, integrationInfo common.IntegrationInfo) fx.Option { + relayModule := snykbroker.Module + if cfg.IsGRPCTunnel() { + relayModule = grpctunnel.Module + } + stack := fx.Options( initStack(cmd, cfg, integrationInfo), AgentModule, fx.Provide(handler.NewHandlerManager), - snykbroker.Module, + relayModule, ) return stack } diff --git a/agent/cmd/serve.go b/agent/cmd/serve.go index 73ad5e3..f3e6563 100644 --- a/agent/cmd/serve.go +++ b/agent/cmd/serve.go @@ -7,6 +7,7 @@ import ( "github.com/cortexapps/axon/common" "github.com/cortexapps/axon/config" + "github.com/cortexapps/axon/server/grpctunnel" "github.com/cortexapps/axon/server/http" "github.com/cortexapps/axon/server/snykbroker" "github.com/spf13/cobra" @@ -41,11 +42,16 @@ var serveCmd = &cobra.Command{ Alias: config.IntegrationAlias, } + relayModule := snykbroker.Module + if config.IsGRPCTunnel() { + relayModule = grpctunnel.Module + } + stack := fx.Options( initStack(cmd, config, info), AgentModule, http.Module, - snykbroker.Module, + relayModule, ) startAgent(stack) diff --git a/agent/config/config.go b/agent/config/config.go index 4a6d964..3fc44d9 100644 --- a/agent/config/config.go +++ b/agent/config/config.go @@ -59,6 +59,13 @@ func (m RelayReflectorMode) IsEnabled() bool { return m != RelayReflectorDisabled } +type RelayMode string + +const ( + RelayModeSnykBroker RelayMode = "snyk-broker" + RelayModeGrpcTunnel RelayMode = "grpc-tunnel" +) + type AgentConfig struct { GrpcPort int CortexApiBaseUrl string @@ -86,12 +93,26 @@ type AgentConfig struct { HttpRelayReflectorMode RelayReflectorMode ReflectorWebSocketUpgrade bool RelayIdleTimeout time.Duration + + // RelayMode selects the tunnel implementation: "snyk-broker" or "grpc-tunnel". + RelayMode string + // TunnelCount is the number of parallel gRPC tunnel streams to open (grpc-tunnel mode only). + TunnelCount int + // GrpcInsecure disables TLS on the gRPC tunnel connection (separate from HttpDisableTLS). + GrpcInsecure bool + // GrpcTunnelServer is the address of the gRPC tunnel server (host:port). + GrpcTunnelServer string } func (ac AgentConfig) HttpBaseUrl() string { return fmt.Sprintf("http://localhost:%d", ac.HttpServerPort) } +// IsGRPCTunnel returns true if the relay mode is grpc-tunnel. +func (ac AgentConfig) IsGRPCTunnel() bool { + return ac.RelayMode == "grpc-tunnel" +} + func (ac AgentConfig) Print() { fmt.Println("Agent Configuration:") if ac.GrpcPort != DefaultGrpcPort { @@ -302,6 +323,26 @@ func NewAgentEnvConfig() AgentConfig { cfg.RelayIdleTimeout = rit } + cfg.RelayMode = "snyk-broker" + if relayMode := os.Getenv("RELAY_MODE"); relayMode != "" { + cfg.RelayMode = relayMode + } + + cfg.TunnelCount = 2 + if tunnelCount := os.Getenv("TUNNEL_COUNT"); tunnelCount != "" { + tc, err := strconv.Atoi(tunnelCount) + if err != nil { + panic(err) + } + cfg.TunnelCount = tc + } + + if grpcInsecure := os.Getenv("GRPC_INSECURE"); grpcInsecure == "true" { + cfg.GrpcInsecure = true + } + + cfg.GrpcTunnelServer = os.Getenv("GRPC_TUNNEL_SERVER") + return cfg } diff --git a/agent/proto/tunnel.proto b/agent/proto/tunnel.proto new file mode 100644 index 0000000..b218111 --- /dev/null +++ b/agent/proto/tunnel.proto @@ -0,0 +1,111 @@ +syntax = "proto3"; + +package cortex.axon.tunnel; + +option go_package = "github.com/cortexapps/axon/tunnelpb"; + +// TunnelService provides a bidirectional streaming tunnel between +// Axon agents (clients) and the gRPC relay server. The server dispatches +// HTTP requests from the Cortex backend through the tunnel to agents +// running inside customer networks. +service TunnelService { + // Tunnel establishes a persistent bidirectional stream. + // Client sends ClientHello first, then heartbeats and HTTP responses. + // Server sends ServerHello, then heartbeats and HTTP requests. + rpc Tunnel(stream TunnelClientMessage) returns (stream TunnelServerMessage); +} + +// Client → Server envelope +message TunnelClientMessage { + oneof message { + ClientHello hello = 1; + Heartbeat heartbeat = 2; + HttpResponse http_response = 3; + } +} + +// Server → Client envelope +message TunnelServerMessage { + oneof message { + ServerHello hello = 1; + Heartbeat heartbeat = 2; + HttpRequest http_request = 3; + } +} + +// ClientHello is the first message sent by the client after opening a stream. +// It carries identity information and the Cortex-API-issued broker token. +message ClientHello { + // Cortex-API-issued token, used for BROKER_SERVER dispatch routing. + string broker_token = 1; + // Client software version. + string client_version = 2; + // Tenant identifier from Cortex API registration. + string tenant_id = 3; + // Integration type, e.g. "github", "jira". + string integration = 4; + // Integration alias name. + string alias = 5; + // Unique agent instance ID (from config.InstanceId). + string instance_id = 6; + // Cortex API token for optional server-side JWT validation. + string cortex_api_token = 7; + // Arbitrary metadata. + map metadata = 8; +} + +// ServerHello is the response to ClientHello. +message ServerHello { + // Server hostname (HOSTNAME env var, or UUID fallback). + string server_id = 1; + // Interval at which the server sends heartbeats. + int32 heartbeat_interval_ms = 2; + // Server-generated UUID for this specific stream. + string stream_id = 3; +} + +// Heartbeat is sent by both sides to verify tunnel liveness. +message Heartbeat { + // Unix timestamp in milliseconds. + int64 timestamp_ms = 1; +} + +// HttpRequest is an HTTP request dispatched from the server to the client. +// Large bodies are chunked: the first chunk carries method/path/headers, +// subsequent chunks carry only body/chunk_index/is_final. +message HttpRequest { + // Unique request identifier for correlating responses. + string request_id = 1; + // HTTP method (GET, POST, PUT, DELETE, etc.). + string method = 2; + // Request path (e.g. "/api/v1/repos"). + string path = 3; + // HTTP headers. Only set on first chunk (chunk_index=0). + map headers = 4; + // Request body chunk. + bytes body = 5; + // Zero-based chunk index. + int32 chunk_index = 6; + // True if this is the last chunk for this request. + bool is_final = 7; + // Maximum time in milliseconds the server will wait for a response. + int32 timeout_ms = 8; +} + +// HttpResponse is an HTTP response sent from the client back to the server. +// Large bodies are chunked: the first chunk carries status_code/headers, +// subsequent chunks carry only body/chunk_index/is_final. +message HttpResponse { + // Must match the request_id of the corresponding HttpRequest. + string request_id = 1; + // HTTP status code. Only set on first chunk (chunk_index=0). + int32 status_code = 2; + // HTTP response headers. Only set on first chunk (chunk_index=0). + map headers = 3; + // Response body chunk. + bytes body = 4; + // Zero-based chunk index. + int32 chunk_index = 5; + // True if this is the last chunk for this response. + bool is_final = 6; +} diff --git a/agent/server/grpctunnel/chunking_test.go b/agent/server/grpctunnel/chunking_test.go new file mode 100644 index 0000000..a4c04e3 --- /dev/null +++ b/agent/server/grpctunnel/chunking_test.go @@ -0,0 +1,327 @@ +package grpctunnel + +import ( + "testing" + + pb "github.com/cortexapps/axon/.generated/proto/github.com/cortexapps/axon/tunnelpb" +) + +func TestRequestAssembler_SingleChunk(t *testing.T) { + ra := newRequestAssembler() + + // A single-chunk request (chunk_index=0, is_final=true) should pass through directly. + req := &pb.HttpRequest{ + RequestId: "req-1", + Method: "GET", + Path: "/api/v1/repos", + Headers: map[string]string{"Authorization": "Bearer tok"}, + Body: []byte("hello"), + ChunkIndex: 0, + IsFinal: true, + TimeoutMs: 5000, + } + + result := ra.handleChunk(req) + if result == nil { + t.Fatal("expected non-nil result for single-chunk request") + } + // Single-chunk should return the exact same pointer (fast path). + if result != req { + t.Fatal("expected single-chunk to return the original request pointer") + } + if result.RequestId != "req-1" { + t.Errorf("expected RequestId=req-1, got %s", result.RequestId) + } + if result.Method != "GET" { + t.Errorf("expected Method=GET, got %s", result.Method) + } + if string(result.Body) != "hello" { + t.Errorf("expected body=hello, got %s", string(result.Body)) + } + + // No pending requests should remain. + ra.mu.Lock() + pendingCount := len(ra.pending) + ra.mu.Unlock() + if pendingCount != 0 { + t.Errorf("expected 0 pending requests, got %d", pendingCount) + } +} + +func TestRequestAssembler_MultiChunk(t *testing.T) { + ra := newRequestAssembler() + + // Chunk 0: first chunk with metadata. + chunk0 := &pb.HttpRequest{ + RequestId: "req-2", + Method: "POST", + Path: "/api/v1/upload", + Headers: map[string]string{"Content-Type": "application/octet-stream"}, + Body: []byte("chunk0-"), + ChunkIndex: 0, + IsFinal: false, + TimeoutMs: 10000, + } + result := ra.handleChunk(chunk0) + if result != nil { + t.Fatal("expected nil result for non-final first chunk") + } + + // Chunk 1: continuation chunk (only has body). + chunk1 := &pb.HttpRequest{ + RequestId: "req-2", + Body: []byte("chunk1-"), + ChunkIndex: 1, + IsFinal: false, + } + result = ra.handleChunk(chunk1) + if result != nil { + t.Fatal("expected nil result for non-final continuation chunk") + } + + // Chunk 2: final chunk. + chunk2 := &pb.HttpRequest{ + RequestId: "req-2", + Body: []byte("chunk2"), + ChunkIndex: 2, + IsFinal: true, + } + result = ra.handleChunk(chunk2) + if result == nil { + t.Fatal("expected non-nil result for final chunk") + } + + // Verify assembled request carries metadata from chunk 0. + if result.RequestId != "req-2" { + t.Errorf("expected RequestId=req-2, got %s", result.RequestId) + } + if result.Method != "POST" { + t.Errorf("expected Method=POST, got %s", result.Method) + } + if result.Path != "/api/v1/upload" { + t.Errorf("expected Path=/api/v1/upload, got %s", result.Path) + } + if result.Headers["Content-Type"] != "application/octet-stream" { + t.Errorf("expected Content-Type header, got %v", result.Headers) + } + if result.TimeoutMs != 10000 { + t.Errorf("expected TimeoutMs=10000, got %d", result.TimeoutMs) + } + if !result.IsFinal { + t.Error("expected IsFinal=true on assembled request") + } + + // Verify body is concatenated in order. + expectedBody := "chunk0-chunk1-chunk2" + if string(result.Body) != expectedBody { + t.Errorf("expected body=%q, got %q", expectedBody, string(result.Body)) + } + + // No pending requests should remain after assembly. + ra.mu.Lock() + pendingCount := len(ra.pending) + ra.mu.Unlock() + if pendingCount != 0 { + t.Errorf("expected 0 pending requests after assembly, got %d", pendingCount) + } +} + +func TestRequestAssembler_FirstChunkStoresMetadata(t *testing.T) { + ra := newRequestAssembler() + + // First chunk carries all metadata. + chunk0 := &pb.HttpRequest{ + RequestId: "req-3", + Method: "PUT", + Path: "/api/v1/data", + Headers: map[string]string{"X-Custom": "value1", "Accept": "application/json"}, + Body: []byte("part1"), + ChunkIndex: 0, + IsFinal: false, + TimeoutMs: 30000, + } + ra.handleChunk(chunk0) + + // Subsequent chunk has different method/path/headers on the proto message, + // but the assembler should use only the first chunk's metadata. + chunk1 := &pb.HttpRequest{ + RequestId: "req-3", + Method: "DELETE", // should be ignored + Path: "/wrong/path", // should be ignored + Headers: map[string]string{"X-Wrong": "ignored"}, // should be ignored + Body: []byte("part2"), + ChunkIndex: 1, + IsFinal: true, + TimeoutMs: 99999, // should be ignored + } + result := ra.handleChunk(chunk1) + if result == nil { + t.Fatal("expected non-nil result for final chunk") + } + + // Verify metadata comes from chunk 0, not chunk 1. + if result.Method != "PUT" { + t.Errorf("expected Method=PUT from first chunk, got %s", result.Method) + } + if result.Path != "/api/v1/data" { + t.Errorf("expected Path=/api/v1/data from first chunk, got %s", result.Path) + } + if result.TimeoutMs != 30000 { + t.Errorf("expected TimeoutMs=30000 from first chunk, got %d", result.TimeoutMs) + } + if result.Headers["X-Custom"] != "value1" { + t.Errorf("expected X-Custom=value1 from first chunk, got %v", result.Headers) + } + if _, ok := result.Headers["X-Wrong"]; ok { + t.Error("expected X-Wrong header from continuation chunk to be ignored") + } + + // Body should be concatenated. + if string(result.Body) != "part1part2" { + t.Errorf("expected body=part1part2, got %q", string(result.Body)) + } +} + +func TestRequestAssembler_IncompleteRequestDiscarded(t *testing.T) { + ra := newRequestAssembler() + + // Start a multi-chunk request but never send the final chunk. + chunk0 := &pb.HttpRequest{ + RequestId: "req-4", + Method: "POST", + Path: "/api/v1/big", + Headers: map[string]string{"Content-Type": "text/plain"}, + Body: []byte("partial-"), + ChunkIndex: 0, + IsFinal: false, + TimeoutMs: 5000, + } + ra.handleChunk(chunk0) + + chunk1 := &pb.HttpRequest{ + RequestId: "req-4", + Body: []byte("data"), + ChunkIndex: 1, + IsFinal: false, + } + ra.handleChunk(chunk1) + + // Verify there is a pending request. + ra.mu.Lock() + pendingCount := len(ra.pending) + ra.mu.Unlock() + if pendingCount != 1 { + t.Fatalf("expected 1 pending request, got %d", pendingCount) + } + + // Simulate stream close: discardAll should remove incomplete requests. + ra.discardAll() + + ra.mu.Lock() + pendingCount = len(ra.pending) + ra.mu.Unlock() + if pendingCount != 0 { + t.Errorf("expected 0 pending requests after discardAll, got %d", pendingCount) + } +} + +func TestRequestAssembler_OrphanChunkIgnored(t *testing.T) { + ra := newRequestAssembler() + + // Send a continuation chunk without a preceding first chunk. + orphan := &pb.HttpRequest{ + RequestId: "req-orphan", + Body: []byte("orphan-data"), + ChunkIndex: 2, + IsFinal: true, + } + result := ra.handleChunk(orphan) + if result != nil { + t.Error("expected nil result for orphan chunk with no matching first chunk") + } + + // No pending requests should exist. + ra.mu.Lock() + pendingCount := len(ra.pending) + ra.mu.Unlock() + if pendingCount != 0 { + t.Errorf("expected 0 pending requests, got %d", pendingCount) + } +} + +func TestRequestAssembler_MultipleConcurrentRequests(t *testing.T) { + ra := newRequestAssembler() + + // Start two multi-chunk requests interleaved. + ra.handleChunk(&pb.HttpRequest{ + RequestId: "req-a", Method: "GET", Path: "/a", + Headers: map[string]string{"X-Req": "a"}, + Body: []byte("a0-"), ChunkIndex: 0, IsFinal: false, TimeoutMs: 1000, + }) + ra.handleChunk(&pb.HttpRequest{ + RequestId: "req-b", Method: "POST", Path: "/b", + Headers: map[string]string{"X-Req": "b"}, + Body: []byte("b0-"), ChunkIndex: 0, IsFinal: false, TimeoutMs: 2000, + }) + + // Continue both. + ra.handleChunk(&pb.HttpRequest{ + RequestId: "req-a", Body: []byte("a1"), ChunkIndex: 1, IsFinal: true, + }) + resultA := ra.handleChunk(&pb.HttpRequest{ + RequestId: "req-a", Body: []byte("a1"), ChunkIndex: 1, IsFinal: true, + }) + + ra.handleChunk(&pb.HttpRequest{ + RequestId: "req-b", Body: []byte("b1"), ChunkIndex: 1, IsFinal: true, + }) + + // req-a was already completed, so a second final chunk for it is an orphan. + // Let's re-test properly: we need to restart. + ra2 := newRequestAssembler() + ra2.handleChunk(&pb.HttpRequest{ + RequestId: "req-a", Method: "GET", Path: "/a", + Headers: map[string]string{"X-Req": "a"}, + Body: []byte("a0-"), ChunkIndex: 0, IsFinal: false, TimeoutMs: 1000, + }) + ra2.handleChunk(&pb.HttpRequest{ + RequestId: "req-b", Method: "POST", Path: "/b", + Headers: map[string]string{"X-Req": "b"}, + Body: []byte("b0-"), ChunkIndex: 0, IsFinal: false, TimeoutMs: 2000, + }) + + // Finalize req-b first. + resultB := ra2.handleChunk(&pb.HttpRequest{ + RequestId: "req-b", Body: []byte("b1"), ChunkIndex: 1, IsFinal: true, + }) + if resultB == nil { + t.Fatal("expected non-nil result for req-b") + } + if resultB.Method != "POST" || resultB.Path != "/b" { + t.Errorf("req-b metadata wrong: method=%s path=%s", resultB.Method, resultB.Path) + } + if string(resultB.Body) != "b0-b1" { + t.Errorf("req-b body wrong: got %q", string(resultB.Body)) + } + + // req-a should still be pending. + ra2.mu.Lock() + if _, ok := ra2.pending["req-a"]; !ok { + t.Error("expected req-a to still be pending") + } + ra2.mu.Unlock() + + // Finalize req-a. + resultA = ra2.handleChunk(&pb.HttpRequest{ + RequestId: "req-a", Body: []byte("a1"), ChunkIndex: 1, IsFinal: true, + }) + if resultA == nil { + t.Fatal("expected non-nil result for req-a") + } + if resultA.Method != "GET" || resultA.Path != "/a" { + t.Errorf("req-a metadata wrong: method=%s path=%s", resultA.Method, resultA.Path) + } + if string(resultA.Body) != "a0-a1" { + t.Errorf("req-a body wrong: got %q", string(resultA.Body)) + } +} diff --git a/agent/server/grpctunnel/module.go b/agent/server/grpctunnel/module.go new file mode 100644 index 0000000..e215339 --- /dev/null +++ b/agent/server/grpctunnel/module.go @@ -0,0 +1,11 @@ +package grpctunnel + +import ( + "github.com/cortexapps/axon/server/snykbroker" + "go.uber.org/fx" +) + +var Module = fx.Module("grpctunnel", + fx.Provide(snykbroker.NewRegistration), + fx.Invoke(NewTunnelClient), +) diff --git a/agent/server/grpctunnel/tunnel_client.go b/agent/server/grpctunnel/tunnel_client.go new file mode 100644 index 0000000..0f59555 --- /dev/null +++ b/agent/server/grpctunnel/tunnel_client.go @@ -0,0 +1,897 @@ +package grpctunnel + +import ( + "bufio" + "context" + "crypto/tls" + "crypto/x509" + "encoding/base64" + "fmt" + "io" + "math/rand/v2" + "net" + "net/http" + "net/url" + "os" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/cortexapps/axon/common" + "github.com/cortexapps/axon/config" + cortexHttp "github.com/cortexapps/axon/server/http" + "github.com/cortexapps/axon/server/requestexecutor" + "github.com/cortexapps/axon/server/snykbroker" + "github.com/cortexapps/axon/server/snykbroker/acceptfile" + pb "github.com/cortexapps/axon/.generated/proto/github.com/cortexapps/axon/tunnelpb" + "github.com/gorilla/mux" + "github.com/prometheus/client_golang/prometheus" + "go.uber.org/fx" + "go.uber.org/zap" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" +) + +const maxChunkSize = 1024 * 1024 // 1MB + +// tunnelClient implements the snykbroker.RelayInstanceManager interface +// using gRPC bidirectional streaming instead of snyk-broker. +type tunnelClient struct { + config config.AgentConfig + logger *zap.Logger + integrationInfo common.IntegrationInfo + registration snykbroker.Registration + executor requestexecutor.RequestExecutor + httpClient *http.Client + + running atomic.Bool + mu sync.Mutex + conn *grpc.ClientConn + streams []*tunnelStream + parentCtx context.Context + cancelAll context.CancelFunc + + // Metrics + connectionsActive *prometheus.GaugeVec + requestsTotal *prometheus.CounterVec + reconnectsTotal *prometheus.CounterVec + requestDuration *prometheus.HistogramVec +} + +type tunnelStream struct { + streamID string + serverID string + cancel context.CancelFunc + done chan struct{} +} + +// sendFunc is a mutex-protected function for sending messages on a gRPC stream. +// Multiple goroutines (heartbeat responses, HTTP response handlers) call Send() +// concurrently; wrapping it in a mutex prevents data races on the underlying stream. +type sendFunc func(msg *pb.TunnelClientMessage) error + +// requestAssembler reassembles chunked HTTP requests from the server. +// The server chunks requests larger than 1MB, sending method/path/headers/timeoutMs +// only on the first chunk (chunk_index=0) and body data on all chunks. +type requestAssembler struct { + mu sync.Mutex + pending map[string]*pendingRequest +} + +type pendingRequest struct { + method string + path string + headers map[string]string + body []byte + timeoutMs int32 +} + +func newRequestAssembler() *requestAssembler { + return &requestAssembler{ + pending: make(map[string]*pendingRequest), + } +} + +// handleChunk processes an incoming HttpRequest chunk. It returns a fully +// assembled *pb.HttpRequest when the final chunk arrives, or nil if more +// chunks are still expected. +func (ra *requestAssembler) handleChunk(chunk *pb.HttpRequest) *pb.HttpRequest { + // Fast path: single-chunk request (most common case). + if chunk.ChunkIndex == 0 && chunk.IsFinal { + return chunk + } + + ra.mu.Lock() + defer ra.mu.Unlock() + + if chunk.ChunkIndex == 0 { + // First chunk of a multi-chunk request: store metadata + body. + ra.pending[chunk.RequestId] = &pendingRequest{ + method: chunk.Method, + path: chunk.Path, + headers: chunk.Headers, + body: append([]byte(nil), chunk.Body...), + timeoutMs: chunk.TimeoutMs, + } + return nil + } + + // Continuation chunk. + pr, ok := ra.pending[chunk.RequestId] + if !ok { + // Orphan chunk — no first chunk was received; discard. + return nil + } + + pr.body = append(pr.body, chunk.Body...) + + if !chunk.IsFinal { + return nil + } + + // Final chunk: assemble and remove from pending. + delete(ra.pending, chunk.RequestId) + return &pb.HttpRequest{ + RequestId: chunk.RequestId, + Method: pr.method, + Path: pr.path, + Headers: pr.headers, + Body: pr.body, + TimeoutMs: pr.timeoutMs, + IsFinal: true, + } +} + +// discardAll removes all incomplete pending requests (called on stream close). +func (ra *requestAssembler) discardAll() { + ra.mu.Lock() + defer ra.mu.Unlock() + ra.pending = make(map[string]*pendingRequest) +} + +type TunnelClientParams struct { + fx.In + Lifecycle fx.Lifecycle `optional:"true"` + Config config.AgentConfig + Logger *zap.Logger + IntegrationInfo common.IntegrationInfo + HttpServer cortexHttp.Server + Registration snykbroker.Registration + HttpClient *http.Client `optional:"true"` + Registry *prometheus.Registry `optional:"true"` +} + +func NewTunnelClient(p TunnelClientParams) snykbroker.RelayInstanceManager { + httpClient := p.HttpClient + if httpClient == nil { + httpClient = &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyFromEnvironment, + }, + } + } + + tc := &tunnelClient{ + config: p.Config, + logger: p.Logger.Named("grpc-tunnel"), + integrationInfo: p.IntegrationInfo, + registration: p.Registration, + httpClient: httpClient, + connectionsActive: prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "grpc_tunnel_connections_active", + Help: "Number of active gRPC tunnel streams", + }, + []string{"server_id"}, + ), + requestsTotal: prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "grpc_tunnel_requests_total", + Help: "Total requests dispatched through gRPC tunnel", + }, + []string{"method", "status"}, + ), + reconnectsTotal: prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "grpc_tunnel_reconnects_total", + Help: "Total tunnel reconnection attempts", + }, + []string{"server_id"}, + ), + requestDuration: prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Name: "grpc_tunnel_request_duration_ms", + Help: "Request execution latency in milliseconds", + Buckets: prometheus.ExponentialBuckets(10, 2, 12), + }, + []string{"method"}, + ), + } + + p.HttpServer.RegisterHandler(tc) + + if p.Registry != nil { + p.Registry.MustRegister( + tc.connectionsActive, + tc.requestsTotal, + tc.reconnectsTotal, + tc.requestDuration, + ) + } + + if p.Lifecycle != nil { + p.Lifecycle.Append(fx.Hook{ + OnStart: func(ctx context.Context) error { + return tc.Start() + }, + OnStop: func(ctx context.Context) error { + return tc.Close() + }, + }) + } + + return tc +} + +func (tc *tunnelClient) RegisterRoutes(mux *mux.Router) error { + subRouter := mux.PathPrefix(fmt.Sprintf("%s/broker", cortexHttp.AxonPathRoot)).Subrouter() + subRouter.HandleFunc("/restart", tc.handleRestart) + subRouter.HandleFunc("/systemcheck", tc.handleSystemCheck) + return nil +} + +func (tc *tunnelClient) ServeHTTP(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(http.StatusNotFound) +} + +func (tc *tunnelClient) handleRestart(w http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodPost { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + if err := tc.Restart(); err != nil { + tc.logger.Error("Restart failed", zap.Error(err)) + w.WriteHeader(http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) +} + +func (tc *tunnelClient) handleSystemCheck(w http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodGet { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + tc.mu.Lock() + streamCount := len(tc.streams) + tc.mu.Unlock() + w.Header().Set("Content-Type", "application/json") + fmt.Fprintf(w, `{"status":"ok","relay_mode":"grpc-tunnel","streams":%d}`, streamCount) +} + +func (tc *tunnelClient) Start() error { + if !tc.running.CompareAndSwap(false, true) { + return fmt.Errorf("already started") + } + + go tc.startAsync() + return nil +} + +func (tc *tunnelClient) startAsync() { + // Check for direct connection config (skip registration if both are set). + serverAddr, token := tc.getConnectionConfig() + + // Render accept file and create RequestExecutor. + if err := tc.setupExecutor(); err != nil { + tc.logger.Error("Failed to set up request executor", zap.Error(err)) + return + } + + // Strip http(s):// if present — gRPC expects host:port. + serverAddr = stripScheme(serverAddr) + + // Establish gRPC connection. + dialOpts, dialAddr := tc.buildDialOptions(serverAddr) + conn, err := grpc.NewClient(dialAddr, dialOpts...) + if err != nil { + tc.logger.Error("Failed to connect to gRPC server", zap.String("addr", serverAddr), zap.Error(err)) + return + } + + tc.mu.Lock() + tc.conn = conn + ctx, cancel := context.WithCancel(context.Background()) + tc.parentCtx = ctx + tc.cancelAll = cancel + tc.mu.Unlock() + + // Open N tunnel streams. + client := pb.NewTunnelServiceClient(conn) + seenServers := make(map[string]bool) + + for i := 0; i < tc.config.TunnelCount; i++ { + ts := tc.openStream(ctx, client, token, i, seenServers) + if ts != nil { + tc.mu.Lock() + tc.streams = append(tc.streams, ts) + tc.mu.Unlock() + } + } + + tc.logger.Info("gRPC tunnel started", + zap.Int("streams", len(tc.streams)), + zap.String("serverAddr", serverAddr), + ) +} + +// getConnectionConfig returns the server address and token for the gRPC tunnel. +// If both BROKER_SERVER_URL and BROKER_TOKEN are set, uses those directly (skips registration). +// Otherwise, registers with the Cortex API to get the server URI and token. +func (tc *tunnelClient) getConnectionConfig() (serverAddr, token string) { + envServerURL := os.Getenv("BROKER_SERVER_URL") + envToken := os.Getenv("BROKER_TOKEN") + + // If both are provided, skip registration and use direct connection. + if envServerURL != "" && envToken != "" { + tc.logger.Info("Using direct connection config (skipping registration)", + zap.String("serverUrl", envServerURL), + ) + return envServerURL, envToken + } + + // Register with Cortex API to get server URI + token. + var regInfo *snykbroker.RegistrationInfoResponse + backoff := tc.config.FailWaitTime + for tc.running.Load() { + var err error + regInfo, err = tc.registration.Register(tc.integrationInfo.Integration, tc.integrationInfo.Alias) + if err != nil { + tc.logger.Error("Registration failed, retrying", zap.Error(err), zap.Duration("backoff", backoff)) + time.Sleep(backoff) + backoff = min(backoff*2, 30*time.Second) + continue + } + break + } + + if regInfo == nil || !tc.running.Load() { + return "", "" + } + + tc.logger.Info("Registered with Cortex API", + zap.String("serverUri", regInfo.ServerUri), + ) + + // Use BROKER_SERVER_URL override if set, otherwise use registered server URI. + serverAddr = envServerURL + if serverAddr == "" { + serverAddr = regInfo.ServerUri + } + + return serverAddr, regInfo.Token +} + +func (tc *tunnelClient) setupExecutor() error { + af, err := tc.integrationInfo.ToAcceptFile(tc.config, tc.logger) + if err != nil { + return fmt.Errorf("error creating accept file: %w", err) + } + + rendered, err := af.Render(tc.logger) + if err != nil { + return fmt.Errorf("error rendering accept file: %w", err) + } + + // Parse rendered rules. + af2, err := acceptfile.NewAcceptFile(rendered, tc.config, tc.logger) + if err != nil { + return fmt.Errorf("error parsing rendered accept file: %w", err) + } + + rules := af2.Wrapper().PrivateRules() + tc.executor = requestexecutor.NewRequestExecutor(rules, tc.httpClient, tc.logger) + return nil +} + +const handshakeTimeout = 30 * time.Second + +func (tc *tunnelClient) openStream( + ctx context.Context, + client pb.TunnelServiceClient, + token string, + index int, + seenServers map[string]bool, +) *tunnelStream { + streamCtx, cancel := context.WithCancel(ctx) + + stream, err := client.Tunnel(streamCtx) + if err != nil { + tc.logger.Error("Failed to open tunnel stream", zap.Int("index", index), zap.Error(err)) + cancel() + return nil + } + + // Cancel the stream if handshake (Send+Recv) takes too long. + handshakeTimer := time.AfterFunc(handshakeTimeout, func() { + tc.logger.Warn("Handshake timeout, cancelling stream", zap.Int("index", index)) + cancel() + }) + defer handshakeTimer.Stop() + + // Send ClientHello. + hello := &pb.TunnelClientMessage{ + Message: &pb.TunnelClientMessage_Hello{ + Hello: &pb.ClientHello{ + BrokerToken: token, + ClientVersion: common.ClientVersion, + TenantId: os.Getenv("CORTEX_TENANT_ID"), + Integration: tc.integrationInfo.Integration.String(), + Alias: tc.integrationInfo.Alias, + InstanceId: tc.config.InstanceId, + CortexApiToken: tc.config.CortexApiToken, + }, + }, + } + + if err := stream.Send(hello); err != nil { + tc.logger.Error("Failed to send ClientHello", zap.Error(err)) + cancel() + return nil + } + + // Receive ServerHello. + msg, err := stream.Recv() + if err != nil { + tc.logger.Error("Failed to receive ServerHello", zap.Error(err)) + cancel() + return nil + } + + serverHello := msg.GetHello() + if serverHello == nil { + tc.logger.Error("Expected ServerHello, got something else") + cancel() + return nil + } + + // Dedup: if we're already connected to this server, skip. + if seenServers[serverHello.ServerId] { + tc.logger.Info("Already connected to this server, skipping duplicate stream", + zap.String("serverId", serverHello.ServerId), + zap.Int("index", index), + ) + cancel() + return nil + } + seenServers[serverHello.ServerId] = true + + tc.logger.Info("Tunnel stream established", + zap.String("streamId", serverHello.StreamId), + zap.String("serverId", serverHello.ServerId), + zap.Int32("heartbeatIntervalMs", serverHello.HeartbeatIntervalMs), + zap.Int("index", index), + ) + + tc.connectionsActive.WithLabelValues(serverHello.ServerId).Inc() + + ts := &tunnelStream{ + streamID: serverHello.StreamId, + serverID: serverHello.ServerId, + cancel: cancel, + done: make(chan struct{}), + } + + // Create a mutex-protected send function to prevent concurrent Send() calls. + // Multiple goroutines (heartbeat responses, HTTP response handlers) may send + // on this stream concurrently. + sendMu := &sync.Mutex{} + sendFn := sendFunc(func(msg *pb.TunnelClientMessage) error { + sendMu.Lock() + defer sendMu.Unlock() + return stream.Send(msg) + }) + + // Start request handler goroutine. + go tc.streamLoop(streamCtx, stream, sendFn, ts, token) + + return ts +} + +func (tc *tunnelClient) streamLoop(ctx context.Context, stream pb.TunnelService_TunnelClient, sendFn sendFunc, ts *tunnelStream, token string) { + assembler := newRequestAssembler() + defer func() { + assembler.discardAll() + tc.connectionsActive.WithLabelValues(ts.serverID).Dec() + close(ts.done) + }() + + for { + select { + case <-ctx.Done(): + return + default: + } + + msg, err := stream.Recv() + if err != nil { + if err != io.EOF && ctx.Err() == nil { + tc.logger.Warn("Stream recv error, will reconnect", + zap.String("streamId", ts.streamID), + zap.Error(err), + ) + tc.reconnectsTotal.WithLabelValues(ts.serverID).Inc() + go tc.reconnectStream(tc.parentCtx, ts, token) + } + return + } + + switch m := msg.Message.(type) { + case *pb.TunnelServerMessage_Heartbeat: + // Respond with heartbeat. + if err := sendFn(&pb.TunnelClientMessage{ + Message: &pb.TunnelClientMessage_Heartbeat{ + Heartbeat: &pb.Heartbeat{ + TimestampMs: time.Now().UnixMilli(), + }, + }, + }); err != nil { + tc.logger.Warn("Failed to send heartbeat response", zap.Error(err)) + } + + case *pb.TunnelServerMessage_HttpRequest: + assembled := assembler.handleChunk(m.HttpRequest) + if assembled != nil { + go tc.handleRequest(sendFn, assembled) + } + + case *pb.TunnelServerMessage_Hello: + tc.logger.Warn("Received unexpected ServerHello after handshake") + } + } +} + +func (tc *tunnelClient) handleRequest(sendFn sendFunc, req *pb.HttpRequest) { + if tc.executor == nil { + tc.sendErrorResponse(sendFn, req.RequestId, 503, "executor not ready") + return + } + + // Convert headers. + headers := make(map[string]string, len(req.Headers)) + for k, v := range req.Headers { + headers[k] = v + } + + start := time.Now() + + // Use server-provided timeout as context deadline. + ctx := context.Background() + if req.TimeoutMs > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, time.Duration(req.TimeoutMs)*time.Millisecond) + defer cancel() + } + + resp, err := tc.executor.Execute(ctx, req.Method, req.Path, headers, req.Body) + if err != nil { + tc.logger.Error("Request execution failed", + zap.String("requestId", req.RequestId), + zap.String("method", req.Method), + zap.String("path", req.Path), + zap.Error(err), + ) + statusCode := 502 + if err == requestexecutor.ErrNoMatchingRule { + statusCode = 404 + } + tc.requestsTotal.WithLabelValues(req.Method, fmt.Sprintf("%d", statusCode)).Inc() + tc.sendErrorResponse(sendFn, req.RequestId, int32(statusCode), err.Error()) + return + } + + duration := time.Since(start) + tc.requestsTotal.WithLabelValues(req.Method, fmt.Sprintf("%d", resp.StatusCode)).Inc() + tc.requestDuration.WithLabelValues(req.Method).Observe(float64(duration.Milliseconds())) + + // Send response back through tunnel (chunked if needed). + tc.sendResponse(sendFn, req.RequestId, resp) +} + +func (tc *tunnelClient) sendResponse(sendFn sendFunc, requestID string, resp *requestexecutor.ExecutorResponse) { + if len(resp.Body) <= maxChunkSize { + if err := sendFn(&pb.TunnelClientMessage{ + Message: &pb.TunnelClientMessage_HttpResponse{ + HttpResponse: &pb.HttpResponse{ + RequestId: requestID, + StatusCode: int32(resp.StatusCode), + Headers: resp.Headers, + Body: resp.Body, + ChunkIndex: 0, + IsFinal: true, + }, + }, + }); err != nil { + tc.logger.Warn("Failed to send response", + zap.String("requestId", requestID), + zap.Error(err), + ) + } + return + } + + // Chunked response. + for i := 0; i < len(resp.Body); i += maxChunkSize { + end := i + maxChunkSize + if end > len(resp.Body) { + end = len(resp.Body) + } + chunkIndex := int32(i / maxChunkSize) + isFinal := end == len(resp.Body) + + httpResp := &pb.HttpResponse{ + RequestId: requestID, + Body: resp.Body[i:end], + ChunkIndex: chunkIndex, + IsFinal: isFinal, + } + + // First chunk includes status code and headers. + if chunkIndex == 0 { + httpResp.StatusCode = int32(resp.StatusCode) + httpResp.Headers = resp.Headers + } + + if err := sendFn(&pb.TunnelClientMessage{ + Message: &pb.TunnelClientMessage_HttpResponse{ + HttpResponse: httpResp, + }, + }); err != nil { + tc.logger.Warn("Failed to send response chunk, aborting remaining chunks", + zap.String("requestId", requestID), + zap.Int32("chunkIndex", chunkIndex), + zap.Error(err), + ) + return + } + } +} + +func (tc *tunnelClient) sendErrorResponse(sendFn sendFunc, requestID string, statusCode int32, message string) { + if err := sendFn(&pb.TunnelClientMessage{ + Message: &pb.TunnelClientMessage_HttpResponse{ + HttpResponse: &pb.HttpResponse{ + RequestId: requestID, + StatusCode: statusCode, + Headers: map[string]string{"Content-Type": "text/plain"}, + Body: []byte(message), + ChunkIndex: 0, + IsFinal: true, + }, + }, + }); err != nil { + tc.logger.Warn("Failed to send error response", + zap.String("requestId", requestID), + zap.Int32("statusCode", statusCode), + zap.Error(err), + ) + } +} + +func (tc *tunnelClient) reconnectStream(parentCtx context.Context, ts *tunnelStream, token string) { + // Add jitter to prevent thundering herd. + jitter := time.Duration(rand.IntN(5000)) * time.Millisecond + time.Sleep(jitter) + + backoff := time.Second + maxBackoff := 30 * time.Second + + for attempt := 0; tc.running.Load(); attempt++ { + // Stop if the parent context was cancelled (e.g. Close() called). + if parentCtx.Err() != nil { + return + } + + tc.logger.Info("Reconnecting tunnel stream", + zap.String("streamId", ts.streamID), + zap.Int("attempt", attempt), + ) + + tc.mu.Lock() + if tc.conn == nil { + tc.mu.Unlock() + return + } + client := pb.NewTunnelServiceClient(tc.conn) + tc.mu.Unlock() + + seenServers := make(map[string]bool) + newStream := tc.openStream(parentCtx, client, token, 0, seenServers) + if newStream != nil { + // Replace the old stream entry. + tc.mu.Lock() + for i, s := range tc.streams { + if s == ts { + tc.streams[i] = newStream + break + } + } + tc.mu.Unlock() + return + } + + // Wait with backoff, but bail if context is cancelled. + select { + case <-time.After(backoff): + case <-parentCtx.Done(): + return + } + backoff = min(backoff*2, maxBackoff) + } +} + +func (tc *tunnelClient) Restart() error { + tc.logger.Info("Restarting gRPC tunnel") + if err := tc.Close(); err != nil { + tc.logger.Error("Error closing tunnel on restart", zap.Error(err)) + } + return tc.Start() +} + +func (tc *tunnelClient) Close() error { + if !tc.running.CompareAndSwap(true, false) { + return nil + } + + tc.mu.Lock() + defer tc.mu.Unlock() + + // Cancel all stream contexts. + if tc.cancelAll != nil { + tc.cancelAll() + tc.cancelAll = nil + } + + // Wait for all streams to finish. + for _, s := range tc.streams { + <-s.done + } + tc.streams = nil + + // Close gRPC connection. + if tc.conn != nil { + tc.conn.Close() + tc.conn = nil + } + + tc.logger.Info("gRPC tunnel closed") + return nil +} + +func (tc *tunnelClient) buildDialOptions(targetAddr string) ([]grpc.DialOption, string) { + opts := []grpc.DialOption{} + dialAddr := targetAddr + + // Add transport credentials. + creds := tc.buildTransportCredentials() + opts = append(opts, grpc.WithTransportCredentials(creds)) + + // Check for HTTP proxy configuration. + proxyURL := tc.getProxyURL(targetAddr) + if proxyURL != nil { + tc.logger.Info("Using HTTP proxy for gRPC connection", + zap.String("proxy", proxyURL.Host), + zap.String("target", targetAddr), + ) + dialer := tc.buildProxyDialer(proxyURL) + opts = append(opts, grpc.WithContextDialer(dialer)) + + // Use passthrough scheme to skip local DNS resolution. + // This ensures the address is passed directly to our custom dialer, + // which will connect through the proxy and let the proxy resolve the hostname. + dialAddr = "passthrough:///" + targetAddr + } + + return opts, dialAddr +} + +// getProxyURL returns the proxy URL to use for the target address, or nil if no proxy. +func (tc *tunnelClient) getProxyURL(targetAddr string) *url.URL { + // Build a fake request to use http.ProxyFromEnvironment + // This respects HTTP_PROXY, HTTPS_PROXY, and NO_PROXY. + scheme := "https" + if tc.config.GrpcInsecure { + scheme = "http" + } + fakeReq, _ := http.NewRequest("GET", fmt.Sprintf("%s://%s/", scheme, targetAddr), nil) + proxyURL, err := http.ProxyFromEnvironment(fakeReq) + if err != nil || proxyURL == nil { + return nil + } + return proxyURL +} + +// buildProxyDialer returns a context dialer that connects through an HTTP CONNECT proxy. +func (tc *tunnelClient) buildProxyDialer(proxyURL *url.URL) func(ctx context.Context, addr string) (net.Conn, error) { + return func(ctx context.Context, addr string) (net.Conn, error) { + // Connect to the proxy. + proxyAddr := proxyURL.Host + if proxyURL.Port() == "" { + proxyAddr = net.JoinHostPort(proxyURL.Hostname(), "8080") + } + + var d net.Dialer + conn, err := d.DialContext(ctx, "tcp", proxyAddr) + if err != nil { + return nil, fmt.Errorf("failed to connect to proxy %s: %w", proxyAddr, err) + } + + // Send HTTP CONNECT request. + connectReq := fmt.Sprintf("CONNECT %s HTTP/1.1\r\nHost: %s\r\n", addr, addr) + + // Add proxy authentication if present. + if proxyURL.User != nil { + username := proxyURL.User.Username() + password, _ := proxyURL.User.Password() + auth := base64.StdEncoding.EncodeToString([]byte(username + ":" + password)) + connectReq += fmt.Sprintf("Proxy-Authorization: Basic %s\r\n", auth) + } + + connectReq += "\r\n" + + if _, err := conn.Write([]byte(connectReq)); err != nil { + conn.Close() + return nil, fmt.Errorf("failed to send CONNECT request: %w", err) + } + + // Read the response. + br := bufio.NewReader(conn) + resp, err := http.ReadResponse(br, nil) + if err != nil { + conn.Close() + return nil, fmt.Errorf("failed to read CONNECT response: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + conn.Close() + return nil, fmt.Errorf("proxy CONNECT failed with status %d", resp.StatusCode) + } + + tc.logger.Debug("HTTP CONNECT tunnel established", + zap.String("proxy", proxyURL.Host), + zap.String("target", addr), + ) + + return conn, nil + } +} + +func (tc *tunnelClient) buildTransportCredentials() credentials.TransportCredentials { + if tc.config.GrpcInsecure { + return insecure.NewCredentials() + } + + tlsConfig := &tls.Config{} + + if tc.config.HttpCaCertFilePath != "" { + caCert, err := os.ReadFile(tc.config.HttpCaCertFilePath) + if err != nil { + tc.logger.Error("Failed to read CA cert", zap.Error(err)) + return insecure.NewCredentials() + } + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) + tlsConfig.RootCAs = caCertPool + } + + return credentials.NewTLS(tlsConfig) +} + +func stripScheme(addr string) string { + addr = strings.TrimPrefix(addr, "https://") + addr = strings.TrimPrefix(addr, "http://") + return addr +} diff --git a/agent/server/requestexecutor/executor.go b/agent/server/requestexecutor/executor.go new file mode 100644 index 0000000..752ba43 --- /dev/null +++ b/agent/server/requestexecutor/executor.go @@ -0,0 +1,166 @@ +package requestexecutor + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "net/url" + "os" + "strings" + + "github.com/cortexapps/axon/server/snykbroker/acceptfile" + "go.uber.org/zap" +) + +// ExecutorResponse is the result of executing an HTTP request through a matched accept file rule. +type ExecutorResponse struct { + StatusCode int + Headers map[string]string + Body []byte +} + +// RequestExecutor applies accept file rules to execute HTTP requests. +// It matches incoming requests against rules, rewrites URLs, injects headers/auth, +// and executes the request against the target origin. +type RequestExecutor interface { + Execute(ctx context.Context, method, path string, headers map[string]string, body []byte) (*ExecutorResponse, error) +} + +// ErrNoMatchingRule is returned when no accept file rule matches the request. +var ErrNoMatchingRule = fmt.Errorf("no matching accept file rule") + +type requestExecutor struct { + rules []acceptfile.AcceptFileRuleWrapper + logger *zap.Logger + httpClient *http.Client + pools *PoolManager +} + +// NewRequestExecutor creates a new RequestExecutor from rendered accept file rules. +// The httpClient parameter should be the shared *http.Client from DI, which already +// handles proxy (http.ProxyFromEnvironment), CA certs (including directories), and TLS config. +func NewRequestExecutor(rules []acceptfile.AcceptFileRuleWrapper, httpClient *http.Client, logger *zap.Logger) RequestExecutor { + return &requestExecutor{ + rules: rules, + logger: logger.Named("request-executor"), + httpClient: httpClient, + pools: NewPoolManager(), + } +} + +func (e *requestExecutor) Execute(ctx context.Context, method, path string, headers map[string]string, body []byte) (*ExecutorResponse, error) { + rule := MatchRule(e.rules, method, path, headers) + if rule == nil { + return nil, ErrNoMatchingRule + } + + origin := e.resolveOrigin(rule.Origin()) + targetURL, err := buildTargetURL(origin, path) + if err != nil { + return nil, fmt.Errorf("failed to build target URL: %w", err) + } + + var bodyReader io.Reader + if len(body) > 0 { + bodyReader = bytes.NewReader(body) + } + + req, err := http.NewRequestWithContext(ctx, method, targetURL, bodyReader) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + // Copy incoming headers first. + for k, v := range headers { + req.Header.Set(k, v) + } + + // Inject rule headers (overrides incoming). + ruleHeaders := rule.Headers() + if ruleHeaders != nil { + resolved := ruleHeaders.ToStringMap() + for k, v := range resolved { + req.Header.Set(k, v) + } + } + + // Inject auth. + e.applyAuth(req, rule.Auth()) + + // Set Host header to target. + parsedOrigin, _ := url.Parse(origin) + if parsedOrigin != nil { + req.Host = parsedOrigin.Host + } + + e.logger.Debug("Executing request", + zap.String("method", method), + zap.String("path", path), + zap.String("targetURL", targetURL), + ) + + resp, err := e.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("request execution failed: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + respHeaders := make(map[string]string, len(resp.Header)) + for k, v := range resp.Header { + respHeaders[k] = strings.Join(v, ", ") + } + + return &ExecutorResponse{ + StatusCode: resp.StatusCode, + Headers: respHeaders, + Body: respBody, + }, nil +} + +// resolveOrigin resolves _POOL variables in the origin URL via environment expansion and pool rotation. +func (e *requestExecutor) resolveOrigin(origin string) string { + // The origin may contain env vars like ${GITHUB_API} or pool vars like ${GITHUB_API_POOL}. + // After preprocessing, env vars are already expanded in the origin string. + // We need to check if the resolved value is a comma-separated pool. + return e.pools.ResolvePoolVars(origin) +} + +func (e *requestExecutor) applyAuth(req *http.Request, auth *acceptfile.AcceptFileRuleAuth) { + if auth == nil { + return + } + switch strings.ToLower(auth.Scheme) { + case "bearer", "token": + token := os.ExpandEnv(auth.Token) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + case "basic": + username := os.ExpandEnv(auth.Username) + password := os.ExpandEnv(auth.Password) + req.SetBasicAuth(username, password) + default: + // Custom scheme: set as Authorization header. + token := os.ExpandEnv(auth.Token) + req.Header.Set("Authorization", fmt.Sprintf("%s %s", auth.Scheme, token)) + } +} + +func buildTargetURL(origin, requestPath string) (string, error) { + parsed, err := url.Parse(origin) + if err != nil { + return "", err + } + // Append the request path to the origin's path. + if parsed.Path == "" || parsed.Path == "/" { + parsed.Path = requestPath + } else { + parsed.Path = strings.TrimRight(parsed.Path, "/") + "/" + strings.TrimLeft(requestPath, "/") + } + return parsed.String(), nil +} diff --git a/agent/server/requestexecutor/executor_test.go b/agent/server/requestexecutor/executor_test.go new file mode 100644 index 0000000..4ce92d0 --- /dev/null +++ b/agent/server/requestexecutor/executor_test.go @@ -0,0 +1,470 @@ +package requestexecutor + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/cortexapps/axon/config" + "github.com/cortexapps/axon/server/snykbroker/acceptfile" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +func makeRules(t *testing.T, rules string, cfg config.AgentConfig) []acceptfile.AcceptFileRuleWrapper { + t.Helper() + af, err := acceptfile.NewAcceptFile([]byte(rules), cfg, zap.NewNop()) + require.NoError(t, err) + rendered, err := af.Render(zap.NewNop()) + require.NoError(t, err) + + var parsed map[string]any + require.NoError(t, json.Unmarshal(rendered, &parsed)) + + // Re-parse to get wrappers via a new accept file. + af2, err := acceptfile.NewAcceptFile(rendered, cfg, zap.NewNop()) + require.NoError(t, err) + rendered2, err := af2.Render(zap.NewNop()) + require.NoError(t, err) + + af3, err := acceptfile.NewAcceptFile(rendered2, cfg, zap.NewNop()) + require.NoError(t, err) + _ = af3 + + // Get private rules from the wrapper. + af4, err := acceptfile.NewAcceptFile(rendered, cfg, zap.NewNop()) + require.NoError(t, err) + wrapper := af4.Wrapper() + return wrapper.PrivateRules() +} + +func TestMatchRule_MethodAndPath(t *testing.T) { + tests := []struct { + name string + ruleMethod string + rulePath string + reqMethod string + reqPath string + shouldMatch bool + }{ + {"exact GET match", "GET", "/api/v1/repos", "GET", "/api/v1/repos", true}, + {"method mismatch", "POST", "/api/v1/repos", "GET", "/api/v1/repos", false}, + {"any method match", "any", "/api/v1/repos", "DELETE", "/api/v1/repos", true}, + {"wildcard path", "GET", "/api/*", "GET", "/api/repos", true}, + {"wildcard path no match", "GET", "/api/*", "GET", "/other/repos", false}, + {"path mismatch", "GET", "/api/v1/repos", "GET", "/api/v2/repos", false}, + {"case insensitive method", "get", "/api/v1/repos", "GET", "/api/v1/repos", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + matched := matchesMethod(tt.ruleMethod, tt.reqMethod) && matchesPath(tt.rulePath, tt.reqPath) + assert.Equal(t, tt.shouldMatch, matched) + }) + } +} + +func TestMatchRule_WildcardSubpath(t *testing.T) { + assert.True(t, matchesPath("/api/*", "/api/repos")) + assert.True(t, matchesPath("/api/*", "/api/anything")) + assert.True(t, matchesPath("/__axon/*", "/__axon/health")) + assert.False(t, matchesPath("/api/*", "/other/repos")) +} + +func TestExecutor_BasicRequest(t *testing.T) { + // Set up a test HTTP server. + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "GET", r.Method) + assert.Equal(t, "/api/v1/repos", r.URL.Path) + w.Header().Set("X-Test", "response-header") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"repos": []}`)) + })) + defer server.Close() + + rulesJSON := fmt.Sprintf(`{ + "private": [ + { + "method": "GET", + "path": "/api/v1/repos", + "origin": "%s" + } + ] + }`, server.URL) + + cfg := config.AgentConfig{ + HttpServerPort: 8080, + PluginDirs: []string{}, + } + + rules := makeRules(t, rulesJSON, cfg) + // Filter out the axon route added by render. + var filteredRules []acceptfile.AcceptFileRuleWrapper + for _, r := range rules { + if r.Path() != "/__axon/*" { + filteredRules = append(filteredRules, r) + } + } + + executor := NewRequestExecutor(filteredRules, &http.Client{}, zap.NewNop()) + + resp, err := executor.Execute(context.Background(), "GET", "/api/v1/repos", nil, nil) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, `{"repos": []}`, string(resp.Body)) + assert.Equal(t, "response-header", resp.Headers["X-Test"]) +} + +func TestExecutor_NoMatchingRule(t *testing.T) { + rulesJSON := `{ + "private": [ + { + "method": "GET", + "path": "/api/v1/repos", + "origin": "https://example.com" + } + ] + }` + + cfg := config.AgentConfig{ + HttpServerPort: 8080, + PluginDirs: []string{}, + } + + rules := makeRules(t, rulesJSON, cfg) + executor := NewRequestExecutor(rules, &http.Client{}, zap.NewNop()) + + _, err := executor.Execute(context.Background(), "GET", "/unknown/path", nil, nil) + assert.ErrorIs(t, err, ErrNoMatchingRule) +} + +func TestExecutor_BearerAuth(t *testing.T) { + t.Setenv("MY_TOKEN", "secret-token-123") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "Bearer secret-token-123", r.Header.Get("Authorization")) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + rulesJSON := fmt.Sprintf(`{ + "private": [ + { + "method": "GET", + "path": "/api/*", + "origin": "%s", + "auth": { + "scheme": "bearer", + "token": "${MY_TOKEN}" + } + } + ] + }`, server.URL) + + cfg := config.AgentConfig{ + HttpServerPort: 8080, + PluginDirs: []string{}, + } + + rules := makeRules(t, rulesJSON, cfg) + var filteredRules []acceptfile.AcceptFileRuleWrapper + for _, r := range rules { + if r.Path() != "/__axon/*" { + filteredRules = append(filteredRules, r) + } + } + + executor := NewRequestExecutor(filteredRules, &http.Client{}, zap.NewNop()) + + resp, err := executor.Execute(context.Background(), "GET", "/api/repos", nil, nil) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) +} + +func TestExecutor_BasicAuth(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user, pass, ok := r.BasicAuth() + assert.True(t, ok) + assert.Equal(t, "myuser", user) + assert.Equal(t, "mypass", pass) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + rulesJSON := fmt.Sprintf(`{ + "private": [ + { + "method": "POST", + "path": "/api/*", + "origin": "%s", + "auth": { + "scheme": "basic", + "username": "myuser", + "password": "mypass" + } + } + ] + }`, server.URL) + + cfg := config.AgentConfig{ + HttpServerPort: 8080, + PluginDirs: []string{}, + } + + rules := makeRules(t, rulesJSON, cfg) + var filteredRules []acceptfile.AcceptFileRuleWrapper + for _, r := range rules { + if r.Path() != "/__axon/*" { + filteredRules = append(filteredRules, r) + } + } + + executor := NewRequestExecutor(filteredRules, &http.Client{}, zap.NewNop()) + + resp, err := executor.Execute(context.Background(), "POST", "/api/data", nil, []byte(`{"key":"value"}`)) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) +} + +func TestPool_RoundRobin(t *testing.T) { + t.Setenv("TEST_API_POOL", "https://api1.example.com,https://api2.example.com,https://api3.example.com") + + pm := NewPoolManager() + + results := make([]string, 6) + for i := 0; i < 6; i++ { + results[i] = pm.ResolvePoolVars("${TEST_API}") + } + + assert.Equal(t, "https://api1.example.com", results[0]) + assert.Equal(t, "https://api2.example.com", results[1]) + assert.Equal(t, "https://api3.example.com", results[2]) + assert.Equal(t, "https://api1.example.com", results[3]) + assert.Equal(t, "https://api2.example.com", results[4]) + assert.Equal(t, "https://api3.example.com", results[5]) +} + +func TestPool_FallbackToEnvVar(t *testing.T) { + t.Setenv("SINGLE_API", "https://api.example.com") + + pm := NewPoolManager() + result := pm.ResolvePoolVars("${SINGLE_API}") + assert.Equal(t, "https://api.example.com", result) +} + +func TestPool_NoMatch(t *testing.T) { + pm := NewPoolManager() + result := pm.ResolvePoolVars("https://static.example.com") + assert.Equal(t, "https://static.example.com", result) +} + +func TestBuildTargetURL(t *testing.T) { + tests := []struct { + origin string + path string + want string + wantErr bool + }{ + {"https://api.github.com", "/repos/foo", "https://api.github.com/repos/foo", false}, + {"https://api.github.com/v3", "/repos/foo", "https://api.github.com/v3/repos/foo", false}, + {"https://api.github.com/", "/repos/foo", "https://api.github.com/repos/foo", false}, + } + + for _, tt := range tests { + t.Run(tt.origin+tt.path, func(t *testing.T) { + got, err := buildTargetURL(tt.origin, tt.path) + if tt.wantErr { + assert.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.want, got) + } + }) + } +} + +func TestMatchRule_ValidHeaderRequirement(t *testing.T) { + tests := []struct { + name string + requirements []acceptfile.ValidHeaderRequirement + headers map[string]string + shouldMatch bool + }{ + { + name: "no requirements - always matches", + requirements: nil, + headers: nil, + shouldMatch: true, + }, + { + name: "header present with matching value", + requirements: []acceptfile.ValidHeaderRequirement{ + {Header: "x-cortex-service", Values: []string{"scaffolder"}}, + }, + headers: map[string]string{"x-cortex-service": "scaffolder"}, + shouldMatch: true, + }, + { + name: "header present but wrong value", + requirements: []acceptfile.ValidHeaderRequirement{ + {Header: "x-cortex-service", Values: []string{"scaffolder"}}, + }, + headers: map[string]string{"x-cortex-service": "other"}, + shouldMatch: false, + }, + { + name: "header missing", + requirements: []acceptfile.ValidHeaderRequirement{ + {Header: "x-cortex-service", Values: []string{"scaffolder"}}, + }, + headers: map[string]string{"x-other": "value"}, + shouldMatch: false, + }, + { + name: "header missing - nil headers", + requirements: []acceptfile.ValidHeaderRequirement{ + {Header: "x-cortex-service", Values: []string{"scaffolder"}}, + }, + headers: nil, + shouldMatch: false, + }, + { + name: "case insensitive header name", + requirements: []acceptfile.ValidHeaderRequirement{ + {Header: "X-Cortex-Service", Values: []string{"scaffolder"}}, + }, + headers: map[string]string{"x-cortex-service": "scaffolder"}, + shouldMatch: true, + }, + { + name: "case insensitive header value", + requirements: []acceptfile.ValidHeaderRequirement{ + {Header: "x-cortex-service", Values: []string{"Scaffolder"}}, + }, + headers: map[string]string{"x-cortex-service": "scaffolder"}, + shouldMatch: true, + }, + { + name: "multiple allowed values - first matches", + requirements: []acceptfile.ValidHeaderRequirement{ + {Header: "x-cortex-service", Values: []string{"scaffolder", "catalog", "other"}}, + }, + headers: map[string]string{"x-cortex-service": "scaffolder"}, + shouldMatch: true, + }, + { + name: "multiple allowed values - second matches", + requirements: []acceptfile.ValidHeaderRequirement{ + {Header: "x-cortex-service", Values: []string{"scaffolder", "catalog", "other"}}, + }, + headers: map[string]string{"x-cortex-service": "catalog"}, + shouldMatch: true, + }, + { + name: "multiple allowed values - none match", + requirements: []acceptfile.ValidHeaderRequirement{ + {Header: "x-cortex-service", Values: []string{"scaffolder", "catalog"}}, + }, + headers: map[string]string{"x-cortex-service": "unknown"}, + shouldMatch: false, + }, + { + name: "multiple requirements - all must match", + requirements: []acceptfile.ValidHeaderRequirement{ + {Header: "x-cortex-service", Values: []string{"scaffolder"}}, + {Header: "x-cortex-tenant", Values: []string{"acme"}}, + }, + headers: map[string]string{ + "x-cortex-service": "scaffolder", + "x-cortex-tenant": "acme", + }, + shouldMatch: true, + }, + { + name: "multiple requirements - one missing", + requirements: []acceptfile.ValidHeaderRequirement{ + {Header: "x-cortex-service", Values: []string{"scaffolder"}}, + {Header: "x-cortex-tenant", Values: []string{"acme"}}, + }, + headers: map[string]string{"x-cortex-service": "scaffolder"}, + shouldMatch: false, + }, + { + name: "empty values array - just check header exists", + requirements: []acceptfile.ValidHeaderRequirement{ + {Header: "x-cortex-service", Values: []string{}}, + }, + headers: map[string]string{"x-cortex-service": "anything"}, + shouldMatch: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := matchesValid(tt.requirements, tt.headers) + assert.Equal(t, tt.shouldMatch, result) + }) + } +} + +func TestMatchRule_WithValidHeaders(t *testing.T) { + // Test that MatchRule correctly uses valid header requirements to select the right rule. + rulesJSON := `{ + "private": [ + { + "method": "any", + "path": "/*", + "origin": "https://github.com", + "valid": [ + { + "header": "x-cortex-service", + "values": ["scaffolder"] + } + ] + }, + { + "method": "any", + "path": "/*", + "origin": "https://api.github.com" + } + ] + }` + + cfg := config.AgentConfig{ + HttpServerPort: 8080, + PluginDirs: []string{}, + } + + rules := makeRules(t, rulesJSON, cfg) + // Filter out the axon route added by render. + var filteredRules []acceptfile.AcceptFileRuleWrapper + for _, r := range rules { + if r.Path() != "/__axon/*" { + filteredRules = append(filteredRules, r) + } + } + + t.Run("with scaffolder header - matches first rule", func(t *testing.T) { + headers := map[string]string{"x-cortex-service": "scaffolder"} + rule := MatchRule(filteredRules, "GET", "/repos/foo", headers) + require.NotNil(t, rule) + assert.Equal(t, "https://github.com", rule.Origin()) + }) + + t.Run("without scaffolder header - skips first rule, matches second", func(t *testing.T) { + headers := map[string]string{"x-other": "value"} + rule := MatchRule(filteredRules, "GET", "/repos/foo", headers) + require.NotNil(t, rule) + assert.Equal(t, "https://api.github.com", rule.Origin()) + }) + + t.Run("no headers - skips first rule, matches second", func(t *testing.T) { + rule := MatchRule(filteredRules, "GET", "/repos/foo") + require.NotNil(t, rule) + assert.Equal(t, "https://api.github.com", rule.Origin()) + }) +} diff --git a/agent/server/requestexecutor/pool.go b/agent/server/requestexecutor/pool.go new file mode 100644 index 0000000..3ae263e --- /dev/null +++ b/agent/server/requestexecutor/pool.go @@ -0,0 +1,98 @@ +package requestexecutor + +import ( + "os" + "regexp" + "strings" + "sync" + "sync/atomic" +) + +// PoolManager handles _POOL variable resolution with round-robin rotation. +// When an environment variable like GITHUB_API_POOL is set to a comma-separated +// list of values (e.g., "https://api1.github.com,https://api2.github.com"), +// the pool manager rotates through them on each resolution. +type PoolManager struct { + mu sync.RWMutex + pools map[string]*poolEntry +} + +type poolEntry struct { + values []string + counter atomic.Uint64 +} + +func NewPoolManager() *PoolManager { + return &PoolManager{ + pools: make(map[string]*poolEntry), + } +} + +// getPool returns the pool entry for the given variable name, creating it if needed. +func (pm *PoolManager) getPool(varName string) *poolEntry { + pm.mu.RLock() + entry, exists := pm.pools[varName] + pm.mu.RUnlock() + if exists { + return entry + } + + pm.mu.Lock() + defer pm.mu.Unlock() + + // Double-check after acquiring write lock. + if entry, exists = pm.pools[varName]; exists { + return entry + } + + poolValue := os.Getenv(varName + "_POOL") + if poolValue == "" { + return nil + } + + values := strings.Split(poolValue, ",") + trimmed := make([]string, 0, len(values)) + for _, v := range values { + v = strings.TrimSpace(v) + if v != "" { + trimmed = append(trimmed, v) + } + } + if len(trimmed) == 0 { + return nil + } + + entry = &poolEntry{values: trimmed} + pm.pools[varName] = entry + return entry +} + +// Next returns the next value from the pool using round-robin. +func (pe *poolEntry) Next() string { + idx := pe.counter.Add(1) - 1 + return pe.values[idx%uint64(len(pe.values))] +} + +// reEnvVar matches ${VAR_NAME} patterns in strings. +var reEnvVar = regexp.MustCompile(`\$\{([^}]+)\}`) + +// ResolvePoolVars resolves any ${VAR} references in the string, checking for +// _POOL variants first (round-robin), then falling back to regular env vars. +func (pm *PoolManager) ResolvePoolVars(s string) string { + return reEnvVar.ReplaceAllStringFunc(s, func(match string) string { + varName := match[2 : len(match)-1] // strip ${ and } + + // Check pool first. + if entry := pm.getPool(varName); entry != nil { + return entry.Next() + } + + // Fall back to regular env var. + if val := os.Getenv(varName); val != "" { + return val + } + + // Check if the value itself (already expanded) is a comma-separated pool. + return match + }) +} diff --git a/agent/server/requestexecutor/rule_matcher.go b/agent/server/requestexecutor/rule_matcher.go new file mode 100644 index 0000000..e751e7d --- /dev/null +++ b/agent/server/requestexecutor/rule_matcher.go @@ -0,0 +1,115 @@ +package requestexecutor + +import ( + "path" + "strings" + + "github.com/cortexapps/axon/server/snykbroker/acceptfile" +) + +// MatchRule finds the first accept file rule that matches the given HTTP method, path, and headers. +// Returns nil if no rule matches. +func MatchRule(rules []acceptfile.AcceptFileRuleWrapper, method, requestPath string, headers ...map[string]string) *acceptfile.AcceptFileRuleWrapper { + var reqHeaders map[string]string + if len(headers) > 0 { + reqHeaders = headers[0] + } + + for i := range rules { + rule := &rules[i] + if matchesMethod(rule.Method(), method) && matchesPath(rule.Path(), requestPath) && matchesValid(rule.Valid(), reqHeaders) { + return rule + } + } + return nil +} + +// matchesValid checks if the request headers satisfy the rule's "valid" requirements. +// If no requirements are specified, returns true. +func matchesValid(requirements []acceptfile.ValidHeaderRequirement, headers map[string]string) bool { + if len(requirements) == 0 { + return true + } + + for _, req := range requirements { + headerValue, exists := getHeaderCaseInsensitive(headers, req.Header) + if !exists { + return false + } + + // Check if the header value matches one of the allowed values. + if len(req.Values) > 0 { + matched := false + for _, allowedValue := range req.Values { + if strings.EqualFold(headerValue, allowedValue) { + matched = true + break + } + } + if !matched { + return false + } + } + } + + return true +} + +// getHeaderCaseInsensitive retrieves a header value with case-insensitive key matching. +func getHeaderCaseInsensitive(headers map[string]string, key string) (string, bool) { + if headers == nil { + return "", false + } + + // Try exact match first. + if v, ok := headers[key]; ok { + return v, true + } + + // Case-insensitive search. + keyLower := strings.ToLower(key) + for k, v := range headers { + if strings.ToLower(k) == keyLower { + return v, true + } + } + + return "", false +} + +// matchesMethod checks if the rule method matches the request method. +// "any" matches all methods. +func matchesMethod(ruleMethod, requestMethod string) bool { + if strings.EqualFold(ruleMethod, "any") { + return true + } + return strings.EqualFold(ruleMethod, requestMethod) +} + +// matchesPath checks if the request path matches the rule's path pattern. +// Supports glob-style wildcards: * matches a single path segment, ** matches any number. +func matchesPath(pattern, requestPath string) bool { + if pattern == "" { + return false + } + + // Normalize paths. + pattern = "/" + strings.TrimLeft(pattern, "/") + requestPath = "/" + strings.TrimLeft(requestPath, "/") + + // Use path.Match for simple glob patterns. + // Handle trailing /* as "match anything under this prefix". + if strings.HasSuffix(pattern, "/*") { + prefix := strings.TrimSuffix(pattern, "/*") + if strings.HasPrefix(requestPath, prefix+"/") || requestPath == prefix { + return true + } + } + + // Try exact path.Match. + matched, err := path.Match(pattern, requestPath) + if err != nil { + return false + } + return matched +} diff --git a/agent/server/snykbroker/acceptfile/accept_file.go b/agent/server/snykbroker/acceptfile/accept_file.go index fe1aaf9..0778d1f 100644 --- a/agent/server/snykbroker/acceptfile/accept_file.go +++ b/agent/server/snykbroker/acceptfile/accept_file.go @@ -49,6 +49,11 @@ func NewAcceptFile(content []byte, cfg config.AgentConfig, logger *zap.Logger) ( return af, nil } +// Wrapper returns the typed wrapper for accessing accept file rules. +func (a *AcceptFile) Wrapper() acceptFileWrapper { + return a.wrapper +} + type RenderContext struct { AcceptFile acceptFileWrapper Logger *zap.Logger @@ -145,28 +150,28 @@ func newAcceptFileWrapper(content []byte, af *AcceptFile) acceptFileWrapper { return acceptFileWrapper{dict: dict, acceptFile: af} } -func (w acceptFileWrapper) PrivateRules() []acceptFileRuleWrapper { +func (w acceptFileWrapper) PrivateRules() []AcceptFileRuleWrapper { return w.rules(RULES_PRIVATE) } -func (w acceptFileWrapper) PublicRules() []acceptFileRuleWrapper { +func (w acceptFileWrapper) PublicRules() []AcceptFileRuleWrapper { return w.rules(RULES_PUBLIC) } -func (w acceptFileWrapper) rules(routeType string) []acceptFileRuleWrapper { +func (w acceptFileWrapper) rules(routeType string) []AcceptFileRuleWrapper { routesEntry, ok := w.dict[routeType].([]interface{}) if !ok { routesEntry = []any{} w.dict[routeType] = routesEntry } - routes := make([]acceptFileRuleWrapper, len(routesEntry)) + routes := make([]AcceptFileRuleWrapper, len(routesEntry)) for i, route := range routesEntry { routeDict, ok := route.(map[string]any) if !ok { return nil } - routes[i] = acceptFileRuleWrapper{ + routes[i] = AcceptFileRuleWrapper{ dict: routeDict, acceptFile: w.acceptFile, } @@ -175,10 +180,10 @@ func (w acceptFileWrapper) rules(routeType string) []acceptFileRuleWrapper { } // AddRule adds a new route to the accept file for the specified route type. -func (w acceptFileWrapper) AddRule(routeType string, entry acceptFileRule) acceptFileRuleWrapper { +func (w acceptFileWrapper) AddRule(routeType string, entry acceptFileRule) AcceptFileRuleWrapper { // with a little extra work here we could probably just directly use - // the entry structure above, but the acceptFileRuleWrapper takes a dict so we need + // the entry structure above, but the AcceptFileRuleWrapper takes a dict so we need // to convert it to a map[string]any first, so we round trip it through JSON. routeAsJson, err := json.Marshal(entry) @@ -192,7 +197,7 @@ func (w acceptFileWrapper) AddRule(routeType string, entry acceptFileRule) accep } existingRoutes := w.dict[routeType].([]any) w.dict[routeType] = append([]any{routeDict}, existingRoutes...) - return acceptFileRuleWrapper{dict: routeDict} + return AcceptFileRuleWrapper{dict: routeDict} } func (w acceptFileWrapper) toJSON() ([]byte, error) { @@ -203,12 +208,14 @@ func (w acceptFileWrapper) toJSON() ([]byte, error) { return jsonData, nil } -type acceptFileRuleWrapper struct { +// AcceptFileRuleWrapper provides strongly typed access to a single accept file rule +// parsed from the raw JSON dictionary. +type AcceptFileRuleWrapper struct { dict map[string]any acceptFile *AcceptFile } -func (r acceptFileRuleWrapper) Origin() string { +func (r AcceptFileRuleWrapper) Origin() string { rawOrigin, ok := r.dict["origin"].(string) if !ok { return "" @@ -227,7 +234,7 @@ func (r acceptFileRuleWrapper) Origin() string { } -func (r acceptFileRuleWrapper) Path() string { +func (r AcceptFileRuleWrapper) Path() string { path, ok := r.dict["path"].(string) if !ok { return "" @@ -235,11 +242,40 @@ func (r acceptFileRuleWrapper) Path() string { return path } -func (r acceptFileRuleWrapper) SetOrigin(origin string) { +func (r AcceptFileRuleWrapper) SetOrigin(origin string) { r.dict["origin"] = origin } -func (r acceptFileRuleWrapper) Headers() ResolverMap { +func (r AcceptFileRuleWrapper) Method() string { + method, ok := r.dict["method"].(string) + if !ok { + return "" + } + return method +} + +func (r AcceptFileRuleWrapper) Auth() *AcceptFileRuleAuth { + authDict, ok := r.dict["auth"].(map[string]any) + if !ok { + return nil + } + auth := &AcceptFileRuleAuth{} + if scheme, ok := authDict["scheme"].(string); ok { + auth.Scheme = scheme + } + if username, ok := authDict["username"].(string); ok { + auth.Username = username + } + if password, ok := authDict["password"].(string); ok { + auth.Password = password + } + if token, ok := authDict["token"].(string); ok { + auth.Token = token + } + return auth +} + +func (r AcceptFileRuleWrapper) Headers() ResolverMap { headers, ok := r.dict["headers"].(map[string]any) if !ok { return nil @@ -254,6 +290,50 @@ func (r acceptFileRuleWrapper) Headers() ResolverMap { return result } +// ValidHeaderRequirement represents a header validation rule from the "valid" field. +type ValidHeaderRequirement struct { + Header string + Values []string +} + +// Valid returns the header validation requirements for this rule. +// If the rule has a "valid" field, incoming requests must have headers matching these requirements. +func (r AcceptFileRuleWrapper) Valid() []ValidHeaderRequirement { + validArr, ok := r.dict["valid"].([]any) + if !ok { + return nil + } + + var requirements []ValidHeaderRequirement + for _, item := range validArr { + itemDict, ok := item.(map[string]any) + if !ok { + continue + } + + header, _ := itemDict["header"].(string) + if header == "" { + continue + } + + var values []string + if valuesArr, ok := itemDict["values"].([]any); ok { + for _, v := range valuesArr { + if str, ok := v.(string); ok { + values = append(values, str) + } + } + } + + requirements = append(requirements, ValidHeaderRequirement{ + Header: header, + Values: values, + }) + } + + return requirements +} + // Here are our JSON structed types that represent the accept file rules. // that we can use for things that we are generating such that we don't need to worry // about additional fields that might be in the accept file that we don't know about. @@ -262,13 +342,39 @@ type acceptFileRule struct { Method string `json:"method"` Path string `json:"path"` Origin string `json:"origin"` - Auth *acceptFileRuleAuth `json:"auth,omitempty"` + Auth *AcceptFileRuleAuth `json:"auth,omitempty"` Headers map[string]string `json:"headers,omitempty"` } -type acceptFileRuleAuth struct { +type AcceptFileRuleAuth struct { Scheme string `json:"scheme"` Username string `json:"username,omitempty"` Password string `json:"password,omitempty"` Token string `json:"token,omitempty"` } + +// Rule represents a routing rule from the accept file. +type Rule struct { + Method string + Path string + Origin string + Headers ResolverMap +} + +// GetPrivateRules returns the private routing rules from the accept file. +func (a *AcceptFile) GetPrivateRules() []Rule { + wrapper := newAcceptFileWrapper(a.content, a) + wrappedRules := wrapper.PrivateRules() + + rules := make([]Rule, len(wrappedRules)) + for i, r := range wrappedRules { + rules[i] = Rule{ + Method: r.Method(), + Path: r.Path(), + Origin: r.Origin(), + Headers: r.Headers(), + } + } + return rules +} + diff --git a/agent/test/relay/Dockerfile.grpc-tunnel-server b/agent/test/relay/Dockerfile.grpc-tunnel-server new file mode 100644 index 0000000..53af51f --- /dev/null +++ b/agent/test/relay/Dockerfile.grpc-tunnel-server @@ -0,0 +1,29 @@ +# Build the gRPC tunnel server for E2E testing +FROM golang:1.26.1-alpine AS builder + +RUN apk add --no-cache protobuf git make + +WORKDIR /build + +# Copy agent module +COPY agent/go.mod agent/go.sum ./ +RUN go mod download + +COPY agent/. . + +# Generate proto files and build server +RUN go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.36.6 && \ + go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.5.1 +RUN make proto +RUN go build -o /grpc-tunnel-server ./cmd/grpc-tunnel-server + +# Runtime image +FROM alpine:latest + +RUN apk add --no-cache curl + +COPY --from=builder /grpc-tunnel-server /grpc-tunnel-server + +EXPOSE 8080 50051 + +CMD ["/grpc-tunnel-server"] diff --git a/agent/test/relay/docker-compose.grpc.noproxy.yml b/agent/test/relay/docker-compose.grpc.noproxy.yml new file mode 100644 index 0000000..24e9df2 --- /dev/null +++ b/agent/test/relay/docker-compose.grpc.noproxy.yml @@ -0,0 +1,12 @@ +# Override for PROXY=0 testing - adds external network for direct access +# +# Usage: docker compose -f docker-compose.grpc.yml -f docker-compose.grpc.noproxy.yml up +# +# Without this override, axon-relay can only reach grpc-tunnel-server through mitmproxy. +# This adds the external network so axon-relay can connect directly. + +services: + axon-relay: + networks: + - internal + - external diff --git a/agent/test/relay/docker-compose.grpc.yml b/agent/test/relay/docker-compose.grpc.yml new file mode 100644 index 0000000..6023458 --- /dev/null +++ b/agent/test/relay/docker-compose.grpc.yml @@ -0,0 +1,110 @@ +services: + grpc-tunnel-server: + build: + context: ../../.. + dockerfile: server/docker/Dockerfile + image: cortex-axon-tunnel-server:local + ports: + - "${GRPC_PORT}:50052" + - "${HTTP_PORT}:8080" + environment: + GRPC_PORT: 50052 + HTTP_PORT: 8080 + SERVER_ID: test-server-1 + DISPATCH_TIMEOUT: 60s + HEARTBEAT_INTERVAL: 10s + ENV: development + TOKEN: ${TOKEN:-0e481b34-76ac-481a-a92f-c94a6cf6f6c1} + networks: + - external # grpc-tunnel-server is on external network only + + mitmproxy: + image: mitmproxy/mitmproxy:12 + ports: + - "9980:8080" + - "9981:8081" + volumes: + - ./.mitmproxy:/home/mitmproxy/.mitmproxy + - ./mitmproxy_addon_header.py:/home/mitmproxy/addon.py + command: mitmdump -s /home/mitmproxy/addon.py + networks: + - internal # mitmproxy bridges both networks + - external + + axon-relay: + build: + context: ../../.. + dockerfile: docker/Dockerfile + image: cortex-axon-agent:local + volumes: + - .:/src + - ./.mitmproxy:/certs + - ../../server/snykbroker/acceptfile:/agent/plugins + environment: + CORTEX_API_TOKEN: fake-token + CORTEX_API_BASE_URL: http://cortex-fake:8081 + GITHUB_RAW_API: https://raw.githubusercontent.com + PORT: 7433 + PLUGIN_DIRS: /agent/plugins + RELAY_MODE: grpc-tunnel + TUNNEL_COUNT: 2 + BROKER_SERVER_URL: grpc-tunnel-server:50052 + GRPC_INSECURE: "true" + CORTEX_TENANT_ID: test-tenant + TOKEN: ${TOKEN:-0e481b34-76ac-481a-a92f-c94a6cf6f6c1} + env_file: ${ENVFILE:-noproxy.env} + command: relay -f /src/accept-client.json -i github -a axon-test + depends_on: + mitmproxy: + condition: service_started + grpc-tunnel-server: + condition: service_started + cortex-fake: + condition: service_healthy + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:7433/healthcheck"] + interval: 5s + timeout: 5s + retries: 6 + networks: + - internal + # For PROXY=1: this is the only network, forcing traffic through mitmproxy + # For PROXY=0: docker-compose.grpc.noproxy.yml adds external network + + python-server: + image: python:3.12-alpine + volumes: + - /tmp:/tmp + command: sh -c "apk add curl && python3 -m http.server 80 -d /tmp" + healthcheck: + test: ["CMD", "sh", "-c", "curl -f http://localhost:80/ || true"] # just check server is up + interval: 5s + timeout: 5s + retries: 6 + stop_grace_period: 1s # SIGKILL after 1s + networks: + - internal # python-server is a local customer service + + cortex-fake: + image: golang:1.24-alpine + volumes: + - .:/src + environment: + PORT: 8081 + BROKER_SERVER_URL: http://grpc-tunnel-server:8080 + TOKEN: ${TOKEN:-0e481b34-76ac-481a-a92f-c94a6cf6f6c1} + command: sh -c "apk add curl && cd /src && go run cortex-registration-fake.go" + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8081/healthcheck"] + interval: 1s + timeout: 1s + retries: 30 + networks: + - internal # cortex-fake is accessed by axon-relay for registration + +networks: + internal: + # Internal network - axon-relay, python-server, cortex-fake + # axon-relay can only reach external network via mitmproxy + external: + # External network - grpc-tunnel-server (mimics Cortex side) diff --git a/agent/test/relay/relay_test.grpc.sh b/agent/test/relay/relay_test.grpc.sh new file mode 100755 index 0000000..498f08e --- /dev/null +++ b/agent/test/relay/relay_test.grpc.sh @@ -0,0 +1,306 @@ +#!/bin/bash +set -e + +# End-to-end test for the gRPC tunnel relay stack. +# This mirrors relay_test.sh but uses the gRPC tunnel server instead of snyk-broker. +# +# Components: +# - [server-side] grpc-tunnel-server: gRPC tunnel server with HTTP dispatch endpoint +# - [server-side] cortex-fake: mimics the Cortex registration API +# - [client-side] axon-relay: agent in gRPC tunnel mode (RELAY_MODE=grpc-tunnel) +# - [client-side] python-server: mimics an API that Cortex is calling out to +# - [optional] mitmproxy: HTTP proxy for proxy-mode testing + +export TOKEN=0e481b34-76ac-481a-a92f-c94a6cf6f6c1 +export GRPC_PORT=50152 +export HTTP_PORT=58180 + +if [ "$PROXY" == "1" ]; then + echo "TESTING WITH PROXY" + export ENVFILE=proxy.env + # Base docker-compose.grpc.yml has axon-relay on internal network only + # This enforces that gRPC connections MUST go through mitmproxy + export COMPOSE_FILES="-f docker-compose.grpc.yml" +else + echo "TESTING WITHOUT PROXY" + export ENVFILE=noproxy.env + # Add external network so axon-relay can connect directly to grpc-tunnel-server + export COMPOSE_FILES="-f docker-compose.grpc.yml -f docker-compose.grpc.noproxy.yml" + + # Also set the HTTP_PORT to a different value to ensure we respect that port + export HTTP_PORT=58280 +fi + +function cleanup { + echo "Cleanup: Stopping docker-compose" + docker compose $COMPOSE_FILES down + rm -f /tmp/token-* /tmp/axon-test-token /tmp/binary-test-*.bin /tmp/binary-test-*.downloaded +} +trap cleanup EXIT + +echo "Starting docker compose (gRPC tunnel)..." +docker compose $COMPOSE_FILES up -d +sleep 5 + +function get_container_status { + result_status=$(docker inspect -f '{{.State.Status}}' $1) + echo "Status $1 = $result_status" >&2 + echo $result_status +} + +if [ -n "$DEBUG" ]; then + echo "Debug mode enabled, sleeping indefinitely" + while true; do + sleep 5 + done +fi + +COUNTER=30 +SERVER_STATUS=$(get_container_status relay-grpc-tunnel-server-1) +AXON_STATUS=$(get_container_status relay-axon-relay-1) + +while [ "$SERVER_STATUS" != "running" ] || [ "$AXON_STATUS" != "running" ]; do + if [ $COUNTER -eq 0 ]; then + echo "Containers did not start in time" + docker compose $COMPOSE_FILES logs + exit 1 + fi + + echo "Waiting for containers to start" + sleep 1 + SERVER_STATUS=$(get_container_status relay-grpc-tunnel-server-1) + AXON_STATUS=$(get_container_status relay-axon-relay-1) + COUNTER=$((COUNTER-1)) +done + +# Wait for grpc-tunnel-server healthz (exposed to host via HTTP_PORT). +echo "Waiting for grpc-tunnel-server healthz..." +COUNTER=30 +while ! curl -sf http://localhost:$HTTP_PORT/healthz > /dev/null 2>&1; do + if [ $COUNTER -eq 0 ]; then + echo "grpc-tunnel-server healthz did not pass in time" + docker compose $COMPOSE_FILES logs grpc-tunnel-server + exit 1 + fi + sleep 1 + COUNTER=$((COUNTER-1)) +done +echo "grpc-tunnel-server is healthy" + +# Wait for at least one tunnel stream to register (agent connects via gRPC). +echo "Waiting for tunnel stream registration..." +COUNTER=60 +while true; do + HEALTH=$(curl -sf http://localhost:$HTTP_PORT/healthz 2>/dev/null || echo '{}') + STREAMS=$(echo "$HEALTH" | grep -o '"streams":[0-9]*' | grep -o '[0-9]*' || echo "0") + if [ "$STREAMS" -gt 0 ]; then + echo "Tunnel has $STREAMS active stream(s)" + break + fi + if [ $COUNTER -eq 0 ]; then + echo "No tunnel streams registered in time" + echo "Server health: $HEALTH" + docker compose $COMPOSE_FILES logs + exit 1 + fi + sleep 1 + COUNTER=$((COUNTER-1)) +done + +real_curl=$(which curl) + +function curlw { + [ -n "$DEBUG" ] && echo "Executing: $real_curl $@" >&2 + if ! curl_result=$($real_curl -s "$@" 2>&1); then + echo "Curl command failed: $@ ==> $curl_result" + exit 1 + else + [ -n "$DEBUG" ] && echo "curl $@ ==> $curl_result" >&2 + fi + echo "$curl_result" +} + +# Dispatch URL: grpc-tunnel-server HTTP port at /broker/{token}/{path} +DISPATCH_URL="http://localhost:$HTTP_PORT/broker/$TOKEN" + +echo "Checking relay broker passthrough..." +# Test relay of a text file through the gRPC tunnel. +# python-server serves files from /tmp. +FILENAME="token-$(date +%s)" +echo "$TOKEN" > /tmp/$FILENAME +echo "$TOKEN" > /tmp/axon-test-token +result=$(curlw $DISPATCH_URL/$FILENAME) + +if [ "$result" != "$TOKEN" ]; then + echo "FAIL: Expected $TOKEN, got $result" + docker compose $COMPOSE_FILES logs + exit 1 +fi +echo "Success: Text file relay through gRPC tunnel" + +echo "Checking binary file relay passthrough..." +BINARY_FILENAME="binary-test-$(date +%s).bin" +dd if=/dev/urandom of="/tmp/$BINARY_FILENAME" bs=1024 count=1536 2>/dev/null +ORIGINAL_CHECKSUM=$(sha256sum "/tmp/$BINARY_FILENAME" | awk '{print $1}') + +BINARY_DOWNLOAD="/tmp/${BINARY_FILENAME}.downloaded" +curl -s -f -o "$BINARY_DOWNLOAD" "$DISPATCH_URL/$BINARY_FILENAME" +DOWNLOAD_STATUS=$? +if [ $DOWNLOAD_STATUS -ne 0 ]; then + echo "FAIL: curl failed to download binary file (exit code $DOWNLOAD_STATUS)" + docker compose $COMPOSE_FILES logs + exit 1 +fi + +DOWNLOADED_CHECKSUM=$(sha256sum "$BINARY_DOWNLOAD" | awk '{print $1}') +if [ "$ORIGINAL_CHECKSUM" != "$DOWNLOADED_CHECKSUM" ]; then + echo "FAIL: Binary checksum mismatch" + echo " Original: $ORIGINAL_CHECKSUM ($(wc -c < /tmp/$BINARY_FILENAME) bytes)" + echo " Downloaded: $DOWNLOADED_CHECKSUM ($(wc -c < $BINARY_DOWNLOAD) bytes)" + exit 1 +else + echo "Success: Binary file (1.5MB) checksum verified ($ORIGINAL_CHECKSUM)" +fi + +# Validate HTTPS relay by fetching the Axon README from GitHub. +echo "Checking HTTPS relay (GitHub README)..." +if ! proxy_result=$(curlw -f -v $DISPATCH_URL/cortexapps/axon/refs/heads/main/README.md 2>&1); then + echo "FAIL: Expected to be able to read the axon readme from GitHub, but got error" + echo "$proxy_result" + docker compose $COMPOSE_FILES logs + exit 1 +fi +echo "Success: HTTPS relay through gRPC tunnel" + +if [ "$PROXY" == "1" ]; then + echo "Checking relay HTTP_PROXY config..." + if ! echo "$proxy_result" | grep -i "x-proxy-mitmproxy"; then + echo "FAIL: Expected 'x-proxy-mitmproxy' header, got nothing" + exit 1 + else + echo "Success: Found 'x-proxy-mitmproxy' header" + fi + + echo "Checking echo endpoint with injected headers..." + if ! proxy_result=$(curlw -f -v $DISPATCH_URL/echo/foobar 2>&1); then + echo "FAIL: Expected to echo 'foobar' via the proxy, but got error" + echo "$proxy_result" + exit 1 + fi + + if ! echo "$proxy_result" | grep -q "added-fake-server"; then + echo "FAIL: Expected injected header value but not found" + echo "$proxy_result" + exit 1 + else + echo "Success: Found expected injected header value in result" + fi + + if ! echo "$proxy_result" | grep -q "HOME=/root"; then + echo "FAIL: Expected injected plugin header value but not found" + echo "$proxy_result" + exit 1 + else + echo "Success: Found expected injected plugin header value in result" + fi + + # Verify gRPC tunnel streams are active (replaces WebSocket tunnel check). + echo "Checking gRPC tunnel streams..." + axon_logs=$(docker compose $COMPOSE_FILES logs axon-relay 2>&1) + if ! echo "$axon_logs" | grep -q "Tunnel stream established"; then + echo "FAIL: Expected 'Tunnel stream established' in agent logs but not found" + echo "=== Axon Relay Logs (last 50) ===" + echo "$axon_logs" | tail -50 + exit 1 + else + echo "Success: gRPC tunnel stream established" + fi +else + echo "Checking relay non-proxy config..." + if echo "$proxy_result" | grep -i "x-proxy-mitmproxy"; then + echo "FAIL: Expected no 'x-proxy-mitmproxy' header, got one" + exit 1 + else + echo "Success: Did not find 'x-proxy-mitmproxy' header (as expected)" + fi +fi + +echo "=== gRPC tunnel reconnection after SIGKILL ===" + +# Force-kill the grpc-tunnel-server container to simulate a non-graceful disconnect. +# This tears down the TCP connection without sending a gRPC GoAway frame, which +# is what happens when the tunnel server crashes or the network drops. +echo "Force-killing grpc-tunnel-server container..." +docker kill --signal=KILL relay-grpc-tunnel-server-1 + +# Wait for the container to be fully dead +sleep 2 +SERVER_STATUS=$(get_container_status relay-grpc-tunnel-server-1) +if [ "$SERVER_STATUS" == "running" ]; then + echo "FAIL: grpc-tunnel-server should be dead after SIGKILL" + exit 1 +fi +echo "grpc-tunnel-server is stopped (status=$SERVER_STATUS)" + +# Restart the tunnel server +echo "Restarting grpc-tunnel-server container..." +docker compose $COMPOSE_FILES up -d grpc-tunnel-server + +# Wait for the tunnel server to be healthy again +COUNTER=30 +while [ $COUNTER -gt 0 ]; do + SERVER_STATUS=$(get_container_status relay-grpc-tunnel-server-1) + if [ "$SERVER_STATUS" == "running" ]; then + if curl -s -f http://localhost:$HTTP_PORT/healthz > /dev/null 2>&1; then + break + fi + fi + echo "Waiting for grpc-tunnel-server to be healthy ($COUNTER)..." + sleep 1 + COUNTER=$((COUNTER-1)) +done + +if [ $COUNTER -eq 0 ]; then + echo "FAIL: grpc-tunnel-server did not become healthy in time" + docker compose $COMPOSE_FILES logs grpc-tunnel-server + exit 1 +fi +echo "grpc-tunnel-server is back up" + +# Give the axon relay time to detect the disconnect and reconnect. +# The gRPC tunnel client uses exponential backoff so 15s should be enough +# for the first reconnect attempt. +echo "Waiting for axon relay to reconnect..." +COUNTER=60 +while [ $COUNTER -gt 0 ]; do + HEALTH=$(curl -sf http://localhost:$HTTP_PORT/healthz 2>/dev/null || echo '{}') + STREAMS=$(echo "$HEALTH" | grep -o '"streams":[0-9]*' | grep -o '[0-9]*' || echo "0") + if [ "$STREAMS" -gt 0 ]; then + break + fi + echo "Waiting for tunnel stream re-establishment ($COUNTER)..." + sleep 1 + COUNTER=$((COUNTER-1)) +done + +if [ $COUNTER -eq 0 ]; then + echo "FAIL: Tunnel stream did not re-establish in time" + docker compose $COMPOSE_FILES logs --tail=80 axon-relay + exit 1 +fi +echo "Tunnel re-established with $STREAMS active stream(s)" + +# Now verify the relay is working again by sending a request through the tunnel +FILENAME="token-reconnect-$(date +%s)" +echo "$TOKEN" > /tmp/$FILENAME +result=$(curlw $DISPATCH_URL/$FILENAME) + +if [ "$result" != "$TOKEN" ]; then + echo "FAIL: Expected $TOKEN after reconnect, got '$result'" + echo "=== Axon Relay Logs (last 80) ===" + docker compose $COMPOSE_FILES logs --tail=80 axon-relay + exit 1 +fi +echo "Success: gRPC tunnel passthrough works after SIGKILL + restart" + +echo "Success! gRPC tunnel e2e test passed!" diff --git a/docker/Dockerfile b/docker/Dockerfile index 9b9c9c2..5b3ff23 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -14,6 +14,7 @@ ENV GOPROXY="https://proxy.golang.org,direct" ENV PATH="${GOROOT}/bin:${GOPATH}/bin:${PATH}" COPY agent/. /build/. +COPY proto/ /proto/ RUN make -C /build setup proto RUN cd /build && go build -o /agent/cortex-axon-agent diff --git a/proto/tunnel/tunnel.proto b/proto/tunnel/tunnel.proto new file mode 100644 index 0000000..9230408 --- /dev/null +++ b/proto/tunnel/tunnel.proto @@ -0,0 +1,117 @@ +syntax = "proto3"; +package cortex.axon.tunnel; + +option go_package = "github.com/cortexapps/axon-server/tunnelpb"; + +// TunnelService provides a bidirectional streaming tunnel between +// Axon agents (clients) and the gRPC relay server. The server dispatches +// HTTP requests from the Cortex backend through the tunnel to agents +// running inside customer networks. +service TunnelService { + // Tunnel establishes a persistent bidirectional stream. + // Client sends ClientHello first, then heartbeats and HTTP responses. + // Server sends ServerHello, then heartbeats and HTTP requests. + rpc Tunnel(stream TunnelClientMessage) returns (stream TunnelServerMessage); +} + +// Client → Server envelope +message TunnelClientMessage { + oneof message { + ClientHello hello = 1; + Heartbeat heartbeat = 2; + HttpResponse http_response = 3; + } +} + +// Server → Client envelope +message TunnelServerMessage { + oneof message { + ServerHello hello = 1; + Heartbeat heartbeat = 2; + HttpRequest http_request = 3; + } +} + +// ClientHello is the first message sent by the client after opening a stream. +// It carries identity information and the Cortex-API-issued broker token. +message ClientHello { + // Cortex-API-issued token, used for BROKER_SERVER dispatch routing. + string broker_token = 1; + // Client software version. + string client_version = 2; + // Tenant identifier from Cortex API registration. + string tenant_id = 3; + // Integration type, e.g. "github", "jira". + string integration = 4; + // Integration alias name. + string alias = 5; + // Unique agent instance ID (from config.InstanceId). + string instance_id = 6; + // Cortex API token for optional server-side JWT validation. + string cortex_api_token = 7; + // Arbitrary metadata. + map metadata = 8; +} + +// ServerHello is the response to ClientHello. +message ServerHello { + // Server hostname (HOSTNAME env var, or UUID fallback). + // Client uses this for metrics tagging and to detect duplicate connections + // to the same server instance (for multi-tunnel dedup). + string server_id = 1; + // Interval at which the server sends heartbeats. + // Client should respond within 2x this interval. + int32 heartbeat_interval_ms = 2; + // Server-generated UUID for this specific stream. + // Used in BROKER_SERVER client-connected/deleted notifications + // to distinguish multiple streams for the same token. + string stream_id = 3; +} + +// Heartbeat is sent by both sides to verify tunnel liveness. +message Heartbeat { + // Unix timestamp in milliseconds. + int64 timestamp_ms = 1; +} + +// HttpRequest is an HTTP request dispatched from the server to the client. +// Large bodies are chunked: the first chunk carries method/path/headers, +// subsequent chunks carry only body/chunk_index/is_final. +message HttpRequest { + // Unique request identifier for correlating responses. + string request_id = 1; + // HTTP method (GET, POST, PUT, DELETE, etc.). + string method = 2; + // Request path (e.g. "/api/v1/repos"). + string path = 3; + // HTTP headers. Only set on first chunk (chunk_index=0). + map headers = 4; + // Request body chunk. + bytes body = 5; + // Zero-based chunk index. + int32 chunk_index = 6; + // True if this is the last chunk for this request. + bool is_final = 7; + // Maximum time in milliseconds the server will wait for a response. + // Client should use this as a deadline for request execution. + // Only set on first chunk (chunk_index=0). 0 means no explicit timeout. + int32 timeout_ms = 8; +} + +// HttpResponse is an HTTP response sent from the client back to the server. +// Large bodies are chunked: the first chunk carries status_code/headers, +// subsequent chunks carry only body/chunk_index/is_final. +message HttpResponse { + // Must match the request_id of the corresponding HttpRequest. + string request_id = 1; + // HTTP status code. Only set on first chunk (chunk_index=0). + int32 status_code = 2; + // HTTP response headers. Only set on first chunk (chunk_index=0). + map headers = 3; + // Response body chunk. + bytes body = 4; + // Zero-based chunk index. + int32 chunk_index = 5; + // True if this is the last chunk for this response. + bool is_final = 6; +} diff --git a/server/.generated/proto/tunnelpb/tunnel.pb.go b/server/.generated/proto/tunnelpb/tunnel.pb.go new file mode 100644 index 0000000..ba76f2b --- /dev/null +++ b/server/.generated/proto/tunnelpb/tunnel.pb.go @@ -0,0 +1,794 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc v3.21.12 +// source: tunnel.proto + +package tunnelpb + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// Client → Server envelope +type TunnelClientMessage struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Types that are valid to be assigned to Message: + // + // *TunnelClientMessage_Hello + // *TunnelClientMessage_Heartbeat + // *TunnelClientMessage_HttpResponse + Message isTunnelClientMessage_Message `protobuf_oneof:"message"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *TunnelClientMessage) Reset() { + *x = TunnelClientMessage{} + mi := &file_tunnel_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *TunnelClientMessage) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TunnelClientMessage) ProtoMessage() {} + +func (x *TunnelClientMessage) ProtoReflect() protoreflect.Message { + mi := &file_tunnel_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TunnelClientMessage.ProtoReflect.Descriptor instead. +func (*TunnelClientMessage) Descriptor() ([]byte, []int) { + return file_tunnel_proto_rawDescGZIP(), []int{0} +} + +func (x *TunnelClientMessage) GetMessage() isTunnelClientMessage_Message { + if x != nil { + return x.Message + } + return nil +} + +func (x *TunnelClientMessage) GetHello() *ClientHello { + if x != nil { + if x, ok := x.Message.(*TunnelClientMessage_Hello); ok { + return x.Hello + } + } + return nil +} + +func (x *TunnelClientMessage) GetHeartbeat() *Heartbeat { + if x != nil { + if x, ok := x.Message.(*TunnelClientMessage_Heartbeat); ok { + return x.Heartbeat + } + } + return nil +} + +func (x *TunnelClientMessage) GetHttpResponse() *HttpResponse { + if x != nil { + if x, ok := x.Message.(*TunnelClientMessage_HttpResponse); ok { + return x.HttpResponse + } + } + return nil +} + +type isTunnelClientMessage_Message interface { + isTunnelClientMessage_Message() +} + +type TunnelClientMessage_Hello struct { + Hello *ClientHello `protobuf:"bytes,1,opt,name=hello,proto3,oneof"` +} + +type TunnelClientMessage_Heartbeat struct { + Heartbeat *Heartbeat `protobuf:"bytes,2,opt,name=heartbeat,proto3,oneof"` +} + +type TunnelClientMessage_HttpResponse struct { + HttpResponse *HttpResponse `protobuf:"bytes,3,opt,name=http_response,json=httpResponse,proto3,oneof"` +} + +func (*TunnelClientMessage_Hello) isTunnelClientMessage_Message() {} + +func (*TunnelClientMessage_Heartbeat) isTunnelClientMessage_Message() {} + +func (*TunnelClientMessage_HttpResponse) isTunnelClientMessage_Message() {} + +// Server → Client envelope +type TunnelServerMessage struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Types that are valid to be assigned to Message: + // + // *TunnelServerMessage_Hello + // *TunnelServerMessage_Heartbeat + // *TunnelServerMessage_HttpRequest + Message isTunnelServerMessage_Message `protobuf_oneof:"message"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *TunnelServerMessage) Reset() { + *x = TunnelServerMessage{} + mi := &file_tunnel_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *TunnelServerMessage) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TunnelServerMessage) ProtoMessage() {} + +func (x *TunnelServerMessage) ProtoReflect() protoreflect.Message { + mi := &file_tunnel_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TunnelServerMessage.ProtoReflect.Descriptor instead. +func (*TunnelServerMessage) Descriptor() ([]byte, []int) { + return file_tunnel_proto_rawDescGZIP(), []int{1} +} + +func (x *TunnelServerMessage) GetMessage() isTunnelServerMessage_Message { + if x != nil { + return x.Message + } + return nil +} + +func (x *TunnelServerMessage) GetHello() *ServerHello { + if x != nil { + if x, ok := x.Message.(*TunnelServerMessage_Hello); ok { + return x.Hello + } + } + return nil +} + +func (x *TunnelServerMessage) GetHeartbeat() *Heartbeat { + if x != nil { + if x, ok := x.Message.(*TunnelServerMessage_Heartbeat); ok { + return x.Heartbeat + } + } + return nil +} + +func (x *TunnelServerMessage) GetHttpRequest() *HttpRequest { + if x != nil { + if x, ok := x.Message.(*TunnelServerMessage_HttpRequest); ok { + return x.HttpRequest + } + } + return nil +} + +type isTunnelServerMessage_Message interface { + isTunnelServerMessage_Message() +} + +type TunnelServerMessage_Hello struct { + Hello *ServerHello `protobuf:"bytes,1,opt,name=hello,proto3,oneof"` +} + +type TunnelServerMessage_Heartbeat struct { + Heartbeat *Heartbeat `protobuf:"bytes,2,opt,name=heartbeat,proto3,oneof"` +} + +type TunnelServerMessage_HttpRequest struct { + HttpRequest *HttpRequest `protobuf:"bytes,3,opt,name=http_request,json=httpRequest,proto3,oneof"` +} + +func (*TunnelServerMessage_Hello) isTunnelServerMessage_Message() {} + +func (*TunnelServerMessage_Heartbeat) isTunnelServerMessage_Message() {} + +func (*TunnelServerMessage_HttpRequest) isTunnelServerMessage_Message() {} + +// ClientHello is the first message sent by the client after opening a stream. +// It carries identity information and the Cortex-API-issued broker token. +type ClientHello struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Cortex-API-issued token, used for BROKER_SERVER dispatch routing. + BrokerToken string `protobuf:"bytes,1,opt,name=broker_token,json=brokerToken,proto3" json:"broker_token,omitempty"` + // Client software version. + ClientVersion string `protobuf:"bytes,2,opt,name=client_version,json=clientVersion,proto3" json:"client_version,omitempty"` + // Tenant identifier from Cortex API registration. + TenantId string `protobuf:"bytes,3,opt,name=tenant_id,json=tenantId,proto3" json:"tenant_id,omitempty"` + // Integration type, e.g. "github", "jira". + Integration string `protobuf:"bytes,4,opt,name=integration,proto3" json:"integration,omitempty"` + // Integration alias name. + Alias string `protobuf:"bytes,5,opt,name=alias,proto3" json:"alias,omitempty"` + // Unique agent instance ID (from config.InstanceId). + InstanceId string `protobuf:"bytes,6,opt,name=instance_id,json=instanceId,proto3" json:"instance_id,omitempty"` + // Cortex API token for optional server-side JWT validation. + CortexApiToken string `protobuf:"bytes,7,opt,name=cortex_api_token,json=cortexApiToken,proto3" json:"cortex_api_token,omitempty"` + // Arbitrary metadata. + Metadata map[string]string `protobuf:"bytes,8,rep,name=metadata,proto3" json:"metadata,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ClientHello) Reset() { + *x = ClientHello{} + mi := &file_tunnel_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ClientHello) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ClientHello) ProtoMessage() {} + +func (x *ClientHello) ProtoReflect() protoreflect.Message { + mi := &file_tunnel_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ClientHello.ProtoReflect.Descriptor instead. +func (*ClientHello) Descriptor() ([]byte, []int) { + return file_tunnel_proto_rawDescGZIP(), []int{2} +} + +func (x *ClientHello) GetBrokerToken() string { + if x != nil { + return x.BrokerToken + } + return "" +} + +func (x *ClientHello) GetClientVersion() string { + if x != nil { + return x.ClientVersion + } + return "" +} + +func (x *ClientHello) GetTenantId() string { + if x != nil { + return x.TenantId + } + return "" +} + +func (x *ClientHello) GetIntegration() string { + if x != nil { + return x.Integration + } + return "" +} + +func (x *ClientHello) GetAlias() string { + if x != nil { + return x.Alias + } + return "" +} + +func (x *ClientHello) GetInstanceId() string { + if x != nil { + return x.InstanceId + } + return "" +} + +func (x *ClientHello) GetCortexApiToken() string { + if x != nil { + return x.CortexApiToken + } + return "" +} + +func (x *ClientHello) GetMetadata() map[string]string { + if x != nil { + return x.Metadata + } + return nil +} + +// ServerHello is the response to ClientHello. +type ServerHello struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Server hostname (HOSTNAME env var, or UUID fallback). + // Client uses this for metrics tagging and to detect duplicate connections + // to the same server instance (for multi-tunnel dedup). + ServerId string `protobuf:"bytes,1,opt,name=server_id,json=serverId,proto3" json:"server_id,omitempty"` + // Interval at which the server sends heartbeats. + // Client should respond within 2x this interval. + HeartbeatIntervalMs int32 `protobuf:"varint,2,opt,name=heartbeat_interval_ms,json=heartbeatIntervalMs,proto3" json:"heartbeat_interval_ms,omitempty"` + // Server-generated UUID for this specific stream. + // Used in BROKER_SERVER client-connected/deleted notifications + // to distinguish multiple streams for the same token. + StreamId string `protobuf:"bytes,3,opt,name=stream_id,json=streamId,proto3" json:"stream_id,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ServerHello) Reset() { + *x = ServerHello{} + mi := &file_tunnel_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ServerHello) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ServerHello) ProtoMessage() {} + +func (x *ServerHello) ProtoReflect() protoreflect.Message { + mi := &file_tunnel_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ServerHello.ProtoReflect.Descriptor instead. +func (*ServerHello) Descriptor() ([]byte, []int) { + return file_tunnel_proto_rawDescGZIP(), []int{3} +} + +func (x *ServerHello) GetServerId() string { + if x != nil { + return x.ServerId + } + return "" +} + +func (x *ServerHello) GetHeartbeatIntervalMs() int32 { + if x != nil { + return x.HeartbeatIntervalMs + } + return 0 +} + +func (x *ServerHello) GetStreamId() string { + if x != nil { + return x.StreamId + } + return "" +} + +// Heartbeat is sent by both sides to verify tunnel liveness. +type Heartbeat struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Unix timestamp in milliseconds. + TimestampMs int64 `protobuf:"varint,1,opt,name=timestamp_ms,json=timestampMs,proto3" json:"timestamp_ms,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Heartbeat) Reset() { + *x = Heartbeat{} + mi := &file_tunnel_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Heartbeat) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Heartbeat) ProtoMessage() {} + +func (x *Heartbeat) ProtoReflect() protoreflect.Message { + mi := &file_tunnel_proto_msgTypes[4] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Heartbeat.ProtoReflect.Descriptor instead. +func (*Heartbeat) Descriptor() ([]byte, []int) { + return file_tunnel_proto_rawDescGZIP(), []int{4} +} + +func (x *Heartbeat) GetTimestampMs() int64 { + if x != nil { + return x.TimestampMs + } + return 0 +} + +// HttpRequest is an HTTP request dispatched from the server to the client. +// Large bodies are chunked: the first chunk carries method/path/headers, +// subsequent chunks carry only body/chunk_index/is_final. +type HttpRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Unique request identifier for correlating responses. + RequestId string `protobuf:"bytes,1,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` + // HTTP method (GET, POST, PUT, DELETE, etc.). + Method string `protobuf:"bytes,2,opt,name=method,proto3" json:"method,omitempty"` + // Request path (e.g. "/api/v1/repos"). + Path string `protobuf:"bytes,3,opt,name=path,proto3" json:"path,omitempty"` + // HTTP headers. Only set on first chunk (chunk_index=0). + Headers map[string]string `protobuf:"bytes,4,rep,name=headers,proto3" json:"headers,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + // Request body chunk. + Body []byte `protobuf:"bytes,5,opt,name=body,proto3" json:"body,omitempty"` + // Zero-based chunk index. + ChunkIndex int32 `protobuf:"varint,6,opt,name=chunk_index,json=chunkIndex,proto3" json:"chunk_index,omitempty"` + // True if this is the last chunk for this request. + IsFinal bool `protobuf:"varint,7,opt,name=is_final,json=isFinal,proto3" json:"is_final,omitempty"` + // Maximum time in milliseconds the server will wait for a response. + // Client should use this as a deadline for request execution. + // Only set on first chunk (chunk_index=0). 0 means no explicit timeout. + TimeoutMs int32 `protobuf:"varint,8,opt,name=timeout_ms,json=timeoutMs,proto3" json:"timeout_ms,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *HttpRequest) Reset() { + *x = HttpRequest{} + mi := &file_tunnel_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *HttpRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*HttpRequest) ProtoMessage() {} + +func (x *HttpRequest) ProtoReflect() protoreflect.Message { + mi := &file_tunnel_proto_msgTypes[5] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use HttpRequest.ProtoReflect.Descriptor instead. +func (*HttpRequest) Descriptor() ([]byte, []int) { + return file_tunnel_proto_rawDescGZIP(), []int{5} +} + +func (x *HttpRequest) GetRequestId() string { + if x != nil { + return x.RequestId + } + return "" +} + +func (x *HttpRequest) GetMethod() string { + if x != nil { + return x.Method + } + return "" +} + +func (x *HttpRequest) GetPath() string { + if x != nil { + return x.Path + } + return "" +} + +func (x *HttpRequest) GetHeaders() map[string]string { + if x != nil { + return x.Headers + } + return nil +} + +func (x *HttpRequest) GetBody() []byte { + if x != nil { + return x.Body + } + return nil +} + +func (x *HttpRequest) GetChunkIndex() int32 { + if x != nil { + return x.ChunkIndex + } + return 0 +} + +func (x *HttpRequest) GetIsFinal() bool { + if x != nil { + return x.IsFinal + } + return false +} + +func (x *HttpRequest) GetTimeoutMs() int32 { + if x != nil { + return x.TimeoutMs + } + return 0 +} + +// HttpResponse is an HTTP response sent from the client back to the server. +// Large bodies are chunked: the first chunk carries status_code/headers, +// subsequent chunks carry only body/chunk_index/is_final. +type HttpResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Must match the request_id of the corresponding HttpRequest. + RequestId string `protobuf:"bytes,1,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` + // HTTP status code. Only set on first chunk (chunk_index=0). + StatusCode int32 `protobuf:"varint,2,opt,name=status_code,json=statusCode,proto3" json:"status_code,omitempty"` + // HTTP response headers. Only set on first chunk (chunk_index=0). + Headers map[string]string `protobuf:"bytes,3,rep,name=headers,proto3" json:"headers,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + // Response body chunk. + Body []byte `protobuf:"bytes,4,opt,name=body,proto3" json:"body,omitempty"` + // Zero-based chunk index. + ChunkIndex int32 `protobuf:"varint,5,opt,name=chunk_index,json=chunkIndex,proto3" json:"chunk_index,omitempty"` + // True if this is the last chunk for this response. + IsFinal bool `protobuf:"varint,6,opt,name=is_final,json=isFinal,proto3" json:"is_final,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *HttpResponse) Reset() { + *x = HttpResponse{} + mi := &file_tunnel_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *HttpResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*HttpResponse) ProtoMessage() {} + +func (x *HttpResponse) ProtoReflect() protoreflect.Message { + mi := &file_tunnel_proto_msgTypes[6] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use HttpResponse.ProtoReflect.Descriptor instead. +func (*HttpResponse) Descriptor() ([]byte, []int) { + return file_tunnel_proto_rawDescGZIP(), []int{6} +} + +func (x *HttpResponse) GetRequestId() string { + if x != nil { + return x.RequestId + } + return "" +} + +func (x *HttpResponse) GetStatusCode() int32 { + if x != nil { + return x.StatusCode + } + return 0 +} + +func (x *HttpResponse) GetHeaders() map[string]string { + if x != nil { + return x.Headers + } + return nil +} + +func (x *HttpResponse) GetBody() []byte { + if x != nil { + return x.Body + } + return nil +} + +func (x *HttpResponse) GetChunkIndex() int32 { + if x != nil { + return x.ChunkIndex + } + return 0 +} + +func (x *HttpResponse) GetIsFinal() bool { + if x != nil { + return x.IsFinal + } + return false +} + +var File_tunnel_proto protoreflect.FileDescriptor + +const file_tunnel_proto_rawDesc = "" + + "\n" + + "\ftunnel.proto\x12\x12cortex.axon.tunnel\"\xe1\x01\n" + + "\x13TunnelClientMessage\x127\n" + + "\x05hello\x18\x01 \x01(\v2\x1f.cortex.axon.tunnel.ClientHelloH\x00R\x05hello\x12=\n" + + "\theartbeat\x18\x02 \x01(\v2\x1d.cortex.axon.tunnel.HeartbeatH\x00R\theartbeat\x12G\n" + + "\rhttp_response\x18\x03 \x01(\v2 .cortex.axon.tunnel.HttpResponseH\x00R\fhttpResponseB\t\n" + + "\amessage\"\xde\x01\n" + + "\x13TunnelServerMessage\x127\n" + + "\x05hello\x18\x01 \x01(\v2\x1f.cortex.axon.tunnel.ServerHelloH\x00R\x05hello\x12=\n" + + "\theartbeat\x18\x02 \x01(\v2\x1d.cortex.axon.tunnel.HeartbeatH\x00R\theartbeat\x12D\n" + + "\fhttp_request\x18\x03 \x01(\v2\x1f.cortex.axon.tunnel.HttpRequestH\x00R\vhttpRequestB\t\n" + + "\amessage\"\xff\x02\n" + + "\vClientHello\x12!\n" + + "\fbroker_token\x18\x01 \x01(\tR\vbrokerToken\x12%\n" + + "\x0eclient_version\x18\x02 \x01(\tR\rclientVersion\x12\x1b\n" + + "\ttenant_id\x18\x03 \x01(\tR\btenantId\x12 \n" + + "\vintegration\x18\x04 \x01(\tR\vintegration\x12\x14\n" + + "\x05alias\x18\x05 \x01(\tR\x05alias\x12\x1f\n" + + "\vinstance_id\x18\x06 \x01(\tR\n" + + "instanceId\x12(\n" + + "\x10cortex_api_token\x18\a \x01(\tR\x0ecortexApiToken\x12I\n" + + "\bmetadata\x18\b \x03(\v2-.cortex.axon.tunnel.ClientHello.MetadataEntryR\bmetadata\x1a;\n" + + "\rMetadataEntry\x12\x10\n" + + "\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n" + + "\x05value\x18\x02 \x01(\tR\x05value:\x028\x01\"{\n" + + "\vServerHello\x12\x1b\n" + + "\tserver_id\x18\x01 \x01(\tR\bserverId\x122\n" + + "\x15heartbeat_interval_ms\x18\x02 \x01(\x05R\x13heartbeatIntervalMs\x12\x1b\n" + + "\tstream_id\x18\x03 \x01(\tR\bstreamId\".\n" + + "\tHeartbeat\x12!\n" + + "\ftimestamp_ms\x18\x01 \x01(\x03R\vtimestampMs\"\xcb\x02\n" + + "\vHttpRequest\x12\x1d\n" + + "\n" + + "request_id\x18\x01 \x01(\tR\trequestId\x12\x16\n" + + "\x06method\x18\x02 \x01(\tR\x06method\x12\x12\n" + + "\x04path\x18\x03 \x01(\tR\x04path\x12F\n" + + "\aheaders\x18\x04 \x03(\v2,.cortex.axon.tunnel.HttpRequest.HeadersEntryR\aheaders\x12\x12\n" + + "\x04body\x18\x05 \x01(\fR\x04body\x12\x1f\n" + + "\vchunk_index\x18\x06 \x01(\x05R\n" + + "chunkIndex\x12\x19\n" + + "\bis_final\x18\a \x01(\bR\aisFinal\x12\x1d\n" + + "\n" + + "timeout_ms\x18\b \x01(\x05R\ttimeoutMs\x1a:\n" + + "\fHeadersEntry\x12\x10\n" + + "\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n" + + "\x05value\x18\x02 \x01(\tR\x05value:\x028\x01\"\xa3\x02\n" + + "\fHttpResponse\x12\x1d\n" + + "\n" + + "request_id\x18\x01 \x01(\tR\trequestId\x12\x1f\n" + + "\vstatus_code\x18\x02 \x01(\x05R\n" + + "statusCode\x12G\n" + + "\aheaders\x18\x03 \x03(\v2-.cortex.axon.tunnel.HttpResponse.HeadersEntryR\aheaders\x12\x12\n" + + "\x04body\x18\x04 \x01(\fR\x04body\x12\x1f\n" + + "\vchunk_index\x18\x05 \x01(\x05R\n" + + "chunkIndex\x12\x19\n" + + "\bis_final\x18\x06 \x01(\bR\aisFinal\x1a:\n" + + "\fHeadersEntry\x12\x10\n" + + "\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n" + + "\x05value\x18\x02 \x01(\tR\x05value:\x028\x012o\n" + + "\rTunnelService\x12^\n" + + "\x06Tunnel\x12'.cortex.axon.tunnel.TunnelClientMessage\x1a'.cortex.axon.tunnel.TunnelServerMessage(\x010\x01B,Z*github.com/cortexapps/axon-server/tunnelpbb\x06proto3" + +var ( + file_tunnel_proto_rawDescOnce sync.Once + file_tunnel_proto_rawDescData []byte +) + +func file_tunnel_proto_rawDescGZIP() []byte { + file_tunnel_proto_rawDescOnce.Do(func() { + file_tunnel_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_tunnel_proto_rawDesc), len(file_tunnel_proto_rawDesc))) + }) + return file_tunnel_proto_rawDescData +} + +var file_tunnel_proto_msgTypes = make([]protoimpl.MessageInfo, 10) +var file_tunnel_proto_goTypes = []any{ + (*TunnelClientMessage)(nil), // 0: cortex.axon.tunnel.TunnelClientMessage + (*TunnelServerMessage)(nil), // 1: cortex.axon.tunnel.TunnelServerMessage + (*ClientHello)(nil), // 2: cortex.axon.tunnel.ClientHello + (*ServerHello)(nil), // 3: cortex.axon.tunnel.ServerHello + (*Heartbeat)(nil), // 4: cortex.axon.tunnel.Heartbeat + (*HttpRequest)(nil), // 5: cortex.axon.tunnel.HttpRequest + (*HttpResponse)(nil), // 6: cortex.axon.tunnel.HttpResponse + nil, // 7: cortex.axon.tunnel.ClientHello.MetadataEntry + nil, // 8: cortex.axon.tunnel.HttpRequest.HeadersEntry + nil, // 9: cortex.axon.tunnel.HttpResponse.HeadersEntry +} +var file_tunnel_proto_depIdxs = []int32{ + 2, // 0: cortex.axon.tunnel.TunnelClientMessage.hello:type_name -> cortex.axon.tunnel.ClientHello + 4, // 1: cortex.axon.tunnel.TunnelClientMessage.heartbeat:type_name -> cortex.axon.tunnel.Heartbeat + 6, // 2: cortex.axon.tunnel.TunnelClientMessage.http_response:type_name -> cortex.axon.tunnel.HttpResponse + 3, // 3: cortex.axon.tunnel.TunnelServerMessage.hello:type_name -> cortex.axon.tunnel.ServerHello + 4, // 4: cortex.axon.tunnel.TunnelServerMessage.heartbeat:type_name -> cortex.axon.tunnel.Heartbeat + 5, // 5: cortex.axon.tunnel.TunnelServerMessage.http_request:type_name -> cortex.axon.tunnel.HttpRequest + 7, // 6: cortex.axon.tunnel.ClientHello.metadata:type_name -> cortex.axon.tunnel.ClientHello.MetadataEntry + 8, // 7: cortex.axon.tunnel.HttpRequest.headers:type_name -> cortex.axon.tunnel.HttpRequest.HeadersEntry + 9, // 8: cortex.axon.tunnel.HttpResponse.headers:type_name -> cortex.axon.tunnel.HttpResponse.HeadersEntry + 0, // 9: cortex.axon.tunnel.TunnelService.Tunnel:input_type -> cortex.axon.tunnel.TunnelClientMessage + 1, // 10: cortex.axon.tunnel.TunnelService.Tunnel:output_type -> cortex.axon.tunnel.TunnelServerMessage + 10, // [10:11] is the sub-list for method output_type + 9, // [9:10] is the sub-list for method input_type + 9, // [9:9] is the sub-list for extension type_name + 9, // [9:9] is the sub-list for extension extendee + 0, // [0:9] is the sub-list for field type_name +} + +func init() { file_tunnel_proto_init() } +func file_tunnel_proto_init() { + if File_tunnel_proto != nil { + return + } + file_tunnel_proto_msgTypes[0].OneofWrappers = []any{ + (*TunnelClientMessage_Hello)(nil), + (*TunnelClientMessage_Heartbeat)(nil), + (*TunnelClientMessage_HttpResponse)(nil), + } + file_tunnel_proto_msgTypes[1].OneofWrappers = []any{ + (*TunnelServerMessage_Hello)(nil), + (*TunnelServerMessage_Heartbeat)(nil), + (*TunnelServerMessage_HttpRequest)(nil), + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_tunnel_proto_rawDesc), len(file_tunnel_proto_rawDesc)), + NumEnums: 0, + NumMessages: 10, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_tunnel_proto_goTypes, + DependencyIndexes: file_tunnel_proto_depIdxs, + MessageInfos: file_tunnel_proto_msgTypes, + }.Build() + File_tunnel_proto = out.File + file_tunnel_proto_goTypes = nil + file_tunnel_proto_depIdxs = nil +} diff --git a/server/.generated/proto/tunnelpb/tunnel_grpc.pb.go b/server/.generated/proto/tunnelpb/tunnel_grpc.pb.go new file mode 100644 index 0000000..85a6be3 --- /dev/null +++ b/server/.generated/proto/tunnelpb/tunnel_grpc.pb.go @@ -0,0 +1,131 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.6.1 +// - protoc v3.21.12 +// source: tunnel.proto + +package tunnelpb + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.64.0 or later. +const _ = grpc.SupportPackageIsVersion9 + +const ( + TunnelService_Tunnel_FullMethodName = "/cortex.axon.tunnel.TunnelService/Tunnel" +) + +// TunnelServiceClient is the client API for TunnelService service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +// +// TunnelService provides a bidirectional streaming tunnel between +// Axon agents (clients) and the gRPC relay server. The server dispatches +// HTTP requests from the Cortex backend through the tunnel to agents +// running inside customer networks. +type TunnelServiceClient interface { + // Tunnel establishes a persistent bidirectional stream. + // Client sends ClientHello first, then heartbeats and HTTP responses. + // Server sends ServerHello, then heartbeats and HTTP requests. + Tunnel(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[TunnelClientMessage, TunnelServerMessage], error) +} + +type tunnelServiceClient struct { + cc grpc.ClientConnInterface +} + +func NewTunnelServiceClient(cc grpc.ClientConnInterface) TunnelServiceClient { + return &tunnelServiceClient{cc} +} + +func (c *tunnelServiceClient) Tunnel(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[TunnelClientMessage, TunnelServerMessage], error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + stream, err := c.cc.NewStream(ctx, &TunnelService_ServiceDesc.Streams[0], TunnelService_Tunnel_FullMethodName, cOpts...) + if err != nil { + return nil, err + } + x := &grpc.GenericClientStream[TunnelClientMessage, TunnelServerMessage]{ClientStream: stream} + return x, nil +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type TunnelService_TunnelClient = grpc.BidiStreamingClient[TunnelClientMessage, TunnelServerMessage] + +// TunnelServiceServer is the server API for TunnelService service. +// All implementations must embed UnimplementedTunnelServiceServer +// for forward compatibility. +// +// TunnelService provides a bidirectional streaming tunnel between +// Axon agents (clients) and the gRPC relay server. The server dispatches +// HTTP requests from the Cortex backend through the tunnel to agents +// running inside customer networks. +type TunnelServiceServer interface { + // Tunnel establishes a persistent bidirectional stream. + // Client sends ClientHello first, then heartbeats and HTTP responses. + // Server sends ServerHello, then heartbeats and HTTP requests. + Tunnel(grpc.BidiStreamingServer[TunnelClientMessage, TunnelServerMessage]) error + mustEmbedUnimplementedTunnelServiceServer() +} + +// UnimplementedTunnelServiceServer must be embedded to have +// forward compatible implementations. +// +// NOTE: this should be embedded by value instead of pointer to avoid a nil +// pointer dereference when methods are called. +type UnimplementedTunnelServiceServer struct{} + +func (UnimplementedTunnelServiceServer) Tunnel(grpc.BidiStreamingServer[TunnelClientMessage, TunnelServerMessage]) error { + return status.Error(codes.Unimplemented, "method Tunnel not implemented") +} +func (UnimplementedTunnelServiceServer) mustEmbedUnimplementedTunnelServiceServer() {} +func (UnimplementedTunnelServiceServer) testEmbeddedByValue() {} + +// UnsafeTunnelServiceServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to TunnelServiceServer will +// result in compilation errors. +type UnsafeTunnelServiceServer interface { + mustEmbedUnimplementedTunnelServiceServer() +} + +func RegisterTunnelServiceServer(s grpc.ServiceRegistrar, srv TunnelServiceServer) { + // If the following call panics, it indicates UnimplementedTunnelServiceServer was + // embedded by pointer and is nil. This will cause panics if an + // unimplemented method is ever invoked, so we test this at initialization + // time to prevent it from happening at runtime later due to I/O. + if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } + s.RegisterService(&TunnelService_ServiceDesc, srv) +} + +func _TunnelService_Tunnel_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(TunnelServiceServer).Tunnel(&grpc.GenericServerStream[TunnelClientMessage, TunnelServerMessage]{ServerStream: stream}) +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type TunnelService_TunnelServer = grpc.BidiStreamingServer[TunnelClientMessage, TunnelServerMessage] + +// TunnelService_ServiceDesc is the grpc.ServiceDesc for TunnelService service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var TunnelService_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "cortex.axon.tunnel.TunnelService", + HandlerType: (*TunnelServiceServer)(nil), + Methods: []grpc.MethodDesc{}, + Streams: []grpc.StreamDesc{ + { + StreamName: "Tunnel", + Handler: _TunnelService_Tunnel_Handler, + ServerStreams: true, + ClientStreams: true, + }, + }, + Metadata: "tunnel.proto", +} diff --git a/server/Makefile b/server/Makefile new file mode 100644 index 0000000..84e9b98 --- /dev/null +++ b/server/Makefile @@ -0,0 +1,36 @@ +PROTO_DIR = ../proto/tunnel +GENERATED_DIR = .generated/proto +GENERATED_PATH = $(GENERATED_DIR)/tunnel + +GOPATH ?= $(HOME)/go +GOBIN ?= $(GOPATH)/bin + +PROTO_FILES := $(wildcard $(PROTO_DIR)/*.proto) +GO_FILES := $(patsubst $(PROTO_DIR)/%.proto,$(GENERATED_PATH)/%.pb.go,$(PROTO_FILES)) + +all: setup proto build + +proto: setup $(GO_FILES) + +$(GENERATED_PATH)/%.pb.go: $(PROTO_DIR)/%.proto + @echo "Generating protobuf: $< ==> $@" + @mkdir -p $(GENERATED_DIR) + protoc -I=$(PROTO_DIR) \ + --go_out=$(GENERATED_DIR) --go_opt=module=github.com/cortexapps/axon-server \ + --go-grpc_out=$(GENERATED_DIR) --go-grpc_opt=module=github.com/cortexapps/axon-server \ + $< + +setup: + @if ! command -v $(GOBIN)/protoc-gen-go >/dev/null; then echo "Installing protoc-gen-go..."; go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.36.6; fi + @if ! command -v $(GOBIN)/protoc-gen-go-grpc >/dev/null; then echo "Installing protoc-gen-go-grpc..."; go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.5.1; fi + +build: proto + go build -o axon-server ./cmd/... + +test: + go test -v ./... + +clean: + rm -rf $(GENERATED_DIR) axon-server + +.PHONY: all setup proto build test clean diff --git a/server/broker/broker_server_client.go b/server/broker/broker_server_client.go new file mode 100644 index 0000000..f568762 --- /dev/null +++ b/server/broker/broker_server_client.go @@ -0,0 +1,239 @@ +package broker + +import ( + "bytes" + "crypto/sha256" + "encoding/json" + "fmt" + "net/http" + "net/url" + "time" + + "go.uber.org/zap" +) + +const dispatcherAPIVersion = "2022-12-02~experimental" + +// Client wraps all BROKER_SERVER (dispatcher) HTTP API interactions. +// Paths match the snyk-broker dispatcher API: +// +// POST /internal/brokerservers/{serverId} — server starting +// DELETE /internal/brokerservers/{serverId} — server stopping +// POST /internal/brokerservers/{serverId}/connections/{hashedToken} — client connected +// DELETE /internal/brokerservers/{serverId}/connections/{hashedToken} — client disconnected +type Client struct { + baseURL string + serverID string + httpClient *http.Client + logger *zap.Logger +} + +// NewClient creates a new BROKER_SERVER client. +// If baseURL is empty, all operations are no-ops (for testing/dev). +func NewClient(baseURL string, serverID string, logger *zap.Logger) *Client { + return &Client{ + baseURL: baseURL, + serverID: serverID, + httpClient: &http.Client{ + Timeout: 10 * time.Second, + }, + logger: logger, + } +} + +// IsConfigured returns true if a BROKER_SERVER URL is set. +func (c *Client) IsConfigured() bool { + return c.baseURL != "" +} + +// Token encapsulates a raw broker token and its SHA-256 hash. +// Create via NewToken (from raw) or TokenFromHash (from pre-hashed). +type Token struct { + raw string + hashed string +} + +// NewToken creates a Token from a raw broker token, computing the SHA-256 hash. +func NewToken(raw string) Token { + h := sha256.Sum256([]byte(raw)) + return Token{ + raw: raw, + hashed: fmt.Sprintf("%x", h[:]), + } +} + +// TokenFromHash creates a Token from an already-hashed value (no raw token available). +func TokenFromHash(hashed string) Token { + return Token{hashed: hashed} +} + +// Raw returns the original unhashed token. May be empty if created via TokenFromHash. +func (t Token) Raw() string { return t.raw } + +// Hashed returns the SHA-256 hex hash of the token. +func (t Token) Hashed() string { return t.hashed } + +// String returns a safe representation of the token for logging, +// showing only the first 12 characters of the hash to prevent accidental +// raw token exposure via %v or %s. +func (t Token) String() string { + h := t.hashed + if len(h) > 12 { + h = h[:12] + } + return fmt.Sprintf("Token{hash=%s...}", h) +} + +// jsonAPIBody wraps request bodies in the JSONAPI envelope expected by the dispatcher. +type jsonAPIBody struct { + Data jsonAPIData `json:"data"` +} + +type jsonAPIData struct { + Attributes map[string]string `json:"attributes"` +} + +// ClientConnected notifies the BROKER_SERVER that a client has connected. +// POST /internal/brokerservers/{serverId}/connections/{hashedToken}?broker_client_id=...&request_type=client-connected&version=... +func (c *Client) ClientConnected(token Token, clientID string, metadata map[string]string) error { + if !c.IsConfigured() { + return nil + } + + path := fmt.Sprintf("/internal/brokerservers/%s/connections/%s", c.serverID, token.Hashed()) + + params := url.Values{} + if clientID != "" { + params.Set("broker_client_id", clientID) + } + params.Set("request_type", "client-connected") + + body := jsonAPIBody{ + Data: jsonAPIData{ + Attributes: map[string]string{ + "health_check_link": fmt.Sprintf("http://%s/healthcheck", c.serverID), + }, + }, + } + + // Merge any additional metadata into attributes. + if metadata != nil { + for k, v := range metadata { + body.Data.Attributes[k] = v + } + } + + return c.doRequest(http.MethodPost, path, params, body) +} + +// ClientDisconnected notifies the BROKER_SERVER that a client has disconnected. +// DELETE /internal/brokerservers/{serverId}/connections/{hashedToken}?broker_client_id=...&version=... +func (c *Client) ClientDisconnected(token Token, clientID string) error { + if !c.IsConfigured() { + return nil + } + + path := fmt.Sprintf("/internal/brokerservers/%s/connections/%s", c.serverID, token.Hashed()) + + params := url.Values{} + if clientID != "" { + params.Set("broker_client_id", clientID) + } + + return c.doRequest(http.MethodDelete, path, params, nil) +} + +// ServerStarting notifies the BROKER_SERVER that this server instance has started. +// POST /internal/brokerservers/{serverId}?version=... +func (c *Client) ServerStarting(hostname string) error { + if !c.IsConfigured() { + return nil + } + + path := fmt.Sprintf("/internal/brokerservers/%s", c.serverID) + + body := jsonAPIBody{ + Data: jsonAPIData{ + Attributes: map[string]string{ + "health_check_link": fmt.Sprintf("http://%s/healthcheck", hostname), + }, + }, + } + + return c.doRequest(http.MethodPost, path, nil, body) +} + +// ServerStopping notifies the BROKER_SERVER that this server instance is shutting down. +// DELETE /internal/brokerservers/{serverId}?version=... +func (c *Client) ServerStopping() error { + if !c.IsConfigured() { + return nil + } + + path := fmt.Sprintf("/internal/brokerservers/%s", c.serverID) + return c.doRequest(http.MethodDelete, path, nil, nil) +} + +// doRequest sends a request to the dispatcher API with the required version param and content type. +func (c *Client) doRequest(method, path string, params url.Values, body any) error { + u, err := url.Parse(c.baseURL + path) + if err != nil { + return fmt.Errorf("parse URL: %w", err) + } + + // Merge params and always add version. + q := u.Query() + if params != nil { + for k, vs := range params { + for _, v := range vs { + q.Set(k, v) + } + } + } + q.Set("version", dispatcherAPIVersion) + u.RawQuery = q.Encode() + + var reqBody *bytes.Reader + if body != nil { + jsonBody, err := json.Marshal(body) + if err != nil { + return fmt.Errorf("marshal request body: %w", err) + } + reqBody = bytes.NewReader(jsonBody) + } + + var httpReq *http.Request + if reqBody != nil { + httpReq, err = http.NewRequest(method, u.String(), reqBody) + } else { + httpReq, err = http.NewRequest(method, u.String(), nil) + } + if err != nil { + return fmt.Errorf("create request: %w", err) + } + httpReq.Header.Set("Content-Type", "application/vnd.api+json") + httpReq.Header.Set("Connection", "Keep-Alive") + httpReq.Header.Set("Keep-Alive", "timeout=60, max=10") + + resp, err := c.httpClient.Do(httpReq) + if err != nil { + c.logger.Warn("BROKER_SERVER request failed", + zap.String("method", method), + zap.String("path", path), + zap.Error(err), + ) + return fmt.Errorf("broker-server %s %s: %w", method, path, err) + } + defer resp.Body.Close() + + if resp.StatusCode >= 300 { + c.logger.Warn("BROKER_SERVER returned non-success status", + zap.String("method", method), + zap.String("path", path), + zap.Int("status", resp.StatusCode), + ) + return fmt.Errorf("broker-server %s %s: status %d", method, path, resp.StatusCode) + } + + return nil +} diff --git a/server/broker/broker_server_client_test.go b/server/broker/broker_server_client_test.go new file mode 100644 index 0000000..d927c7d --- /dev/null +++ b/server/broker/broker_server_client_test.go @@ -0,0 +1,147 @@ +package broker + +import ( + "net/http" + "net/http/httptest" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" +) + +func TestNewToken(t *testing.T) { + token := NewToken("my-secret-token") + assert.Equal(t, "my-secret-token", token.Raw()) + assert.Len(t, token.Hashed(), 64) // SHA-256 hex = 64 chars + + // Same input produces same hash. + assert.Equal(t, token.Hashed(), NewToken("my-secret-token").Hashed()) + + // Different input produces different hash. + assert.NotEqual(t, token.Hashed(), NewToken("different-token").Hashed()) +} + +func TestTokenFromHash(t *testing.T) { + token := TokenFromHash("abc123") + assert.Equal(t, "", token.Raw()) + assert.Equal(t, "abc123", token.Hashed()) +} + +func TestClientConnected(t *testing.T) { + var mu sync.Mutex + var reqMethod, reqPath, reqContentType, reqQuery string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + defer mu.Unlock() + reqMethod = r.Method + reqPath = r.URL.Path + reqContentType = r.Header.Get("Content-Type") + reqQuery = r.URL.RawQuery + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + logger := zaptest.NewLogger(t) + client := NewClient(server.URL, "server-42", logger) + + token := NewToken("raw-token") + err := client.ClientConnected(token, "client-123", map[string]string{"broker_client_version": "1.0"}) + require.NoError(t, err) + + mu.Lock() + defer mu.Unlock() + assert.Equal(t, http.MethodPost, reqMethod) + assert.Equal(t, "/internal/brokerservers/server-42/connections/"+token.Hashed(), reqPath) + assert.Equal(t, "application/vnd.api+json", reqContentType) + assert.Contains(t, reqQuery, "broker_client_id=client-123") + assert.Contains(t, reqQuery, "request_type=client-connected") + assert.Contains(t, reqQuery, "version="+dispatcherAPIVersion) +} + +func TestClientDisconnected(t *testing.T) { + var reqMethod, reqPath string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reqMethod = r.Method + reqPath = r.URL.Path + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + logger := zaptest.NewLogger(t) + client := NewClient(server.URL, "server-42", logger) + + token := NewToken("raw-token") + err := client.ClientDisconnected(token, "client-123") + require.NoError(t, err) + assert.Equal(t, http.MethodDelete, reqMethod) + assert.Equal(t, "/internal/brokerservers/server-42/connections/"+token.Hashed(), reqPath) +} + +func TestServerStarting(t *testing.T) { + var reqMethod, reqPath string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reqMethod = r.Method + reqPath = r.URL.Path + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + logger := zaptest.NewLogger(t) + client := NewClient(server.URL, "server-42", logger) + + err := client.ServerStarting("my-hostname") + require.NoError(t, err) + assert.Equal(t, http.MethodPost, reqMethod) + assert.Equal(t, "/internal/brokerservers/server-42", reqPath) +} + +func TestServerStopping(t *testing.T) { + var reqMethod, reqPath string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reqMethod = r.Method + reqPath = r.URL.Path + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + logger := zaptest.NewLogger(t) + client := NewClient(server.URL, "server-42", logger) + + err := client.ServerStopping() + require.NoError(t, err) + assert.Equal(t, http.MethodDelete, reqMethod) + assert.Equal(t, "/internal/brokerservers/server-42", reqPath) +} + +func TestNotConfigured(t *testing.T) { + logger := zaptest.NewLogger(t) + client := NewClient("", "server-42", logger) + + assert.False(t, client.IsConfigured()) + + // All operations should be no-ops. + assert.NoError(t, client.ClientConnected(NewToken("t"), "c", nil)) + assert.NoError(t, client.ClientDisconnected(NewToken("t"), "c")) + assert.NoError(t, client.ServerStarting("host")) + assert.NoError(t, client.ServerStopping()) +} + +func TestServerError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + logger := zaptest.NewLogger(t) + client := NewClient(server.URL, "server-42", logger) + + err := client.ClientConnected(NewToken("token"), "client", nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "status 500") +} diff --git a/server/cmd/main.go b/server/cmd/main.go new file mode 100644 index 0000000..33685ac --- /dev/null +++ b/server/cmd/main.go @@ -0,0 +1,176 @@ +package main + +import ( + "context" + "fmt" + "net" + "net/http" + "os" + "os/signal" + "syscall" + "time" + + "github.com/cortexapps/axon-server/broker" + "github.com/cortexapps/axon-server/config" + "github.com/cortexapps/axon-server/dispatch" + "github.com/cortexapps/axon-server/metrics" + "github.com/cortexapps/axon-server/tunnel" + pb "github.com/cortexapps/axon-server/.generated/proto/tunnelpb" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "google.golang.org/grpc" + "google.golang.org/grpc/keepalive" +) + +func main() { + cfg := config.NewConfigFromEnv() + + // Set up structured JSON logging. + zapCfg := zap.NewProductionConfig() + zapCfg.EncoderConfig.TimeKey = "time" + zapCfg.EncoderConfig.EncodeTime = func(t time.Time, enc zapcore.PrimitiveArrayEncoder) { + enc.AppendString(t.UTC().Format("2006-01-02T15:04:05.000Z")) + } + if os.Getenv("ENV") != "production" { + zapCfg = zap.NewDevelopmentConfig() + } + logger, err := zapCfg.Build() + if err != nil { + panic(err) + } + logger = logger.Named("axon-tunnel-server") + defer logger.Sync() + + cfg.Print() + + // Initialize metrics. + m := metrics.New(cfg.ServerID) + defer m.Closer() + + // Initialize BROKER_SERVER client. + brokerClient := broker.NewClient(cfg.BrokerServerURL, cfg.ServerID, logger) + + // Initialize client registry and tunnel service. + registry := tunnel.NewClientRegistry(logger) + tunnelService := tunnel.NewService(cfg, logger, registry, brokerClient, m) + + // Create gRPC server with keepalive. + grpcServer := grpc.NewServer( + grpc.KeepaliveParams(keepalive.ServerParameters{ + Time: 30 * time.Second, + Timeout: 10 * time.Second, + }), + grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{ + MinTime: 15 * time.Second, + PermitWithoutStream: true, + }), + ) + pb.RegisterTunnelServiceServer(grpcServer, tunnelService) + + // Start gRPC listener. + grpcLis, err := net.Listen("tcp", fmt.Sprintf(":%d", cfg.GrpcPort)) + if err != nil { + logger.Fatal("Failed to listen for gRPC", zap.Error(err)) + } + + // Initialize dispatch handler and wire response delivery. + dispatchHandler := dispatch.NewHandler(cfg, registry, m, logger) + tunnelService.SetResponseHandler(dispatchHandler.HandleResponse) + tunnelService.SetStreamCloseHandler(dispatchHandler.HandleStreamClose) + + // Start HTTP server for metrics, health, and dispatch. + httpMux := http.NewServeMux() + httpMux.Handle("/metrics", m.Handler()) + httpMux.Handle("/broker/", dispatchHandler) + httpMux.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprintf(w, `{"status":"ok","server_id":%q,"clients":%d,"streams":%d,"inflight":%d,"broker_server_configured":%t}`, + cfg.ServerID, registry.Count(), registry.StreamCount(), dispatchHandler.PendingCount(), brokerClient.IsConfigured()) + }) + + httpServer := &http.Server{ + Addr: fmt.Sprintf(":%d", cfg.HttpPort), + Handler: httpMux, + } + + // Notify BROKER_SERVER that this server instance has started. + if brokerClient.IsConfigured() { + go func() { + backoff := time.Second + for { + if err := brokerClient.ServerStarting(cfg.ServerID); err != nil { + logger.Warn("BROKER_SERVER server-starting failed, retrying", + zap.Error(err), zap.Duration("backoff", backoff)) + time.Sleep(backoff) + backoff = min(backoff*2, 30*time.Second) + continue + } + logger.Info("BROKER_SERVER server-starting succeeded") + break + } + }() + } + + // Start periodic re-registration of all active clients. + ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer stop() + + if brokerClient.IsConfigured() { + go func() { + ticker := time.NewTicker(cfg.ReRegistrationInterval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + registry.ForEach(func(token broker.Token, identity tunnel.ClientIdentity) { + if err := brokerClient.ClientConnected(token, identity.InstanceID, nil); err != nil { + logger.Warn("Periodic re-registration failed", + zap.String("tenantId", identity.TenantID), + zap.Error(err)) + } + }) + } + } + }() + } + + // Start servers. + go func() { + logger.Info("Starting HTTP server", zap.Int("port", cfg.HttpPort)) + if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { + logger.Fatal("HTTP server failed", zap.Error(err)) + } + }() + + go func() { + logger.Info("Starting gRPC server", zap.Int("port", cfg.GrpcPort)) + if err := grpcServer.Serve(grpcLis); err != nil { + logger.Fatal("gRPC server failed", zap.Error(err)) + } + }() + + // Wait for shutdown signal. + <-ctx.Done() + logger.Info("Shutting down...") + + // Graceful shutdown. + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer shutdownCancel() + + // Notify BROKER_SERVER of shutdown before draining (best-effort). + if brokerClient.IsConfigured() { + if err := brokerClient.ServerStopping(); err != nil { + logger.Warn("BROKER_SERVER server-stopping failed", zap.Error(err)) + } else { + logger.Info("BROKER_SERVER server-stopping succeeded") + } + } + + // Stop accepting new connections and drain. + grpcServer.GracefulStop() + httpServer.Shutdown(shutdownCtx) + + logger.Info("Server stopped") +} diff --git a/server/config/config.go b/server/config/config.go new file mode 100644 index 0000000..ab19436 --- /dev/null +++ b/server/config/config.go @@ -0,0 +1,119 @@ +package config + +import ( + "fmt" + "os" + "strconv" + "time" + + "github.com/google/uuid" +) + +const ( + DefaultGrpcPort = 50052 + DefaultHttpPort = 8080 + DefaultHeartbeatInterval = 30 * time.Second +) + +type Config struct { + // GrpcPort is the port the gRPC tunnel server listens on. + GrpcPort int + // HttpPort is the port the HTTP dispatch server listens on. + HttpPort int + // BrokerServerURL is the base URL of the BROKER_SERVER HTTP API + // for client-connected/deleted and server-connected/deleted notifications. + BrokerServerURL string + // JWTPublicKeyPath is the path to a PEM-encoded public key for + // validating JWT tokens in ClientHello. Empty disables JWT validation. + JWTPublicKeyPath string + // HeartbeatInterval is how often the server sends heartbeat messages. + // Clients must respond within 2x this interval. + HeartbeatInterval time.Duration + // DispatchTimeout is the maximum time to wait for a client response + // to a dispatched HTTP request. + DispatchTimeout time.Duration + // ServerID identifies this server instance. Used in metrics and + // returned to clients in ServerHello for dedup. + ServerID string + // ReRegistrationInterval is how often the server re-sends + // client-connected notifications to BROKER_SERVER as a TTL refresh. + ReRegistrationInterval time.Duration +} + +func (c Config) Print() { + fmt.Println("Server Configuration:") + fmt.Printf("\tgRPC Port: %d\n", c.GrpcPort) + fmt.Printf("\tHTTP Port: %d\n", c.HttpPort) + fmt.Printf("\tBroker Server URL: %s\n", c.BrokerServerURL) + fmt.Printf("\tServer ID: %s\n", c.ServerID) + fmt.Printf("\tHeartbeat Interval: %v\n", c.HeartbeatInterval) + fmt.Printf("\tDispatch Timeout: %v\n", c.DispatchTimeout) + if c.JWTPublicKeyPath != "" { + fmt.Printf("\tJWT Public Key: %s\n", c.JWTPublicKeyPath) + } else { + fmt.Println("\tJWT Validation: Disabled") + } +} + +func NewConfigFromEnv() Config { + cfg := Config{ + GrpcPort: DefaultGrpcPort, + HttpPort: DefaultHttpPort, + HeartbeatInterval: DefaultHeartbeatInterval, + DispatchTimeout: 60 * time.Second, + ServerID: getServerID(), + ReRegistrationInterval: 5 * time.Minute, + } + + if v := os.Getenv("GRPC_PORT"); v != "" { + p, err := strconv.Atoi(v) + if err != nil { + panic(fmt.Errorf("invalid GRPC_PORT: %w", err)) + } + cfg.GrpcPort = p + } + + if v := os.Getenv("HTTP_PORT"); v != "" { + p, err := strconv.Atoi(v) + if err != nil { + panic(fmt.Errorf("invalid HTTP_PORT: %w", err)) + } + cfg.HttpPort = p + } + + cfg.BrokerServerURL = os.Getenv("BROKER_SERVER_URL") + cfg.JWTPublicKeyPath = os.Getenv("JWT_PUBLIC_KEY_PATH") + + if v := os.Getenv("HEARTBEAT_INTERVAL"); v != "" { + d, err := time.ParseDuration(v) + if err != nil { + panic(fmt.Errorf("invalid HEARTBEAT_INTERVAL: %w", err)) + } + cfg.HeartbeatInterval = d + } + + if v := os.Getenv("DISPATCH_TIMEOUT"); v != "" { + d, err := time.ParseDuration(v) + if err != nil { + panic(fmt.Errorf("invalid DISPATCH_TIMEOUT: %w", err)) + } + cfg.DispatchTimeout = d + } + + if v := os.Getenv("RE_REGISTRATION_INTERVAL"); v != "" { + d, err := time.ParseDuration(v) + if err != nil { + panic(fmt.Errorf("invalid RE_REGISTRATION_INTERVAL: %w", err)) + } + cfg.ReRegistrationInterval = d + } + + return cfg +} + +func getServerID() string { + if h := os.Getenv("HOSTNAME"); h != "" && h != "localhost" { + return h + } + return uuid.New().String() +} diff --git a/server/dispatch/handler.go b/server/dispatch/handler.go new file mode 100644 index 0000000..a4ea53c --- /dev/null +++ b/server/dispatch/handler.go @@ -0,0 +1,236 @@ +package dispatch + +import ( + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/cortexapps/axon-server/broker" + "github.com/cortexapps/axon-server/config" + "github.com/cortexapps/axon-server/metrics" + "github.com/cortexapps/axon-server/tunnel" + pb "github.com/cortexapps/axon-server/.generated/proto/tunnelpb" + "github.com/google/uuid" + "go.uber.org/zap" +) + +const maxChunkSize = 1024 * 1024 // 1MB +const maxRequestBodySize = 100 * 1024 * 1024 // 100MB + +// Handler is the HTTP handler that dispatches requests through tunnel streams. +// It mounts at /broker/:token/* and routes HTTP requests to connected agents. +type Handler struct { + registry *tunnel.ClientRegistry + pending *PendingRequests + metrics *metrics.Metrics + logger *zap.Logger + dispatchTimeout time.Duration +} + +// NewHandler creates a new dispatch handler. +func NewHandler( + cfg config.Config, + registry *tunnel.ClientRegistry, + m *metrics.Metrics, + logger *zap.Logger, +) *Handler { + return &Handler{ + registry: registry, + pending: NewPendingRequests(cfg.DispatchTimeout), + metrics: m, + logger: logger.Named("dispatch"), + dispatchTimeout: cfg.DispatchTimeout, + } +} + +// HandleResponse processes an incoming HttpResponse from the tunnel service. +// This is the ResponseHandler callback set on the tunnel service. +func (h *Handler) HandleResponse(response *pb.HttpResponse) { + if err := h.pending.Deliver(response); err != nil { + h.logger.Debug("Response delivery failed", zap.String("requestId", response.RequestId), zap.Error(err)) + } +} + +// HandleStreamClose fails all pending dispatch requests for a closed stream. +func (h *Handler) HandleStreamClose(streamID string) { + h.pending.FailStream(streamID) +} + +// PendingCount returns the number of inflight dispatch requests. +func (h *Handler) PendingCount() int { + return h.pending.Count() +} + +// ServeHTTP handles HTTP requests at /broker//. +func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Extract token and path from URL: /broker// + trimmed := strings.TrimPrefix(r.URL.Path, "/broker/") + slashIdx := strings.Index(trimmed, "/") + if slashIdx == -1 { + http.Error(w, "invalid path: missing token", http.StatusBadRequest) + return + } + + tokenOrHash := trimmed[:slashIdx] + dispatchPath := trimmed[slashIdx:] + + // Try as raw token first, then as already-hashed. + token := broker.NewToken(tokenOrHash) + identity := h.registry.GetIdentity(token) + if identity == nil { + token = broker.TokenFromHash(tokenOrHash) + identity = h.registry.GetIdentity(token) + } + + // Pick a stream for dispatch (round-robin). + stream := h.registry.PickStream(token) + if stream == nil { + h.logger.Warn("No tunnel available for token", + zap.String("hashedToken", token.Hashed()), + ) + h.metrics.DispatchErrors.Inc(1) + http.Error(w, "no tunnel available", http.StatusBadGateway) + return + } + + // Read request body with size limit to prevent OOM. + var body []byte + if r.Body != nil { + limitedReader := io.LimitReader(r.Body, maxRequestBodySize+1) + var err error + body, err = io.ReadAll(limitedReader) + if err != nil { + h.logger.Error("Failed to read request body", zap.Error(err)) + http.Error(w, "failed to read request body", http.StatusInternalServerError) + return + } + if len(body) > maxRequestBodySize { + http.Error(w, "request body too large", http.StatusRequestEntityTooLarge) + return + } + } + + // Extract request headers. + headers := make(map[string]string, len(r.Header)) + for k, v := range r.Header { + headers[k] = strings.Join(v, ", ") + } + + requestID := uuid.New().String() + + h.logger.Debug("Dispatching request", + zap.String("requestId", requestID), + zap.String("method", r.Method), + zap.String("path", dispatchPath), + zap.String("streamId", stream.StreamID), + ) + + start := time.Now() + h.metrics.DispatchInflight.Update(float64(h.pending.Count() + 1)) + + // Register pending response before sending. + respCh := h.pending.Add(requestID, stream.StreamID) + + // Send request through the tunnel (chunked if needed). + timeoutMs := int32(h.dispatchTimeout.Milliseconds()) + if err := h.sendRequest(stream, requestID, r.Method, dispatchPath, headers, body, timeoutMs); err != nil { + h.pending.Timeout(requestID) + h.logger.Error("Failed to send request through tunnel", zap.Error(err)) + h.metrics.DispatchErrors.Inc(1) + http.Error(w, "tunnel send failed", http.StatusBadGateway) + return + } + + h.metrics.DispatchBytesSent.Inc(int64(len(body))) + + // Wait for response. + resp, ok := <-respCh + duration := time.Since(start) + h.metrics.DispatchInflight.Update(float64(h.pending.Count())) + + if !ok || resp == nil { + h.logger.Warn("Dispatch timeout or stream closed", + zap.String("requestId", requestID), + zap.Duration("duration", duration), + ) + h.metrics.DispatchErrors.Inc(1) + http.Error(w, "gateway timeout", http.StatusGatewayTimeout) + return + } + + // Record tagged metrics if we have identity info. + tenantID, integration, alias := "", "", "" + if identity != nil { + tenantID = identity.TenantID + integration = identity.Integration + alias = identity.Alias + } + h.metrics.DispatchCount(tenantID, integration, alias, r.Method, resp.StatusCode) + h.metrics.DispatchDuration(tenantID, integration, alias, float64(duration.Milliseconds())) + h.metrics.DispatchBytesRecv.Inc(int64(len(resp.Body))) + + // Write response. + for k, v := range resp.Headers { + w.Header().Set(k, v) + } + w.WriteHeader(resp.StatusCode) + if len(resp.Body) > 0 { + w.Write(resp.Body) + } +} + +// sendRequest sends an HTTP request through a tunnel stream, chunking large bodies. +func (h *Handler) sendRequest(stream *tunnel.StreamHandle, requestID, method, path string, headers map[string]string, body []byte, timeoutMs int32) error { + if len(body) <= maxChunkSize { + return stream.Send(&pb.TunnelServerMessage{ + Message: &pb.TunnelServerMessage_HttpRequest{ + HttpRequest: &pb.HttpRequest{ + RequestId: requestID, + Method: method, + Path: path, + Headers: headers, + Body: body, + ChunkIndex: 0, + IsFinal: true, + TimeoutMs: timeoutMs, + }, + }, + }) + } + + // Chunked send for large bodies. + for i := 0; i < len(body); i += maxChunkSize { + end := i + maxChunkSize + if end > len(body) { + end = len(body) + } + chunkIndex := int32(i / maxChunkSize) + isFinal := end == len(body) + + req := &pb.HttpRequest{ + RequestId: requestID, + ChunkIndex: chunkIndex, + IsFinal: isFinal, + Body: body[i:end], + } + + // First chunk includes method/path/headers and timeout. + if chunkIndex == 0 { + req.Method = method + req.Path = path + req.Headers = headers + req.TimeoutMs = timeoutMs + } + + if err := stream.Send(&pb.TunnelServerMessage{ + Message: &pb.TunnelServerMessage_HttpRequest{ + HttpRequest: req, + }, + }); err != nil { + return fmt.Errorf("send chunk %d: %w", chunkIndex, err) + } + } + return nil +} diff --git a/server/dispatch/pending_requests.go b/server/dispatch/pending_requests.go new file mode 100644 index 0000000..870e361 --- /dev/null +++ b/server/dispatch/pending_requests.go @@ -0,0 +1,156 @@ +package dispatch + +import ( + "fmt" + "sync" + "time" + + pb "github.com/cortexapps/axon-server/.generated/proto/tunnelpb" +) + +// PendingRequests tracks inflight HTTP requests dispatched through tunnels, +// correlating request IDs to response channels. It supports chunked responses. +type PendingRequests struct { + mu sync.RWMutex + pending map[string]*pendingEntry + timeout time.Duration +} + +type pendingEntry struct { + ch chan *assembledResponse + streamID string + chunks []*pb.HttpResponse + timer *time.Timer +} + +type assembledResponse struct { + StatusCode int + Headers map[string]string + Body []byte +} + +// NewPendingRequests creates a new pending requests tracker. +func NewPendingRequests(timeout time.Duration) *PendingRequests { + return &PendingRequests{ + pending: make(map[string]*pendingEntry), + timeout: timeout, + } +} + +// Add registers a new pending request and returns a channel to await the response. +func (pr *PendingRequests) Add(requestID, streamID string) <-chan *assembledResponse { + ch := make(chan *assembledResponse, 1) + + pr.mu.Lock() + defer pr.mu.Unlock() + + timer := time.AfterFunc(pr.timeout, func() { + pr.Timeout(requestID) + }) + + pr.pending[requestID] = &pendingEntry{ + ch: ch, + streamID: streamID, + timer: timer, + } + return ch +} + +// Deliver processes an incoming HttpResponse chunk. When the final chunk is +// received (is_final=true), the assembled response is sent to the waiting channel. +func (pr *PendingRequests) Deliver(response *pb.HttpResponse) error { + pr.mu.Lock() + entry, ok := pr.pending[response.RequestId] + if !ok { + pr.mu.Unlock() + return fmt.Errorf("no pending request for ID %s", response.RequestId) + } + + entry.chunks = append(entry.chunks, response) + + if !response.IsFinal { + pr.mu.Unlock() + return nil + } + + // Final chunk received — assemble and deliver. + entry.timer.Stop() + delete(pr.pending, response.RequestId) + pr.mu.Unlock() + + assembled := assembleChunks(entry.chunks) + entry.ch <- assembled + return nil +} + +// Timeout fails a pending request with a timeout error. +func (pr *PendingRequests) Timeout(requestID string) { + pr.mu.Lock() + entry, ok := pr.pending[requestID] + if !ok { + pr.mu.Unlock() + return + } + delete(pr.pending, requestID) + pr.mu.Unlock() + + // Send nil to indicate timeout. + close(entry.ch) +} + +// FailStream fails all pending requests that were dispatched on a given stream. +func (pr *PendingRequests) FailStream(streamID string) { + pr.mu.Lock() + var toFail []*pendingEntry + var toDelete []string + + for reqID, entry := range pr.pending { + if entry.streamID == streamID { + toFail = append(toFail, entry) + toDelete = append(toDelete, reqID) + } + } + for _, reqID := range toDelete { + delete(pr.pending, reqID) + } + pr.mu.Unlock() + + for _, entry := range toFail { + entry.timer.Stop() + close(entry.ch) + } +} + +// Count returns the number of inflight requests. +func (pr *PendingRequests) Count() int { + pr.mu.RLock() + defer pr.mu.RUnlock() + return len(pr.pending) +} + +// assembleChunks reconstructs a full response from ordered chunks. +func assembleChunks(chunks []*pb.HttpResponse) *assembledResponse { + if len(chunks) == 0 { + return &assembledResponse{} + } + + // First chunk has status code and headers. + first := chunks[0] + resp := &assembledResponse{ + StatusCode: int(first.StatusCode), + Headers: first.Headers, + } + + // Concatenate all body chunks. + var totalLen int + for _, c := range chunks { + totalLen += len(c.Body) + } + body := make([]byte, 0, totalLen) + for _, c := range chunks { + body = append(body, c.Body...) + } + resp.Body = body + + return resp +} diff --git a/server/dispatch/pending_requests_test.go b/server/dispatch/pending_requests_test.go new file mode 100644 index 0000000..62aac45 --- /dev/null +++ b/server/dispatch/pending_requests_test.go @@ -0,0 +1,118 @@ +package dispatch + +import ( + "testing" + "time" + + pb "github.com/cortexapps/axon-server/.generated/proto/tunnelpb" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPendingRequests_BasicFlow(t *testing.T) { + pr := NewPendingRequests(5 * time.Second) + + ch := pr.Add("req-1", "stream-1") + assert.Equal(t, 1, pr.Count()) + + err := pr.Deliver(&pb.HttpResponse{ + RequestId: "req-1", + StatusCode: 200, + Headers: map[string]string{"Content-Type": "application/json"}, + Body: []byte(`{"ok":true}`), + ChunkIndex: 0, + IsFinal: true, + }) + require.NoError(t, err) + + resp := <-ch + require.NotNil(t, resp) + assert.Equal(t, 200, resp.StatusCode) + assert.Equal(t, `{"ok":true}`, string(resp.Body)) + assert.Equal(t, "application/json", resp.Headers["Content-Type"]) + assert.Equal(t, 0, pr.Count()) +} + +func TestPendingRequests_ChunkedResponse(t *testing.T) { + pr := NewPendingRequests(5 * time.Second) + + ch := pr.Add("req-1", "stream-1") + + // Chunk 0 - headers + partial body. + err := pr.Deliver(&pb.HttpResponse{ + RequestId: "req-1", + StatusCode: 200, + Headers: map[string]string{"Content-Type": "text/plain"}, + Body: []byte("Hello "), + ChunkIndex: 0, + IsFinal: false, + }) + require.NoError(t, err) + + // Chunk 1 - final body. + err = pr.Deliver(&pb.HttpResponse{ + RequestId: "req-1", + Body: []byte("World!"), + ChunkIndex: 1, + IsFinal: true, + }) + require.NoError(t, err) + + resp := <-ch + require.NotNil(t, resp) + assert.Equal(t, 200, resp.StatusCode) + assert.Equal(t, "Hello World!", string(resp.Body)) +} + +func TestPendingRequests_Timeout(t *testing.T) { + pr := NewPendingRequests(100 * time.Millisecond) + + ch := pr.Add("req-1", "stream-1") + + // Wait for timeout. + resp, ok := <-ch + assert.False(t, ok) + assert.Nil(t, resp) + assert.Equal(t, 0, pr.Count()) +} + +func TestPendingRequests_FailStream(t *testing.T) { + pr := NewPendingRequests(5 * time.Second) + + ch1 := pr.Add("req-1", "stream-1") + ch2 := pr.Add("req-2", "stream-1") + ch3 := pr.Add("req-3", "stream-2") + + pr.FailStream("stream-1") + + // Requests on stream-1 should fail. + _, ok := <-ch1 + assert.False(t, ok) + _, ok = <-ch2 + assert.False(t, ok) + + // Request on stream-2 should still be pending. + assert.Equal(t, 1, pr.Count()) + + // Deliver to stream-2 request normally. + pr.Deliver(&pb.HttpResponse{ + RequestId: "req-3", + StatusCode: 200, + Body: []byte("ok"), + IsFinal: true, + }) + + resp := <-ch3 + require.NotNil(t, resp) + assert.Equal(t, 200, resp.StatusCode) +} + +func TestPendingRequests_UnknownRequestID(t *testing.T) { + pr := NewPendingRequests(5 * time.Second) + + err := pr.Deliver(&pb.HttpResponse{ + RequestId: "unknown", + IsFinal: true, + }) + assert.Error(t, err) +} diff --git a/server/docker/Dockerfile b/server/docker/Dockerfile new file mode 100644 index 0000000..48cf623 --- /dev/null +++ b/server/docker/Dockerfile @@ -0,0 +1,24 @@ +FROM golang:1.25-bookworm AS builder + +RUN apt-get update && apt-get install -y protobuf-compiler + +WORKDIR /build + +# Copy proto definitions. +COPY proto/ /proto/ + +# Copy server source and generate proto. +COPY server/ /build/ +RUN make setup proto +RUN go build -o /axon-tunnel-server ./cmd/... + +# Runtime image. +FROM debian:bookworm-slim + +RUN apt-get update && apt-get install -y ca-certificates && rm -rf /var/lib/apt/lists/* + +COPY --from=builder /axon-tunnel-server /usr/local/bin/axon-tunnel-server + +EXPOSE 50052 8080 + +ENTRYPOINT ["/usr/local/bin/axon-tunnel-server"] diff --git a/server/go.mod b/server/go.mod new file mode 100644 index 0000000..d520020 --- /dev/null +++ b/server/go.mod @@ -0,0 +1,35 @@ +module github.com/cortexapps/axon-server + +go 1.25.7 + +require ( + github.com/google/uuid v1.6.0 + github.com/prometheus/client_golang v1.23.2 + github.com/stretchr/testify v1.11.1 + github.com/uber-go/tally/v4 v4.1.17 + go.uber.org/zap v1.27.1 + google.golang.org/grpc v1.79.2 + google.golang.org/protobuf v1.36.11 +) + +require ( + github.com/beorn7/perks v1.0.1 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/golang/mock v1.6.0 // indirect + github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/prometheus/client_model v0.6.2 // indirect + github.com/prometheus/common v0.66.1 // indirect + github.com/prometheus/procfs v0.16.1 // indirect + github.com/twmb/murmur3 v1.1.8 // indirect + go.uber.org/atomic v1.11.0 // indirect + go.uber.org/multierr v1.10.0 // indirect + go.yaml.in/yaml/v2 v2.4.2 // indirect + golang.org/x/net v0.48.0 // indirect + golang.org/x/sys v0.39.0 // indirect + golang.org/x/text v0.32.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/server/go.sum b/server/go.sum new file mode 100644 index 0000000..576c139 --- /dev/null +++ b/server/go.sum @@ -0,0 +1,112 @@ +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= +github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o= +github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg= +github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= +github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= +github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs= +github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA= +github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg= +github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is= +github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= +github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/twmb/murmur3 v1.1.8 h1:8Yt9taO/WN3l08xErzjeschgZU2QSrwm1kclYq+0aRg= +github.com/twmb/murmur3 v1.1.8/go.mod h1:Qq/R7NUyOfr65zD+6Q5IHKsJLwP7exErjN6lyyq3OSQ= +github.com/uber-go/tally/v4 v4.1.17 h1:C+U4BKtVDXTszuzU+WH8JVQvRVnaVKxzZrROFyDrvS8= +github.com/uber-go/tally/v4 v4.1.17/go.mod h1:ZdpiHRGSa3z4NIAc1VlEH4SiknR885fOIF08xmS0gaU= +github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= +go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= +go.opentelemetry.io/otel v1.39.0 h1:8yPrr/S0ND9QEfTfdP9V+SiwT4E0G7Y5MO7p85nis48= +go.opentelemetry.io/otel v1.39.0/go.mod h1:kLlFTywNWrFyEdH0oj2xK0bFYZtHRYUdv1NklR/tgc8= +go.opentelemetry.io/otel/metric v1.39.0 h1:d1UzonvEZriVfpNKEVmHXbdf909uGTOQjA0HF0Ls5Q0= +go.opentelemetry.io/otel/metric v1.39.0/go.mod h1:jrZSWL33sD7bBxg1xjrqyDjnuzTUB0x1nBERXd7Ftcs= +go.opentelemetry.io/otel/sdk v1.39.0 h1:nMLYcjVsvdui1B/4FRkwjzoRVsMK8uL/cj0OyhKzt18= +go.opentelemetry.io/otel/sdk v1.39.0/go.mod h1:vDojkC4/jsTJsE+kh+LXYQlbL8CgrEcwmt1ENZszdJE= +go.opentelemetry.io/otel/sdk/metric v1.39.0 h1:cXMVVFVgsIf2YL6QkRF4Urbr/aMInf+2WKg+sEJTtB8= +go.opentelemetry.io/otel/sdk/metric v1.39.0/go.mod h1:xq9HEVH7qeX69/JnwEfp6fVq5wosJsY1mt4lLfYdVew= +go.opentelemetry.io/otel/trace v1.39.0 h1:2d2vfpEDmCJ5zVYz7ijaJdOF59xLomrvj7bjt6/qCJI= +go.opentelemetry.io/otel/trace v1.39.0/go.mod h1:88w4/PnZSazkGzz/w84VHpQafiU4EtqqlVdxWy+rNOA= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ= +go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc= +go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= +go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= +go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= +golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= +golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= +golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= +golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= +gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 h1:gRkg/vSppuSQoDjxyiGfN4Upv/h/DQmIR10ZU8dh4Ww= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk= +google.golang.org/grpc v1.79.2 h1:fRMD94s2tITpyJGtBBn7MkMseNpOZU8ZxgC3MMBaXRU= +google.golang.org/grpc v1.79.2/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ= +google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= +google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/server/metrics/metrics.go b/server/metrics/metrics.go new file mode 100644 index 0000000..8dc0917 --- /dev/null +++ b/server/metrics/metrics.go @@ -0,0 +1,119 @@ +package metrics + +import ( + "net/http" + + prom "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promhttp" + "github.com/uber-go/tally/v4" + tallyprom "github.com/uber-go/tally/v4/prometheus" +) + +// Labels used for tagging all metrics. +const ( + LabelServerID = "server_id" + LabelTenantID = "tenant_id" + LabelIntegration = "integration" + LabelAlias = "alias" + LabelMethod = "method" + LabelStatusCode = "status_code" + LabelErrorType = "error_type" +) + +// Metrics holds all server-side metric instruments. +type Metrics struct { + Scope tally.Scope + Closer func() + Registry *prom.Registry + + ConnectionsActive tally.Gauge + ConnectionsTotal tally.Counter + HeartbeatSent tally.Counter + HeartbeatReceived tally.Counter + HeartbeatMissed tally.Counter + DispatchInflight tally.Gauge + DispatchErrors tally.Counter + DispatchBytesSent tally.Counter + DispatchBytesRecv tally.Counter + AuthFailures tally.Counter +} + +// New creates a new Metrics instance backed by a Prometheus reporter. +func New(serverID string) *Metrics { + registry := prom.NewRegistry() + + reporter := tallyprom.NewReporter(tallyprom.Options{ + Registerer: registry, + Gatherer: registry, + }) + + scope, closer := tally.NewRootScope(tally.ScopeOptions{ + Tags: map[string]string{LabelServerID: serverID}, + CachedReporter: reporter, + Prefix: "tunnel", + Separator: tallyprom.DefaultSeparator, + }, 0) + + m := &Metrics{ + Scope: scope, + Closer: func() { closer.Close() }, + Registry: registry, + + ConnectionsActive: scope.Gauge("connections.active"), + ConnectionsTotal: scope.Counter("connections.total"), + HeartbeatSent: scope.Counter("heartbeat.sent"), + HeartbeatReceived: scope.Counter("heartbeat.received"), + HeartbeatMissed: scope.Counter("heartbeat.missed"), + DispatchInflight: scope.Gauge("dispatch.inflight"), + DispatchErrors: scope.Counter("dispatch.errors"), + DispatchBytesSent: scope.Counter("dispatch.bytes_sent"), + DispatchBytesRecv: scope.Counter("dispatch.bytes_received"), + AuthFailures: scope.Counter("auth.failures"), + } + + return m +} + +// DispatchCount returns a tagged counter for dispatch operations. +func (m *Metrics) DispatchCount(tenantID, integration, alias, method string, statusCode int) { + m.Scope.Tagged(map[string]string{ + LabelTenantID: tenantID, + LabelIntegration: integration, + LabelAlias: alias, + LabelMethod: method, + LabelStatusCode: http.StatusText(statusCode), + }).Counter("dispatch.count").Inc(1) +} + +// DispatchDuration records dispatch latency. +func (m *Metrics) DispatchDuration(tenantID, integration, alias string, d float64) { + m.Scope.Tagged(map[string]string{ + LabelTenantID: tenantID, + LabelIntegration: integration, + LabelAlias: alias, + }).Histogram("dispatch.duration_ms", tally.DefaultBuckets).RecordValue(d) +} + +// DispatchError records a dispatch error. +func (m *Metrics) DispatchError(tenantID, integration, alias, errorType string) { + m.Scope.Tagged(map[string]string{ + LabelTenantID: tenantID, + LabelIntegration: integration, + LabelAlias: alias, + LabelErrorType: errorType, + }).Counter("dispatch.errors").Inc(1) +} + +// StreamDuration records how long a tunnel stream was alive. +func (m *Metrics) StreamDuration(tenantID, integration, alias string) tally.Stopwatch { + return m.Scope.Tagged(map[string]string{ + LabelTenantID: tenantID, + LabelIntegration: integration, + LabelAlias: alias, + }).Timer("stream.duration_seconds").Start() +} + +// Handler returns an HTTP handler for the /metrics endpoint. +func (m *Metrics) Handler() http.Handler { + return promhttp.HandlerFor(m.Registry, promhttp.HandlerOpts{}) +} diff --git a/server/tunnel/client_registry.go b/server/tunnel/client_registry.go new file mode 100644 index 0000000..fff0d86 --- /dev/null +++ b/server/tunnel/client_registry.go @@ -0,0 +1,206 @@ +package tunnel + +import ( + "fmt" + "sort" + "sync" + "sync/atomic" + + pb "github.com/cortexapps/axon-server/.generated/proto/tunnelpb" + "github.com/cortexapps/axon-server/broker" + "go.uber.org/zap" +) + +// StreamHandle represents a single tunnel stream to a client. +type StreamHandle struct { + StreamID string + // Send sends a TunnelServerMessage to the client through this stream. + Send func(msg *pb.TunnelServerMessage) error + // Cancel closes this stream. + Cancel func() +} + +// ClientIdentity holds the identity metadata for a connected client. +type ClientIdentity struct { + TenantID string + Integration string + Alias string + InstanceID string +} + +// clientEntry represents all connections for a single broker token. +type clientEntry struct { + Identity ClientIdentity + Token broker.Token + Streams map[string]*StreamHandle // streamID -> handle + BrokerServerRegistered atomic.Bool + roundRobin atomic.Uint64 +} + +// ClientRegistry is a thread-safe registry of connected clients, +// keyed by hashed broker token. +type ClientRegistry struct { + mu sync.RWMutex + entries map[string]*clientEntry // hashed token -> entry + logger *zap.Logger +} + +// NewClientRegistry creates a new client registry. +func NewClientRegistry(logger *zap.Logger) *ClientRegistry { + return &ClientRegistry{ + entries: make(map[string]*clientEntry), + logger: logger, + } +} + +// Register adds a new stream for a broker token. +// If the token already exists, it validates the identity matches (same tenant) +// and adds the stream. Returns an error on identity collision. +func (r *ClientRegistry) Register(token broker.Token, identity ClientIdentity, stream *StreamHandle) error { + r.mu.Lock() + defer r.mu.Unlock() + + key := token.Hashed() + if existing, ok := r.entries[key]; ok { + // Same tenant is allowed (reconnect or new instance) + if existing.Identity.TenantID != identity.TenantID { + return fmt.Errorf("token collision: different tenant_id for token (existing=%s, new=%s)", + existing.Identity.TenantID, identity.TenantID) + } + existing.Streams[stream.StreamID] = stream + r.logger.Info("Added stream to existing client entry", + zap.String("tenantId", identity.TenantID), + zap.String("instanceId", identity.InstanceID), + zap.String("streamId", stream.StreamID), + zap.Int("totalStreams", len(existing.Streams)), + ) + return nil + } + + r.entries[key] = &clientEntry{ + Identity: identity, + Token: token, + Streams: map[string]*StreamHandle{stream.StreamID: stream}, + } + + r.logger.Info("Registered new client", + zap.String("tenantId", identity.TenantID), + zap.String("integration", identity.Integration), + zap.String("alias", identity.Alias), + zap.String("instanceId", identity.InstanceID), + zap.String("streamId", stream.StreamID), + ) + return nil +} + +// Unregister removes a specific stream for a token. +// If it was the last stream, the entire entry is removed. +// Returns true if the entire entry was removed. +func (r *ClientRegistry) Unregister(token broker.Token, streamID string) bool { + r.mu.Lock() + defer r.mu.Unlock() + + key := token.Hashed() + entry, ok := r.entries[key] + if !ok { + return false + } + + delete(entry.Streams, streamID) + + if len(entry.Streams) == 0 { + delete(r.entries, key) + r.logger.Info("Removed client entry (last stream closed)", + zap.String("tenantId", entry.Identity.TenantID), + zap.String("streamId", streamID), + ) + return true + } + + r.logger.Info("Removed stream from client entry", + zap.String("tenantId", entry.Identity.TenantID), + zap.String("streamId", streamID), + zap.Int("remainingStreams", len(entry.Streams)), + ) + return false +} + +// GetIdentity returns the identity for a token, or nil if not found. +func (r *ClientRegistry) GetIdentity(token broker.Token) *ClientIdentity { + r.mu.RLock() + defer r.mu.RUnlock() + + entry, ok := r.entries[token.Hashed()] + if !ok { + return nil + } + id := entry.Identity + return &id +} + +// PickStream returns a stream handle for dispatching via round-robin. +// Returns nil if no streams are available for the token. +func (r *ClientRegistry) PickStream(token broker.Token) *StreamHandle { + r.mu.RLock() + defer r.mu.RUnlock() + + entry, ok := r.entries[token.Hashed()] + if !ok || len(entry.Streams) == 0 { + return nil + } + + // Collect stream handles into a slice for round-robin. + streams := make([]*StreamHandle, 0, len(entry.Streams)) + for _, s := range entry.Streams { + streams = append(streams, s) + } + + // Sort by StreamID to ensure deterministic ordering across calls, + // since Go map iteration order is randomized. + sort.Slice(streams, func(i, j int) bool { + return streams[i].StreamID < streams[j].StreamID + }) + + idx := entry.roundRobin.Add(1) - 1 + return streams[idx%uint64(len(streams))] +} + +// SetBrokerServerRegistered marks a token as successfully registered with BROKER_SERVER. +func (r *ClientRegistry) SetBrokerServerRegistered(token broker.Token) { + r.mu.RLock() + defer r.mu.RUnlock() + + if entry, ok := r.entries[token.Hashed()]; ok { + entry.BrokerServerRegistered.Store(true) + } +} + +// ForEach calls fn for each registered client entry. +// Used for periodic re-registration. +func (r *ClientRegistry) ForEach(fn func(token broker.Token, identity ClientIdentity)) { + r.mu.RLock() + defer r.mu.RUnlock() + + for _, entry := range r.entries { + fn(entry.Token, entry.Identity) + } +} + +// Count returns the number of registered client entries. +func (r *ClientRegistry) Count() int { + r.mu.RLock() + defer r.mu.RUnlock() + return len(r.entries) +} + +// StreamCount returns the total number of active streams across all clients. +func (r *ClientRegistry) StreamCount() int { + r.mu.RLock() + defer r.mu.RUnlock() + + total := 0 + for _, entry := range r.entries { + total += len(entry.Streams) + } + return total +} diff --git a/server/tunnel/client_registry_test.go b/server/tunnel/client_registry_test.go new file mode 100644 index 0000000..8d2582d --- /dev/null +++ b/server/tunnel/client_registry_test.go @@ -0,0 +1,174 @@ +package tunnel + +import ( + "testing" + + pb "github.com/cortexapps/axon-server/.generated/proto/tunnelpb" + "github.com/cortexapps/axon-server/broker" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" +) + +func testIdentity(tenantID string) ClientIdentity { + return ClientIdentity{ + TenantID: tenantID, + Integration: "github", + Alias: "my-github", + InstanceID: "instance-1", + } +} + +func testStream(streamID string) *StreamHandle { + return &StreamHandle{ + StreamID: streamID, + Send: func(msg *pb.TunnelServerMessage) error { return nil }, + Cancel: func() {}, + } +} + +func TestRegisterAndLookup(t *testing.T) { + logger := zaptest.NewLogger(t) + registry := NewClientRegistry(logger) + + token := broker.NewToken("token-abc") + identity := testIdentity("tenant-1") + stream := testStream("stream-1") + + err := registry.Register(token, identity, stream) + require.NoError(t, err) + + assert.Equal(t, 1, registry.Count()) + assert.Equal(t, 1, registry.StreamCount()) + + got := registry.GetIdentity(token) + require.NotNil(t, got) + assert.Equal(t, "tenant-1", got.TenantID) + assert.Equal(t, "github", got.Integration) + assert.Equal(t, "my-github", got.Alias) +} + +func TestRegisterMultipleStreams(t *testing.T) { + logger := zaptest.NewLogger(t) + registry := NewClientRegistry(logger) + + token := broker.NewToken("token-abc") + identity := testIdentity("tenant-1") + + err := registry.Register(token, identity, testStream("stream-1")) + require.NoError(t, err) + + // Same tenant, different instance — allowed. + identity2 := identity + identity2.InstanceID = "instance-2" + err = registry.Register(token, identity2, testStream("stream-2")) + require.NoError(t, err) + + assert.Equal(t, 1, registry.Count()) + assert.Equal(t, 2, registry.StreamCount()) +} + +func TestRegisterTokenCollision(t *testing.T) { + logger := zaptest.NewLogger(t) + registry := NewClientRegistry(logger) + + token := broker.NewToken("token-abc") + + err := registry.Register(token, testIdentity("tenant-1"), testStream("stream-1")) + require.NoError(t, err) + + // Different tenant with same token hash — rejected. + err = registry.Register(token, testIdentity("tenant-2"), testStream("stream-2")) + require.Error(t, err) + assert.Contains(t, err.Error(), "token collision") +} + +func TestUnregisterStream(t *testing.T) { + logger := zaptest.NewLogger(t) + registry := NewClientRegistry(logger) + + token := broker.NewToken("token-abc") + identity := testIdentity("tenant-1") + registry.Register(token, identity, testStream("stream-1")) + registry.Register(token, identity, testStream("stream-2")) + + // Remove one stream — entry should remain. + removed := registry.Unregister(token, "stream-1") + assert.False(t, removed) + assert.Equal(t, 1, registry.Count()) + assert.Equal(t, 1, registry.StreamCount()) + + // Remove last stream — entry should be removed. + removed = registry.Unregister(token, "stream-2") + assert.True(t, removed) + assert.Equal(t, 0, registry.Count()) + assert.Equal(t, 0, registry.StreamCount()) + assert.Nil(t, registry.GetIdentity(token)) +} + +func TestUnregisterNonexistent(t *testing.T) { + logger := zaptest.NewLogger(t) + registry := NewClientRegistry(logger) + + removed := registry.Unregister(broker.TokenFromHash("no-such-hash"), "stream-1") + assert.False(t, removed) +} + +func TestPickStreamRoundRobin(t *testing.T) { + logger := zaptest.NewLogger(t) + registry := NewClientRegistry(logger) + + token := broker.NewToken("token-abc") + identity := testIdentity("tenant-1") + registry.Register(token, identity, testStream("stream-1")) + registry.Register(token, identity, testStream("stream-2")) + + // Pick multiple times and verify we get both streams. + seen := map[string]bool{} + for range 10 { + s := registry.PickStream(token) + require.NotNil(t, s) + seen[s.StreamID] = true + } + assert.True(t, seen["stream-1"], "should pick stream-1") + assert.True(t, seen["stream-2"], "should pick stream-2") +} + +func TestPickStreamNoEntry(t *testing.T) { + logger := zaptest.NewLogger(t) + registry := NewClientRegistry(logger) + + s := registry.PickStream(broker.TokenFromHash("no-such-hash")) + assert.Nil(t, s) +} + +func TestBrokerServerRegistered(t *testing.T) { + logger := zaptest.NewLogger(t) + registry := NewClientRegistry(logger) + + token := broker.NewToken("token-abc") + identity := testIdentity("tenant-1") + registry.Register(token, identity, testStream("stream-1")) + + // Not registered initially. + registry.SetBrokerServerRegistered(token) + + // Verify no panic on non-existent entry. + registry.SetBrokerServerRegistered(broker.TokenFromHash("no-such-hash")) +} + +func TestForEach(t *testing.T) { + logger := zaptest.NewLogger(t) + registry := NewClientRegistry(logger) + + registry.Register(broker.NewToken("token-1"), testIdentity("tenant-1"), testStream("s1")) + registry.Register(broker.NewToken("token-2"), testIdentity("tenant-2"), testStream("s2")) + + var entries []string + registry.ForEach(func(token broker.Token, identity ClientIdentity) { + entries = append(entries, identity.TenantID) + }) + assert.Len(t, entries, 2) + assert.Contains(t, entries, "tenant-1") + assert.Contains(t, entries, "tenant-2") +} diff --git a/server/tunnel/service.go b/server/tunnel/service.go new file mode 100644 index 0000000..4f921b4 --- /dev/null +++ b/server/tunnel/service.go @@ -0,0 +1,337 @@ +package tunnel + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/cortexapps/axon-server/broker" + "github.com/cortexapps/axon-server/config" + "github.com/cortexapps/axon-server/metrics" + pb "github.com/cortexapps/axon-server/.generated/proto/tunnelpb" + "github.com/google/uuid" + "go.uber.org/zap" +) + +// ResponseHandler is called when an HttpResponse is received from a client. +// It's used to deliver responses to pending dispatch requests. +type ResponseHandler func(response *pb.HttpResponse) + +// StreamCloseHandler is called when a tunnel stream is closed. +// It's used to fail pending dispatch requests for the closed stream. +type StreamCloseHandler func(streamID string) + +// Service implements the TunnelService gRPC server. +type Service struct { + pb.UnimplementedTunnelServiceServer + + config config.Config + logger *zap.Logger + registry *ClientRegistry + brokerClient *broker.Client + metrics *metrics.Metrics + responseHandler ResponseHandler + streamCloseHandler StreamCloseHandler + + mu sync.RWMutex +} + +// NewService creates a new tunnel service. +func NewService( + cfg config.Config, + logger *zap.Logger, + registry *ClientRegistry, + brokerClient *broker.Client, + m *metrics.Metrics, +) *Service { + return &Service{ + config: cfg, + logger: logger, + registry: registry, + brokerClient: brokerClient, + metrics: m, + } +} + +// SetResponseHandler sets the callback for delivering HTTP responses +// to the dispatch layer. +func (s *Service) SetResponseHandler(handler ResponseHandler) { + s.mu.Lock() + defer s.mu.Unlock() + s.responseHandler = handler +} + +// SetStreamCloseHandler sets the callback for when a tunnel stream closes. +// This is used to fail pending dispatch requests for the closed stream. +func (s *Service) SetStreamCloseHandler(handler StreamCloseHandler) { + s.mu.Lock() + defer s.mu.Unlock() + s.streamCloseHandler = handler +} + +// Tunnel implements the bidirectional streaming RPC. +func (s *Service) Tunnel(stream pb.TunnelService_TunnelServer) error { + // Read ClientHello as the first message. + firstMsg, err := stream.Recv() + if err != nil { + return fmt.Errorf("recv ClientHello: %w", err) + } + + hello := firstMsg.GetHello() + if hello == nil { + return fmt.Errorf("first message must be ClientHello") + } + + // Validate required fields. + if hello.BrokerToken == "" { + return fmt.Errorf("broker_token is required") + } + if hello.TenantId == "" { + return fmt.Errorf("tenant_id is required") + } + + streamID := uuid.New().String() + token := broker.NewToken(hello.BrokerToken) + + identity := ClientIdentity{ + TenantID: hello.TenantId, + Integration: hello.Integration, + Alias: hello.Alias, + InstanceID: hello.InstanceId, + } + + s.logger.Info("Client connecting", + zap.String("tenantId", identity.TenantID), + zap.String("integration", identity.Integration), + zap.String("alias", identity.Alias), + zap.String("instanceId", identity.InstanceID), + zap.String("clientVersion", hello.ClientVersion), + zap.String("streamId", streamID), + ) + + // Create stream handle with a context for cancellation. + ctx, cancel := context.WithCancel(stream.Context()) + defer cancel() + + sendMu := &sync.Mutex{} + handle := &StreamHandle{ + StreamID: streamID, + Send: func(msg *pb.TunnelServerMessage) error { + sendMu.Lock() + defer sendMu.Unlock() + return stream.Send(msg) + }, + Cancel: cancel, + } + + // Send ServerHello before registering so the handshake completes + // before the stream becomes dispatchable. Use sendMu for consistency. + sendMu.Lock() + err = stream.Send(&pb.TunnelServerMessage{ + Message: &pb.TunnelServerMessage_Hello{ + Hello: &pb.ServerHello{ + ServerId: s.config.ServerID, + HeartbeatIntervalMs: int32(s.config.HeartbeatInterval.Milliseconds()), + StreamId: streamID, + }, + }, + }) + sendMu.Unlock() + if err != nil { + return fmt.Errorf("send ServerHello: %w", err) + } + + // Register in client registry (now safe — handshake is done). + if err := s.registry.Register(token, identity, handle); err != nil { + s.logger.Error("Failed to register client", zap.Error(err)) + return err + } + s.metrics.ConnectionsActive.Update(float64(s.registry.StreamCount())) + s.metrics.ConnectionsTotal.Inc(1) + + // Start stream duration tracking. + stopwatch := s.metrics.StreamDuration(identity.TenantID, identity.Integration, identity.Alias) + + // Notify BROKER_SERVER asynchronously (infinite retry). + go s.notifyClientConnected(ctx, token, hello.InstanceId, hello.ClientVersion) + + // Start heartbeat sender. + heartbeatDone := make(chan struct{}) + go s.heartbeatSender(ctx, stream, sendMu, heartbeatDone) + + // Track last heartbeat for timeout detection using atomic for goroutine safety. + var lastHeartbeat atomic.Int64 + lastHeartbeat.Store(time.Now().UnixNano()) + + // Start heartbeat timeout monitor goroutine. + go func() { + timeout := 2 * s.config.HeartbeatInterval + ticker := time.NewTicker(timeout) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + last := time.Unix(0, lastHeartbeat.Load()) + if time.Since(last) > timeout { + s.logger.Warn("Heartbeat timeout — closing stream", + zap.String("streamId", streamID), + zap.String("tenantId", identity.TenantID), + zap.Duration("elapsed", time.Since(last)), + ) + s.metrics.HeartbeatMissed.Inc(1) + cancel() + return + } + } + } + }() + + // Read loop for client messages. + for { + select { + case <-ctx.Done(): + s.cleanupStream(token, streamID, stopwatch) + return nil + default: + } + + msg, err := stream.Recv() + if err != nil { + s.logger.Info("Client stream closed", + zap.String("streamId", streamID), + zap.String("tenantId", identity.TenantID), + zap.Error(err), + ) + s.cleanupStream(token, streamID, stopwatch) + return nil + } + + switch m := msg.Message.(type) { + case *pb.TunnelClientMessage_Heartbeat: + lastHeartbeat.Store(time.Now().UnixNano()) + s.metrics.HeartbeatReceived.Inc(1) + + case *pb.TunnelClientMessage_HttpResponse: + s.mu.RLock() + handler := s.responseHandler + s.mu.RUnlock() + if handler != nil { + handler(m.HttpResponse) + } + + case *pb.TunnelClientMessage_Hello: + s.logger.Warn("Received duplicate ClientHello, ignoring", + zap.String("streamId", streamID), + ) + } + } +} + +// cleanupStream removes a stream from the registry and notifies BROKER_SERVER. +func (s *Service) cleanupStream(token broker.Token, streamID string, stopwatch interface{ Stop() }) { + stopwatch.Stop() + + // Fail any pending dispatch requests for this stream. + s.mu.RLock() + closeHandler := s.streamCloseHandler + s.mu.RUnlock() + if closeHandler != nil { + closeHandler(streamID) + } + + // Fetch identity before unregistering so we can pass clientID to the disconnect notification. + var clientID string + if identity := s.registry.GetIdentity(token); identity != nil { + clientID = identity.InstanceID + } + + entryRemoved := s.registry.Unregister(token, streamID) + s.metrics.ConnectionsActive.Update(float64(s.registry.StreamCount())) + + // Only notify BROKER_SERVER if the entire entry was removed (last stream). + if entryRemoved { + go s.notifyClientDisconnected(token, clientID) + } +} + +// notifyClientConnected sends client-connected to BROKER_SERVER with infinite retry. +func (s *Service) notifyClientConnected(ctx context.Context, token broker.Token, clientID, clientVersion string) { + backoff := time.Second + maxBackoff := 30 * time.Second + + for { + err := s.brokerClient.ClientConnected(token, clientID, map[string]string{ + "broker_client_version": clientVersion, + }) + if err == nil { + s.registry.SetBrokerServerRegistered(token) + s.logger.Info("BROKER_SERVER client-connected succeeded", + zap.String("clientId", clientID), + ) + return + } + + s.logger.Warn("BROKER_SERVER client-connected failed, retrying", + zap.Error(err), + zap.Duration("backoff", backoff), + ) + + select { + case <-ctx.Done(): + return + case <-time.After(backoff): + } + + backoff = min(backoff*2, maxBackoff) + } +} + +// notifyClientDisconnected sends client-disconnected to BROKER_SERVER with limited retry. +func (s *Service) notifyClientDisconnected(token broker.Token, clientID string) { + backoff := time.Second + for attempt := range 3 { + err := s.brokerClient.ClientDisconnected(token, clientID) + if err == nil { + return + } + s.logger.Warn("BROKER_SERVER client-disconnected failed", + zap.Error(err), + zap.Int("attempt", attempt+1), + ) + time.Sleep(backoff) + backoff *= 2 + } +} + +// heartbeatSender periodically sends heartbeat messages to the client. +func (s *Service) heartbeatSender(ctx context.Context, stream pb.TunnelService_TunnelServer, sendMu *sync.Mutex, done chan struct{}) { + defer close(done) + ticker := time.NewTicker(s.config.HeartbeatInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + sendMu.Lock() + err := stream.Send(&pb.TunnelServerMessage{ + Message: &pb.TunnelServerMessage_Heartbeat{ + Heartbeat: &pb.Heartbeat{ + TimestampMs: time.Now().UnixMilli(), + }, + }, + }) + sendMu.Unlock() + if err != nil { + s.logger.Debug("Failed to send heartbeat", zap.Error(err)) + return + } + s.metrics.HeartbeatSent.Inc(1) + } + } +}