diff --git a/Cargo.lock b/Cargo.lock index 6fb297aa..9132c8dd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -813,7 +813,7 @@ dependencies = [ "rand 0.8.5", "rand_chacha 0.3.1", "recipher", - "reqwest", + "reqwest 0.12.15", "reqwest-middleware", "reqwest-retry", "rmp-serde", @@ -870,6 +870,7 @@ dependencies = [ "async-trait", "aws-lc-rs", "bigdecimal", + "blake3", "bytes", "chrono", "cipherstash-client", @@ -893,7 +894,7 @@ dependencies = [ "rust_decimal", "rustls", "rustls-pki-types", - "rustls-platform-verifier", + "rustls-platform-verifier 0.5.1", "serde", "serde_json", "socket2 0.5.8", @@ -925,6 +926,7 @@ dependencies = [ "hex", "postgres-types", "rand 0.9.0", + "reqwest 0.13.1", "rustls", "serde", "serde_json", @@ -1101,6 +1103,16 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation" version = "0.10.0" @@ -1500,6 +1512,15 @@ dependencies = [ "serde", ] +[[package]] +name = "encoding_rs" +version = "0.8.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3" +dependencies = [ + "cfg-if", +] + [[package]] name = "eql-mapper" version = "1.0.0" @@ -1953,13 +1974,14 @@ checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" [[package]] name = "hyper" -version = "1.6.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc2b571658e38e0c01b1fdca3bbbe93c00d3d71693ff2770043f8c29bc7d6f80" +checksum = "2ab2d4f250c3d7b1c9fcdff1cece94ea4e2dfbec68614f7b87cb205f24ca9d11" dependencies = [ + "atomic-waker", "bytes", "futures-channel", - "futures-util", + "futures-core", "h2", "http", "http-body", @@ -1967,6 +1989,7 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", + "pin-utils", "smallvec", "tokio", "want", @@ -1993,21 +2016,28 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.10" +version = "0.1.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df2dcfbe0677734ab2f3ffa7fa7bfd4706bfdc1ef393f2ee30184aed67e631b4" +checksum = "727805d60e7938b76b826a6ef209eb70eaa1812794f9424d4a4e2d740662df5f" dependencies = [ + "base64", "bytes", "futures-channel", + "futures-core", "futures-util", "http", "http-body", "hyper", + "ipnet", + "libc", + "percent-encoding", "pin-project-lite", - "socket2 0.5.8", + "socket2 0.6.1", + "system-configuration", "tokio", "tower-service", "tracing", + "windows-registry", ] [[package]] @@ -2228,6 +2258,16 @@ version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" +[[package]] +name = "iri-string" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c91338f0783edbd6195decb37bae672fd3b165faffb89bf7b9e6942f8b1a731a" +dependencies = [ + "memchr", + "serde", +] + [[package]] name = "is_ci" version = "1.2.0" @@ -3112,6 +3152,7 @@ version = "0.11.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49df843a9161c85bb8aae55f101bc0bac8bcafd637a620d9122fd7e0b2f7422e" dependencies = [ + "aws-lc-rs", "bytes", "getrandom 0.3.2", "lru-slab", @@ -3402,6 +3443,44 @@ dependencies = [ "windows-registry", ] +[[package]] +name = "reqwest" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04e9018c9d814e5f30cc16a0f03271aeab3571e609612d9fe78c1aa8d11c2f62" +dependencies = [ + "base64", + "bytes", + "encoding_rs", + "futures-core", + "h2", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-rustls", + "hyper-util", + "js-sys", + "log", + "mime", + "percent-encoding", + "pin-project-lite", + "quinn", + "rustls", + "rustls-pki-types", + "rustls-platform-verifier 0.6.2", + "sync_wrapper", + "tokio", + "tokio-rustls", + "tower", + "tower-http", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "reqwest-middleware" version = "0.3.3" @@ -3411,7 +3490,7 @@ dependencies = [ "anyhow", "async-trait", "http", - "reqwest", + "reqwest 0.12.15", "serde", "thiserror 1.0.69", "tower-service", @@ -3431,7 +3510,7 @@ dependencies = [ "http", "hyper", "parking_lot 0.11.2", - "reqwest", + "reqwest 0.12.15", "reqwest-middleware", "retry-policies", "tokio", @@ -3587,9 +3666,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.25" +version = "0.23.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "822ee9188ac4ec04a2f0531e55d035fb2de73f18b41a63c70c2712503b6fb13c" +checksum = "7160e3e10bf4535308537f3c4e1641468cd0e485175d6163087c0393c7d46643" dependencies = [ "aws-lc-rs", "log", @@ -3637,7 +3716,7 @@ version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4a5467026f437b4cb2a533865eaa73eb840019a0916f4b9ec563c6e617e086c9" dependencies = [ - "core-foundation", + "core-foundation 0.10.0", "core-foundation-sys", "jni", "log", @@ -3648,10 +3727,31 @@ dependencies = [ "rustls-webpki", "security-framework", "security-framework-sys", - "webpki-root-certs", + "webpki-root-certs 0.26.8", "windows-sys 0.59.0", ] +[[package]] +name = "rustls-platform-verifier" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d99feebc72bae7ab76ba994bb5e121b8d83d910ca40b36e0921f53becc41784" +dependencies = [ + "core-foundation 0.10.0", + "core-foundation-sys", + "jni", + "log", + "once_cell", + "rustls", + "rustls-native-certs", + "rustls-platform-verifier-android", + "rustls-webpki", + "security-framework", + "security-framework-sys", + "webpki-root-certs 1.0.5", + "windows-sys 0.61.2", +] + [[package]] name = "rustls-platform-verifier-android" version = "0.1.1" @@ -3660,9 +3760,9 @@ checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f" [[package]] name = "rustls-webpki" -version = "0.103.1" +version = "0.103.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fef8b8769aaccf73098557a87cd1816b4f9c7c16811c9c77142aa695c16f2c03" +checksum = "e4a72fe2bcf7a6ac6fd7d0b9e5cb68aeb7d4c0a0271730218b3e92d43b4eb435" dependencies = [ "aws-lc-rs", "ring", @@ -3720,12 +3820,12 @@ checksum = "1c107b6f4780854c8b126e228ea8869f4d7b71260f962fefb57b996b8959ba6b" [[package]] name = "security-framework" -version = "3.2.0" +version = "3.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "271720403f46ca04f7ba6f55d438f8bd878d6b8ca0a1046e8228c4145bcbb316" +checksum = "b3297343eaf830f66ede390ea39da1d462b6b0c1b000f420d0a83f898bbbe6ef" dependencies = [ "bitflags 2.9.0", - "core-foundation", + "core-foundation 0.10.0", "core-foundation-sys", "libc", "security-framework-sys", @@ -3733,9 +3833,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.14.0" +version = "2.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49db231d56a190491cb4aeda9527f1ad45345af50b0851622a7adb8c03b01c32" +checksum = "cc1f0cbffaac4852523ce30d8bd3c5cdc873501d96ff467ca09b6767bb8cd5c0" dependencies = [ "core-foundation-sys", "libc", @@ -4121,6 +4221,27 @@ dependencies = [ "syn 2.0.101", ] +[[package]] +name = "system-configuration" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" +dependencies = [ + "bitflags 2.9.0", + "core-foundation 0.9.4", + "system-configuration-sys", +] + +[[package]] +name = "system-configuration-sys" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "tagptr" version = "0.2.0" @@ -4439,6 +4560,24 @@ dependencies = [ "tracing", ] +[[package]] +name = "tower-http" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8" +dependencies = [ + "bitflags 2.9.0", + "bytes", + "futures-util", + "http", + "http-body", + "iri-string", + "pin-project-lite", + "tower", + "tower-layer", + "tower-service", +] + [[package]] name = "tower-layer" version = "0.3.3" @@ -4970,6 +5109,15 @@ dependencies = [ "rustls-pki-types", ] +[[package]] +name = "webpki-root-certs" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36a29fc0408b113f68cf32637857ab740edfafdf460c326cd2afaa2d84cc05dc" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "webpki-roots" version = "0.26.11" diff --git a/docs/SLOW_STATEMENTS.md b/docs/SLOW_STATEMENTS.md new file mode 100644 index 00000000..c337128a --- /dev/null +++ b/docs/SLOW_STATEMENTS.md @@ -0,0 +1,89 @@ +# Slow Statement Logging + +CipherStash Proxy includes built-in slow statement logging for troubleshooting performance issues. + +## Configuration + +Enable slow statement logging via environment variables: + +```bash +# Enable slow statement logging (required) +CS_LOG__SLOW_STATEMENTS=true + +# Optional: Set minimum duration threshold +# Default is 2000ms (2 seconds) - only set this if you want a different threshold +CS_LOG__SLOW_STATEMENT_MIN_DURATION_MS=500 + +# Optional: Set log level (default: warn when enabled) +CS_LOG__SLOW_STATEMENTS_LEVEL=warn + +# Recommended: Use structured logging for parsing +CS_LOG__FORMAT=structured +``` + +## Slow Statement Logs + +When a statement exceeds the threshold, the proxy logs a detailed breakdown: + +```json +{ + "client_id": 1, + "duration_ms": 10500, + "statement_type": "INSERT", + "protocol": "extended", + "encrypted": true, + "encrypted_values_count": 3, + "param_bytes": 1024, + "query_fingerprint": "a1b2c3d4", + "keyset_id": "uuid", + "mapping_disabled": false, + "breakdown": { + "parse_ms": 5, + "encrypt_ms": 450, + "server_write_ms": 12, + "server_wait_ms": 9800, + "server_response_ms": 233 + } +} +``` + +### Query Fingerprints + +**Note:** Query fingerprints are ephemeral and instance-local. Each proxy instance generates a unique random key at startup used to compute `query_fingerprint` values. This means fingerprints will change when the proxy restarts and cannot be correlated across different proxy instances. This is intentional for security (prevents dictionary attacks on query patterns). Use fingerprints for correlation within a single proxy instance's runtime only. + +## Prometheus Metrics + +### Labeled Histograms + +Duration histograms now include labels for filtering: + +- `statement_type`: insert, update, delete, select, other +- `protocol`: simple, extended +- `mapped`: true, false +- `multi_statement`: true, false + +Example queries: +```promql +# Average INSERT duration +histogram_quantile(0.5, rate(cipherstash_proxy_statements_session_duration_seconds_bucket{statement_type="insert"}[5m])) + +# Compare encrypted vs passthrough +histogram_quantile(0.99, rate(cipherstash_proxy_statements_session_duration_seconds_bucket{mapped="true"}[5m])) +``` + +### ZeroKMS Cipher Init + +``` +cipherstash_proxy_keyset_cipher_init_duration_seconds +``` + +Measures time for cipher initialization including ZeroKMS network call. High values indicate ZeroKMS connectivity issues. + +## Interpreting Results + +| Symptom | Likely Cause | +|---------|--------------| +| High `encrypt_ms` | ZeroKMS latency or large payload | +| High `server_wait_ms` | Database latency | +| High `cipher_init_duration` | ZeroKMS cold start or network | +| High `parse_ms` | Complex SQL or schema lookup | diff --git a/packages/cipherstash-proxy-integration/Cargo.toml b/packages/cipherstash-proxy-integration/Cargo.toml index 44e3d9c0..5a1f45ce 100644 --- a/packages/cipherstash-proxy-integration/Cargo.toml +++ b/packages/cipherstash-proxy-integration/Cargo.toml @@ -27,3 +27,4 @@ tokio-postgres-rustls = "0.13.0" tracing = { workspace = true } tracing-subscriber = { workspace = true } uuid = { version = "1.11.0", features = ["serde", "v4"] } +reqwest = { version = "0.13", features = ["rustls"] } diff --git a/packages/cipherstash-proxy-integration/src/common.rs b/packages/cipherstash-proxy-integration/src/common.rs index c4e085ad..79704876 100644 --- a/packages/cipherstash-proxy-integration/src/common.rs +++ b/packages/cipherstash-proxy-integration/src/common.rs @@ -12,6 +12,7 @@ use tracing::info; use tracing_subscriber::{filter::Directive, EnvFilter, FmtSubscriber}; pub const PROXY: u16 = 6432; +pub const PROXY_METRICS_PORT: u16 = 9930; pub const PG_PORT: u16 = 5532; pub const PG_TLS_PORT: u16 = 5617; diff --git a/packages/cipherstash-proxy-integration/src/diagnostics.rs b/packages/cipherstash-proxy-integration/src/diagnostics.rs new file mode 100644 index 00000000..72fcedf7 --- /dev/null +++ b/packages/cipherstash-proxy-integration/src/diagnostics.rs @@ -0,0 +1,127 @@ +#[cfg(test)] +mod tests { + use crate::common::{clear, connect_with_tls, PROXY, PROXY_METRICS_PORT}; + + /// Maximum number of retry attempts for fetching metrics. + /// 5 retries with 200ms delay gives ~1 second total wait time, + /// sufficient for Prometheus scrape interval in CI environments. + const METRICS_FETCH_MAX_RETRIES: u32 = 5; + + /// Delay between retry attempts in milliseconds. + /// 200ms provides a reasonable balance between responsiveness and allowing + /// sufficient time for metrics to be published by the Prometheus client. + const METRICS_FETCH_RETRY_DELAY_MS: u64 = 200; + + /// Fetch metrics with retry logic to handle CI timing variability. + async fn fetch_metrics_with_retry(max_retries: u32, delay_ms: u64) -> String { + let url = format!("http://localhost:{}/metrics", PROXY_METRICS_PORT); + let mut last_error = None; + + for attempt in 0..max_retries { + if attempt > 0 { + tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await; + } + + match reqwest::get(&url).await { + Ok(response) => match response.text().await { + Ok(body) => return body, + Err(e) => last_error = Some(format!("Failed to read response: {}", e)), + }, + Err(e) => last_error = Some(format!("Failed to fetch metrics: {}", e)), + } + } + + panic!( + "Failed to fetch metrics after {} retries: {}", + max_retries, + last_error.unwrap_or_else(|| "unknown error".to_string()) + ); + } + + #[tokio::test] + async fn metrics_include_statement_labels() { + let client = connect_with_tls(PROXY).await; + + clear().await; + + // Insert a value to generate metrics + client + .execute( + "INSERT INTO plaintext (id, plaintext) VALUES ($1, $2)", + &[&1i64, &"metrics test"], + ) + .await + .unwrap(); + + // Select a value to generate metrics + let _rows = client + .query("SELECT * FROM plaintext LIMIT 1", &[]) + .await + .unwrap(); + + // Fetch metrics with retry logic for CI robustness + let body = + fetch_metrics_with_retry(METRICS_FETCH_MAX_RETRIES, METRICS_FETCH_RETRY_DELAY_MS).await; + + // Assert that the metrics include the expected labels + assert!( + body.contains("statement_type=\"insert\""), + "Metrics should include insert statement_type label" + ); + assert!( + body.contains("statement_type=\"select\""), + "Metrics should include select statement_type label" + ); + assert!( + body.contains("multi_statement=\"false\""), + "Metrics should include multi_statement=false label" + ); + } + + #[tokio::test] + async fn slow_statement_metrics_and_logs() { + let client = connect_with_tls(PROXY).await; + + clear().await; + + // Execute a query that takes longer than the default 2s threshold + // We use pg_sleep(2.1) to ensure it's considered slow + client.query("SELECT pg_sleep(2.1)", &[]).await.unwrap(); + + // Fetch metrics with retry logic + let body = + fetch_metrics_with_retry(METRICS_FETCH_MAX_RETRIES, METRICS_FETCH_RETRY_DELAY_MS).await; + + // Assert that the slow statements counter is present and non-zero + assert!( + body.contains("cipherstash_proxy_slow_statements_total"), + "Metrics should include slow statement counter. Found: {}", + body + ); + + // Extract the value to ensure it's at least 1 + let slow_statements_line = body + .lines() + .find(|l| l.starts_with("cipherstash_proxy_slow_statements_total")) + .expect("Slow statements counter line should exist"); + let slow_statements_count: u64 = slow_statements_line + .split_whitespace() + .last() + .expect("Should have a value") + .parse() + .expect("Should be a valid number"); + + assert!( + slow_statements_count >= 1, + "Slow statements count should be at least 1, found {}", + slow_statements_count + ); + + // Verify that duration histograms also reflect the slow query + // We check for _count as it works for both histograms and summaries + assert!( + body.contains("cipherstash_proxy_statements_session_duration_seconds"), + "Metrics should include session duration metrics" + ); + } +} diff --git a/packages/cipherstash-proxy-integration/src/lib.rs b/packages/cipherstash-proxy-integration/src/lib.rs index da1f0a47..e88d7cb8 100644 --- a/packages/cipherstash-proxy-integration/src/lib.rs +++ b/packages/cipherstash-proxy-integration/src/lib.rs @@ -1,5 +1,6 @@ mod common; mod decrypt; +mod diagnostics; mod disable_mapping; mod empty_result; mod encryption_sanity; diff --git a/packages/cipherstash-proxy/Cargo.toml b/packages/cipherstash-proxy/Cargo.toml index b47033da..8c920404 100644 --- a/packages/cipherstash-proxy/Cargo.toml +++ b/packages/cipherstash-proxy/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" async-trait = "0.1" aws-lc-rs = "1.13.3" bigdecimal = { version = "0.4.6", features = ["serde-json"] } +blake3 = "1" arc-swap = "1.7.1" bytes = { version = "1.9", default-features = false } chrono = { version = "0.4.39", features = ["clock"] } diff --git a/packages/cipherstash-proxy/src/config/log.rs b/packages/cipherstash-proxy/src/config/log.rs index b5f61011..6f5d10ac 100644 --- a/packages/cipherstash-proxy/src/config/log.rs +++ b/packages/cipherstash-proxy/src/config/log.rs @@ -20,6 +20,14 @@ pub struct LogConfig { #[serde(default = "LogConfig::default_log_level")] pub level: LogLevel, + /// Enable slow statement logging + #[serde(default)] + pub slow_statements: bool, + + /// Threshold in milliseconds for slow statement logging (default: 2000ms) + #[serde(default = "LogConfig::default_slow_statement_min_duration_ms")] + pub slow_statement_min_duration_ms: u64, + // All log target levels - automatically generated and flattened from LogTargetLevels // To add a new target: just add it to the define_log_targets! macro in log/targets.rs #[serde(flatten)] @@ -90,6 +98,8 @@ impl LogConfig { output: LogConfig::default_log_output(), ansi_enabled: LogConfig::default_ansi_enabled(), level, + slow_statements: false, + slow_statement_min_duration_ms: Self::default_slow_statement_min_duration_ms(), // All target levels automatically set using generated LogTargetLevels targets: LogTargetLevels::with_level(level), } @@ -114,6 +124,15 @@ impl LogConfig { pub const fn default_log_level() -> LogLevel { LogLevel::Info } + + /// Default threshold for slow statement logging (2 seconds). + /// + /// This value represents a reasonable baseline for identifying slow queries in most + /// PostgreSQL workloads. Queries exceeding this duration are likely candidates for + /// optimization. Operators can adjust via CS_LOG__SLOW_STATEMENT_MIN_DURATION_MS. + pub const fn default_slow_statement_min_duration_ms() -> u64 { + 2000 + } } #[cfg(test)] diff --git a/packages/cipherstash-proxy/src/config/tandem.rs b/packages/cipherstash-proxy/src/config/tandem.rs index ef4ac427..a88a3b0d 100644 --- a/packages/cipherstash-proxy/src/config/tandem.rs +++ b/packages/cipherstash-proxy/src/config/tandem.rs @@ -302,7 +302,15 @@ impl TandemConfig { DEFAULT_THREAD_STACK_SIZE } + /// Returns true if slow statement logging is enabled + pub fn slow_statements_enabled(&self) -> bool { + self.log.slow_statements + } + /// Returns the slow statement minimum duration as a Duration + pub fn slow_statement_min_duration(&self) -> std::time::Duration { + std::time::Duration::from_millis(self.log.slow_statement_min_duration_ms) + } #[cfg(test)] pub fn for_testing() -> Self { Self { @@ -855,4 +863,47 @@ mod tests { }) }) } + + #[test] + fn slow_statement_accessors() { + with_no_cs_vars(|| { + temp_env::with_vars( + [ + ("CS_LOG__SLOW_STATEMENTS", Some("true")), + ("CS_LOG__SLOW_STATEMENT_MIN_DURATION_MS", Some("1000")), + ], + || { + let config = + TandemConfig::build_path("tests/config/cipherstash-proxy-test.toml") + .unwrap(); + + assert!(config.slow_statements_enabled()); + assert_eq!( + config.slow_statement_min_duration(), + std::time::Duration::from_millis(1000) + ); + }, + ); + }); + } + + #[test] + fn slow_statements_config() { + with_no_cs_vars(|| { + temp_env::with_vars( + [ + ("CS_LOG__SLOW_STATEMENTS", Some("true")), + ("CS_LOG__SLOW_STATEMENT_MIN_DURATION_MS", Some("500")), + ], + || { + let config = + TandemConfig::build_path("tests/config/cipherstash-proxy-test.toml") + .unwrap(); + + assert!(config.log.slow_statements); + assert_eq!(config.log.slow_statement_min_duration_ms, 500); + }, + ); + }); + } } diff --git a/packages/cipherstash-proxy/src/log/mod.rs b/packages/cipherstash-proxy/src/log/mod.rs index 725a0384..e57390f7 100644 --- a/packages/cipherstash-proxy/src/log/mod.rs +++ b/packages/cipherstash-proxy/src/log/mod.rs @@ -16,7 +16,7 @@ use tracing_subscriber::{ // All targets are now defined in the targets module using the define_log_targets! macro. pub use targets::{ AUTHENTICATION, CONFIG, CONTEXT, DECRYPT, DEVELOPMENT, ENCODING, ENCRYPT, ENCRYPT_CONFIG, - KEYSET, MAPPER, MIGRATE, PROTOCOL, PROXY, SCHEMA, + KEYSET, MAPPER, MIGRATE, PROTOCOL, PROXY, SCHEMA, SLOW_STATEMENTS, }; static INIT: Once = Once::new(); @@ -112,6 +112,8 @@ mod tests { output: LogConfig::default_log_output(), ansi_enabled: LogConfig::default_ansi_enabled(), level: LogLevel::Info, + slow_statements: false, + slow_statement_min_duration_ms: LogConfig::default_slow_statement_min_duration_ms(), targets: LogTargetLevels { development_level: LogLevel::Info, authentication_level: LogLevel::Debug, @@ -127,6 +129,7 @@ mod tests { mapper_level: LogLevel::Info, schema_level: LogLevel::Info, config_level: LogLevel::Info, + slow_statements_level: LogLevel::Info, }, }; diff --git a/packages/cipherstash-proxy/src/log/targets.rs b/packages/cipherstash-proxy/src/log/targets.rs index e45c0988..f959c56b 100644 --- a/packages/cipherstash-proxy/src/log/targets.rs +++ b/packages/cipherstash-proxy/src/log/targets.rs @@ -78,4 +78,5 @@ define_log_targets!( (PROXY, proxy_level), (MAPPER, mapper_level), (SCHEMA, schema_level), + (SLOW_STATEMENTS, slow_statements_level), ); diff --git a/packages/cipherstash-proxy/src/postgresql/backend.rs b/packages/cipherstash-proxy/src/postgresql/backend.rs index b5113aec..196fb1a8 100644 --- a/packages/cipherstash-proxy/src/postgresql/backend.rs +++ b/packages/cipherstash-proxy/src/postgresql/backend.rs @@ -157,6 +157,7 @@ where ) .await?; let read_duration = read_start.elapsed(); + self.context.record_execute_server_timing(read_duration); let sent: u64 = bytes.len() as u64; counter!(SERVER_BYTES_RECEIVED_TOTAL).increment(sent); @@ -362,7 +363,11 @@ where let sent: u64 = bytes.len() as u64; counter!(CLIENTS_BYTES_SENT_TOTAL).increment(sent); + let start = Instant::now(); self.client_sender.send(bytes)?; + let duration = start.elapsed(); + self.context.add_client_write_duration_for_execute(duration); + Ok(()) } @@ -463,7 +468,12 @@ where counter!(DECRYPTION_ERROR_TOTAL).increment(1); })?; - // Avoid the iter calculation if we can + let duration = Instant::now().duration_since(start); + + // Always record for slow-statement diagnostics + self.context.add_decrypt_duration_for_execute(duration); + + // Prometheus metrics remain gated if self.context.prometheus_enabled() { let decrypted_count = plaintexts @@ -472,8 +482,6 @@ where counter!(DECRYPTION_REQUESTS_TOTAL).increment(1); counter!(DECRYPTED_VALUES_TOTAL).increment(decrypted_count); - - let duration = Instant::now().duration_since(start); histogram!(DECRYPTION_DURATION_SECONDS).record(duration); } diff --git a/packages/cipherstash-proxy/src/postgresql/context/mod.rs b/packages/cipherstash-proxy/src/postgresql/context/mod.rs index f5af5d92..29df3cdb 100644 --- a/packages/cipherstash-proxy/src/postgresql/context/mod.rs +++ b/packages/cipherstash-proxy/src/postgresql/context/mod.rs @@ -1,8 +1,9 @@ pub mod column; +pub mod phase_timing; pub mod portal; pub mod statement; - -pub use self::{portal::Portal, statement::Statement}; +pub mod statement_metadata; +pub use self::{phase_timing::PhaseTiming, portal::Portal, statement::Statement}; use super::{ column_mapper::ColumnMapper, messages::{describe::Describe, Name, Target}, @@ -11,17 +12,25 @@ use super::{ use crate::{ config::TandemConfig, error::{EncryptError, Error}, - log::CONTEXT, - prometheus::{STATEMENTS_EXECUTION_DURATION_SECONDS, STATEMENTS_SESSION_DURATION_SECONDS}, + log::{CONTEXT, SLOW_STATEMENTS}, + prometheus::{ + SLOW_STATEMENTS_TOTAL, STATEMENTS_EXECUTION_DURATION_SECONDS, + STATEMENTS_SESSION_DURATION_SECONDS, + }, proxy::{EncryptConfig, EncryptionService, ReloadCommand, ReloadSender}, }; use cipherstash_client::IdentifiedBy; use eql_mapper::{Schema, TableResolver}; -use metrics::histogram; +use metrics::{counter, histogram}; +use serde_json::json; use sqltk::parser::ast::{Expr, Ident, ObjectName, ObjectNamePart, Set, Value, ValueWithSpan}; +pub use statement_metadata::StatementMetadata; use std::{ collections::{HashMap, VecDeque}, - sync::{Arc, LazyLock, RwLock}, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, LazyLock, RwLock, + }, time::{Duration, Instant}, }; use tokio::sync::oneshot; @@ -33,6 +42,9 @@ type ExecuteQueue = Queue; type SessionMetricsQueue = Queue; type PortalQueue = Queue>; +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub struct SessionId(u64); + #[derive(Clone, Debug, PartialEq)] pub struct KeysetIdentifier(pub IdentifiedBy); @@ -54,6 +66,7 @@ where reload_sender: ReloadSender, column_mapper: ColumnMapper, statements: Arc>>>, + statement_sessions: Arc>>, portals: Arc>>, describe: Arc>, execute: Arc>, @@ -62,42 +75,89 @@ where table_resolver: Arc, unsafe_disable_mapping: bool, keyset_id: Arc>>, + session_id_counter: Arc, } +/// Context for tracking an in-flight Execute operation. +/// +/// Timing data is accumulated here during backend message processing because +/// the backend operates on the execute queue rather than having direct access +/// to the session metrics queue. On completion via `complete_execution()`, +/// timing is transferred to the associated SessionMetricsContext. #[derive(Clone, Debug)] pub struct ExecuteContext { name: Name, start: Instant, + session_id: Option, + /// Server wait duration (time to first response byte). + /// Accumulated here during execution, transferred to SessionMetricsContext on completion. + server_wait_duration: Option, + /// Server response duration (time spent receiving response data after first byte). + /// Accumulated here during execution, transferred to SessionMetricsContext on completion. + server_response_duration: Duration, } impl ExecuteContext { - fn new(name: Name) -> ExecuteContext { + fn new(name: Name, session_id: Option) -> ExecuteContext { ExecuteContext { name, start: Instant::now(), + session_id, + server_wait_duration: None, + server_response_duration: Duration::from_secs(0), } } fn duration(&self) -> Duration { Instant::now().duration_since(self.start) } + + fn record_server_wait_or_add_response(&mut self, duration: Duration) { + if self.server_wait_duration.is_none() { + self.server_wait_duration = Some(duration); + } else { + self.server_response_duration += duration; + } + } + + fn server_wait_duration(&self) -> Option { + self.server_wait_duration + } + + fn server_response_duration(&self) -> Duration { + self.server_response_duration + } + + fn session_id(&self) -> Option { + self.session_id + } } #[derive(Clone, Debug)] pub struct SessionMetricsContext { + id: SessionId, start: Instant, + pub phase_timing: PhaseTiming, + pub metadata: StatementMetadata, } impl SessionMetricsContext { - fn new() -> SessionMetricsContext { + fn new(id: SessionId) -> SessionMetricsContext { SessionMetricsContext { + id, start: Instant::now(), + phase_timing: PhaseTiming::new(), + metadata: StatementMetadata::new(), } } fn duration(&self) -> Duration { Instant::now().duration_since(self.start) } + + fn id(&self) -> SessionId { + self.id + } } #[derive(Clone, Debug)] @@ -121,6 +181,7 @@ where Context { statements: Arc::new(RwLock::new(HashMap::new())), + statement_sessions: Arc::new(RwLock::new(HashMap::new())), portals: Arc::new(RwLock::new(HashMap::new())), describe: Arc::new(RwLock::from(Queue::new())), execute: Arc::new(RwLock::from(Queue::new())), @@ -135,6 +196,7 @@ where reload_sender, unsafe_disable_mapping: false, keyset_id: Arc::new(RwLock::new(None)), + session_id_counter: Arc::new(AtomicU64::new(1)), } } @@ -151,16 +213,79 @@ where let _ = self.describe.write().map(|mut queue| queue.complete()); } - pub fn start_session(&mut self) { - let ctx = SessionMetricsContext::new(); + pub fn start_session(&mut self) -> SessionId { + let id = SessionId(self.session_id_counter.fetch_add(1, Ordering::Relaxed)); + let ctx = SessionMetricsContext::new(id); let _ = self.session_metrics.write().map(|mut queue| queue.add(ctx)); + id } pub fn finish_session(&mut self) { debug!(target: CONTEXT, client_id = self.client_id, msg = "Session Metrics finished"); if let Some(session) = self.get_session_metrics() { - histogram!(STATEMENTS_SESSION_DURATION_SECONDS).record(session.duration()); + let duration = session.duration(); + let metadata = &session.metadata; + + // Get labels for metrics + let statement_type = metadata + .statement_type + .map(|t| t.as_label()) + .unwrap_or("unknown"); + let protocol = metadata.protocol.map(|p| p.as_label()).unwrap_or("unknown"); + let mapped = if metadata.encrypted { "true" } else { "false" }; + let multi_statement = if metadata.multi_statement { + "true" + } else { + "false" + }; + + // Record with labels + histogram!( + STATEMENTS_SESSION_DURATION_SECONDS, + "statement_type" => statement_type, + "protocol" => protocol, + "mapped" => mapped, + "multi_statement" => multi_statement + ) + .record(duration); + + // Log slow statements when enabled + if self.config.slow_statements_enabled() + && duration > self.config.slow_statement_min_duration() + { + let timing = &session.phase_timing; + + // Increment slow statements counter + counter!(SLOW_STATEMENTS_TOTAL).increment(1); + + let breakdown = json!({ + "parse_ms": timing.parse_duration.map(|d| d.as_millis()), + "encrypt_ms": timing.encrypt_duration.map(|d| d.as_millis()), + "server_write_ms": timing.server_write_duration.map(|d| d.as_millis()), + "server_wait_ms": timing.server_wait_duration.map(|d| d.as_millis()), + "server_response_ms": timing.server_response_duration.map(|d| d.as_millis()), + "client_write_ms": timing.client_write_duration.map(|d| d.as_millis()), + "decrypt_ms": timing.decrypt_duration.map(|d| d.as_millis()), + }); + + warn!( + target: SLOW_STATEMENTS, + client_id = self.client_id, + duration_ms = duration.as_millis() as u64, + statement_type = statement_type, + protocol = protocol, + encrypted = metadata.encrypted, + multi_statement = metadata.multi_statement, + encrypted_values_count = metadata.encrypted_values_count, + param_bytes = metadata.param_bytes, + query_fingerprint = ?metadata.query_fingerprint, + keyset_id = ?self.keyset_identifier(), + mapping_disabled = self.mapping_disabled(), + breakdown = %breakdown, + msg = "Slow statement detected" + ); + } } let _ = self @@ -169,30 +294,81 @@ where .map(|mut queue| queue.complete()); } - pub fn set_execute(&mut self, name: Name) { + pub fn set_execute(&mut self, name: Name, session_id: Option) { debug!(target: CONTEXT, client_id = self.client_id, execute = ?name); - let ctx = ExecuteContext::new(name); + let ctx = ExecuteContext::new(name, session_id); let _ = self.execute.write().map(|mut queue| queue.add(ctx)); } + /// Set execute state for portal, looking up session ID internally. + pub fn set_execute_for_portal(&mut self, name: Name) { + let session_id = self.get_portal_session_id(&name); + self.set_execute(name, session_id); + } + + /// Marks the current Execution as Complete. + /// + /// Transfers accumulated timing data from ExecuteContext to SessionMetricsContext.phase_timing: + /// - `server_wait_duration` (time to first response byte) is recorded to the session + /// - `server_response_duration` (time receiving response data) is added to the session /// - /// Marks the current Execution as Complete + /// This two-phase timing pattern exists because the backend operates on the execute queue + /// rather than having direct access to the session. Timing is accumulated in ExecuteContext + /// during message processing, then transferred to the correct SessionMetricsContext here. /// - /// If the associated portal is Unnamed, it is closed + /// If the associated portal is Unnamed, it is closed. /// /// From the PostgreSQL Extended Query docs: /// If successfully created, a named portal object lasts till the end of the current transaction, unless explicitly destroyed. /// An unnamed portal is destroyed at the end of the transaction, or as soon as the next Bind statement specifying the unnamed portal as destination is issued /// /// https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY - /// pub fn complete_execution(&mut self) { debug!(target: CONTEXT, client_id = self.client_id, msg = "Execute complete"); if let Some(execute) = self.get_execute() { - histogram!(STATEMENTS_EXECUTION_DURATION_SECONDS).record(execute.duration()); + if let Some(session_id) = execute.session_id() { + if let Some(wait) = execute.server_wait_duration() { + self.record_server_wait_duration(session_id, wait); + } + let response = execute.server_response_duration(); + if !response.is_zero() { + self.add_server_response_duration(session_id, response); + } + } + + // Get labels from current session metadata + let (statement_type, protocol, mapped, multi_statement) = + if let Some(session) = self.get_session_metrics() { + let metadata = &session.metadata; + ( + metadata + .statement_type + .map(|t| t.as_label()) + .unwrap_or("unknown"), + metadata.protocol.map(|p| p.as_label()).unwrap_or("unknown"), + if metadata.encrypted { "true" } else { "false" }, + if metadata.multi_statement { + "true" + } else { + "false" + }, + ) + } else { + ("unknown", "unknown", "false", "false") + }; + + histogram!( + STATEMENTS_EXECUTION_DURATION_SECONDS, + "statement_type" => statement_type, + "protocol" => protocol, + "mapped" => mapped, + "multi_statement" => multi_statement + ) + .record(execute.duration()); + if execute.name.is_unnamed() { self.close_portal(&execute.name); } @@ -216,6 +392,17 @@ where .statements .write() .map(|mut guarded| guarded.remove(name)); + + let _ = self + .statement_sessions + .write() + .map(|mut guarded| guarded.remove(name)); + } + + /// Close both statement and its associated portal. + pub fn close_statement_and_portal(&mut self, name: &Name) { + self.close_portal(name); + self.close_statement(name); } pub fn add_portal(&mut self, name: Name, portal: Portal) { @@ -234,6 +421,36 @@ where statements.get(name).cloned() } + pub fn set_statement_session(&mut self, name: Name, session_id: SessionId) { + let _ = self + .statement_sessions + .write() + .map(|mut sessions| sessions.insert(name, session_id)); + } + + pub fn get_statement_session(&self, name: &Name) -> Option { + let sessions = self.statement_sessions.read().ok()?; + sessions.get(name).copied() + } + + /// Get session for statement, falling back to latest session with warning log. + pub fn get_statement_session_or_latest(&self, name: &Name) -> Option { + if let Some(id) = self.get_statement_session(name) { + return Some(id); + } + + let fallback = self.latest_session_id(); + if fallback.is_some() { + warn!( + target: CONTEXT, + client_id = self.client_id, + prepared_statement = %name.as_str(), + msg = "Session lookup failed for prepared statement, using latest session" + ); + } + fallback + } + /// /// Close the portal identified by `name` /// Portal is removed from queue @@ -264,10 +481,17 @@ where match portal.as_ref() { Portal::Encrypted { statement, .. } => Some(statement.clone()), - Portal::Passthrough => None, + Portal::Passthrough { .. } => None, } } + pub fn get_portal_session_id(&self, name: &Name) -> Option { + let portals = self.portals.read().ok()?; + let queue = portals.get(name)?; + let portal = queue.next()?; + portal.session_id() + } + pub fn get_statement_for_row_decription(&self) -> Option> { if let Some(statement) = self.get_statement_from_describe() { return Some(statement.clone()); @@ -546,6 +770,13 @@ where debug!(target: CONTEXT, msg = "Database schema reloaded", ?response); } + /// Reload schema if it has changed since last check. + pub async fn reload_schema_if_changed(&self) { + if self.schema_changed() { + self.reload_schema().await; + } + } + pub fn is_passthrough(&self) -> bool { self.encrypt_config.is_empty() || self.config.mapping_disabled() } @@ -632,6 +863,141 @@ where pub fn config(&self) -> &crate::config::TandemConfig { &self.config } + + fn with_session_metrics_mut(&mut self, session_id: SessionId, f: F) + where + F: FnOnce(&mut SessionMetricsContext), + { + if let Ok(mut queue) = self.session_metrics.write() { + if let Some(session) = queue + .queue + .iter_mut() + .find(|session| session.id() == session_id) + { + f(session); + } + } + } + + pub fn latest_session_id(&self) -> Option { + let queue = self.session_metrics.read().ok()?; + queue.queue.back().map(|session| session.id()) + } + + /// Record parse phase duration for the session (first write wins) + pub fn record_parse_duration(&mut self, session_id: SessionId, duration: Duration) { + self.with_session_metrics_mut(session_id, |session| { + session.phase_timing.record_parse(duration); + }); + } + + /// Add encrypt phase duration for the session (accumulate) + pub fn add_encrypt_duration(&mut self, session_id: SessionId, duration: Duration) { + self.with_session_metrics_mut(session_id, |session| { + session.phase_timing.add_encrypt(duration); + }); + } + + /// Record server write phase duration + pub fn record_server_write_duration(&mut self, session_id: SessionId, duration: Duration) { + self.with_session_metrics_mut(session_id, |session| { + session.phase_timing.record_server_write(duration); + }); + } + + /// Add server write phase duration (accumulate) + pub fn add_server_write_duration(&mut self, session_id: SessionId, duration: Duration) { + self.with_session_metrics_mut(session_id, |session| { + session.phase_timing.add_server_write(duration); + }); + } + + /// Record server wait phase duration (time to first response byte) + pub fn record_server_wait_duration(&mut self, session_id: SessionId, duration: Duration) { + self.with_session_metrics_mut(session_id, |session| { + session.phase_timing.record_server_wait(duration); + }); + } + + /// Record server response phase duration + pub fn record_server_response_duration(&mut self, session_id: SessionId, duration: Duration) { + self.with_session_metrics_mut(session_id, |session| { + session.phase_timing.record_server_response(duration); + }); + } + + /// Add server response phase duration (accumulate) + pub fn add_server_response_duration(&mut self, session_id: SessionId, duration: Duration) { + self.with_session_metrics_mut(session_id, |session| { + session.phase_timing.add_server_response(duration); + }); + } + + /// Record client write phase duration + pub fn record_client_write_duration(&mut self, session_id: SessionId, duration: Duration) { + self.with_session_metrics_mut(session_id, |session| { + session.phase_timing.record_client_write(duration); + }); + } + + /// Add client write phase duration (accumulate) + pub fn add_client_write_duration(&mut self, session_id: SessionId, duration: Duration) { + self.with_session_metrics_mut(session_id, |session| { + session.phase_timing.add_client_write(duration); + }); + } + + /// Add decrypt phase duration (accumulate) + pub fn add_decrypt_duration(&mut self, session_id: SessionId, duration: Duration) { + self.with_session_metrics_mut(session_id, |session| { + session.phase_timing.add_decrypt(duration); + }); + } + + /// Update statement metadata for a session + pub fn update_statement_metadata(&mut self, session_id: SessionId, f: F) + where + F: FnOnce(&mut StatementMetadata), + { + self.with_session_metrics_mut(session_id, |session| { + f(&mut session.metadata); + }); + } + + /// Update statement metadata if session ID is present, no-op otherwise. + pub fn with_session(&mut self, session_id: Option, f: F) + where + F: FnOnce(&mut SessionMetricsContext), + { + if let Some(sid) = session_id { + self.with_session_metrics_mut(sid, f); + } + } + + /// Record server wait for first response; otherwise accumulate response time for the current execute + pub fn record_execute_server_timing(&mut self, duration: Duration) { + if let Ok(mut queue) = self.execute.write() { + if let Some(execute) = queue.current_mut() { + execute.record_server_wait_or_add_response(duration); + } + } + } + + /// Add decrypt phase duration for the current execute session (if any) + pub fn add_decrypt_duration_for_execute(&mut self, duration: Duration) { + let session_id = self.get_execute().and_then(|execute| execute.session_id()); + if let Some(session_id) = session_id { + self.add_decrypt_duration(session_id, duration); + } + } + + /// Add client write duration for the current execute session (if any) + pub fn add_client_write_duration_for_execute(&mut self, duration: Duration) { + let session_id = self.get_execute().and_then(|execute| execute.session_id()); + if let Some(session_id) = session_id { + self.add_client_write_duration(session_id, duration); + } + } } impl Queue { @@ -652,6 +1018,11 @@ impl Queue { pub fn add(&mut self, item: T) { self.queue.push_back(item); } + + /// Get mutable reference to the current (first) item in the queue + pub fn current_mut(&mut self) -> Option<&mut T> { + self.queue.front_mut() + } } #[cfg(test)] @@ -727,7 +1098,7 @@ mod tests { } fn portal(statement: &Arc) -> Portal { - Portal::encrypted_with_format_codes(statement.clone(), vec![]) + Portal::encrypted_with_format_codes(statement.clone(), vec![], None) } fn get_statement(portal: Arc) -> Arc { @@ -781,7 +1152,7 @@ mod tests { context.add_portal(portal_name.clone(), portal(&statement)); // Add statement name to execute context - context.set_execute(portal_name.clone()); + context.set_execute(portal_name.clone(), None); // Portal statement should be the right statement let portal = context.get_portal_from_execute().unwrap(); @@ -827,8 +1198,8 @@ mod tests { context.add_portal(portal_name.clone(), portal(&statement_2)); // Execute both portals - context.set_execute(portal_name.clone()); - context.set_execute(portal_name.clone()); + context.set_execute(portal_name.clone(), None); + context.set_execute(portal_name.clone(), None); // Portal should point to first statement let portal = context.get_portal_from_execute().unwrap(); @@ -880,9 +1251,9 @@ mod tests { context.add_portal(portal_name_3.clone(), portal(&statement_3)); // Add portals to execute context - context.set_execute(portal_name_1.clone()); - context.set_execute(portal_name_2.clone()); - context.set_execute(portal_name_3.clone()); + context.set_execute(portal_name_1.clone(), None); + context.set_execute(portal_name_2.clone(), None); + context.set_execute(portal_name_3.clone(), None); // Multiple calls return the portal for the first Execution context let portal = context.get_portal_from_execute().unwrap(); diff --git a/packages/cipherstash-proxy/src/postgresql/context/phase_timing.rs b/packages/cipherstash-proxy/src/postgresql/context/phase_timing.rs new file mode 100644 index 00000000..8385023e --- /dev/null +++ b/packages/cipherstash-proxy/src/postgresql/context/phase_timing.rs @@ -0,0 +1,188 @@ +use std::time::{Duration, Instant}; + +/// Tracks timing for individual phases of statement processing +#[derive(Clone, Debug, Default)] +pub struct PhaseTiming { + /// SQL parsing and type-checking time + pub parse_duration: Option, + /// Encryption operation time (includes ZeroKMS network) + pub encrypt_duration: Option, + /// Time to write to PostgreSQL server + pub server_write_duration: Option, + /// Time from server write to first response byte + pub server_wait_duration: Option, + /// Time to receive complete server response + pub server_response_duration: Option, + /// Time to write response to client + pub client_write_duration: Option, + /// Decryption operation time + pub decrypt_duration: Option, +} + +impl PhaseTiming { + pub fn new() -> Self { + Self::default() + } + + /// Record parse phase duration (first write wins) + pub fn record_parse(&mut self, duration: Duration) { + self.parse_duration.get_or_insert(duration); + } + + /// Add parse duration (accumulate) + pub fn add_parse(&mut self, duration: Duration) { + self.parse_duration = Some(self.parse_duration.unwrap_or_default() + duration); + } + + /// Record encrypt phase duration (first write wins) + pub fn record_encrypt(&mut self, duration: Duration) { + self.encrypt_duration.get_or_insert(duration); + } + + /// Add encrypt duration (accumulate) + pub fn add_encrypt(&mut self, duration: Duration) { + self.encrypt_duration = Some(self.encrypt_duration.unwrap_or_default() + duration); + } + + /// Record server write phase duration (first write wins) + pub fn record_server_write(&mut self, duration: Duration) { + self.server_write_duration.get_or_insert(duration); + } + + /// Add server write duration (accumulate) + pub fn add_server_write(&mut self, duration: Duration) { + self.server_write_duration = + Some(self.server_write_duration.unwrap_or_default() + duration); + } + + /// Record server wait phase duration (first byte latency, first write wins) + pub fn record_server_wait(&mut self, duration: Duration) { + self.server_wait_duration.get_or_insert(duration); + } + + /// Record server response phase duration (first write wins) + pub fn record_server_response(&mut self, duration: Duration) { + self.server_response_duration.get_or_insert(duration); + } + + /// Add server response duration (accumulate) + pub fn add_server_response(&mut self, duration: Duration) { + self.server_response_duration = + Some(self.server_response_duration.unwrap_or_default() + duration); + } + + /// Record client write phase duration (first write wins) + pub fn record_client_write(&mut self, duration: Duration) { + self.client_write_duration.get_or_insert(duration); + } + + /// Add client write duration (accumulate) + pub fn add_client_write(&mut self, duration: Duration) { + self.client_write_duration = + Some(self.client_write_duration.unwrap_or_default() + duration); + } + + /// Record decrypt phase duration (first write wins) + pub fn record_decrypt(&mut self, duration: Duration) { + self.decrypt_duration.get_or_insert(duration); + } + + /// Add decrypt duration (accumulate) + pub fn add_decrypt(&mut self, duration: Duration) { + self.decrypt_duration = Some(self.decrypt_duration.unwrap_or_default() + duration); + } + + /// Calculate total tracked duration + pub fn total_tracked(&self) -> Duration { + [ + self.parse_duration, + self.encrypt_duration, + self.server_write_duration, + self.server_wait_duration, + self.server_response_duration, + self.client_write_duration, + self.decrypt_duration, + ] + .iter() + .filter_map(|d| *d) + .sum() + } +} + +/// Helper to time a phase +pub struct PhaseTimer { + start: Instant, +} + +impl PhaseTimer { + pub fn start() -> Self { + Self { + start: Instant::now(), + } + } + + pub fn elapsed(&self) -> Duration { + self.start.elapsed() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn phase_timing_records_durations() { + let mut timing = PhaseTiming::new(); + + timing.record_parse(Duration::from_millis(5)); + timing.record_encrypt(Duration::from_millis(100)); + timing.record_server_wait(Duration::from_millis(50)); + + assert_eq!(timing.parse_duration, Some(Duration::from_millis(5))); + assert_eq!(timing.encrypt_duration, Some(Duration::from_millis(100))); + assert_eq!(timing.server_wait_duration, Some(Duration::from_millis(50))); + } + + #[test] + fn total_tracked_sums_durations() { + let mut timing = PhaseTiming::new(); + + timing.record_parse(Duration::from_millis(5)); + timing.record_encrypt(Duration::from_millis(100)); + timing.record_server_wait(Duration::from_millis(50)); + + assert_eq!(timing.total_tracked(), Duration::from_millis(155)); + } + + #[test] + fn add_encrypt_accumulates() { + let mut timing = PhaseTiming::new(); + + timing.add_encrypt(Duration::from_millis(10)); + timing.add_encrypt(Duration::from_millis(15)); + + assert_eq!(timing.encrypt_duration, Some(Duration::from_millis(25))); + } + + #[test] + fn add_server_write_accumulates() { + let mut timing = PhaseTiming::new(); + + timing.add_server_write(Duration::from_millis(3)); + timing.add_server_write(Duration::from_millis(7)); + + assert_eq!( + timing.server_write_duration, + Some(Duration::from_millis(10)) + ); + } + + #[test] + fn phase_timer_measures_elapsed() { + let timer = PhaseTimer::start(); + std::thread::sleep(Duration::from_millis(10)); + let elapsed = timer.elapsed(); + + assert!(elapsed >= Duration::from_millis(10)); + } +} diff --git a/packages/cipherstash-proxy/src/postgresql/context/portal.rs b/packages/cipherstash-proxy/src/postgresql/context/portal.rs index c6cc6585..ad6b033b 100644 --- a/packages/cipherstash-proxy/src/postgresql/context/portal.rs +++ b/packages/cipherstash-proxy/src/postgresql/context/portal.rs @@ -1,4 +1,4 @@ -use super::{super::format_code::FormatCode, Column}; +use super::{super::format_code::FormatCode, Column, SessionId}; use crate::postgresql::context::statement::Statement; use std::sync::Arc; @@ -7,38 +7,44 @@ pub enum Portal { Encrypted { format_codes: Vec, statement: Arc, + session_id: Option, + }, + Passthrough { + session_id: Option, }, - Passthrough, } impl Portal { pub fn encrypted_with_format_codes( statement: Arc, format_codes: Vec, + session_id: Option, ) -> Portal { Portal::Encrypted { statement, format_codes, + session_id, } } - pub fn encrypted(statement: Arc) -> Portal { + pub fn encrypted(statement: Arc, session_id: Option) -> Portal { let format_codes = vec![]; Portal::Encrypted { statement, format_codes, + session_id, } } - pub fn passthrough() -> Portal { - Portal::Passthrough + pub fn passthrough(session_id: Option) -> Portal { + Portal::Passthrough { session_id } } pub fn projection_columns(&self) -> &Vec> { static EMPTY: Vec> = vec![]; match self { Portal::Encrypted { statement, .. } => &statement.projection_columns, - _ => &EMPTY, + Portal::Passthrough { .. } => &EMPTY, } } @@ -60,9 +66,16 @@ impl Portal { } _ => format_codes.clone(), }, - Portal::Passthrough => { + Portal::Passthrough { .. } => { unreachable!() } } } + + pub fn session_id(&self) -> Option { + match self { + Portal::Encrypted { session_id, .. } => *session_id, + Portal::Passthrough { session_id } => *session_id, + } + } } diff --git a/packages/cipherstash-proxy/src/postgresql/context/statement_metadata.rs b/packages/cipherstash-proxy/src/postgresql/context/statement_metadata.rs new file mode 100644 index 00000000..8973aef3 --- /dev/null +++ b/packages/cipherstash-proxy/src/postgresql/context/statement_metadata.rs @@ -0,0 +1,202 @@ +use serde::Serialize; +use sqltk::parser::ast::Statement; + +/// Statement type classification for metrics labels +#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize)] +#[serde(rename_all = "lowercase")] +pub enum StatementType { + Insert, + Update, + Delete, + Select, + Other, +} + +impl StatementType { + /// Create from parsed AST statement + pub fn from_statement(stmt: &Statement) -> Self { + match stmt { + Statement::Insert(_) => StatementType::Insert, + Statement::Update { .. } => StatementType::Update, + Statement::Delete(_) => StatementType::Delete, + Statement::Query(_) => StatementType::Select, + _ => StatementType::Other, + } + } + + /// Return lowercase label for metrics + pub fn as_label(&self) -> &'static str { + match self { + StatementType::Insert => "insert", + StatementType::Update => "update", + StatementType::Delete => "delete", + StatementType::Select => "select", + StatementType::Other => "other", + } + } +} + +/// Protocol type for metrics labels +#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize)] +#[serde(rename_all = "lowercase")] +pub enum ProtocolType { + Simple, + Extended, +} + +impl ProtocolType { + pub fn as_label(&self) -> &'static str { + match self { + ProtocolType::Simple => "simple", + ProtocolType::Extended => "extended", + } + } +} + +/// Metadata collected during statement processing for diagnostics +#[derive(Clone, Debug, Default)] +pub struct StatementMetadata { + /// Type of SQL statement + pub statement_type: Option, + /// Protocol used (simple or extended) + pub protocol: Option, + /// Whether encryption/decryption was performed + pub encrypted: bool, + /// Number of encrypted values in the statement + pub encrypted_values_count: usize, + /// Approximate size of parameters in bytes + pub param_bytes: usize, + /// Query fingerprint (first 8 chars of normalized query hash) + pub query_fingerprint: Option, + /// Whether the simple query contained multiple statements + pub multi_statement: bool, +} + +impl StatementMetadata { + pub fn new() -> Self { + Self::default() + } + + pub fn with_statement_type(mut self, stmt_type: StatementType) -> Self { + self.statement_type = Some(stmt_type); + self + } + + pub fn with_protocol(mut self, protocol: ProtocolType) -> Self { + self.protocol = Some(protocol); + self + } + + pub fn with_encrypted(mut self, encrypted: bool) -> Self { + self.encrypted = encrypted; + self + } + + pub fn set_encrypted_values_count(&mut self, count: usize) { + self.encrypted_values_count = count; + } + + pub fn set_param_bytes(&mut self, bytes: usize) { + self.param_bytes = bytes; + } + + /// Set query fingerprint from SQL statement. + /// + /// Uses Blake3 keyed hashing with a per-instance random key to prevent dictionary attacks + /// that could reveal SQL statements from fingerprints in logs/metrics. + /// + /// Fingerprints are instance-local identifiers for correlating log entries within a single + /// proxy instance. They are NOT stable across restarts or deployments and should not + /// be used for cross-instance correlation or persistent storage. + pub fn set_query_fingerprint(&mut self, sql: &str) { + use std::sync::LazyLock; + + // Random key generated once per proxy instance - makes fingerprints + // resistant to dictionary attacks while remaining consistent within instance + static FINGERPRINT_KEY: LazyLock<[u8; 32]> = LazyLock::new(rand::random); + + let hash = blake3::keyed_hash(&FINGERPRINT_KEY, sql.as_bytes()); + self.query_fingerprint = Some(hex::encode(&hash.as_bytes()[..4])); + } + + pub fn set_multi_statement(&mut self, value: bool) { + self.multi_statement = value; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use sqltk::parser::dialect::PostgreSqlDialect; + use sqltk::parser::parser::Parser; + + fn parse(sql: &str) -> Statement { + Parser::new(&PostgreSqlDialect {}) + .try_with_sql(sql) + .unwrap() + .parse_statement() + .unwrap() + } + + #[test] + fn statement_type_from_statement() { + assert_eq!( + StatementType::from_statement(&parse("INSERT INTO foo VALUES (1)")), + StatementType::Insert + ); + assert_eq!( + StatementType::from_statement(&parse("UPDATE foo SET bar = 1")), + StatementType::Update + ); + assert_eq!( + StatementType::from_statement(&parse("DELETE FROM foo")), + StatementType::Delete + ); + assert_eq!( + StatementType::from_statement(&parse("SELECT * FROM foo")), + StatementType::Select + ); + assert_eq!( + StatementType::from_statement(&parse("CREATE TABLE foo (id INT)")), + StatementType::Other + ); + } + + #[test] + fn statement_type_labels() { + assert_eq!(StatementType::Insert.as_label(), "insert"); + assert_eq!(StatementType::Update.as_label(), "update"); + assert_eq!(StatementType::Delete.as_label(), "delete"); + assert_eq!(StatementType::Select.as_label(), "select"); + assert_eq!(StatementType::Other.as_label(), "other"); + } + + #[test] + fn metadata_builder_pattern() { + let metadata = StatementMetadata::new() + .with_statement_type(StatementType::Insert) + .with_protocol(ProtocolType::Extended) + .with_encrypted(true); + + assert_eq!(metadata.statement_type, Some(StatementType::Insert)); + assert_eq!(metadata.protocol, Some(ProtocolType::Extended)); + assert!(metadata.encrypted); + } + + #[test] + fn query_fingerprint_is_deterministic() { + let mut m1 = StatementMetadata::new(); + let mut m2 = StatementMetadata::new(); + + m1.set_query_fingerprint("SELECT * FROM users WHERE id = $1"); + m2.set_query_fingerprint("SELECT * FROM users WHERE id = $1"); + + assert_eq!(m1.query_fingerprint, m2.query_fingerprint); + } + + #[test] + fn multi_statement_flag_defaults_false() { + let metadata = StatementMetadata::new(); + assert!(!metadata.multi_statement); + } +} diff --git a/packages/cipherstash-proxy/src/postgresql/frontend.rs b/packages/cipherstash-proxy/src/postgresql/frontend.rs index 7d07098e..780360bb 100644 --- a/packages/cipherstash-proxy/src/postgresql/frontend.rs +++ b/packages/cipherstash-proxy/src/postgresql/frontend.rs @@ -1,4 +1,5 @@ -use super::context::{Context, Statement}; +use super::context::phase_timing::PhaseTimer; +use super::context::{Context, SessionId, Statement}; use super::error_handler::PostgreSqlErrorHandler; use super::messages::bind::Bind; use super::messages::describe::Describe; @@ -12,6 +13,7 @@ use crate::connect::Sender; use crate::error::{EncryptError, Error, MappingError}; use crate::log::{MAPPER, PROTOCOL}; use crate::postgresql::context::column::Column; +use crate::postgresql::context::statement_metadata::{ProtocolType, StatementType}; use crate::postgresql::context::Portal; use crate::postgresql::data::literal_from_sql; use crate::postgresql::messages::close::Close; @@ -267,9 +269,7 @@ where ?code, ); - if self.context.schema_changed() { - self.context.reload_schema().await; - } + self.context.reload_schema_if_changed().await; if self.error_state.is_some() { debug!(target: PROTOCOL, @@ -300,7 +300,14 @@ where debug!(target: PROTOCOL, msg = "Write to server", ?bytes); let sent: u64 = bytes.len() as u64; counter!(SERVER_BYTES_SENT_TOTAL).increment(sent); + + let start = Instant::now(); self.server_writer.write_all(&bytes).await?; + let duration = start.elapsed(); + if let Some(session_id) = self.context.latest_session_id() { + self.context.add_server_write_duration(session_id, duration); + } + Ok(()) } @@ -323,10 +330,7 @@ where debug!(target: PROTOCOL, client_id = self.context.client_id, ?close); match close.target { Target::Portal => self.context.close_portal(&close.name), - Target::Statement => { - self.context.close_portal(&close.name); - self.context.close_statement(&close.name); - } + Target::Statement => self.context.close_statement_and_portal(&close.name), } Ok(()) } @@ -334,7 +338,8 @@ where async fn execute_handler(&mut self, bytes: &BytesMut) -> Result<(), Error> { let execute = Execute::try_from(bytes)?; debug!(target: PROTOCOL, client_id = self.context.client_id, ?execute); - self.context.set_execute(execute.portal.to_owned()); + self.context + .set_execute_for_portal(execute.portal.to_owned()); Ok(()) } @@ -372,7 +377,14 @@ where /// - `Err(error)` - Processing failed, error should be sent to client async fn query_handler(&mut self, bytes: &BytesMut) -> Result, Error> { let handler_start = Instant::now(); - self.context.start_session(); + let session_id = self.context.start_session(); + + // Set protocol type for diagnostics + self.context.update_statement_metadata(session_id, |m| { + m.protocol = Some(ProtocolType::Simple); + }); + + let parse_timer = PhaseTimer::start(); let mut query = Query::try_from(bytes)?; @@ -385,12 +397,12 @@ where statements = parsed_statements.len(), ); - let mut portal = Portal::passthrough(); + let mut portal = Portal::passthrough(Some(session_id)); let mut encrypted = false; + let mut parse_duration_recorded = false; - for statement in parsed_statements { - if let Some(mapping_disabled) = - self.context.maybe_set_unsafe_disable_mapping(&statement) + for statement in &parsed_statements { + if let Some(mapping_disabled) = self.context.maybe_set_unsafe_disable_mapping(statement) { warn!( msg = "SET CIPHERSTASH.DISABLE_MAPPING = {mapping_disabled}", @@ -405,16 +417,16 @@ where continue; } - self.handle_set_keyset(&statement)?; + self.handle_set_keyset(statement)?; - self.check_for_schema_change(&statement); + self.check_for_schema_change(statement); - if !eql_mapper::requires_type_check(&statement) { + if !eql_mapper::requires_type_check(statement) { counter!(STATEMENTS_PASSTHROUGH_TOTAL).increment(1); continue; } - let typed_statement = match self.type_check(&statement) { + let typed_statement = match self.type_check(statement) { Ok(ts) => ts, Err(err) => { if self.context.mapping_errors_enabled() { @@ -433,8 +445,19 @@ where ); if typed_statement.requires_transform() { + // Record parse duration before encryption work starts + if !parse_duration_recorded { + self.context + .record_parse_duration(session_id, parse_timer.elapsed()); + parse_duration_recorded = true; + } + let encrypted_literals = self - .encrypt_literals(&typed_statement, &statement.literal_columns) + .encrypt_literals( + session_id, + &typed_statement, + &statement.literal_columns, + ) .await?; if let Some(transformed_statement) = self @@ -453,8 +476,11 @@ where counter!(STATEMENTS_ENCRYPTED_TOTAL).increment(1); - // Set Encrypted portal - portal = Portal::encrypted(Arc::new(statement)); + // Set Encrypted portal and mark as mapped + portal = Portal::encrypted(Arc::new(statement), Some(session_id)); + self.context.update_statement_metadata(session_id, |m| { + m.encrypted = true; + }); } None => { debug!(target: MAPPER, @@ -462,13 +488,38 @@ where msg = "Passthrough Statement" ); counter!(STATEMENTS_PASSTHROUGH_TOTAL).increment(1); - transformed_statements.push(statement); + transformed_statements.push(statement.clone()); } }; } + // Record parse/typecheck duration (if not already recorded before encryption) + if !parse_duration_recorded { + self.context + .record_parse_duration(session_id, parse_timer.elapsed()); + } + + // Set statement type based on parsed statements + let statement_type = if parsed_statements.len() == 1 { + parsed_statements + .first() + .map(StatementType::from_statement) + .unwrap_or(StatementType::Other) + } else { + StatementType::Other + }; + self.context.update_statement_metadata(session_id, |m| { + m.statement_type = Some(statement_type); + m.set_multi_statement(parsed_statements.len() > 1); + }); + + // Set query fingerprint + self.context.update_statement_metadata(session_id, |m| { + m.set_query_fingerprint(&query.statement); + }); + self.context.add_portal(Name::unnamed(), portal); - self.context.set_execute(Name::unnamed()); + self.context.set_execute(Name::unnamed(), Some(session_id)); if encrypted { let transformed_statement = transformed_statements @@ -534,6 +585,7 @@ where /// literals that don't require encryption and `Some(EqlCiphertext)` for encrypted values. async fn encrypt_literals( &mut self, + session_id: SessionId, typed_statement: &TypeCheckedStatement<'_>, literal_columns: &Vec>, ) -> Result>, Error> { @@ -564,10 +616,20 @@ where ?encrypted ); - counter!(ENCRYPTION_REQUESTS_TOTAL).increment(1); - counter!(ENCRYPTED_VALUES_TOTAL).increment(encrypted.len() as u64); - let duration = Instant::now().duration_since(start); + + // Add to phase timing diagnostics (accumulate) + self.context.add_encrypt_duration(session_id, duration); + + // Update metadata with encrypted values count + let encrypted_count = encrypted.iter().filter(|e| e.is_some()).count(); + self.context.update_statement_metadata(session_id, |m| { + m.encrypted = true; + m.set_encrypted_values_count(encrypted_count); + }); + + counter!(ENCRYPTION_REQUESTS_TOTAL).increment(1); + counter!(ENCRYPTED_VALUES_TOTAL).increment(encrypted_count as u64); histogram!(ENCRYPTION_DURATION_SECONDS).record(duration); Ok(encrypted) @@ -657,9 +719,18 @@ where /// - `Ok(None)` - No transformation needed, forward original message /// - `Err(error)` - Processing failed, error should be sent to client async fn parse_handler(&mut self, bytes: &BytesMut) -> Result, Error> { - self.context.start_session(); + let session_id = self.context.start_session(); + + // Set protocol type + self.context.update_statement_metadata(session_id, |m| { + m.protocol = Some(ProtocolType::Extended); + }); + + let parse_timer = PhaseTimer::start(); let mut message = Parse::try_from(bytes)?; + self.context + .set_statement_session(message.name.to_owned(), session_id); debug!( target: PROTOCOL, @@ -707,11 +778,18 @@ where // These override the underlying column type let param_types = message.param_types.clone(); + let mut parse_duration_recorded = false; + match self.to_encryptable_statement(&typed_statement, param_types)? { Some(statement) => { if typed_statement.requires_transform() { + // Record parse duration before encryption work starts + self.context + .record_parse_duration(session_id, parse_timer.elapsed()); + parse_duration_recorded = true; + let encrypted_literals = self - .encrypt_literals(&typed_statement, &statement.literal_columns) + .encrypt_literals(session_id, &typed_statement, &statement.literal_columns) .await?; if let Some(transformed_statement) = self @@ -742,6 +820,18 @@ where } } + // Record parse duration (if not already recorded before encryption) + if !parse_duration_recorded { + self.context + .record_parse_duration(session_id, parse_timer.elapsed()); + } + + // Set statement type and fingerprint + self.context.update_statement_metadata(session_id, |m| { + m.statement_type = Some(StatementType::from_statement(&statement)); + m.set_query_fingerprint(&message.statement); + }); + if message.requires_rewrite() { let bytes = BytesMut::try_from(message)?; @@ -891,22 +981,34 @@ where let mut bind = Bind::try_from(bytes)?; + let session_id = self + .context + .get_statement_session_or_latest(&bind.prepared_statement); + + // Track param bytes for diagnostics + let param_bytes: usize = bind.param_values.iter().map(|p| p.bytes.len()).sum(); + self.context + .with_session(session_id, |m| m.metadata.set_param_bytes(param_bytes)); + debug!(target: PROTOCOL, client_id = self.context.client_id, bind = ?bind); - let mut portal = Portal::passthrough(); + let mut portal = Portal::passthrough(session_id); if let Some(statement) = self.context.get_statement(&bind.prepared_statement) { debug!(target:MAPPER, client_id = self.context.client_id, ?statement); if statement.has_params() { - let encrypted = self.encrypt_params(&bind, &statement).await?; + let encrypted = self.encrypt_params(session_id, &bind, &statement).await?; bind.rewrite(encrypted)?; } if statement.has_projection() { portal = Portal::encrypted_with_format_codes( statement, bind.result_columns_format_codes.to_owned(), + session_id, ); + self.context + .with_session(session_id, |m| m.metadata.encrypted = true); } }; @@ -937,6 +1039,7 @@ where /// async fn encrypt_params( &mut self, + session_id: Option, bind: &Bind, statement: &Statement, ) -> Result>, Error> { @@ -955,14 +1058,22 @@ where counter!(ENCRYPTION_ERROR_TOTAL).increment(1); })?; - // Avoid the iter calculation if we can - if self.context.prometheus_enabled() { - let encrypted_count = encrypted.iter().filter(|e| e.is_some()).count() as u64; + let duration = Instant::now().duration_since(start); + // Record timing and metadata for this encryption operation + let encrypted_count = encrypted.iter().filter(|e| e.is_some()).count(); + self.context.with_session(session_id, |m| { + // Add to phase timing diagnostics (accumulate) + m.phase_timing.add_encrypt(duration); + // Always update metadata for slow-statement logging + m.metadata.encrypted = true; + m.metadata.set_encrypted_values_count(encrypted_count); + }); + + // Prometheus metrics remain gated + if self.context.prometheus_enabled() { counter!(ENCRYPTION_REQUESTS_TOTAL).increment(1); - counter!(ENCRYPTED_VALUES_TOTAL).increment(encrypted_count); - - let duration = Instant::now().duration_since(start); + counter!(ENCRYPTED_VALUES_TOTAL).increment(encrypted_count as u64); histogram!(ENCRYPTION_DURATION_SECONDS).record(duration); } diff --git a/packages/cipherstash-proxy/src/prometheus.rs b/packages/cipherstash-proxy/src/prometheus.rs index afaeef8d..d14f8322 100644 --- a/packages/cipherstash-proxy/src/prometheus.rs +++ b/packages/cipherstash-proxy/src/prometheus.rs @@ -26,6 +26,7 @@ pub const STATEMENTS_SESSION_DURATION_SECONDS: &str = "cipherstash_proxy_statements_session_duration_seconds"; pub const STATEMENTS_EXECUTION_DURATION_SECONDS: &str = "cipherstash_proxy_statements_execution_duration_seconds"; +pub const SLOW_STATEMENTS_TOTAL: &str = "cipherstash_proxy_slow_statements_total"; pub const ROWS_TOTAL: &str = "cipherstash_proxy_rows_total"; pub const ROWS_ENCRYPTED_TOTAL: &str = "cipherstash_proxy_rows_encrypted_total"; @@ -39,6 +40,8 @@ pub const SERVER_BYTES_RECEIVED_TOTAL: &str = "cipherstash_proxy_server_bytes_re pub const KEYSET_CIPHER_INIT_TOTAL: &str = "cipherstash_proxy_keyset_cipher_init_total"; pub const KEYSET_CIPHER_CACHE_HITS_TOTAL: &str = "cipherstash_proxy_keyset_cipher_cache_hits_total"; +pub const KEYSET_CIPHER_INIT_DURATION_SECONDS: &str = + "cipherstash_proxy_keyset_cipher_init_duration_seconds"; pub fn start(host: String, port: u16) -> Result<(), Error> { let address = format!("{host}:{port}"); @@ -115,6 +118,10 @@ pub fn start(host: String, port: u16) -> Result<(), Error> { Unit::Seconds, "Duration of time the proxied database spent executing SQL statements" ); + describe_counter!( + SLOW_STATEMENTS_TOTAL, + "Total number of statements exceeding slow statement threshold" + ); describe_counter!(ROWS_TOTAL, "Total number of rows returned to clients"); describe_counter!( @@ -156,6 +163,11 @@ pub fn start(host: String, port: u16) -> Result<(), Error> { KEYSET_CIPHER_CACHE_HITS_TOTAL, "Number of times a keyset-scoped cipher was found in the cache" ); + describe_histogram!( + KEYSET_CIPHER_INIT_DURATION_SECONDS, + Unit::Seconds, + "Duration of keyset-scoped cipher initialization (includes ZeroKMS network call)" + ); // Prometheus endpoint is empty on startup and looks like an error // Explicitly set count to zero diff --git a/packages/cipherstash-proxy/src/proxy/zerokms/zerokms.rs b/packages/cipherstash-proxy/src/proxy/zerokms/zerokms.rs index c3dddf21..945d6504 100644 --- a/packages/cipherstash-proxy/src/proxy/zerokms/zerokms.rs +++ b/packages/cipherstash-proxy/src/proxy/zerokms/zerokms.rs @@ -3,7 +3,10 @@ use crate::{ error::{EncryptError, Error, ZeroKMSError}, log::{ENCRYPT, PROXY}, postgresql::{Column, KeysetIdentifier}, - prometheus::{KEYSET_CIPHER_CACHE_HITS_TOTAL, KEYSET_CIPHER_INIT_TOTAL}, + prometheus::{ + KEYSET_CIPHER_CACHE_HITS_TOTAL, KEYSET_CIPHER_INIT_DURATION_SECONDS, + KEYSET_CIPHER_INIT_TOTAL, + }, proxy::EncryptionService, }; use cipherstash_client::{ @@ -15,10 +18,13 @@ use cipherstash_client::{ schema::column::IndexType, }; use eql_mapper::EqlTermVariant; -use metrics::counter; +use metrics::{counter, histogram}; use moka::future::Cache; -use std::borrow::Cow; -use std::{sync::Arc, time::Duration}; +use std::{ + borrow::Cow, + sync::Arc, + time::{Duration, Instant}, +}; use tracing::{debug, info, warn}; use uuid::Uuid; @@ -85,11 +91,17 @@ impl ZeroKms { let identified_by = keyset_id.as_ref().map(|id| id.0.clone()); + // Time the cipher initialization (includes network call to ZeroKMS) + let start = Instant::now(); + match ScopedCipher::init(zerokms_client, identified_by).await { Ok(cipher) => { + let init_duration = start.elapsed(); + let arc_cipher = Arc::new(cipher); counter!(KEYSET_CIPHER_INIT_TOTAL).increment(1); + histogram!(KEYSET_CIPHER_INIT_DURATION_SECONDS).record(init_duration); // Store in cache self.cipher_cache @@ -103,12 +115,23 @@ impl ZeroKms { let memory_usage_bytes = self.cipher_cache.weighted_size(); info!(msg = "Connected to ZeroKMS"); - debug!(target: PROXY, msg = "ScopedCipher cached", ?keyset_id, entry_count, memory_usage_bytes); + debug!(target: PROXY, + msg = "ScopedCipher cached", + ?keyset_id, + entry_count, + memory_usage_bytes, + init_duration_ms = init_duration.as_millis() + ); Ok(arc_cipher) } Err(err) => { - debug!(target: PROXY, msg = "Error initializing ZeroKMS ScopedCipher", error = err.to_string()); + let init_duration = start.elapsed(); + debug!(target: PROXY, + msg = "Error initializing ZeroKMS ScopedCipher", + error = err.to_string(), + init_duration_ms = init_duration.as_millis() + ); warn!(msg = "Error initializing ZeroKMS", error = err.to_string()); match err { diff --git a/tests/docker-compose.yml b/tests/docker-compose.yml index aecc812f..9f02b650 100644 --- a/tests/docker-compose.yml +++ b/tests/docker-compose.yml @@ -70,6 +70,7 @@ services: - CS_LOG__PROTOCOL_LEVEL=${CS_LOG__PROTOCOL_LEVEL:-debug} - CS_LOG__MAPPER_LEVEL=${CS_LOG__MAPPER_LEVEL:-debug} - CS_LOG__CONTEXT_LEVEL=${CS_LOG__CONTEXT_LEVEL:-debug} + - CS_LOG__SLOW_STATEMENTS=${CS_LOG__SLOW_STATEMENTS:-true} networks: - postgres extra_hosts: @@ -92,6 +93,7 @@ services: container_name: proxy-tls ports: - 6432:6432 + - 9930:9930 environment: - CS_DATABASE__NAME=${CS_DATABASE__NAME} - CS_DATABASE__USERNAME=${CS_DATABASE__USERNAME} @@ -112,6 +114,7 @@ services: - CS_LOG__PROTOCOL_LEVEL=${CS_LOG__PROTOCOL_LEVEL:-debug} - CS_LOG__MAPPER_LEVEL=${CS_LOG__MAPPER_LEVEL:-debug} - CS_LOG__CONTEXT_LEVEL=${CS_LOG__CONTEXT_LEVEL:-debug} + - CS_LOG__SLOW_STATEMENTS=${CS_LOG__SLOW_STATEMENTS:-true} volumes: - ./tls/server.cert:/etc/cipherstash-proxy/server.cert - ./tls/server.key:/etc/cipherstash-proxy/server.key diff --git a/tests/tasks/test/integration/prometheus.sh b/tests/tasks/test/integration/prometheus.sh index 36253db3..ad2685b8 100755 --- a/tests/tasks/test/integration/prometheus.sh +++ b/tests/tasks/test/integration/prometheus.sh @@ -33,13 +33,13 @@ if [[ $response != *"cipherstash_proxy_rows_total 1"* ]]; then exit 1 fi -if [[ $response != *"cipherstash_proxy_statements_execution_duration_seconds{quantile=\"1\"} 0."* ]]; then - echo "error: did not see string in output: \"cipherstash_proxy_statements_execution_duration_seconds{quantile=\"1\"} 0.\"" +if [[ ! $response =~ cipherstash_proxy_statements_execution_duration_seconds\{.*quantile=\"1\"\} ]]; then + echo "error: did not see execution duration metric with quantile=\"1\" in output" exit 1 fi -if [[ $response != *"cipherstash_proxy_statements_session_duration_seconds{quantile=\"1\"} 0."* ]]; then - echo "error: did not see string in output: \"cipherstash_proxy_statements_session_duration_seconds{quantile=\"1\"} 0.\"" +if [[ ! $response =~ cipherstash_proxy_statements_session_duration_seconds\{.*quantile=\"1\"\} ]]; then + echo "error: did not see session duration metric with quantile=\"1\" in output" exit 1 fi