diff --git a/Cargo.lock b/Cargo.lock index a6fd04a..adcb394 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4926,9 +4926,11 @@ dependencies = [ "axum", "base64 0.22.1", "chrono", + "dashmap", "dotenvy", "ed25519-dalek", "fastembed", + "futures-util", "git2", "hex", "hmac 0.13.0", diff --git a/crates/tracevault-server/Cargo.toml b/crates/tracevault-server/Cargo.toml index 6c98823..b609e82 100644 --- a/crates/tracevault-server/Cargo.toml +++ b/crates/tracevault-server/Cargo.toml @@ -31,6 +31,8 @@ rand = "0.8" base64 = "0.22" git2 = "0.20" reqwest = { version = "0.13", features = ["json", "stream"] } +dashmap = "6" +futures-util = "0.3" async-trait = "0.1" aes-gcm = "0.10" dotenvy = "0.15.7" diff --git a/crates/tracevault-server/migrations/025_user_anthropic_keys_max_concurrent.sql b/crates/tracevault-server/migrations/025_user_anthropic_keys_max_concurrent.sql new file mode 100644 index 0000000..cc83c77 --- /dev/null +++ b/crates/tracevault-server/migrations/025_user_anthropic_keys_max_concurrent.sql @@ -0,0 +1,13 @@ +-- Per-credential concurrency cap for the transparent Anthropic LLM proxy +-- (issue softwaremill/tracevault#210, parent #181). +-- +-- The cap is the maximum number of in-flight proxy requests this credential +-- can have at any one moment. Enforced in-process via a tokio Semaphore in +-- AppState, sized to this value at first use of the credential. +-- +-- Default 8: comfortable for typical multi-agent setups (Claude Code + GSD2), +-- well under any paid Anthropic tier. Upper bound 256 prevents user-typed +-- nonsense values; lower bound 1 prevents accidental lockout. 
+ALTER TABLE user_anthropic_keys + ADD COLUMN max_concurrent INTEGER NOT NULL DEFAULT 8 + CHECK (max_concurrent > 0 AND max_concurrent <= 256); diff --git a/crates/tracevault-server/src/api/me.rs b/crates/tracevault-server/src/api/me.rs index f488818..219a4f7 100644 --- a/crates/tracevault-server/src/api/me.rs +++ b/crates/tracevault-server/src/api/me.rs @@ -19,11 +19,24 @@ use crate::AppState; pub struct AnthropicKeyStatus { pub configured: bool, pub configured_at: Option>, + /// Per-credential proxy concurrency cap. `None` when no key is + /// configured; otherwise the value stored on the row. + pub max_concurrent: Option, } #[derive(Deserialize)] pub struct PutAnthropicKeyRequest { - pub key: String, + /// Optional new Anthropic key. When omitted the existing ciphertext is + /// preserved — this is the "cap only" update path, used from the UI + /// when the user wants to change `max_concurrent` without rotating + /// the key. At least one of `key` or `max_concurrent` must be present. + #[serde(default)] + pub key: Option, + /// Optional per-credential proxy concurrency cap. Omit to keep the + /// existing value on update, or fall back to the DB default (8) on + /// first insert. 
+ #[serde(default)] + pub max_concurrent: Option, } /// Reject the synthetic nil user_id that the AuthUser extractor returns when @@ -50,17 +63,40 @@ pub async fn get_anthropic_key_status( auth: AuthUser, ) -> Result, AppError> { let user_id = require_real_user(&auth)?; - let configured_at = UserAnthropicKeyRepo::configured_at(&state.pool, user_id).await?; - Ok(Json(AnthropicKeyStatus { - configured: configured_at.is_some(), - configured_at, + let status = UserAnthropicKeyRepo::status(&state.pool, user_id).await?; + Ok(Json(match status { + Some(s) => AnthropicKeyStatus { + configured: true, + configured_at: Some(s.configured_at), + max_concurrent: Some(s.max_concurrent), + }, + None => AnthropicKeyStatus { + configured: false, + configured_at: None, + max_concurrent: None, + }, })) } /// PUT /api/v1/me/anthropic-key /// -/// Upserts the caller's Anthropic key, encrypted with the server's master -/// encryption key. Returns 204 on success. +/// Upserts the caller's Anthropic key and/or its concurrency cap. The +/// request body has two optional fields, `key` and `max_concurrent`, but +/// at least one must be present. Use cases: +/// +/// * `{ key: "sk-ant-...", max_concurrent: 16 }` — first-time setup or +/// full rotation. +/// * `{ key: "sk-ant-..." }` — rotate the key; cap preserved (default 8 +/// applied if no row yet). +/// * `{ max_concurrent: 16 }` — change only the cap; key must already +/// exist (400 otherwise). +/// +/// In all cases the in-memory per-credential semaphore for this user is +/// dropped from the DashMap so the *next* proxy request rebuilds it +/// against the new cap value. In-flight requests keep their permits on +/// the old (dropped) semaphore for the lifetime of their response, +/// effectively letting the cap change apply at the natural next quiet +/// point. 
pub async fn put_anthropic_key( State(state): State, auth: AuthUser, @@ -68,39 +104,82 @@ pub async fn put_anthropic_key( ) -> Result { let user_id = require_real_user(&auth)?; - let key = req.key.trim(); - if key.is_empty() { + if req.key.is_none() && req.max_concurrent.is_none() { return Err(AppError::BadRequest( - "Anthropic key must not be empty".into(), + "Request must include `key`, `max_concurrent`, or both".into(), )); } - // Real Anthropic keys are ~110 chars; cap at 256 to leave generous - // headroom for future formats while preventing the endpoint from - // accepting a ~2 MB junk string and persisting it encrypted on the - // user_anthropic_keys row. - if key.len() > 256 { - return Err(AppError::BadRequest( - "Anthropic key is unreasonably long (max 256 chars)".into(), - )); + + // Validate max_concurrent if the caller specified one. Bounds mirror + // the DB CHECK constraint so we fail fast with a clear 400 instead of + // surfacing a generic constraint-violation 500 from the upsert. + if let Some(n) = req.max_concurrent { + if !(1..=256).contains(&n) { + return Err(AppError::BadRequest( + "max_concurrent must be between 1 and 256".into(), + )); + } } - // Anthropic API keys begin with `sk-ant-` (modern format). We reject - // anything that doesn't look like one to catch obvious paste mistakes - // (TV session token, empty string, environment variable name, etc.). - // We do *not* validate the key against api.anthropic.com here — that - // would couple this endpoint to upstream availability. 
- if !key.starts_with("sk-ant-") { - return Err(AppError::BadRequest( - "Anthropic key must start with 'sk-ant-'".into(), - )); + + match req.key.as_deref() { + Some(raw_key) => { + let key = raw_key.trim(); + if key.is_empty() { + return Err(AppError::BadRequest( + "Anthropic key must not be empty".into(), + )); + } + // Real Anthropic keys are ~110 chars; cap at 256 to leave generous + // headroom for future formats while preventing the endpoint from + // accepting a ~2 MB junk string and persisting it encrypted. + if key.len() > 256 { + return Err(AppError::BadRequest( + "Anthropic key is unreasonably long (max 256 chars)".into(), + )); + } + if !key.starts_with("sk-ant-") { + return Err(AppError::BadRequest( + "Anthropic key must start with 'sk-ant-'".into(), + )); + } + let encryption_key = state.encryption_key.as_deref().ok_or_else(|| { + AppError::Internal( + "Server is not configured with an encryption key; cannot store Anthropic keys" + .into(), + ) + })?; + UserAnthropicKeyRepo::upsert( + &state.pool, + encryption_key, + user_id, + key, + req.max_concurrent, + ) + .await?; + } + None => { + // Settings-only update — the caller explicitly passed + // max_concurrent without a new key. Requires an existing row; + // otherwise there is nothing to update and we refuse with 400 + // rather than silently inserting a half-row. + let new_cap = req.max_concurrent.expect("checked above"); + let updated = + UserAnthropicKeyRepo::update_max_concurrent(&state.pool, user_id, new_cap).await?; + if !updated { + return Err(AppError::BadRequest( + "Cannot update settings: no Anthropic key configured yet".into(), + )); + } + } } - let encryption_key = state.encryption_key.as_deref().ok_or_else(|| { - AppError::Internal( - "Server is not configured with an encryption key; cannot store Anthropic keys".into(), - ) - })?; + // Flush the in-memory per-credential semaphore so the next request + // rebuilds it against the new cap (or the freshly-persisted row). 
+ // In-flight requests still hold permits on the old, now-orphaned + // Arc — when they finish they release naturally and the + // arc drops. + state.proxy_per_credential_semaphores.remove(&user_id); - UserAnthropicKeyRepo::upsert(&state.pool, encryption_key, user_id, key).await?; Ok(StatusCode::NO_CONTENT) } diff --git a/crates/tracevault-server/src/api/proxy.rs b/crates/tracevault-server/src/api/proxy.rs index e3cb2d6..f01ed38 100644 --- a/crates/tracevault-server/src/api/proxy.rs +++ b/crates/tracevault-server/src/api/proxy.rs @@ -71,10 +71,19 @@ const FORWARDED_RESPONSE_HEADERS: &[&str] = &[ /// `error.type` discriminants used in the Anthropic-shaped error envelope. /// Mirrors the documented Anthropic API error types so unmodified clients /// route these the same way they'd route a real api.anthropic.com error. +// +// All variants share the `*Error` suffix to mirror Anthropic's wire +// vocabulary (the `error.type` JSON field). +#[allow(clippy::enum_variant_names)] #[derive(Debug, Clone, Copy)] enum ProxyErrorKind { AuthenticationError, ApiError, + /// Mirrors Anthropic's `overloaded_error` — agents already back off + /// gracefully on this `type` value, so reusing it for our internal + /// concurrency caps keeps client behavior identical to a real upstream + /// overload. + OverloadedError, } impl ProxyErrorKind { @@ -82,6 +91,7 @@ impl ProxyErrorKind { match self { ProxyErrorKind::AuthenticationError => "authentication_error", ProxyErrorKind::ApiError => "api_error", + ProxyErrorKind::OverloadedError => "overloaded_error", } } } @@ -147,10 +157,29 @@ pub async fn anthropic_proxy( ); } - let (user_id, upstream_key) = match authenticate(&state, &headers, &path).await { - Ok(pair) => pair, + let (user_id, upstream_key, max_concurrent) = match authenticate(&state, &headers, &path).await + { + Ok(triple) => triple, + Err(resp) => return resp, + }; + + // Acquire concurrency permits BEFORE dispatching upstream. 
Global cap + // first, then per-credential — see HeldPermits / build_downstream_response + // for why permits travel with the response stream rather than living + // as locals. + let global_permit = match try_acquire_global_permit(&state, user_id, &path) { + Ok(p) => p, Err(resp) => return resp, }; + let credential_permit = + match try_acquire_credential_permit(&state, user_id, max_concurrent, &path) { + Ok(p) => p, + Err(resp) => return resp, + }; + let permits = HeldPermits { + _credential: credential_permit, + _global: global_permit, + }; let upstream_resp = match forward_to_upstream( &state, @@ -178,18 +207,18 @@ pub async fn anthropic_proxy( "proxied request" ); - build_downstream_response(upstream_resp) + build_downstream_response(upstream_resp, permits) } /// Concern 1: extract `x-api-key`, resolve it to a user, and load that /// user's decrypted Anthropic credential. Returns the -/// `(user_id, upstream_plaintext_key)` pair on success, or an -/// Anthropic-shaped error envelope on any auth/credential failure. +/// `(user_id, upstream_plaintext_key, max_concurrent)` triple on success, +/// or an Anthropic-shaped error envelope on any auth/credential failure. 
async fn authenticate( state: &AppState, headers: &HeaderMap, path: &str, -) -> Result<(Uuid, String), Response> { +) -> Result<(Uuid, String, i32), Response> { let tv_token = match headers.get("x-api-key").and_then(|v| v.to_str().ok()) { Some(t) if !t.is_empty() => t, _ => { @@ -209,8 +238,8 @@ async fn authenticate( let token_hash = sha256_hex(tv_token); let user_id = resolve_token(state, &token_hash).await?; - let upstream_key = load_anthropic_key(state, user_id).await?; - Ok((user_id, upstream_key)) + let (upstream_key, max_concurrent) = load_credential(state, user_id).await?; + Ok((user_id, upstream_key, max_concurrent)) } /// Concern 2: build the upstream request from the user's downstream @@ -271,10 +300,20 @@ async fn forward_to_upstream( /// `Response` — copies status + allow-listed response headers and streams /// the body byte-for-byte via `bytes_stream()` so SSE responses pass /// through without buffering. -fn build_downstream_response(upstream_resp: reqwest::Response) -> Response { +/// +/// `permits` carries any concurrency permits acquired earlier in the +/// handler. We attach them to the response stream so they are dropped +/// only when the *streaming body* finishes — not when this function +/// returns. Otherwise SSE streams would release capacity the moment the +/// upstream's headers came back, allowing far more concurrent in-flight +/// upstream connections than the cap allows. 
+fn build_downstream_response(upstream_resp: reqwest::Response, permits: HeldPermits) -> Response { let upstream_status = upstream_resp.status(); let upstream_headers = upstream_resp.headers().clone(); - let body_stream = upstream_resp.bytes_stream(); + let body_stream = PermitHoldingStream { + inner: upstream_resp.bytes_stream(), + _permits: permits, + }; let mut downstream = Response::builder().status(upstream_status); if let Some(hdrs) = downstream.headers_mut() { @@ -293,6 +332,41 @@ fn build_downstream_response(upstream_resp: reqwest::Response) -> Response { }) } +/// Bundle of concurrency permits that must be held for the lifetime of a +/// proxy response (including its streaming body). Permits are released +/// in field-declaration order on drop, so the per-credential permit +/// releases before the global one — the inverse of acquisition order. +struct HeldPermits { + _credential: tokio::sync::OwnedSemaphorePermit, + _global: Option, +} + +/// Stream wrapper that owns concurrency permits alongside the inner +/// `bytes_stream()`. Dropping the stream (including via the response +/// body completing or the client disconnecting) drops the permits. +struct PermitHoldingStream { + inner: S, + _permits: HeldPermits, +} + +impl futures_util::Stream for PermitHoldingStream +where + S: futures_util::Stream + Unpin, +{ + type Item = S::Item; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::pin::Pin::new(&mut self.inner).poll_next(cx) + } + + fn size_hint(&self) -> (usize, Option) { + self.inner.size_hint() + } +} + /// Resolve a sha256'd TV token to a user_id. 
Returns: /// - Ok(user_id) when the token is a valid, non-expired `auth_sessions` row /// - Err(401 envelope) when the token is missing or matches an org @@ -367,31 +441,14 @@ async fn resolve_token(state: &AppState, token_hash: &str) -> Result Result { - let row = UserAnthropicKeyRepo::get_ciphertext(&state.pool, user_id) - .await - .map_err(|e| { - tracing::warn!( - user_id = %user_id, - error_type = "api_error", - err = %e, - "failed to load user_anthropic_keys row" - ); - anthropic_error( - StatusCode::INTERNAL_SERVER_ERROR, - ProxyErrorKind::ApiError, - "Failed to load upstream credentials", - ) - })?; - - let (encrypted, nonce) = match row { - Some(r) => r, - None => { +/// Fetch the user's stored credential (encrypted Anthropic key + cap) and +/// decrypt the key. Returns `(plaintext, max_concurrent)` on success or an +/// Anthropic-shaped error envelope on any failure (no key configured, no +/// master key on this server, ciphertext corrupted, DB error). +async fn load_credential(state: &AppState, user_id: Uuid) -> Result<(String, i32), Response> { + let credential = match UserAnthropicKeyRepo::get_credential(&state.pool, user_id).await { + Ok(Some(c)) => c, + Ok(None) => { tracing::warn!( user_id = %user_id, error_type = "authentication_error", @@ -404,6 +461,19 @@ async fn load_anthropic_key(state: &AppState, user_id: Uuid) -> Result { + tracing::warn!( + user_id = %user_id, + error_type = "api_error", + err = %e, + "failed to load user_anthropic_keys row" + ); + return Err(anthropic_error( + StatusCode::INTERNAL_SERVER_ERROR, + ProxyErrorKind::ApiError, + "Failed to load upstream credentials", + )); + } }; let master_key = state.encryption_key.as_deref().ok_or_else(|| { @@ -419,19 +489,104 @@ async fn load_anthropic_key(state: &AppState, user_id: Uuid) -> Result` is the established error-return shape in this +// module (see `authenticate`, `forward_to_upstream`). 
+#[allow(clippy::result_large_err)] +fn try_acquire_global_permit( + state: &AppState, + user_id: Uuid, + path: &str, +) -> Result, Response> { + let Some(sem) = state.proxy_global_semaphore.as_ref() else { + return Ok(None); + }; + match sem.clone().try_acquire_owned() { + Ok(p) => Ok(Some(p)), + Err(_) => { + tracing::warn!( + user_id = %user_id, + error_type = "overloaded_error", + reason = "global_cap", + path = %path, + "proxy rejected request: global concurrency cap reached" + ); + Err(anthropic_error( + StatusCode::TOO_MANY_REQUESTS, + ProxyErrorKind::OverloadedError, + "Server is at capacity. Retry shortly.", + )) + } + } +} + +/// Try to acquire a permit from the per-credential concurrency cap. +/// Lazily creates the semaphore on first use, sized to `max_concurrent`. +/// On capacity exhaustion returns an Anthropic-shaped 429 with +/// `overloaded_error` and a message naming the configured cap so the user +/// can debug it from their `/me/proxy/` UI. +#[allow(clippy::result_large_err)] +fn try_acquire_credential_permit( + state: &AppState, + user_id: Uuid, + max_concurrent: i32, + path: &str, +) -> Result { + // i32 -> usize is safe because the DB CHECK constraint clamps to (0, 256]. + // Defensive clamp at the lower end protects against an out-of-spec row. + let cap = max_concurrent.max(1) as usize; + + // Look up or insert the per-credential semaphore. The DashMap entry + // guard is held only across the `.clone()` of the Arc — never across + // the .await/.acquire — so there is no chance of a guard living across + // a yield point or self-deadlocking on the same shard. 
+ let sem = state + .proxy_per_credential_semaphores + .entry(user_id) + .or_insert_with(|| std::sync::Arc::new(tokio::sync::Semaphore::new(cap))) + .clone(); + + match sem.try_acquire_owned() { + Ok(p) => Ok(p), + Err(_) => { + tracing::warn!( + user_id = %user_id, + error_type = "overloaded_error", + reason = "per_credential_cap", + cap_value = max_concurrent, + path = %path, + "proxy rejected request: per-credential concurrency cap reached" + ); + Err(anthropic_error( + StatusCode::TOO_MANY_REQUESTS, + ProxyErrorKind::OverloadedError, + &format!( + "Too many concurrent requests against this credential (cap: {max_concurrent}). Retry shortly." + ), + )) + } + } } /// Copy allow-listed and `anthropic-*` headers from `src` into `dst`. @@ -563,5 +718,6 @@ mod tests { "authentication_error" ); assert_eq!(ProxyErrorKind::ApiError.as_str(), "api_error"); + assert_eq!(ProxyErrorKind::OverloadedError.as_str(), "overloaded_error"); } } diff --git a/crates/tracevault-server/src/lib.rs b/crates/tracevault-server/src/lib.rs index 1b1d0bd..e05e24b 100644 --- a/crates/tracevault-server/src/lib.rs +++ b/crates/tracevault-server/src/lib.rs @@ -60,4 +60,26 @@ pub struct AppState { /// `https://api.anthropic.com` in production; overridden in tests so a /// wiremock stub upstream can stand in for the real Anthropic API. pub anthropic_upstream_base: String, + /// Optional global cap on in-flight proxy requests across all users. + /// `None` = unlimited (default); set the operator env var + /// `PROXY_MAX_GLOBAL_CONCURRENT` to enable. + pub proxy_global_semaphore: Option>, + /// Per-credential concurrency semaphores. Keyed by + /// `user_anthropic_keys.user_id` (effectively the credential ID today; + /// generalizes to org/credential IDs once those land). Each semaphore is + /// lazily created on first request for a credential, sized to the + /// credential's stored `max_concurrent` at that moment. 
+ /// + /// Update semantics: a PUT to `/api/v1/me/anthropic-key` updates the DB + /// row and then evicts this user's entry from the map, so the *next* + /// proxy request rebuilds the semaphore sized to the new cap. In-flight + /// requests keep their permits on the old semaphore until they finish, + /// which avoids the atomic-swap edge cases of mid-flight cap changes. + /// + /// Growth: this DashMap grows monotonically with credentials that have + /// received at least one proxy request since startup. At expected scale + /// (<= ~10k credentials) the footprint is a few MB. Revisit eviction + /// (TTL or LRU) if active credentials exceed that threshold. + pub proxy_per_credential_semaphores: + std::sync::Arc>>, } diff --git a/crates/tracevault-server/src/main.rs b/crates/tracevault-server/src/main.rs index d49b641..29a32fa 100644 --- a/crates/tracevault-server/src/main.rs +++ b/crates/tracevault-server/src/main.rs @@ -72,6 +72,29 @@ async fn main() { .build() .expect("Failed to build proxy reqwest client"); + // Optional global concurrency cap across all proxy requests. Unset = no + // global limit; this is the right default for the small-team deployments + // we ship to today. Operators turn this on after capacity testing; a + // sensible starting value is 256. 
+ let proxy_global_semaphore: Option> = + match std::env::var("PROXY_MAX_GLOBAL_CONCURRENT") { + Ok(s) => match s.parse::() { + Ok(n) if n > 0 => { + tracing::info!(cap = n, "proxy global concurrency cap enabled"); + Some(std::sync::Arc::new(tokio::sync::Semaphore::new(n))) + } + _ => { + tracing::warn!( + value = %s, + "PROXY_MAX_GLOBAL_CONCURRENT is set but not a positive integer; ignoring" + ); + None + } + }, + Err(_) => None, + }; + let proxy_per_credential_semaphores = std::sync::Arc::new(dashmap::DashMap::new()); + // Auto-sync repos that are in 'ready' state on startup sync_repos_on_startup(&pool, &repo_manager, &extensions).await; @@ -624,6 +647,8 @@ async fn main() { invite_expiry_minutes: cfg.invite_expiry_minutes, anthropic_upstream_base: api::proxy::DEFAULT_ANTHROPIC_UPSTREAM_BASE.to_string(), embedding_service, + proxy_global_semaphore: proxy_global_semaphore.clone(), + proxy_per_credential_semaphores: proxy_per_credential_semaphores.clone(), }); let listener = tokio::net::TcpListener::bind(&bind_addr).await.unwrap(); diff --git a/crates/tracevault-server/src/repo/user_anthropic_keys.rs b/crates/tracevault-server/src/repo/user_anthropic_keys.rs index 262a1bc..5194b92 100644 --- a/crates/tracevault-server/src/repo/user_anthropic_keys.rs +++ b/crates/tracevault-server/src/repo/user_anthropic_keys.rs @@ -14,67 +14,126 @@ use crate::error::AppError; pub struct UserAnthropicKeyRepo; +/// The still-encrypted key (ciphertext + nonce) returned by `get_credential`, +/// plus its concurrency cap. We pull the cap out of the same query so the +/// proxy hot path stays one round-trip per request. +pub struct StoredCredential { + pub encrypted: String, + pub nonce: String, + pub max_concurrent: i32, +} + +/// Status returned by the GET endpoint: when the row was last updated + the +/// current concurrency cap. Never reveals key material. 
+pub struct StoredStatus { + pub configured_at: DateTime, + pub max_concurrent: i32, +} + impl UserAnthropicKeyRepo { /// Encrypt `plaintext_key` with the configured master `encryption_key` /// and upsert it for `user_id`. On conflict the existing row is /// overwritten and `updated_at` advances; `created_at` is preserved. + /// + /// `max_concurrent` is `Some(N)` to set or change the cap, or `None` + /// to keep the existing value on update (or fall back to `8`, + /// mirroring the column default, on insert). pub async fn upsert( pool: &PgPool, encryption_key: &str, user_id: Uuid, plaintext_key: &str, + max_concurrent: Option, ) -> Result<(), AppError> { let (encrypted, nonce) = encryption::encrypt(plaintext_key, encryption_key) .map_err(|e| AppError::Internal(format!("failed to encrypt anthropic key: {e}")))?; + // COALESCE-based upsert lets us either accept an explicit new cap + // or preserve whatever was already stored. On INSERT, a NULL $4 + // becomes the literal 8 via COALESCE($4, 8), which must be kept in + // sync with the column DEFAULT declared in migration 025. sqlx::query( - "INSERT INTO user_anthropic_keys (user_id, key_encrypted, key_nonce) - VALUES ($1, $2, $3) + "INSERT INTO user_anthropic_keys (user_id, key_encrypted, key_nonce, max_concurrent) + VALUES ($1, $2, $3, COALESCE($4, 8)) ON CONFLICT (user_id) DO UPDATE SET key_encrypted = EXCLUDED.key_encrypted, key_nonce = EXCLUDED.key_nonce, + max_concurrent = COALESCE($4, user_anthropic_keys.max_concurrent), updated_at = now()", ) .bind(user_id) .bind(&encrypted) .bind(&nonce) + .bind(max_concurrent) .execute(pool) .await?; Ok(()) } - /// Return the encrypted ciphertext and nonce for `user_id`, or `None` - /// if no key is configured. Callers decrypt via `crate::encryption::decrypt`. - pub async fn get_ciphertext( + /// Return the stored credential (ciphertext + nonce + cap) for + /// `user_id`, or `None` if no key is configured. The proxy calls this + /// on every request — it is the one read on the hot path. 
+ pub async fn get_credential( pool: &PgPool, user_id: Uuid, - ) -> Result, AppError> { - let row = sqlx::query_as::<_, (String, String)>( - "SELECT key_encrypted, key_nonce FROM user_anthropic_keys WHERE user_id = $1", + ) -> Result, AppError> { + let row = sqlx::query_as::<_, (String, String, i32)>( + "SELECT key_encrypted, key_nonce, max_concurrent + FROM user_anthropic_keys + WHERE user_id = $1", ) .bind(user_id) .fetch_optional(pool) .await?; - Ok(row) + Ok( + row.map(|(encrypted, nonce, max_concurrent)| StoredCredential { + encrypted, + nonce, + max_concurrent, + }), + ) } - /// Return `Some(updated_at)` if a key is configured for `user_id`, `None` - /// otherwise. Used by the status-only GET endpoint — never reveals key - /// material. - pub async fn configured_at( - pool: &PgPool, - user_id: Uuid, - ) -> Result>, AppError> { - let row = sqlx::query_scalar::<_, DateTime>( - "SELECT updated_at FROM user_anthropic_keys WHERE user_id = $1", + /// Return `Some(StoredStatus)` if a key is configured for `user_id`, + /// `None` otherwise. Used by the status-only GET endpoint — never + /// reveals key material. + pub async fn status(pool: &PgPool, user_id: Uuid) -> Result, AppError> { + let row = sqlx::query_as::<_, (DateTime, i32)>( + "SELECT updated_at, max_concurrent + FROM user_anthropic_keys + WHERE user_id = $1", ) .bind(user_id) .fetch_optional(pool) .await?; - Ok(row) + Ok(row.map(|(configured_at, max_concurrent)| StoredStatus { + configured_at, + max_concurrent, + })) + } + + /// Update only the `max_concurrent` cap for `user_id`, leaving the + /// stored ciphertext + nonce untouched. Returns `true` when a row was + /// updated, `false` when no row existed (the caller should surface + /// "configure a key first" in that case). 
+ pub async fn update_max_concurrent( + pool: &PgPool, + user_id: Uuid, + max_concurrent: i32, + ) -> Result { + let res = sqlx::query( + "UPDATE user_anthropic_keys + SET max_concurrent = $2, updated_at = now() + WHERE user_id = $1", + ) + .bind(user_id) + .bind(max_concurrent) + .execute(pool) + .await?; + Ok(res.rows_affected() > 0) } /// Remove the row for `user_id`. Idempotent — returns Ok even if no row diff --git a/crates/tracevault-server/tests/proxy_integration.rs b/crates/tracevault-server/tests/proxy_integration.rs index 4a638b1..dc5aba3 100644 --- a/crates/tracevault-server/tests/proxy_integration.rs +++ b/crates/tracevault-server/tests/proxy_integration.rs @@ -35,6 +35,9 @@ use wiremock::{ struct Harness { app: Router, upstream: MockServer, + /// The same PgPool wired into the AppState — useful for tests that + /// seed extra users or keys beyond what the harness creates. + pool: sqlx::PgPool, /// Raw TV session token to send in x-api-key. Test user has a stored /// Anthropic key of `sk-ant-test-upstream-key`. user_session_token: String, @@ -76,6 +79,7 @@ async fn build_harness(pool: sqlx::PgPool) -> Harness { &encryption_key, user_with_key, "sk-ant-test-upstream-key", + None, ) .await .unwrap(); @@ -91,6 +95,8 @@ async fn build_harness(pool: sqlx::PgPool) -> Harness { invite_expiry_minutes: 60, embedding_service: None, anthropic_upstream_base: upstream.uri(), + proxy_global_semaphore: None, + proxy_per_credential_semaphores: std::sync::Arc::new(dashmap::DashMap::new()), }; let app = Router::new() @@ -115,6 +121,93 @@ async fn build_harness(pool: sqlx::PgPool) -> Harness { Harness { app, upstream, + pool, + user_session_token, + user_no_key_session_token, + org_api_key_token: raw_org_token, + } +} + +/// Build a harness with explicit concurrency caps. The default `build_harness` +/// uses `max_concurrent = 8` (DB default) and no global cap, which works for +/// every test that does not exercise the cap. 
The cap-specific tests need +/// tighter knobs: +/// * `per_credential_cap`: overrides the seeded user's `max_concurrent`. +/// * `global_cap`: when `Some(n)`, the AppState carries a global +/// `Semaphore::new(n)`; when `None`, the global cap is disabled. +async fn build_harness_with_caps( + pool: sqlx::PgPool, + per_credential_cap: i32, + global_cap: Option, +) -> Harness { + let upstream = MockServer::start().await; + + let org_id = common::seed_org(&pool).await; + let user_with_key = common::seed_user(&pool).await; + let user_without_key = common::seed_user(&pool).await; + let user_session_token = common::seed_auth_session(&pool, user_with_key).await; + let user_no_key_session_token = common::seed_auth_session(&pool, user_without_key).await; + + let raw_org_token = format!("tv_ak_{}", Uuid::new_v4()); + let org_token_hash = tracevault_server::auth::sha256_hex(&raw_org_token); + sqlx::query("INSERT INTO api_keys (org_id, key_hash, name) VALUES ($1, $2, $3)") + .bind(org_id) + .bind(&org_token_hash) + .bind("test-org-key") + .execute(&pool) + .await + .unwrap(); + + let encryption_key = common::fixture_encryption_key(); + tracevault_server::repo::user_anthropic_keys::UserAnthropicKeyRepo::upsert( + &pool, + &encryption_key, + user_with_key, + "sk-ant-test-upstream-key", + Some(per_credential_cap), + ) + .await + .unwrap(); + + let proxy_global_semaphore = + global_cap.map(|n| std::sync::Arc::new(tokio::sync::Semaphore::new(n))); + + let state = AppState { + pool: pool.clone(), + repo_manager: repo_manager::RepoManager::new("/tmp"), + extensions: tracevault_server::extensions::community_registry(), + encryption_key: Some(encryption_key), + http_client: reqwest::Client::new(), + proxy_http_client: reqwest::Client::new(), + cors_origin: "*".to_string(), + invite_expiry_minutes: 60, + embedding_service: None, + anthropic_upstream_base: upstream.uri(), + proxy_global_semaphore, + proxy_per_credential_semaphores: std::sync::Arc::new(dashmap::DashMap::new()), + }; + + 
let app = Router::new() + .route( + "/proxy/anthropic/{*path}", + get(api::proxy::anthropic_proxy) + .post(api::proxy::anthropic_proxy) + .put(api::proxy::anthropic_proxy) + .delete(api::proxy::anthropic_proxy), + ) + .layer(DefaultBodyLimit::max(32 * 1024 * 1024)) + .route( + "/api/v1/me/anthropic-key", + get(api::me::get_anthropic_key_status) + .put(api::me::put_anthropic_key) + .delete(api::me::delete_anthropic_key), + ) + .with_state(state); + + Harness { + app, + upstream, + pool, user_session_token, user_no_key_session_token, org_api_key_token: raw_org_token, @@ -406,6 +499,7 @@ async fn proxy_returns_502_when_upstream_unreachable(pool: sqlx::PgPool) { &encryption_key, user, "sk-ant-doesnt-matter", + None, ) .await .unwrap(); @@ -425,6 +519,8 @@ async fn proxy_returns_502_when_upstream_unreachable(pool: sqlx::PgPool) { invite_expiry_minutes: 60, embedding_service: None, anthropic_upstream_base: "http://127.0.0.1:1".to_string(), + proxy_global_semaphore: None, + proxy_per_credential_semaphores: std::sync::Arc::new(dashmap::DashMap::new()), }; let app = Router::new() @@ -600,6 +696,187 @@ async fn proxy_rejects_path_traversal_segments(pool: sqlx::PgPool) { ); } +// --- Proxy: per-credential and global concurrency caps (#210) ------------- + +use std::time::Duration; + +/// Build a request to the proxy with the standard headers + a marker query +/// so we can tell wiremock-served requests apart. +fn proxy_request(token: &str) -> Request { + Request::builder() + .method("POST") + .uri("/proxy/anthropic/v1/messages") + .header("x-api-key", token) + .header("content-type", "application/json") + .body(Body::from(r#"{"model":"claude-haiku","max_tokens":1}"#)) + .unwrap() +} + +/// Per-credential cap exceeded: with `max_concurrent = 2`, two in-flight +/// requests succeed (eventually), but the third in-flight request returns +/// 429 / `overloaded_error` with `reason = per_credential_cap`. 
+#[sqlx::test(migrations = "./migrations")] +async fn proxy_rejects_when_per_credential_cap_exceeded(pool: sqlx::PgPool) { + let h = build_harness_with_caps(pool, 2, None).await; + + // Upstream sits on each request for 2s so the in-flight permits are + // really held when we issue the rejecting request. + Mock::given(method("POST")) + .and(wm_path("/v1/messages")) + .respond_with( + ResponseTemplate::new(200) + .set_body_string("{}") + .set_delay(Duration::from_secs(2)), + ) + .mount(&h.upstream) + .await; + + let app = h.app.clone(); + let token = h.user_session_token.clone(); + + // Spawn two slow-but-eventually-OK requests so the per-credential + // semaphore is at full capacity. We deliberately do not await these; + // they keep the permits held until the wiremock delay elapses or the + // task is dropped at end-of-test. + let _h1 = tokio::spawn({ + let app = app.clone(); + let token = token.clone(); + async move { app.oneshot(proxy_request(&token)).await } + }); + let _h2 = tokio::spawn({ + let app = app.clone(); + let token = token.clone(); + async move { app.oneshot(proxy_request(&token)).await } + }); + + // Brief yield so both spawned tasks reach the acquire/upstream-send + // boundary before we issue the rejecting request. + tokio::time::sleep(Duration::from_millis(150)).await; + + let resp = app + .clone() + .oneshot(proxy_request(&token)) + .await + .expect("third request should respond, not panic"); + assert_eq!( + resp.status(), + StatusCode::TOO_MANY_REQUESTS, + "third in-flight request must hit the per-credential cap" + ); + let body = read_body_to_value(resp.into_body()).await; + assert_eq!(body["type"], "error"); + assert_eq!(body["error"]["type"], "overloaded_error"); + let msg = body["error"]["message"].as_str().unwrap_or(""); + assert!( + msg.contains("cap: 2"), + "error message should name the configured cap; got: {msg}" + ); +} + +/// After the in-flight requests complete and release their permits, a +/// new request must succeed. 
Guards against the bug where permits leak. +#[sqlx::test(migrations = "./migrations")] +async fn proxy_frees_permit_when_request_completes(pool: sqlx::PgPool) { + let h = build_harness_with_caps(pool, 1, None).await; + + Mock::given(method("POST")) + .and(wm_path("/v1/messages")) + .respond_with( + ResponseTemplate::new(200) + .set_body_string("{}") + .set_delay(Duration::from_millis(100)), + ) + .mount(&h.upstream) + .await; + + // Cap is 1. First request must succeed. + let r1 = h + .app + .clone() + .oneshot(proxy_request(&h.user_session_token)) + .await + .unwrap(); + assert_eq!(r1.status(), StatusCode::OK); + // Drain the body so the streaming permit is dropped — otherwise the + // permit stays held until r1.into_body() is consumed. + let _ = read_body_to_bytes(r1.into_body()).await; + + // Second request, sequential, after the first completes: must succeed. + // If the permit leaked we'd get 429 here instead. + let r2 = h + .app + .clone() + .oneshot(proxy_request(&h.user_session_token)) + .await + .unwrap(); + assert_eq!( + r2.status(), + StatusCode::OK, + "second sequential request must succeed once the first releases its permit" + ); +} + +/// Global cap exceeded: with `Semaphore::new(1)`, one in-flight request +/// from any user holds the only global slot; a request from a *different* +/// user must be rejected with `reason = global_cap`. +#[sqlx::test(migrations = "./migrations")] +async fn proxy_rejects_when_global_cap_exceeded(pool: sqlx::PgPool) { + let h = build_harness_with_caps(pool, 8, Some(1)).await; + + // Seed a second user + session with their own Anthropic key so we can + // prove the cap is global (cross-user), not per-credential. 
+ let second_user = common::seed_user(&h.pool).await; + let second_token = common::seed_auth_session(&h.pool, second_user).await; + let encryption_key = common::fixture_encryption_key(); + tracevault_server::repo::user_anthropic_keys::UserAnthropicKeyRepo::upsert( + &h.pool, + &encryption_key, + second_user, + "sk-ant-second-upstream-key", + Some(8), + ) + .await + .unwrap(); + + Mock::given(method("POST")) + .and(wm_path("/v1/messages")) + .respond_with( + ResponseTemplate::new(200) + .set_body_string("{}") + .set_delay(Duration::from_secs(2)), + ) + .mount(&h.upstream) + .await; + + let app = h.app.clone(); + + // User 1 holds the only global slot. + let token1 = h.user_session_token.clone(); + let _holder = tokio::spawn({ + let app = app.clone(); + async move { app.oneshot(proxy_request(&token1)).await } + }); + tokio::time::sleep(Duration::from_millis(150)).await; + + // User 2 tries to use the proxy — they have their own per-credential + // budget but the global cap is exhausted, so this must 429. 
+ let resp = app + .clone() + .oneshot(proxy_request(&second_token)) + .await + .expect("request should respond, not panic"); + assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS); + let body = read_body_to_value(resp.into_body()).await; + assert_eq!(body["error"]["type"], "overloaded_error"); + assert!( + body["error"]["message"] + .as_str() + .unwrap_or("") + .contains("Server is at capacity"), + "global cap rejection should use the server-wide message: {body}" + ); +} + // --- /api/v1/me/anthropic-key HTTP lifecycle (deferred from T02) --------- #[sqlx::test(migrations = "./migrations")] @@ -622,6 +899,8 @@ async fn me_anthropic_key_lifecycle(pool: sqlx::PgPool) { invite_expiry_minutes: 60, embedding_service: None, anthropic_upstream_base: upstream.uri(), + proxy_global_semaphore: None, + proxy_per_credential_semaphores: std::sync::Arc::new(dashmap::DashMap::new()), }; let app = Router::new() @@ -751,3 +1030,164 @@ async fn me_anthropic_key_lifecycle(pool: sqlx::PgPool) { let body = read_body_to_value(resp.into_body()).await; assert_eq!(body["configured"], false); } + +/// Build a minimal app + state with just the /me/anthropic-key endpoints, +/// for tests that exercise the settings-only PUT path. Returns +/// (app, bearer-header-value, user_id, shared-semaphore-map). 
+async fn build_me_endpoints_only( + pool: sqlx::PgPool, +) -> ( + Router, + String, + uuid::Uuid, + std::sync::Arc>>, +) { + let upstream = MockServer::start().await; + let user = common::seed_user(&pool).await; + let session = common::seed_auth_session(&pool, user).await; + let encryption_key = common::fixture_encryption_key(); + let sems: std::sync::Arc>> = + std::sync::Arc::new(dashmap::DashMap::new()); + + let state = AppState { + pool: pool.clone(), + repo_manager: repo_manager::RepoManager::new("/tmp"), + extensions: tracevault_server::extensions::community_registry(), + encryption_key: Some(encryption_key), + http_client: reqwest::Client::new(), + proxy_http_client: reqwest::Client::new(), + cors_origin: "*".to_string(), + invite_expiry_minutes: 60, + embedding_service: None, + anthropic_upstream_base: upstream.uri(), + proxy_global_semaphore: None, + proxy_per_credential_semaphores: sems.clone(), + }; + + let app = Router::new() + .route( + "/api/v1/me/anthropic-key", + get(api::me::get_anthropic_key_status) + .put(api::me::put_anthropic_key) + .delete(api::me::delete_anthropic_key), + ) + .with_state(state); + + (app, format!("Bearer {session}"), user, sems) +} + +fn put_request(bearer: &str, body: serde_json::Value) -> Request { + Request::builder() + .method("PUT") + .uri("/api/v1/me/anthropic-key") + .header("authorization", bearer) + .header("content-type", "application/json") + .body(Body::from(serde_json::to_vec(&body).unwrap())) + .unwrap() +} + +/// Setting cap-only on an existing row must update only max_concurrent and +/// preserve the ciphertext, *and* must drop the in-memory semaphore so the +/// next proxy request rebuilds it with the new cap. +#[sqlx::test(migrations = "./migrations")] +async fn me_anthropic_key_put_updates_cap_only(pool: sqlx::PgPool) { + let (app, bearer, user_id, sems) = build_me_endpoints_only(pool.clone()).await; + + // Seed initial key with cap=4. 
+ let r = app + .clone() + .oneshot(put_request( + &bearer, + serde_json::json!({ "key": "sk-ant-initial-fixture", "max_concurrent": 4 }), + )) + .await + .unwrap(); + assert_eq!(r.status(), StatusCode::NO_CONTENT); + + // Prime the in-memory semaphore so we can prove it gets dropped. + sems.entry(user_id) + .or_insert_with(|| std::sync::Arc::new(tokio::sync::Semaphore::new(4))); + assert!(sems.contains_key(&user_id)); + + // Cap-only PUT. + let r = app + .clone() + .oneshot(put_request( + &bearer, + serde_json::json!({ "max_concurrent": 16 }), + )) + .await + .unwrap(); + assert_eq!(r.status(), StatusCode::NO_CONTENT); + + // Verify ciphertext unchanged via the GET endpoint (status returns + // configured=true, max_concurrent=16) AND that the semaphore entry + // is gone. + let r = app + .clone() + .oneshot( + Request::builder() + .method("GET") + .uri("/api/v1/me/anthropic-key") + .header("authorization", &bearer) + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + assert_eq!(r.status(), StatusCode::OK); + let body = read_body_to_value(r.into_body()).await; + assert_eq!(body["configured"], true); + assert_eq!(body["max_concurrent"], 16); + assert!( + !sems.contains_key(&user_id), + "settings-only PUT must drop the in-memory semaphore so the new cap takes effect" + ); + + // And the encrypted key in the DB really is still the initial one. + let cred = tracevault_server::repo::user_anthropic_keys::UserAnthropicKeyRepo::get_credential( + &pool, user_id, + ) + .await + .unwrap() + .unwrap(); + let plaintext = tracevault_server::encryption::decrypt( + &cred.encrypted, + &cred.nonce, + &common::fixture_encryption_key(), + ) + .unwrap(); + assert_eq!(plaintext, "sk-ant-initial-fixture"); + assert_eq!(cred.max_concurrent, 16); +} + +/// Cap-only PUT before any key has been configured must return 400 — we +/// don't want a half-row containing only a cap and no key material. 
+#[sqlx::test(migrations = "./migrations")] +async fn me_anthropic_key_put_rejects_cap_only_when_unconfigured(pool: sqlx::PgPool) { + let (app, bearer, _user_id, _sems) = build_me_endpoints_only(pool).await; + + let r = app + .clone() + .oneshot(put_request( + &bearer, + serde_json::json!({ "max_concurrent": 16 }), + )) + .await + .unwrap(); + assert_eq!(r.status(), StatusCode::BAD_REQUEST); +} + +/// Empty body (neither key nor cap) must return 400 rather than silently +/// noop. +#[sqlx::test(migrations = "./migrations")] +async fn me_anthropic_key_put_rejects_empty_body(pool: sqlx::PgPool) { + let (app, bearer, _user_id, _sems) = build_me_endpoints_only(pool).await; + + let r = app + .clone() + .oneshot(put_request(&bearer, serde_json::json!({}))) + .await + .unwrap(); + assert_eq!(r.status(), StatusCode::BAD_REQUEST); +} diff --git a/crates/tracevault-server/tests/proxy_real_anthropic.rs b/crates/tracevault-server/tests/proxy_real_anthropic.rs index 2210873..179fa40 100644 --- a/crates/tracevault-server/tests/proxy_real_anthropic.rs +++ b/crates/tracevault-server/tests/proxy_real_anthropic.rs @@ -47,6 +47,7 @@ async fn build_real_state(pool: &sqlx::PgPool, upstream_key: &str) -> (AppState, &encryption_key, user, upstream_key, + None, ) .await .unwrap(); @@ -66,6 +67,8 @@ async fn build_real_state(pool: &sqlx::PgPool, upstream_key: &str) -> (AppState, embedding_service: None, // Defaults to the real api.anthropic.com — exactly what we want here. 
anthropic_upstream_base: api::proxy::DEFAULT_ANTHROPIC_UPSTREAM_BASE.to_string(), + proxy_global_semaphore: None, + proxy_per_credential_semaphores: std::sync::Arc::new(dashmap::DashMap::new()), }; (state, session_token) } diff --git a/crates/tracevault-server/tests/repo_user_anthropic_keys_test.rs b/crates/tracevault-server/tests/repo_user_anthropic_keys_test.rs index 4942777..f68d61d 100644 --- a/crates/tracevault-server/tests/repo_user_anthropic_keys_test.rs +++ b/crates/tracevault-server/tests/repo_user_anthropic_keys_test.rs @@ -1,7 +1,8 @@ //! Integration tests for `UserAnthropicKeyRepo`. Verifies the -//! upsert / get / configured_at / delete lifecycle and that the on-disk -//! ciphertext is recoverable via `encryption::decrypt` — i.e. the layer -//! that the proxy hot path will rely on. +//! upsert / get / status / delete lifecycle, the on-disk ciphertext is +//! recoverable via `encryption::decrypt`, and that the per-credential +//! `max_concurrent` cap roundtrips correctly through upsert / read paths +//! (issue softwaremill/tracevault#210). 
mod common; @@ -22,16 +23,16 @@ async fn upsert_then_get_roundtrips_plaintext(pool: sqlx::PgPool) { let master = fixture_key(); let plaintext = "sk-ant-test-fixture-not-a-real-key"; - UserAnthropicKeyRepo::upsert(&pool, &master, user_id, plaintext) + UserAnthropicKeyRepo::upsert(&pool, &master, user_id, plaintext, None) .await .expect("upsert"); - let (ct, nonce) = UserAnthropicKeyRepo::get_ciphertext(&pool, user_id) + let cred = UserAnthropicKeyRepo::get_credential(&pool, user_id) .await .expect("get") .expect("row present after upsert"); - let recovered = encryption::decrypt(&ct, &nonce, &master).expect("decrypt"); + let recovered = encryption::decrypt(&cred.encrypted, &cred.nonce, &master).expect("decrypt"); assert_eq!(recovered, plaintext); } @@ -40,50 +41,51 @@ async fn upsert_replaces_existing_key(pool: sqlx::PgPool) { let user_id = common::seed_user(&pool).await; let master = fixture_key(); - UserAnthropicKeyRepo::upsert(&pool, &master, user_id, "sk-ant-first") + UserAnthropicKeyRepo::upsert(&pool, &master, user_id, "sk-ant-first", None) .await .unwrap(); - UserAnthropicKeyRepo::upsert(&pool, &master, user_id, "sk-ant-second") + UserAnthropicKeyRepo::upsert(&pool, &master, user_id, "sk-ant-second", None) .await .unwrap(); - let (ct, nonce) = UserAnthropicKeyRepo::get_ciphertext(&pool, user_id) + let cred = UserAnthropicKeyRepo::get_credential(&pool, user_id) .await .unwrap() .unwrap(); - let recovered = encryption::decrypt(&ct, &nonce, &master).unwrap(); + let recovered = encryption::decrypt(&cred.encrypted, &cred.nonce, &master).unwrap(); assert_eq!(recovered, "sk-ant-second"); } #[sqlx::test(migrations = "./migrations")] -async fn get_ciphertext_returns_none_when_missing(pool: sqlx::PgPool) { +async fn get_credential_returns_none_when_missing(pool: sqlx::PgPool) { let user_id = common::seed_user(&pool).await; - let result = UserAnthropicKeyRepo::get_ciphertext(&pool, user_id) + let result = UserAnthropicKeyRepo::get_credential(&pool, user_id) .await 
.unwrap(); assert!(result.is_none()); } #[sqlx::test(migrations = "./migrations")] -async fn configured_at_reflects_presence(pool: sqlx::PgPool) { +async fn status_reflects_presence(pool: sqlx::PgPool) { let user_id = common::seed_user(&pool).await; let master = fixture_key(); - assert!(UserAnthropicKeyRepo::configured_at(&pool, user_id) + assert!(UserAnthropicKeyRepo::status(&pool, user_id) .await .unwrap() .is_none()); - UserAnthropicKeyRepo::upsert(&pool, &master, user_id, "sk-ant-test") + UserAnthropicKeyRepo::upsert(&pool, &master, user_id, "sk-ant-test", None) .await .unwrap(); - let ts = UserAnthropicKeyRepo::configured_at(&pool, user_id) + let s = UserAnthropicKeyRepo::status(&pool, user_id) .await - .unwrap(); - assert!( - ts.is_some(), - "configured_at should return Some after upsert" + .unwrap() + .expect("status should return Some after upsert"); + assert_eq!( + s.max_concurrent, 8, + "fresh upsert without explicit cap should use DB default of 8" ); } @@ -92,25 +94,27 @@ async fn upsert_advances_updated_at(pool: sqlx::PgPool) { let user_id = common::seed_user(&pool).await; let master = fixture_key(); - UserAnthropicKeyRepo::upsert(&pool, &master, user_id, "sk-ant-first") + UserAnthropicKeyRepo::upsert(&pool, &master, user_id, "sk-ant-first", None) .await .unwrap(); - let t1 = UserAnthropicKeyRepo::configured_at(&pool, user_id) + let t1 = UserAnthropicKeyRepo::status(&pool, user_id) .await .unwrap() - .unwrap(); + .unwrap() + .configured_at; // Sleep briefly so postgres `now()` resolves to a later timestamp. // Postgres `now()` has microsecond resolution; 10ms is plenty. 
tokio::time::sleep(std::time::Duration::from_millis(10)).await; - UserAnthropicKeyRepo::upsert(&pool, &master, user_id, "sk-ant-second") + UserAnthropicKeyRepo::upsert(&pool, &master, user_id, "sk-ant-second", None) .await .unwrap(); - let t2 = UserAnthropicKeyRepo::configured_at(&pool, user_id) + let t2 = UserAnthropicKeyRepo::status(&pool, user_id) .await .unwrap() - .unwrap(); + .unwrap() + .configured_at; assert!( t2 > t1, @@ -118,21 +122,88 @@ async fn upsert_advances_updated_at(pool: sqlx::PgPool) { ); } +#[sqlx::test(migrations = "./migrations")] +async fn upsert_persists_explicit_max_concurrent(pool: sqlx::PgPool) { + let user_id = common::seed_user(&pool).await; + let master = fixture_key(); + + UserAnthropicKeyRepo::upsert(&pool, &master, user_id, "sk-ant-test", Some(32)) + .await + .unwrap(); + + let cred = UserAnthropicKeyRepo::get_credential(&pool, user_id) + .await + .unwrap() + .unwrap(); + assert_eq!(cred.max_concurrent, 32); + + let status = UserAnthropicKeyRepo::status(&pool, user_id) + .await + .unwrap() + .unwrap(); + assert_eq!(status.max_concurrent, 32); +} + +#[sqlx::test(migrations = "./migrations")] +async fn upsert_without_cap_preserves_existing_value(pool: sqlx::PgPool) { + let user_id = common::seed_user(&pool).await; + let master = fixture_key(); + + // First write picks an explicit non-default cap. + UserAnthropicKeyRepo::upsert(&pool, &master, user_id, "sk-ant-first", Some(16)) + .await + .unwrap(); + + // Rotate the key without specifying the cap — the existing 16 must be + // preserved, *not* reset to the DB default of 8. 
+ UserAnthropicKeyRepo::upsert(&pool, &master, user_id, "sk-ant-second", None) + .await + .unwrap(); + + let cred = UserAnthropicKeyRepo::get_credential(&pool, user_id) + .await + .unwrap() + .unwrap(); + assert_eq!( + cred.max_concurrent, 16, + "rotating the key without a new cap must keep the existing cap" + ); +} + +#[sqlx::test(migrations = "./migrations")] +async fn upsert_with_new_cap_overrides_existing_value(pool: sqlx::PgPool) { + let user_id = common::seed_user(&pool).await; + let master = fixture_key(); + + UserAnthropicKeyRepo::upsert(&pool, &master, user_id, "sk-ant-first", Some(16)) + .await + .unwrap(); + UserAnthropicKeyRepo::upsert(&pool, &master, user_id, "sk-ant-second", Some(4)) + .await + .unwrap(); + + let cred = UserAnthropicKeyRepo::get_credential(&pool, user_id) + .await + .unwrap() + .unwrap(); + assert_eq!(cred.max_concurrent, 4); +} + #[sqlx::test(migrations = "./migrations")] async fn delete_removes_row(pool: sqlx::PgPool) { let user_id = common::seed_user(&pool).await; let master = fixture_key(); - UserAnthropicKeyRepo::upsert(&pool, &master, user_id, "sk-ant-test") + UserAnthropicKeyRepo::upsert(&pool, &master, user_id, "sk-ant-test", None) .await .unwrap(); UserAnthropicKeyRepo::delete(&pool, user_id).await.unwrap(); - assert!(UserAnthropicKeyRepo::get_ciphertext(&pool, user_id) + assert!(UserAnthropicKeyRepo::get_credential(&pool, user_id) .await .unwrap() .is_none()); - assert!(UserAnthropicKeyRepo::configured_at(&pool, user_id) + assert!(UserAnthropicKeyRepo::status(&pool, user_id) .await .unwrap() .is_none()); @@ -151,7 +222,7 @@ async fn user_deletion_cascades_to_anthropic_key(pool: sqlx::PgPool) { let user_id = common::seed_user(&pool).await; let master = fixture_key(); - UserAnthropicKeyRepo::upsert(&pool, &master, user_id, "sk-ant-test") + UserAnthropicKeyRepo::upsert(&pool, &master, user_id, "sk-ant-test", None) .await .unwrap(); sqlx::query("DELETE FROM users WHERE id = $1") @@ -160,7 +231,7 @@ async fn 
user_deletion_cascades_to_anthropic_key(pool: sqlx::PgPool) { .await .unwrap(); - assert!(UserAnthropicKeyRepo::get_ciphertext(&pool, user_id) + assert!(UserAnthropicKeyRepo::get_credential(&pool, user_id) .await .unwrap() .is_none()); diff --git a/web/src/routes/me/proxy/+page.svelte b/web/src/routes/me/proxy/+page.svelte index 9b07c1e..28a48a8 100644 --- a/web/src/routes/me/proxy/+page.svelte +++ b/web/src/routes/me/proxy/+page.svelte @@ -9,8 +9,13 @@ interface AnthropicKeyStatus { configured: boolean; configured_at: string | null; + max_concurrent: number | null; } + const DEFAULT_MAX_CONCURRENT = 8; + const MIN_MAX_CONCURRENT = 1; + const MAX_MAX_CONCURRENT = 256; + let status: AnthropicKeyStatus | null = $state(null); let loading = $state(true); let saving = $state(false); @@ -19,6 +24,7 @@ let error = $state(''); let success = $state(''); let newKey = $state(''); + let newMaxConcurrent: number = $state(DEFAULT_MAX_CONCURRENT); let copied = $state(false); const proxyBaseUrl = $derived( @@ -34,6 +40,11 @@ error = ''; try { status = await api.get('/api/v1/me/anthropic-key'); + // Pre-fill the form's cap with whatever's currently stored so a + // "just rotate the key" flow doesn't accidentally reset the cap. + if (status?.max_concurrent != null) { + newMaxConcurrent = status.max_concurrent; + } } catch (err) { error = err instanceof Error ? err.message : 'Failed to load proxy configuration'; } finally { @@ -41,19 +52,55 @@ } } + /// True when the user has either typed a new key or changed the cap + /// away from what's stored. The submit button is gated on this — it + /// prevents the user from submitting a no-op request. + const hasUnsavedChange = $derived.by(() => { + const keyTyped = newKey.trim().length > 0; + const capChanged = + status?.max_concurrent != null && newMaxConcurrent !== status.max_concurrent; + // When not configured yet, only a key counts — the cap on its own + // can't be the first write (server returns 400 in that case). 
+ if (!status?.configured) return keyTyped; + return keyTyped || capChanged; + }); + async function handleSave(event: SubmitEvent) { event.preventDefault(); - if (!newKey.trim()) return; + if (!hasUnsavedChange) return; + // Defensive client-side bounds check. The server enforces the same + // range (DB CHECK + handler validation) but failing here gives a + // clearer error than a 400 from the API. + if ( + !Number.isInteger(newMaxConcurrent) || + newMaxConcurrent < MIN_MAX_CONCURRENT || + newMaxConcurrent > MAX_MAX_CONCURRENT + ) { + error = `Max concurrent must be a whole number between ${MIN_MAX_CONCURRENT} and ${MAX_MAX_CONCURRENT}.`; + return; + } + + // Build a minimal request body: only include `key` when the user + // actually typed one. Cap is always sent so the server picks up + // any change. + const body: { key?: string; max_concurrent: number } = { + max_concurrent: newMaxConcurrent + }; + const trimmedKey = newKey.trim(); + if (trimmedKey.length > 0) body.key = trimmedKey; + saving = true; error = ''; success = ''; try { - await api.put('/api/v1/me/anthropic-key', { key: newKey.trim() }); + await api.put('/api/v1/me/anthropic-key', body); newKey = ''; - success = 'Anthropic API key saved.'; + success = trimmedKey + ? 'Anthropic API key saved.' + : 'Concurrency cap updated.'; await loadStatus(); } catch (err) { - error = err instanceof Error ? err.message : 'Failed to save key'; + error = err instanceof Error ? err.message : 'Failed to save settings'; } finally { saving = false; } @@ -154,6 +201,11 @@ >last set {formatTimestamp(status.configured_at)} {/if} + {#if status.max_concurrent != null} + + · cap {status.max_concurrent} + + {/if} {:else}
@@ -168,7 +220,7 @@
+
+ + +

+ The proxy rejects further requests for this credential once this many are in + flight. Range {MIN_MAX_CONCURRENT}–{MAX_MAX_CONCURRENT}; default {DEFAULT_MAX_CONCURRENT}. + New value applies on the next proxy request; in-flight requests keep their + existing budget. +

+
+
- {#if status?.configured} {#if confirmingRemove}