diff --git a/Cargo.lock b/Cargo.lock index 6800415..4304554 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -137,7 +137,7 @@ checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" [[package]] name = "attested-tls-proxy" -version = "0.0.1" +version = "0.0.2" dependencies = [ "anyhow", "axum", diff --git a/Cargo.toml b/Cargo.toml index f361a66..73e3e7c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ members = [".", "dummy-attestation-server"] [package] name = "attested-tls-proxy" -version = "0.0.1" +version = "0.0.2" edition = "2024" license = "MIT" description = "An HTTP attested TLS proxy server and client for secure communication with CVM services" @@ -28,7 +28,7 @@ rand_core = { version = "0.6.4", features = ["getrandom"] } dcap-qvl = "0.3.10" hex = "0.4.3" hyper = { version = "1.7.0", features = ["server", "http2"] } -hyper-util = "0.1.17" +hyper-util = { version = "0.1.17", features = ["tokio"] } http-body-util = "0.1.3" bytes = "1.10.1" http = "1.3.1" diff --git a/src/lib.rs b/src/lib.rs index 8d33f1e..fb3d533 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,7 +19,7 @@ use hyper_util::rt::TokioIo; use thiserror::Error; use tokio::sync::{mpsc, oneshot}; use tokio_rustls::rustls::server::VerifierBuilderError; -use tracing::{error, warn}; +use tracing::{debug, error, warn}; #[cfg(test)] mod test_helpers; @@ -54,12 +54,21 @@ static X_REAL_IP: HeaderName = HeaderName::from_static("x-real-ip"); /// The longest time in seconds to wait between reconnection attempts const SERVER_RECONNECT_MAX_BACKOFF_SECS: u64 = 120; +const KEEP_ALIVE_INTERVAL: u64 = 30; +const KEEP_ALIVE_TIMEOUT: u64 = 10; + type RequestWithResponseSender = ( http::Request, oneshot::Sender>, hyper::Error>>, ); type Http2Sender = hyper::client::conn::http2::SendRequest; +type Http2Connection = hyper::client::conn::http2::Connection< + TokioIo>, + hyper::body::Incoming, + TokioExecutor, +>; + /// A TLS over TCP server which provides an attestation before forwarding traffic to a given target address pub struct ProxyServer { /// The underlying attested TLS server @@ -126,12 +135,14 @@ impl ProxyServer { } /// Accept an incoming connection and handle it in a seperate task - pub async fn accept(&self) -> Result<(), ProxyError> { + /// + /// Returns the handle for the task handling the connection + pub async fn accept(&self) -> Result, ProxyError> { let target = self.target.clone(); let (inbound, client_addr) = self.listener.accept().await?; let attested_tls_server = self.attested_tls_server.clone(); - tokio::spawn(async move { + let join_handle = tokio::spawn(async move { match attested_tls_server.handle_connection(inbound).await { Ok((tls_stream, measurements, attestation_type)) => { if let Err(err) = Self::handle_connection( @@ -152,7 +163,7 @@ impl ProxyServer { } }); - Ok(()) + Ok(join_handle) } /// Helper to get the socket address of the underlying TCP listener @@ -168,18 +179,16 @@ impl ProxyServer { target: String, client_addr: SocketAddr, ) -> Result<(), ProxyError> { - tracing::debug!("proxy-server accepted connection"); - - // Setup an HTTP server - let http = hyper::server::conn::http2::Builder::new(TokioExecutor); + debug!("[proxy-server] accepted connection"); // Setup a request handler let service = service_fn(move |mut req| { + debug!("[proxy-server] Handling request {req:?}"); let headers = req.headers_mut(); // Add or update the HOST header let old_value = update_header(headers, &http::header::HOST, &target); - tracing::info!("Updating Host header - old value: {old_value:?} new value: {target}",); + debug!("Updating Host header - old value: {old_value:?} new value: {target}",); // Add the x-real-ip header let client_ip = client_addr.ip().to_string(); @@ -221,6 +230,7 @@ impl ProxyServer { async move { match Self::handle_http_request(req, target).await { Ok(res) => { + debug!("[proxy-server] Responding {res:?}"); Ok::>, hyper::Error>(res) } Err(e) => { @@ -235,7 +245,14 @@ impl ProxyServer { // Serve this connection using the request handler defined above let io = TokioIo::new(tls_stream); - http.serve_connection(io, service).await?; + + // Setup an HTTP server + hyper::server::conn::http2::Builder::new(TokioExecutor) + .timer(hyper_util::rt::tokio::TokioTimer::new()) + .keep_alive_interval(Some(Duration::from_secs(KEEP_ALIVE_INTERVAL))) + .keep_alive_timeout(Duration::from_secs(KEEP_ALIVE_TIMEOUT)) + .serve_connection(io, service) + .await?; Ok(()) } @@ -353,73 +370,125 @@ impl ProxyClient { >, )>(1024); - // Connect to the proxy server and provide / verify attestation - let (mut sender, mut measurements, mut remote_attestation_type) = - Self::setup_connection_with_backoff(&target, &attested_tls_client, true).await?; + // used only to signal "initial connect succeeded" or "failed with error" + let (ready_tx, ready_rx) = oneshot::channel::>(); - let attested_tls_client_clone = attested_tls_client.clone(); tokio::spawn(async move { - // Read an incoming request from the channel (from the source client) - while let Some((req, response_tx)) = requests_rx.recv().await { - // Attempt to forward it to the proxy server - let (response, should_reconnect) = match sender.send_request(req).await { - Ok(mut resp) => { - // If we have measurements from the proxy-server, inject them into the - // response header - let headers = resp.headers_mut(); - if let Some(measurements) = measurements.clone() { - match measurements.to_header_format() { - Ok(header_value) => { - headers.insert(MEASUREMENT_HEADER, header_value); + let mut first = true; + let mut ready_tx = Some(ready_tx); + 'reconnect: loop { + let (mut sender, conn, measurements, remote_attestation_type) = + // Connect to the proxy server and provide / verify attestation + match Self::setup_connection_with_backoff(&target, &attested_tls_client, first) + .await + { + Ok(output) => { + if first { + if let Some(tx) = ready_tx.take() { + let _ = tx.send(Ok(())); } - Err(e) => { - // This error is highly unlikely - that the measurement values fail to - // encode to JSON or fit in an HTTP header - error!("Failed to encode measurement values: {e}"); + first = false; + } + output + } + Err(err) => { + if first { + if let Some(tx) = ready_tx.take() { + let _ = tx.send(Err(err)); } + return; + } else { + error!("Reconnect setup failed unexpectedly: {err}"); + continue; } } + }; + + let (conn_done_tx, mut conn_done_rx) = + tokio::sync::watch::channel::>(None); + + tokio::spawn(async move { + let res = conn.await; + let _ = conn_done_tx.send(res.err()); + }); + loop { + tokio::select! { + // Read an incoming request from the channel (from the source client) + incoming_req_option = requests_rx.recv() => { + if let Some((req, response_tx)) = incoming_req_option { + debug!("[proxy-client] Read incoming request from source client: {req:?}"); + // Attempt to forward it to the proxy server + let (response, should_reconnect) = match sender.send_request(req).await { + Ok(mut resp) => { + debug!("[proxy-client] Read response from proxy-server: {resp:?}"); + // If we have measurements from the proxy-server, inject them into the + // response header + let headers = resp.headers_mut(); + if let Some(measurements) = measurements.clone() { + match measurements.to_header_format() { + Ok(header_value) => { + headers.insert(MEASUREMENT_HEADER, header_value); + } + Err(e) => { + // This error is highly unlikely - that the measurement values fail to + // encode to JSON or fit in an HTTP header + error!("Failed to encode measurement values: {e}"); + } + } + } + + update_header( + headers, + ATTESTATION_TYPE_HEADER, + remote_attestation_type.as_str(), + ); + (Ok(resp.map(|b| b.boxed())), false) + } + Err(e) => { + warn!("Failed to send request to proxy-server: {e}"); + let mut resp = Response::new(full(format!("Request failed: {e}"))); + *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY; + + (Ok(resp), true) + } + }; + + // Send the response back to the source client + if response_tx.send(response).is_err() { + warn!("Failed to forward response to source client, probably they dropped the connection"); + } - update_header( - headers, - ATTESTATION_TYPE_HEADER, - remote_attestation_type.as_str(), - ); - (Ok(resp.map(|b| b.boxed())), false) - } - Err(e) => { - warn!("Failed to send request to proxy-server: {e}"); - let mut resp = Response::new(full(format!("Request failed: {e}"))); - *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY; - - (Ok(resp), true) - } - }; - - // Send the response back to the source client - if response_tx.send(response).is_err() { - warn!("Failed to forward response to source client, probably they dropped the connection"); - } + if should_reconnect { + // Leave the inner loop and continue on the reconnect loop + warn!("Reconnecting to proxy-server due to failed request"); + break; + } + } else { + // The request sender was dropped - so no more incoming requests + debug!("Request sender dropped - leaving connection handler loop"); + break 'reconnect; + } + } - // If the connection to the proxy server failed, reconnect - if should_reconnect { - // Reconnect to the server - retrying indefinately with a backoff - (sender, measurements, remote_attestation_type) = - Self::setup_connection_with_backoff( - &target, - &attested_tls_client_clone, - false, - ) - .await - .expect("Function will not return an error when should_bail is false"); + // Connection closed + _ = conn_done_rx.changed() => { + // Leave the inner loop and continue on the reconnect loop + warn!("Connection dropped - reconnecting..."); + break; + } + }; } } }); - Ok(Self { - listener, - requests_tx, - }) + match ready_rx.await { + Ok(Ok(())) => Ok(Self { + listener, + requests_tx, + }), + Ok(Err(e)) => Err(e), + Err(e) => Err(e.into()), + } } /// Helper to return the local socket address from the underlying TCP listener @@ -480,7 +549,15 @@ impl ProxyClient { target: &str, attested_tls_client: &AttestedTlsClient, should_bail: bool, - ) -> Result<(Http2Sender, Option, AttestationType), ProxyError> { + ) -> Result< + ( + Http2Sender, + Http2Connection, + Option, + AttestationType, + ), + ProxyError, + > { let mut delay = Duration::from_secs(1); let max_delay = Duration::from_secs(SERVER_RECONNECT_MAX_BACKOFF_SECS); @@ -509,25 +586,30 @@ impl ProxyClient { async fn setup_connection( inner: &AttestedTlsClient, target: &str, - ) -> Result<(Http2Sender, Option, AttestationType), ProxyError> { + ) -> Result< + ( + Http2Sender, + Http2Connection, + Option, + AttestationType, + ), + ProxyError, + > { let (tls_stream, measurements, remote_attestation_type) = inner.connect_tcp(target).await?; // The attestation exchange is now complete - setup an HTTP client let outbound_io = TokioIo::new(tls_stream); let (sender, conn) = hyper::client::conn::http2::Builder::new(TokioExecutor) + .timer(hyper_util::rt::tokio::TokioTimer::new()) + .keep_alive_interval(Some(Duration::from_secs(KEEP_ALIVE_INTERVAL))) + .keep_alive_timeout(Duration::from_secs(KEEP_ALIVE_TIMEOUT)) + .keep_alive_while_idle(true) .handshake::<_, hyper::body::Incoming>(outbound_io) .await?; - // Drive the connection - tokio::spawn(async move { - if let Err(e) = conn.await { - warn!("Client connection error: {e}"); - } - }); - // Return the HTTP client, as well as remote measurements - Ok((sender, measurements, remote_attestation_type)) + Ok((sender, conn, measurements, remote_attestation_type)) } // Handle a request from the source client to the proxy server @@ -634,7 +716,7 @@ mod tests { use super::*; use test_helpers::{ example_http_service, generate_certificate_chain, generate_tls_config, - generate_tls_config_with_client_auth, mock_dcap_measurements, + generate_tls_config_with_client_auth, init_tracing, mock_dcap_measurements, }; // Server has mock DCAP, client has no attestation and no client auth @@ -1105,4 +1187,90 @@ mod tests { )) )); } + + #[tokio::test] + async fn http_proxy_client_reconnects_on_lost_connection() { + init_tracing(); + + let target_addr = example_http_service().await; + + let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); + let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); + + let proxy_server = ProxyServer::new_with_tls_config( + cert_chain, + server_config, + "127.0.0.1:0", + target_addr.to_string(), + AttestationGenerator::new_not_dummy(AttestationType::DcapTdx).unwrap(), + AttestationVerifier::expect_none(), + ) + .await + .unwrap(); + + let proxy_addr = proxy_server.local_addr().unwrap(); + + // This is used to trigger a dropped connection to the proxy server + let (connection_breaker_tx, connection_breaker_rx) = oneshot::channel(); + + tokio::spawn(async move { + let connection_handle = proxy_server.accept().await.unwrap(); + + // Wait for a signal to simulate a dropped connection, then drop the task handling the + // connection + connection_breaker_rx.await.unwrap(); + connection_handle.abort(); + + // Now accept another connection + proxy_server.accept().await.unwrap(); + }); + + let proxy_client = ProxyClient::new_with_tls_config( + client_config, + "127.0.0.1:0".to_string(), + proxy_addr.to_string(), + AttestationGenerator::with_no_attestation(), + AttestationVerifier::mock(), + None, + ) + .await + .unwrap(); + + let proxy_client_addr = proxy_client.local_addr().unwrap(); + + tokio::spawn(async move { + proxy_client.accept().await.unwrap(); + proxy_client.accept().await.unwrap(); + }); + + let _initial_response = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) + .await + .unwrap(); + + // Now break the connection + connection_breaker_tx.send(()).unwrap(); + + // Make another request + let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) + .await + .unwrap(); + + let headers = res.headers(); + + let attestation_type = headers + .get(ATTESTATION_TYPE_HEADER) + .unwrap() + .to_str() + .unwrap(); + assert_eq!(attestation_type, AttestationType::DcapTdx.as_str()); + + let measurements_json = headers.get(MEASUREMENT_HEADER).unwrap().to_str().unwrap(); + let measurements = + MultiMeasurements::from_header_format(measurements_json, AttestationType::DcapTdx) + .unwrap(); + assert_eq!(measurements, mock_dcap_measurements()); + + let res_body = res.text().await.unwrap(); + assert_eq!(res_body, "No measurements"); + } } diff --git a/src/test_helpers.rs b/src/test_helpers.rs index 065f510..44de3c9 100644 --- a/src/test_helpers.rs +++ b/src/test_helpers.rs @@ -3,7 +3,7 @@ use axum::response::IntoResponse; use std::{ collections::HashMap, net::{IpAddr, SocketAddr}, - sync::Arc, + sync::{Arc, Once}, }; use tokio::net::TcpListener; use tokio_rustls::rustls::{ @@ -11,6 +11,9 @@ use tokio_rustls::rustls::{ server::{danger::ClientCertVerifier, WebPkiClientVerifier}, ClientConfig, RootCertStore, ServerConfig, }; +use tracing_subscriber::{fmt, EnvFilter}; + +static INIT: Once = Once::new(); use crate::{ attestation::measurements::{DcapMeasurementRegister, MultiMeasurements}, @@ -171,3 +174,14 @@ pub fn mock_dcap_measurements() -> MultiMeasurements { (DcapMeasurementRegister::RTMR3, [0u8; 48]), ])) } + +pub fn init_tracing() { + INIT.call_once(|| { + let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); + + fmt() + .with_env_filter(filter) + .with_test_writer() // <-- IMPORTANT for tests + .init(); + }); +}