diff --git a/Cargo.lock b/Cargo.lock index 950ec89..3900279 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -136,36 +136,25 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" [[package]] -name = "attested-tls-proxy" +name = "attested-tls" version = "0.0.1" dependencies = [ "anyhow", - "axum", "az-tdx-vtpm", "base64 0.22.1", - "bytes", - "clap", "configfs-tsm", "dcap-qvl", "futures-util", "hex", "http", - "http-body-util", - "hyper", - "hyper-util", "num-bigint", "once_cell", "openssl", - "p256", "parity-scale-codec", "pem-rfc7468", - "pkcs1", - "pkcs8", "rand_core 0.6.4", "rcgen", "reqwest", - "rsa", - "rustls-pemfile", "rustls-webpki 0.103.8", "serde", "serde_json", @@ -177,14 +166,46 @@ dependencies = [ "tokio", "tokio-rustls", "tokio-tungstenite", - "tower-http", "tracing", - "tracing-subscriber", "tss-esapi", "webpki-roots", "x509-parser", ] +[[package]] +name = "attested-tls-proxy" +version = "0.0.1" +dependencies = [ + "anyhow", + "attested-tls", + "axum", + "bytes", + "clap", + "http", + "http-body-util", + "hyper", + "hyper-util", + "p256", + "pem-rfc7468", + "pkcs1", + "pkcs8", + "rcgen", + "reqwest", + "rsa", + "rustls-pemfile", + "serde", + "serde_json", + "tdx-quote", + "tempfile", + "thiserror 2.0.17", + "time", + "tokio", + "tokio-rustls", + "tower-http", + "tracing", + "tracing-subscriber", +] + [[package]] name = "autocfg" version = "1.5.0" @@ -1584,9 +1605,9 @@ checksum = "2874a2af47a2325c2001a6e6fad9b16a53b802102b528163885171cf92b15976" [[package]] name = "libm" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" +checksum = "b6d2cec3eae94f9f509c767b45932f1ada8350c4bdb85af2fcab4a3c14807981" [[package]] name = "libredox" diff --git a/Cargo.toml b/Cargo.toml index 4f399d3..31f3422 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = [".", "dummy-attestation-server"] +members = [".", "dummy-attestation-server", "attested-tls"] [package] name = "attested-tls-proxy" @@ -11,22 +11,16 @@ repository = "https://github.com/flashbots/attested-tls-proxy" keywords = ["attested-TLS", "CVM", "TDX"] [dependencies] +attested-tls = { path = "attested-tls" } tokio = { version = "1.48.0", features = ["full"] } tokio-rustls = { version = "0.26.4", default-features = false, features = [ "ring", ] } -sha2 = "0.10.9" -x509-parser = "0.18.0" thiserror = "2.0.17" clap = { version = "4.5.51", features = ["derive", "env"] } -webpki-roots = "1.0.4" rustls-pemfile = "2.2.0" anyhow = "1.0.100" pem-rfc7468 = { version = "0.7.0", features = ["std"] } -configfs-tsm = "0.0.2" -rand_core = { version = "0.6.4", features = ["getrandom"] } -dcap-qvl = "0.3.4" -hex = "0.4.3" hyper = { version = "1.7.0", features = ["server", "http2"] } hyper-util = "0.1.17" http-body-util = "0.1.3" @@ -34,24 +28,14 @@ bytes = "1.10.1" http = "1.3.1" serde_json = "1.0.145" serde = "1.0.228" -base64 = "0.22.1" reqwest = { version = "0.12.23", default-features = false, features = [ "rustls-tls-webpki-roots-no-provider", ] } tracing = "0.1.41" tracing-subscriber = { version = "0.3.20", features = ["env-filter", "json"] } -parity-scale-codec = "3.7.5" -openssl = "0.10.75" -az-tdx-vtpm = { version = "0.7.4", optional = true } -tss-esapi = { version = "7.6.0", optional = true } -num-bigint = "0.4.6" -webpki = { package = "rustls-webpki", version = "0.103.8" } time = "0.3.44" -once_cell = "1.21.3" axum = "0.8.6" tower-http = { version = "0.6.7", features = ["fs"] } -tokio-tungstenite = { version = "0.28.0", optional = true } -futures-util = { version = "0.3.31", optional = true } rsa = { version = "0.9", default-features = false } p256 = { version = "0.13.2", features = ["pkcs8"] } @@ -62,15 +46,13 @@ pkcs8 = "0.10.2" rcgen = "0.14.5" tempfile = "3.23.0" tdx-quote = { version = "0.0.5", features = ["mock"] } +attested-tls = { path = "attested-tls", features = ["test-helpers", "mock"] } [features] -default = ["azure", "ws"] +default = ["azure"] # Adds support for Microsoft Azure attestation generation and verification -azure = ["tss-esapi", "az-tdx-vtpm"] - -# Adds websocket support -ws = ["tokio-tungstenite", "futures-util"] +azure = ["attested-tls/azure"] [package.metadata.deb] maintainer = "Flashbots Team " diff --git a/attested-tls/Cargo.toml b/attested-tls/Cargo.toml new file mode 100644 index 0000000..1fa5829 --- /dev/null +++ b/attested-tls/Cargo.toml @@ -0,0 +1,64 @@ +[package] +name = "attested-tls" +version = "0.0.1" +edition = "2024" +license = "MIT" +description = "A remote-attested TLS protocol for secure communication with CVM services" +repository = "https://github.com/flashbots/attested-tls-proxy" +keywords = ["attested-TLS", "CVM", "TDX"] + +[dependencies] +tokio = { version = "1.48.0", features = ["full"] } +tokio-rustls = { version = "0.26.4", default-features = false, features = [ + "ring", +] } +sha2 = "0.10.9" +x509-parser = "0.18.0" +thiserror = "2.0.17" +webpki-roots = "1.0.4" +anyhow = "1.0.100" +pem-rfc7468 = { version = "0.7.0", features = ["std"] } +configfs-tsm = "0.0.2" +rand_core = { version = "0.6.4", features = ["getrandom"] } +dcap-qvl = "0.3.4" +hex = "0.4.3" +http = "1.3.1" +serde_json = "1.0.145" +serde = "1.0.228" +base64 = "0.22.1" +reqwest = { version = "0.12.23", default-features = false, features = [ + "rustls-tls-webpki-roots-no-provider", +] } +tracing = "0.1.41" +parity-scale-codec = "3.7.5" +openssl = "0.10.75" +az-tdx-vtpm = { version = "0.7.4", optional = true } +tss-esapi = { version = "7.6.0", optional = true } +num-bigint = "0.4.6" +webpki = { package = "rustls-webpki", version = "0.103.8" } +time = "0.3.44" +once_cell = "1.21.3" +tokio-tungstenite = { version = "0.28.0", optional = true } +futures-util = { version = "0.3.31", optional = true } + +rcgen = { version = "0.14.5", optional = true } +tdx-quote = { version = "0.0.5", features = ["mock"], optional = true } + +[dev-dependencies] +rcgen = "0.14.5" +tempfile = "3.23.0" +tdx-quote = { version = "0.0.5", features = ["mock"] } + +[features] +default = ["azure", "ws"] + +# Adds support for Microsoft Azure attestation generation and verification +azure = ["tss-esapi", "az-tdx-vtpm"] + +# Adds websocket support +ws = ["tokio-tungstenite", "futures-util"] + +# Exposes helper functions for testing - do not enable in production as this allows dangerous configuration +test-helpers = ["rcgen"] + +mock = ["tdx-quote"] diff --git a/attested-tls/README.md b/attested-tls/README.md new file mode 100644 index 0000000..fed9c69 --- /dev/null +++ b/attested-tls/README.md @@ -0,0 +1,92 @@ +# attested-tls + +### Measurements File + +Accepted measurements for the remote party can be specified in a JSON file containing an array of objects, each of which specifies an accepted attestation type and set of measurements. + +This aims to match the formatting used by `cvm-reverse-proxy`. + +These objects have the following fields: +- `measurement_id` - a name used to describe the entry. For example the name and version of the CVM OS image that these measurements correspond to. +- `attestation_type` - a string containing one of the attestation types (confidential computing platforms) described below. +- `measurements` - an object with fields referring to the five measurement registers. Field names are the same as for the measurement headers (see below). + +Example: + +```JSON +[ + { + "measurement_id": "dcap-tdx-example", + "attestation_type": "dcap-tdx", + "measurements": { + "0": { + "expected": "47a1cc074b914df8596bad0ed13d50d561ad1effc7f7cc530ab86da7ea49ffc03e57e7da829f8cba9c629c3970505323" + }, + "1": { + "expected": "da6e07866635cb34a9ffcdc26ec6622f289e625c42c39b320f29cdf1dc84390b4f89dd0b073be52ac38ca7b0a0f375bb" + }, + "2": { + "expected": "a7157e7c5f932e9babac9209d4527ec9ed837b8e335a931517677fa746db51ee56062e3324e266e3f39ec26a516f4f71" + }, + "3": { + "expected": "e63560e50830e22fbc9b06cdce8afe784bf111e4251256cf104050f1347cd4ad9f30da408475066575145da0b098a124" + }, + "4": { + "expected": "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" + } + } + } +] +``` + +The only mandatory field is `attestation_type`. If an attestation type is specified, but no measurements, *any* measurements will be accepted for this attestation type. The measurements can still be checked up-stream by the source client or target service using header injection described below. But it is then up to these external programs to reject unacceptable measurements. + +If a measurements file is not provided, a single allowed attestation type **must** be specified using the `--allowed-remote-attestation-type` option. This may be `none` for cases where the remote party is not running in a CVM, but that must be explicitly specified. + +### Attestation Types + +These are the attestation type names used in the HTTP headers, and the measurements file, and when specifying a local attestation type with the `--client-attestation-type` or `--server-attestation-type` command line options. + +- `auto` - detect attestation type (used only when specifying the local attestation type as a command-line argument) +- `none` - No attestation provided +- `dummy` - Forwards the attestation to a remote service (for testing purposes, not yet supported) +- `gcp-tdx` - DCAP TDX on Google Cloud Platform +- `azure-tdx` - TDX on Azure, with MAA (not yet supported) +- `qemu-tdx` - TDX on Qemu (no cloud platform) +- `dcap-tdx` - DCAP TDX (platform not specified) + +## Protocol Specification + +This is based on TLS 1.3. + +The protocol name `flashbots-ratls/1` must be given in the TLS configuration for ALPN protocol negotiation during the TLS handshake. Future versions of this protocol will use incrementing version numbers, eg: `flashbots-ratls/2`. + +### Attestation Exchange + +Immediately after the TLS handshake, an attestation exchange is made. The server first provides an attestation message (even if it has the `none` attestation type). The client verifies, if verification is successful it also provides an attestation message and otherwise closes the connection. If the server cannot verify the client's attestation, it closes the connection. + +Attestation exchange messages are formatted as follows: +- A 4 byte length prefix - a big endian encoded unsigned 32 bit integer +- A SCALE (Simple Concatenated Aggregate Little-Endian) encoded [struct](./src/attestation/mod.rs) with the following fields: + - `attestation_type` - a string with one of the attestation types (described above) including `none`. + - `attestation` - the actual attestation data. In the case of DCAP this is a binary quote report. In the case of `none` this is an empty byte array. + +SCALE is used by parity/substrate and was chosen because it is simple and actually matches the formatting used in TDX quotes. So it was already used as a dependency (via the [`dcap-qvl`](https://docs.rs/dcap-qvl) crate). + +### Attestation Generation and Verification + +Attestation input takes the form of a 64 byte array. + +The first 32 bytes are the SHA256 hash of the encoded public key from the TLS leaf certificate of the party providing the attestation, DER encoded exactly as given in the certificate. + +The remaining 32 bytes are exported key material ([RFC5705](https://www.rfc-editor.org/rfc/rfc5705)) from the TLS session. This must have the exporter label `EXPORTER-Channel-Binding` and no context data. + +In the case of attestation types `dcap-tdx`, `gcp-tdx`, and `qemu-tdx`, a standard DCAP attestation is generated using the `configfs-tsm` linux filesystem interface. This means that this binary must be run with access to `/sys/kernel/config/tsm/report` which on many systems requires sudo. + +When verifying DCAP attestations, the Intel PCS is used to retrieve collateral unless a PCCS url is provided via a command line argument. If expired TCB collateral is provided, the quote will fail to verify. + +## Dependencies and feature flags + +The `azure` feature, for Microsoft Azure attestation requires [tpm2](https://tpm2-software.github.io) to be installed. On Debian-based systems this is provided by [`libtss2-dev`](https://packages.debian.org/trixie/libtss2-dev), and on nix `tpm2-tss`. + +This feature is enabled by default. For non-azure deployments you can compile without this requirement by specifying `--no-default-features`. But note that this is will disable both generation and verification of azure attestations. diff --git a/src/attestation/azure/ak_certificate.rs b/attested-tls/src/attestation/azure/ak_certificate.rs similarity index 100% rename from src/attestation/azure/ak_certificate.rs rename to attested-tls/src/attestation/azure/ak_certificate.rs diff --git a/src/attestation/azure/mod.rs b/attested-tls/src/attestation/azure/mod.rs similarity index 100% rename from src/attestation/azure/mod.rs rename to attested-tls/src/attestation/azure/mod.rs diff --git a/src/attestation/azure/nv_index.rs b/attested-tls/src/attestation/azure/nv_index.rs similarity index 100% rename from src/attestation/azure/nv_index.rs rename to attested-tls/src/attestation/azure/nv_index.rs diff --git a/src/attestation/dcap.rs b/attested-tls/src/attestation/dcap.rs similarity index 96% rename from src/attestation/dcap.rs rename to attested-tls/src/attestation/dcap.rs index 5094972..921d222 100644 --- a/src/attestation/dcap.rs +++ b/attested-tls/src/attestation/dcap.rs @@ -19,7 +19,7 @@ pub async fn create_dcap_attestation(input_data: [u8; 64]) -> Result, At } /// Verify a DCAP TDX quote, and return the measurement values -#[cfg(not(test))] +#[cfg(not(any(test, feature = "mock")))] pub async fn verify_dcap_attestation( input: Vec, expected_input_data: [u8; 64], @@ -32,7 +32,7 @@ pub async fn verify_dcap_attestation( } /// Allows the timestamp to be given, making it possible to test with existing attestations -async fn verify_dcap_attestation_with_given_timestamp( +pub async fn verify_dcap_attestation_with_given_timestamp( input: Vec, expected_input_data: [u8; 64], pccs_url: Option, @@ -62,7 +62,7 @@ async fn verify_dcap_attestation_with_given_timestamp( Ok(measurements) } -#[cfg(test)] +#[cfg(any(test, feature = "mock"))] pub async fn verify_dcap_attestation( input: Vec, expected_input_data: [u8; 64], @@ -77,7 +77,7 @@ pub async fn verify_dcap_attestation( } /// Create a mock quote for testing on non-confidential hardware -#[cfg(test)] +#[cfg(any(test, feature = "mock"))] fn generate_quote(input: [u8; 64]) -> Result, QuoteGenerationError> { let attestation_key = tdx_quote::SigningKey::random(&mut rand_core::OsRng); let provisioning_certification_key = tdx_quote::SigningKey::random(&mut rand_core::OsRng); @@ -91,7 +91,7 @@ fn generate_quote(input: [u8; 64]) -> Result, QuoteGenerationError> { } /// Create a quote -#[cfg(not(test))] +#[cfg(not(any(test, feature = "mock")))] fn generate_quote(input: [u8; 64]) -> Result, QuoteGenerationError> { configfs_tsm::create_tdx_quote(input) } @@ -116,7 +116,7 @@ pub enum DcapVerificationError { SystemTime(#[from] std::time::SystemTimeError), #[error("DCAP quote verification: {0}")] DcapQvl(#[from] anyhow::Error), - #[cfg(test)] + #[cfg(any(test, feature = "mock"))] #[error("Quote parse: {0}")] QuoteParse(#[from] tdx_quote::QuoteParseError), } diff --git a/src/attestation/measurements.rs b/attested-tls/src/attestation/measurements.rs similarity index 99% rename from src/attestation/measurements.rs rename to attested-tls/src/attestation/measurements.rs index 0f2c721..be039fe 100644 --- a/src/attestation/measurements.rs +++ b/attested-tls/src/attestation/measurements.rs @@ -120,7 +120,7 @@ impl MultiMeasurements { ]))) } - #[cfg(test)] + #[cfg(any(test, feature = "mock"))] pub fn from_tdx_quote(quote: &tdx_quote::Quote) -> Self { Self::Dcap(HashMap::from([ (DcapMeasurementRegister::MRTD, quote.mrtd()), @@ -236,7 +236,7 @@ impl MeasurementPolicy { } /// Expect mock measurements used in tests - #[cfg(test)] + #[cfg(any(test, feature = "mock"))] pub fn mock() -> Self { Self { accepted_measurements: vec![MeasurementRecord { diff --git a/src/attestation/mod.rs b/attested-tls/src/attestation/mod.rs similarity index 99% rename from src/attestation/mod.rs rename to attested-tls/src/attestation/mod.rs index a3df20d..09b145a 100644 --- a/src/attestation/mod.rs +++ b/attested-tls/src/attestation/mod.rs @@ -283,7 +283,7 @@ impl AttestationVerifier { } /// Expect mock measurements used in tests - #[cfg(test)] + #[cfg(any(test, feature = "test-helpers"))] pub fn mock() -> Self { Self { measurement_policy: MeasurementPolicy::mock(), diff --git a/src/attested_tls.rs b/attested-tls/src/lib.rs similarity index 77% rename from src/attested_tls.rs rename to attested-tls/src/lib.rs index 6b5c2ee..c9af316 100644 --- a/src/attested_tls.rs +++ b/attested-tls/src/lib.rs @@ -1,10 +1,15 @@ //! Attested TLS protocol server and client -use crate::{ - attestation::{ - measurements::MultiMeasurements, AttestationError, AttestationExchangeMessage, - AttestationGenerator, AttestationType, AttestationVerifier, - }, - host_to_host_with_port, +pub mod attestation; + +#[cfg(feature = "ws")] +pub mod websockets; + +#[cfg(any(test, feature = "test-helpers"))] +pub mod test_helpers; + +use crate::attestation::{ + measurements::MultiMeasurements, AttestationError, AttestationExchangeMessage, + AttestationGenerator, AttestationType, AttestationVerifier, }; use parity_scale_codec::{Decode, Encode}; use sha2::{Digest, Sha256}; @@ -51,6 +56,16 @@ pub struct AttestedTlsServer { acceptor: TlsAcceptor, } +impl std::fmt::Debug for AttestedTlsServer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AttestedTlsServer") + .field("attestation_generator", &self.attestation_generator) + .field("attestation_verifier", &self.attestation_verifier) + .field("cert_chain", &self.cert_chain) + .finish() + } +} + impl AttestedTlsServer { pub async fn new( cert_and_key: TlsCertAndKey, @@ -88,13 +103,16 @@ impl AttestedTlsServer { /// Start with preconfigured TLS /// - /// This is not fully public as it allows dangerous configuration - pub(crate) async fn new_with_tls_config( + /// This allows dangerous configuration + pub async fn new_with_tls_config( cert_chain: Vec>, server_config: Arc, attestation_generator: AttestationGenerator, attestation_verifier: AttestationVerifier, ) -> Result { + #[cfg(feature = "mock")] + tracing::warn!("AttestedTlsServer instantiated in MOCK mode - do NOT use in production"); + let acceptor = tokio_rustls::TlsAcceptor::from(server_config); Ok(Self { @@ -256,13 +274,16 @@ impl AttestedTlsClient { /// Create a new proxy client with given TLS configuration /// - /// This not fully public as it allows dangerous configuration but is used in tests - pub(crate) async fn new_with_tls_config( + /// This allows dangerous configuration but is used in tests + pub async fn new_with_tls_config( client_config: Arc, attestation_generator: AttestationGenerator, attestation_verifier: AttestationVerifier, cert_chain: Option>>, ) -> Result { + #[cfg(feature = "mock")] + tracing::warn!("AttestedTlsClient instantiated in MOCK mode - do NOT use in production"); + let connector = TlsConnector::from(client_config.clone()); Ok(Self { @@ -414,8 +435,8 @@ pub async fn get_tls_cert( } /// Helper for testing getting remote certificate -#[cfg(test)] -pub(crate) async fn get_tls_cert_with_config( +#[cfg(any(test, feature = "test-helpers"))] +pub async fn get_tls_cert_with_config( server_name: &str, attestation_verifier: AttestationVerifier, client_config: Arc, @@ -502,10 +523,26 @@ fn server_name_from_host( ServerName::try_from(host_part.to_string()) } +/// If no port was provided, default to 443 +fn host_to_host_with_port(host: &str) -> String { + if host.contains(':') { + host.to_string() + } else { + format!("{host}:443") + } +} + #[cfg(test)] mod tests { + use std::collections::HashMap; + use super::*; - use crate::test_helpers::{generate_certificate_chain, generate_tls_config}; + use crate::{ + attestation::measurements::{ + DcapMeasurementRegister, MeasurementPolicy, MeasurementRecord, + }, + test_helpers::{generate_certificate_chain, generate_tls_config}, + }; use tokio::net::TcpListener; #[tokio::test] @@ -543,4 +580,105 @@ mod tests { let (_stream, _measurements, _attestation_type) = client.connect_tcp(&server_addr.to_string()).await.unwrap(); } + + // Negative test - server does not provide attestation but client requires it + // Server has no attestation, client has no attestation and no client auth + #[tokio::test] + async fn fails_on_no_attestation_when_expected() { + let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); + let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); + + let server = AttestedTlsServer::new_with_tls_config( + cert_chain, + server_config, + AttestationGenerator::with_no_attestation(), + AttestationVerifier::expect_none(), + ) + .await + .unwrap(); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let server_addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { + let (tcp_stream, _) = listener.accept().await.unwrap(); + let (_stream, _measurements, _attestation_type) = + server.handle_connection(tcp_stream).await.unwrap(); + }); + + let client = AttestedTlsClient::new_with_tls_config( + client_config, + AttestationGenerator::with_no_attestation(), + AttestationVerifier::mock(), + None, + ) + .await + .unwrap(); + + let client_result = client.connect_tcp(&server_addr.to_string()).await; + + assert!(matches!( + client_result.unwrap_err(), + AttestedTlsError::Attestation(AttestationError::AttestationTypeNotAccepted) + )); + } + + // Negative test - server does not provide attestation but client requires it + // Server has no attestaion, client has no attestation and no client auth + #[tokio::test] + async fn fails_on_bad_measurements() { + let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); + let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); + + let server = AttestedTlsServer::new_with_tls_config( + cert_chain, + server_config, + AttestationGenerator::new_not_dummy(AttestationType::DcapTdx).unwrap(), + AttestationVerifier::expect_none(), + ) + .await + .unwrap(); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let server_addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { + let (tcp_stream, _) = listener.accept().await.unwrap(); + let (_stream, _measurements, _attestation_type) = + server.handle_connection(tcp_stream).await.unwrap(); + }); + + let attestation_verifier = AttestationVerifier { + measurement_policy: MeasurementPolicy { + accepted_measurements: vec![MeasurementRecord { + measurement_id: "test".to_string(), + measurements: MultiMeasurements::Dcap(HashMap::from([ + (DcapMeasurementRegister::MRTD, [0; 48]), + (DcapMeasurementRegister::RTMR0, [0; 48]), + (DcapMeasurementRegister::RTMR1, [1; 48]), // This differs from the mock measurements + (DcapMeasurementRegister::RTMR2, [0; 48]), + (DcapMeasurementRegister::RTMR3, [0; 48]), + ])), + }], + }, + pccs_url: None, + log_dcap_quote: false, + }; + + let client = AttestedTlsClient::new_with_tls_config( + client_config, + AttestationGenerator::with_no_attestation(), + attestation_verifier, + None, + ) + .await + .unwrap(); + + let client_result = client.connect_tcp(&server_addr.to_string()).await; + + assert!(matches!( + client_result.unwrap_err(), + AttestedTlsError::Attestation(AttestationError::MeasurementsNotAccepted) + )); + } } diff --git a/attested-tls/src/test_helpers.rs b/attested-tls/src/test_helpers.rs new file mode 100644 index 0000000..0008a22 --- /dev/null +++ b/attested-tls/src/test_helpers.rs @@ -0,0 +1,143 @@ +//! Helper functions used in tests +use std::{collections::HashMap, net::IpAddr, sync::Arc}; +use tokio_rustls::rustls::{ + pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer}, + server::{danger::ClientCertVerifier, WebPkiClientVerifier}, + ClientConfig, RootCertStore, ServerConfig, +}; + +use crate::{ + attestation::measurements::{DcapMeasurementRegister, MultiMeasurements}, + SUPPORTED_ALPN_PROTOCOL_VERSIONS, +}; + +/// Helper to generate a self-signed certificate for testing +pub fn generate_certificate_chain( + ip: IpAddr, +) -> (Vec>, PrivateKeyDer<'static>) { + let mut params = rcgen::CertificateParams::new(vec![]).unwrap(); + params.subject_alt_names.push(rcgen::SanType::IpAddress(ip)); + params + .distinguished_name + .push(rcgen::DnType::CommonName, ip.to_string()); + + let keypair = rcgen::KeyPair::generate().unwrap(); + let cert = params.self_signed(&keypair).unwrap(); + + let certs = vec![CertificateDer::from(cert)]; + let key = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(keypair.serialize_der())); + (certs, key) +} + +/// Helper to generate TLS configuration for testing +/// +/// For the server: A given self-signed certificate +/// For the client: A root certificate store with the server's certificate +pub fn generate_tls_config( + certificate_chain: Vec>, + key: PrivateKeyDer<'static>, +) -> (Arc, Arc) { + let supported_protocols: Vec<_> = SUPPORTED_ALPN_PROTOCOL_VERSIONS + .into_iter() + .map(|p| p.to_vec()) + .collect(); + + let mut server_config = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(certificate_chain.clone(), key) + .expect("Failed to create rustls server config"); + + server_config.alpn_protocols = supported_protocols.clone(); + + let mut root_store = RootCertStore::empty(); + root_store.add(certificate_chain[0].clone()).unwrap(); + + let mut client_config = ClientConfig::builder() + .with_root_certificates(root_store) + .with_no_client_auth(); + + client_config.alpn_protocols = supported_protocols; + + (Arc::new(server_config), Arc::new(client_config)) +} + +/// Helper to generate a mutual TLS configuration with client authentification for testing +pub fn generate_tls_config_with_client_auth( + alice_certificate_chain: Vec>, + alice_key: PrivateKeyDer<'static>, + bob_certificate_chain: Vec>, + bob_key: PrivateKeyDer<'static>, +) -> ( + (Arc, Arc), + (Arc, Arc), +) { + let supported_protocols: Vec<_> = SUPPORTED_ALPN_PROTOCOL_VERSIONS + .into_iter() + .map(|p| p.to_vec()) + .collect(); + + let (alice_client_verifier, alice_root_store) = + client_verifier_from_remote_cert(bob_certificate_chain[0].clone()); + + let mut alice_server_config = ServerConfig::builder() + .with_client_cert_verifier(alice_client_verifier) + .with_single_cert(alice_certificate_chain.clone(), alice_key.clone_key()) + .expect("Failed to create rustls server config"); + + alice_server_config.alpn_protocols = supported_protocols.clone(); + + let mut alice_client_config = ClientConfig::builder() + .with_root_certificates(alice_root_store) + .with_client_auth_cert(alice_certificate_chain.clone(), alice_key) + .unwrap(); + + alice_client_config.alpn_protocols = supported_protocols.clone(); + + let (bob_client_verifier, bob_root_store) = + client_verifier_from_remote_cert(alice_certificate_chain[0].clone()); + + let mut bob_server_config = ServerConfig::builder() + .with_client_cert_verifier(bob_client_verifier) + .with_single_cert(bob_certificate_chain.clone(), bob_key.clone_key()) + .expect("Failed to create rustls server config"); + + bob_server_config.alpn_protocols = supported_protocols.clone(); + + let mut bob_client_config = ClientConfig::builder() + .with_root_certificates(bob_root_store) + .with_client_auth_cert(bob_certificate_chain, bob_key) + .unwrap(); + + bob_client_config.alpn_protocols = supported_protocols; + ( + (Arc::new(alice_server_config), Arc::new(alice_client_config)), + (Arc::new(bob_server_config), Arc::new(bob_client_config)), + ) +} + +/// Given a TLS certificate, return a [WebPkiClientVerifier] and [RootCertStore] which will accept +/// that certificate +fn client_verifier_from_remote_cert( + cert: CertificateDer<'static>, +) -> (Arc, RootCertStore) { + let mut root_store = RootCertStore::empty(); + root_store.add(cert).unwrap(); + + ( + WebPkiClientVerifier::builder(Arc::new(root_store.clone())) + .build() + .unwrap(), + root_store, + ) +} + +/// All-zero measurment values used in some tests +pub fn mock_dcap_measurements() -> MultiMeasurements { + MultiMeasurements::Dcap(HashMap::from([ + (DcapMeasurementRegister::MRTD, [0u8; 48]), + (DcapMeasurementRegister::RTMR0, [0u8; 48]), + (DcapMeasurementRegister::RTMR1, [0u8; 48]), + (DcapMeasurementRegister::RTMR2, [0u8; 48]), + (DcapMeasurementRegister::RTMR3, [0u8; 48]), + ])) +} diff --git a/src/websockets.rs b/attested-tls/src/websockets.rs similarity index 98% rename from src/websockets.rs rename to attested-tls/src/websockets.rs index 3c9e4dc..90823b0 100644 --- a/src/websockets.rs +++ b/attested-tls/src/websockets.rs @@ -5,7 +5,7 @@ use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, WebSocketStream} use crate::{ attestation::{measurements::MultiMeasurements, AttestationType}, - attested_tls::{AttestedTlsClient, AttestedTlsError, AttestedTlsServer}, + AttestedTlsClient, AttestedTlsError, AttestedTlsServer, }; /// Websocket message type re-exported for convenience diff --git a/test-assets/azure-tdx-1764662251380464271 b/attested-tls/test-assets/azure-tdx-1764662251380464271 similarity index 100% rename from test-assets/azure-tdx-1764662251380464271 rename to attested-tls/test-assets/azure-tdx-1764662251380464271 diff --git a/test-assets/dcap-tdx-1766059550570652607 b/attested-tls/test-assets/dcap-tdx-1766059550570652607 similarity index 100% rename from test-assets/dcap-tdx-1766059550570652607 rename to attested-tls/test-assets/dcap-tdx-1766059550570652607 diff --git a/test-assets/hclreport.bin b/attested-tls/test-assets/hclreport.bin similarity index 100% rename from test-assets/hclreport.bin rename to attested-tls/test-assets/hclreport.bin diff --git a/test-assets/measurements.json b/attested-tls/test-assets/measurements.json similarity index 100% rename from test-assets/measurements.json rename to attested-tls/test-assets/measurements.json diff --git a/test-assets/measurements_2.json b/attested-tls/test-assets/measurements_2.json similarity index 100% rename from test-assets/measurements_2.json rename to attested-tls/test-assets/measurements_2.json diff --git a/src/lib.rs b/src/lib.rs index 8d33f1e..3378e87 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,15 +1,12 @@ //! An attested TLS protocol and HTTPS proxy -pub mod attestation; pub mod attested_get; -pub mod attested_tls; pub mod file_server; pub mod health_check; pub mod normalize_pem; -#[cfg(feature = "azure")] -pub mod websockets; - -pub use attestation::AttestationGenerator; +pub use attested_tls; +pub use attested_tls::attestation; +pub use attested_tls::attestation::AttestationGenerator; use bytes::Bytes; use http::{HeaderMap, HeaderName, HeaderValue}; @@ -32,11 +29,11 @@ use tokio_rustls::rustls::pki_types::CertificateDer; #[cfg(test)] use tokio_rustls::rustls::{ClientConfig, ServerConfig}; -use crate::{ +use attested_tls::{ attestation::{ measurements::MultiMeasurements, AttestationError, AttestationType, AttestationVerifier, }, - attested_tls::{AttestedTlsClient, AttestedTlsError, AttestedTlsServer, TlsCertAndKey}, + AttestedTlsClient, AttestedTlsError, AttestedTlsServer, TlsCertAndKey, }; /// The header name for giving attestation type @@ -332,10 +329,8 @@ impl ProxyClient { Self::new_with_inner(address, attested_tls_client, &target_name).await } - /// Create a new proxy client with given TLS configuration - /// - /// This is private as it allows dangerous configuration but is used in tests - async fn new_with_inner( + /// Create a new proxy client with given [AttestedTlsClient] + pub async fn new_with_inner( address: impl ToSocketAddrs, attested_tls_client: AttestedTlsClient, target_name: &str, @@ -622,13 +617,8 @@ where #[cfg(test)] mod tests { - use std::collections::HashMap; - use crate::{ - attestation::measurements::{ - DcapMeasurementRegister, MeasurementPolicy, MeasurementRecord, MultiMeasurements, - }, - attested_tls::get_tls_cert_with_config, + attestation::measurements::MultiMeasurements, attested_tls::get_tls_cert_with_config, }; use super::*; @@ -640,6 +630,7 @@ mod tests { // Server has mock DCAP, client has no attestation and no client auth #[tokio::test] async fn http_proxy_with_server_attestation() { + let _ = tracing_subscriber::fmt::try_init(); let target_addr = example_http_service().await; let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); @@ -1000,109 +991,4 @@ mod tests { assert_eq!(retrieved_chain, cert_chain); } - - // Negative test - server does not provide attestation but client requires it - // Server has no attestaion, client has no attestation and no client auth - #[tokio::test] - async fn fails_on_no_attestation_when_expected() { - let target_addr = example_http_service().await; - - let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); - let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); - - let proxy_server = ProxyServer::new_with_tls_config( - cert_chain, - server_config, - "127.0.0.1:0", - target_addr.to_string(), - AttestationGenerator::with_no_attestation(), - AttestationVerifier::expect_none(), - ) - .await - .unwrap(); - - let proxy_addr = proxy_server.local_addr().unwrap(); - - tokio::spawn(async move { - proxy_server.accept().await.unwrap(); - }); - - let proxy_client_result = ProxyClient::new_with_tls_config( - client_config, - "127.0.0.1:0".to_string(), - proxy_addr.to_string(), - AttestationGenerator::with_no_attestation(), - AttestationVerifier::mock(), - None, - ) - .await; - - assert!(matches!( - proxy_client_result.unwrap_err(), - ProxyError::AttestedTls(AttestedTlsError::Attestation( - AttestationError::AttestationTypeNotAccepted - )) - )); - } - - // Negative test - server does not provide attestation but client requires it - // Server has no attestaion, client has no attestation and no client auth - #[tokio::test] - async fn fails_on_bad_measurements() { - let target_addr = example_http_service().await; - - let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); - let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); - - let proxy_server = ProxyServer::new_with_tls_config( - cert_chain, - server_config, - "127.0.0.1:0", - target_addr.to_string(), - AttestationGenerator::new_not_dummy(AttestationType::DcapTdx).unwrap(), - AttestationVerifier::expect_none(), - ) - .await - .unwrap(); - - let proxy_addr = proxy_server.local_addr().unwrap(); - - tokio::spawn(async move { - proxy_server.accept().await.unwrap(); - }); - - let attestation_verifier = AttestationVerifier { - measurement_policy: MeasurementPolicy { - accepted_measurements: vec![MeasurementRecord { - measurement_id: "test".to_string(), - measurements: MultiMeasurements::Dcap(HashMap::from([ - (DcapMeasurementRegister::MRTD, [0; 48]), - (DcapMeasurementRegister::RTMR0, [0; 48]), - (DcapMeasurementRegister::RTMR1, [1; 48]), // This differs from the mock measurements - (DcapMeasurementRegister::RTMR2, [0; 48]), - (DcapMeasurementRegister::RTMR3, [0; 48]), - ])), - }], - }, - pccs_url: None, - log_dcap_quote: false, - }; - - let proxy_client_result = ProxyClient::new_with_tls_config( - client_config, - "127.0.0.1:0".to_string(), - proxy_addr.to_string(), - AttestationGenerator::with_no_attestation(), - attestation_verifier, - None, - ) - .await; - - assert!(matches!( - proxy_client_result.unwrap_err(), - ProxyError::AttestedTls(AttestedTlsError::Attestation( - AttestationError::MeasurementsNotAccepted - )) - )); - } } diff --git a/src/main.rs b/src/main.rs index f946139..a4fbdc7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,9 +6,11 @@ use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer}; use tracing::level_filters::LevelFilter; use attested_tls_proxy::{ - attestation::{measurements::MeasurementPolicy, AttestationType, AttestationVerifier}, attested_get::attested_get, - attested_tls::{get_tls_cert, TlsCertAndKey}, + attested_tls::{ + attestation::{measurements::MeasurementPolicy, AttestationType, AttestationVerifier}, + get_tls_cert, TlsCertAndKey, + }, file_server::attested_file_server, health_check, normalize_pem::normalize_private_key_pem_to_pkcs8,