diff --git a/crates/attestation/src/azure/mod.rs b/crates/attestation/src/azure/mod.rs index 42c88ed..4c8aecc 100644 --- a/crates/attestation/src/azure/mod.rs +++ b/crates/attestation/src/azure/mod.rs @@ -802,7 +802,7 @@ mod tests { /// Verify a complete observed Azure attestation payload that includes /// AK intermediates fetched from the leaf certificate's AIA URLs. #[tokio::test] - async fn test_verify() { + async fn test_verify_with_ak_intermediates() { // generated using [capture_azure_fixture] above. let attestation_bytes: &'static [u8] = include_bytes!("../../test-assets/azure-tdx-with-ak-intermediates-1780922561.yaml"); diff --git a/crates/attestation/src/dcap.rs b/crates/attestation/src/dcap.rs index dca2c10..96ae6e0 100644 --- a/crates/attestation/src/dcap.rs +++ b/crates/attestation/src/dcap.rs @@ -33,7 +33,7 @@ pub async fn verify_dcap_attestation( input: Vec, expected_input_data: [u8; 64], pccs: Option, -) -> Result { +) -> Result<(MultiMeasurements, Quote), DcapVerificationError> { let now = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH)?.as_secs(); let override_azure_outdated_tcb = false; verify_dcap_attestation_with_given_timestamp( @@ -58,7 +58,7 @@ pub fn verify_dcap_attestation_sync( input: Vec, expected_input_data: [u8; 64], pccs: Pccs, -) -> Result { +) -> Result<(MultiMeasurements, Quote), DcapVerificationError> { let now = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH)?.as_secs(); let override_azure_outdated_tcb = false; verify_dcap_attestation_with_timestamp_sync( @@ -84,7 +84,7 @@ pub fn verify_dcap_attestation_with_timestamp_sync( collateral: Option, now: u64, override_azure_outdated_tcb: bool, -) -> Result { +) -> Result<(MultiMeasurements, Quote), DcapVerificationError> { let quote = Quote::parse(&input)?; let ca = quote.ca()?; @@ -118,7 +118,7 @@ pub async fn verify_dcap_attestation_with_given_timestamp( collateral: Option, now: u64, override_azure_outdated_tcb: bool, -) -> Result { +) -> Result<(MultiMeasurements, Quote), DcapVerificationError> { let quote = Quote::parse(&input)?; let ca = quote.ca()?; @@ -156,7 +156,7 @@ fn verify_dcap_attestation_with_collateral_and_timestamp( collateral: QuoteCollateralV3, now: u64, override_azure_outdated_tcb: bool, -) -> Result { +) -> Result<(MultiMeasurements, Quote), DcapVerificationError> { tracing::info!("Verifying DCAP attestation: {quote:?}"); let fmspc = hex::encode_upper(quote.fmspc()?); @@ -197,11 +197,11 @@ fn verify_dcap_attestation_with_collateral_and_timestamp( let measurements = MultiMeasurements::from_dcap_qvl_quote("e)?; - if get_quote_input_data(quote.report) != expected_input_data { + if get_quote_input_data("e.report) != expected_input_data { return Err(DcapVerificationError::InputMismatch); } - Ok(measurements) + Ok((measurements, quote)) } #[cfg(any(test, feature = "mock"))] @@ -209,7 +209,7 @@ pub async fn verify_dcap_attestation( input: Vec, expected_input_data: [u8; 64], pccs: Option, -) -> Result { +) -> Result<(MultiMeasurements, Quote), DcapVerificationError> { let quote = Quote::parse(&input)?; let ca = quote.ca()?; let fmspc = hex::encode_upper(quote.fmspc()?); @@ -224,11 +224,11 @@ pub async fn verify_dcap_attestation( verifier.verify(&input, &collateral, now)?; let measurements = MultiMeasurements::from_dcap_qvl_quote("e)?; - if get_quote_input_data(quote.report) != expected_input_data { + if get_quote_input_data("e.report) != expected_input_data { return Err(DcapVerificationError::InputMismatch); } - Ok(measurements) + Ok((measurements, quote)) } #[cfg(any(test, feature = "mock"))] @@ -236,7 +236,7 @@ pub fn verify_dcap_attestation_sync( input: Vec, expected_input_data: [u8; 64], pccs: Pccs, -) -> Result { +) -> Result<(MultiMeasurements, Quote), DcapVerificationError> { let quote = Quote::parse(&input)?; let ca = quote.ca()?; let fmspc = hex::encode_upper(quote.fmspc()?); @@ -246,10 +246,11 @@ pub fn verify_dcap_attestation_sync( verifier.verify(&input, &collateral, now)?; let measurements = MultiMeasurements::from_dcap_qvl_quote("e)?; - if get_quote_input_data(quote.report.clone()) != expected_input_data { + if get_quote_input_data("e.report) != expected_input_data { return Err(DcapVerificationError::InputMismatch); } - Ok(measurements) + + Ok((measurements, quote)) } /// Create a mock quote for testing on non-confidential hardware @@ -267,7 +268,7 @@ fn generate_quote(input: [u8; 64]) -> Result, tdx_attest::TdxAttestError } /// Given a [Report] get the input data regardless of report type -pub fn get_quote_input_data(report: Report) -> [u8; 64] { +pub fn get_quote_input_data(report: &Report) -> [u8; 64] { match report { Report::TD10(r) => r.report_data, Report::TD15(r) => r.base.report_data, @@ -297,7 +298,7 @@ mod tests { use mock_tdx::{MockPcsConfig, spawn_mock_pcs_server}; use super::*; - use crate::measurements::MeasurementPolicy; + use crate::{AttestationType, measurements::MeasurementPolicy}; #[tokio::test] async fn test_dcap_verify() { @@ -331,7 +332,7 @@ mod tests { let async_collateral = serde_saphyr::from_slice(collateral_bytes).unwrap(); let sync_collateral = serde_saphyr::from_slice(collateral_bytes).unwrap(); - let async_measurements = verify_dcap_attestation_with_given_timestamp( + let (async_measurements, _) = verify_dcap_attestation_with_given_timestamp( attestation_bytes.to_vec(), [ 116, 39, 106, 100, 143, 31, 212, 145, 244, 116, 162, 213, 44, 114, 216, 80, 227, @@ -347,7 +348,7 @@ mod tests { .await .unwrap(); - let sync_measurements = verify_dcap_attestation_with_timestamp_sync( + let (sync_measurements, _) = verify_dcap_attestation_with_timestamp_sync( attestation_bytes.to_vec(), [ 116, 39, 106, 100, 143, 31, 212, 145, 244, 116, 162, 213, 44, 114, 216, 80, 227, @@ -363,7 +364,9 @@ mod tests { .unwrap(); assert_eq!(async_measurements, sync_measurements); - measurement_policy.check_measurement(&async_measurements).unwrap(); + measurement_policy + .check_measurement(AttestationType::DcapTdx, &async_measurements) + .unwrap(); } // This specifically tests a quote which has outdated TCB level from Azure @@ -409,7 +412,7 @@ mod tests { let expected_input_data = [0xA5; 64]; let attestation_bytes = create_dcap_attestation(expected_input_data).unwrap(); - let measurements = + let (measurements, _) = verify_dcap_attestation(attestation_bytes, expected_input_data, Some(pccs)) .await .unwrap(); diff --git a/crates/attestation/src/gcp.rs b/crates/attestation/src/gcp.rs new file mode 100644 index 0000000..fe328dc --- /dev/null +++ b/crates/attestation/src/gcp.rs @@ -0,0 +1,361 @@ +use std::{ + collections::HashMap, + io::Read, + sync::{Arc, RwLock}, + time::{Duration, Instant}, +}; + +use dcap_qvl::{intel, quote::Quote}; +use serde_json::Value; +use thiserror::Error; + +/// Public registry of GCP Confidential VM TDX PPIDs +const GCP_PROVENANCE_REGISTRY_URL: &str = + "https://storage.googleapis.com/confidential-host-registry"; + +/// Maximum size in bytes of GCP provenance documents +const GCP_PROVENANCE_DOCUMENT_MAX_BYTES: u64 = 16 * 1024; +/// How long a cached PPID remains trusted before revalidation +const GCP_PROVENANCE_CACHE_TTL: Duration = Duration::from_secs(7 * 24 * 60 * 60); + +/// Checks PPIDs extracted from DCAP quotes against Googles public bucket, +/// to establish whether this is a GCP machine +#[derive(Clone, Debug)] +pub(crate) struct GcpProvenanceChecker { + /// Cached entries with retrieval timestamp + known_gcp_ppids: Arc, Instant>>>, +} + +impl GcpProvenanceChecker { + pub(crate) fn new() -> Self { + Self { known_gcp_ppids: Default::default() } + } + + /// Given a DCAP TDX quote, check if the associated PPID has a + /// 'provenance document' from GCP + pub(crate) async fn verify_provenance(&self, quote: Quote) -> Result<(), GcpProvenanceError> { + let now = Instant::now(); + let checker = self.clone(); + tokio::task::spawn_blocking(move || { + checker.verify_provenance_with_registry_url_sync_at( + "e, + GCP_PROVENANCE_REGISTRY_URL, + now, + ) + }) + .await + .map_err(|err| GcpProvenanceError::TaskJoin(err.to_string()))? + } + + /// Given a DCAP TDX quote, check if the associated PPID has a + /// 'provenance document' from GCP + pub(crate) fn verify_provenance_sync(&self, quote: &Quote) -> Result<(), GcpProvenanceError> { + self.verify_provenance_with_registry_url_sync_at( + quote, + GCP_PROVENANCE_REGISTRY_URL, + Instant::now(), + ) + } + + fn verify_provenance_with_registry_url_sync_at( + &self, + quote: &Quote, + registry_url: &str, + now: Instant, + ) -> Result<(), GcpProvenanceError> { + let ppid = extract_ppid_from_quote(quote)?; + { + let known_gcp_ppids = self + .known_gcp_ppids + .read() + .map_err(|err| GcpProvenanceError::CacheLock(err.to_string()))?; + if let Some(stored_at) = known_gcp_ppids.get(&ppid) && + is_cache_entry_fresh(*stored_at, now) + { + return Ok(()); + } + } + + { + let mut known_gcp_ppids = self + .known_gcp_ppids + .write() + .map_err(|err| GcpProvenanceError::CacheLock(err.to_string()))?; + if known_gcp_ppids + .get(&ppid) + .is_some_and(|stored_at| !is_cache_entry_fresh(*stored_at, now)) + { + known_gcp_ppids.remove(&ppid); + } else if known_gcp_ppids.contains_key(&ppid) { + return Ok(()); + } + } + + let provenance_url = + format!("{}/{}", registry_url.trim_end_matches('/'), hex::encode(&ppid)); + let document = fetch_provenance_document(&provenance_url)?; + validate_provenance_document(&document)?; + + let fetched_at = Instant::now(); + self.known_gcp_ppids + .write() + .map_err(|err| GcpProvenanceError::CacheLock(err.to_string()))? + .insert(ppid, fetched_at); + + Ok(()) + } +} + +fn is_cache_entry_fresh(stored_at: Instant, now: Instant) -> bool { + now.checked_duration_since(stored_at).is_some_and(|age| age <= GCP_PROVENANCE_CACHE_TTL) +} + +fn extract_ppid_from_quote(quote: &Quote) -> Result, GcpProvenanceError> { + let cert_chain = intel::extract_cert_chain(quote) + .map_err(|err| GcpProvenanceError::PpidExtraction(err.to_string()))?; + let leaf = cert_chain.first().ok_or(GcpProvenanceError::NoPckCertificate)?; + let extension = intel::parse_pck_extension(leaf) + .map_err(|err| GcpProvenanceError::PpidExtraction(err.to_string()))?; + + if extension.ppid.is_empty() { + return Err(GcpProvenanceError::EmptyPpid); + } + + Ok(extension.ppid) +} + +fn fetch_provenance_document(url: &str) -> Result { + let agent = ureq::AgentBuilder::new().timeout(Duration::from_secs(2)).build(); + let response = + agent.get(url).call().map_err(|err| GcpProvenanceError::RegistryFetch(err.to_string()))?; + + let mut limited_reader = response.into_reader().take(GCP_PROVENANCE_DOCUMENT_MAX_BYTES + 1); + let mut document = String::new(); + limited_reader + .read_to_string(&mut document) + .map_err(|err| GcpProvenanceError::RegistryFetch(err.to_string()))?; + + if document.len() as u64 > GCP_PROVENANCE_DOCUMENT_MAX_BYTES { + return Err(GcpProvenanceError::DocumentTooLarge); + } + + Ok(document) +} + +fn validate_provenance_document(document: &str) -> Result<(), GcpProvenanceError> { + let value: Value = serde_json::from_str(document)?; + let object = value.as_object().ok_or(GcpProvenanceError::InvalidDocument)?; + + let has_zone = object.get("zone").and_then(Value::as_str).is_some_and(|zone| !zone.is_empty()); + let has_timestamp = object.get("timestamp").is_some_and(|timestamp| match timestamp { + Value::String(timestamp) => !timestamp.is_empty(), + Value::Number(_) => true, + _ => false, + }); + + if has_zone && has_timestamp { Ok(()) } else { Err(GcpProvenanceError::InvalidDocument) } +} + +#[derive(Error, Debug)] +pub enum GcpProvenanceError { + #[error("quote parse: {0}")] + Quote(String), + #[error("PCK certificate chain is empty")] + NoPckCertificate, + #[error("PPID is empty")] + EmptyPpid, + #[error("PPID extraction: {0}")] + PpidExtraction(String), + #[error("registry fetch: {0}")] + RegistryFetch(String), + #[error("provenance document is invalid")] + InvalidDocument, + #[error("provenance document exceeds maximum size")] + DocumentTooLarge, + #[error("provenance document JSON: {0}")] + Json(#[from] serde_json::Error), + #[error("provenance cache lock: {0}")] + CacheLock(String), + #[error("blocking task join: {0}")] + TaskJoin(String), +} + +#[cfg(test)] +mod tests { + use std::{ + io::{Read as _, Write as _}, + net::SocketAddr, + thread, + time::{Duration, Instant}, + }; + + use super::*; + use crate::dcap; + + const MOCK_PPID_HEX: &str = "d04ec06d4e6d92dc90d0ad3cf5ee2ddf"; + + fn spawn_test_registry_server( + status: u16, + body: impl Into, + ) -> (SocketAddr, thread::JoinHandle) { + let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = listener.local_addr().unwrap(); + let body = body.into(); + + let handle = thread::spawn(move || { + let (mut stream, _) = listener.accept().unwrap(); + let mut buf = [0u8; 1024]; + let bytes_read = stream.read(&mut buf).unwrap(); + let request = String::from_utf8_lossy(&buf[..bytes_read]).to_string(); + let status_text = if status == 200 { "OK" } else { "Not Found" }; + let response = format!( + "HTTP/1.1 {status} {status_text}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}", + body.len() + ); + stream.write_all(response.as_bytes()).unwrap(); + request + }); + + (addr, handle) + } + + #[test] + fn extracts_ppid_from_mock_tdx_quote() { + let attestation = dcap::create_dcap_attestation([0u8; 64]).unwrap(); + let quote = Quote::parse(&attestation).unwrap(); + let ppid = extract_ppid_from_quote("e).unwrap(); + + assert_eq!(hex::encode(ppid), MOCK_PPID_HEX); + } + + #[test] + fn extracts_ppid_from_fixture_dcap_quote() { + let attestation = include_bytes!("../test-assets/dcap-tdx-1766059550570652607"); + let quote = Quote::parse(attestation).unwrap(); + let ppid = extract_ppid_from_quote("e).unwrap(); + + assert_eq!(ppid.len(), 16); + assert!(!ppid.iter().all(|byte| *byte == 0)); + } + + #[test] + fn provenance_check_fetches_registry_document_for_ppid() { + let attestation = dcap::create_dcap_attestation([0u8; 64]).unwrap(); + let quote = Quote::parse(&attestation).unwrap(); + let (addr, request_handle) = spawn_test_registry_server( + 200, + r#"{"zone":"projects/test/zones/us-central1-a","timestamp":"2026-06-11T00:00:00Z"}"#, + ); + + GcpProvenanceChecker::new() + .verify_provenance_with_registry_url_sync_at( + "e, + &format!("http://{addr}"), + Instant::now(), + ) + .unwrap(); + + let request = request_handle.join().unwrap(); + assert!(request.starts_with(&format!("GET /{MOCK_PPID_HEX} HTTP/1.1"))); + } + + #[test] + fn provenance_check_caches_known_gcp_ppids() { + let attestation = dcap::create_dcap_attestation([0u8; 64]).unwrap(); + let quote = Quote::parse(&attestation).unwrap(); + let (addr, request_handle) = spawn_test_registry_server( + 200, + r#"{"zone":"projects/test/zones/us-central1-a","timestamp":"2026-06-11T00:00:00Z"}"#, + ); + let checker = GcpProvenanceChecker::new(); + let registry_url = format!("http://{addr}"); + + checker + .verify_provenance_with_registry_url_sync_at("e, ®istry_url, Instant::now()) + .unwrap(); + checker + .verify_provenance_with_registry_url_sync_at("e, ®istry_url, Instant::now()) + .unwrap(); + + let request = request_handle.join().unwrap(); + assert!(request.starts_with(&format!("GET /{MOCK_PPID_HEX} HTTP/1.1"))); + } + + #[test] + fn provenance_check_revalidates_stale_cached_ppids() { + let attestation = dcap::create_dcap_attestation([0u8; 64]).unwrap(); + let quote = Quote::parse(&attestation).unwrap(); + let (addr, request_handle) = spawn_test_registry_server( + 200, + r#"{"zone":"projects/test/zones/us-central1-a","timestamp":"2026-06-11T00:00:00Z"}"#, + ); + let checker = GcpProvenanceChecker::new(); + let registry_url = format!("http://{addr}"); + let ppid = extract_ppid_from_quote("e).unwrap(); + let stale_at = Instant::now() - (GCP_PROVENANCE_CACHE_TTL + Duration::from_secs(1)); + + checker.known_gcp_ppids.write().unwrap().insert(ppid, stale_at); + + checker + .verify_provenance_with_registry_url_sync_at("e, ®istry_url, Instant::now()) + .unwrap(); + + let request = request_handle.join().unwrap(); + assert!(request.starts_with(&format!("GET /{MOCK_PPID_HEX} HTTP/1.1"))); + } + + #[test] + fn provenance_check_fails_closed_on_registry_miss() { + let attestation = dcap::create_dcap_attestation([0u8; 64]).unwrap(); + let quote = Quote::parse(&attestation).unwrap(); + let (addr, request_handle) = spawn_test_registry_server(404, "not found"); + + let err = GcpProvenanceChecker::new() + .verify_provenance_with_registry_url_sync_at( + "e, + &format!("http://{addr}"), + Instant::now(), + ) + .unwrap_err(); + + request_handle.join().unwrap(); + assert!(matches!(err, GcpProvenanceError::RegistryFetch(_))); + } + + #[test] + fn provenance_check_fails_closed_on_invalid_document() { + let attestation = dcap::create_dcap_attestation([0u8; 64]).unwrap(); + let quote = Quote::parse(&attestation).unwrap(); + let (addr, request_handle) = spawn_test_registry_server(200, r#"{"zone":""}"#); + + let err = GcpProvenanceChecker::new() + .verify_provenance_with_registry_url_sync_at( + "e, + &format!("http://{addr}"), + Instant::now(), + ) + .unwrap_err(); + + request_handle.join().unwrap(); + assert!(matches!(err, GcpProvenanceError::InvalidDocument)); + } + + #[test] + fn provenance_check_fails_closed_on_oversized_document() { + let attestation = dcap::create_dcap_attestation([0u8; 64]).unwrap(); + let quote = Quote::parse(&attestation).unwrap(); + let oversized_body = "x".repeat((GCP_PROVENANCE_DOCUMENT_MAX_BYTES + 1) as usize); + let (addr, request_handle) = spawn_test_registry_server(200, oversized_body); + + let err = GcpProvenanceChecker::new() + .verify_provenance_with_registry_url_sync_at( + "e, + &format!("http://{addr}"), + Instant::now(), + ) + .unwrap_err(); + + request_handle.join().unwrap(); + assert!(matches!(err, GcpProvenanceError::DocumentTooLarge)); + } +} diff --git a/crates/attestation/src/lib.rs b/crates/attestation/src/lib.rs index 73bc5ba..419c423 100644 --- a/crates/attestation/src/lib.rs +++ b/crates/attestation/src/lib.rs @@ -3,6 +3,7 @@ #[cfg(feature = "azure")] pub mod azure; pub mod dcap; +mod gcp; pub mod measurements; use std::{ @@ -18,7 +19,11 @@ use pccs::{Pccs, PccsError}; use serde::{Deserialize, Serialize}; use thiserror::Error; -use crate::{dcap::DcapVerificationError, measurements::MeasurementPolicy}; +use crate::{ + dcap::DcapVerificationError, + gcp::{GcpProvenanceChecker, GcpProvenanceError}, + measurements::MeasurementPolicy, +}; /// Used in attestation type detection to check if we are on GCP const GCP_METADATA_API: &str = "http://metadata.google.internal"; @@ -284,6 +289,8 @@ pub struct AttestationVerifier { pub override_azure_outdated_tcb: bool, /// Internal cache for collateral pub internal_pccs: Option, + /// Internal cache for known GCP PPIDs + gcp_provenance_checker: GcpProvenanceChecker, } impl AttestationVerifier { @@ -299,6 +306,7 @@ impl AttestationVerifier { dump_dcap_quotes, override_azure_outdated_tcb, internal_pccs: Some(Pccs::new(pccs_url)), + gcp_provenance_checker: GcpProvenanceChecker::new(), } } @@ -311,6 +319,7 @@ impl AttestationVerifier { dump_dcap_quotes: false, override_azure_outdated_tcb: false, internal_pccs: None, + gcp_provenance_checker: GcpProvenanceChecker::new(), } } @@ -323,6 +332,7 @@ impl AttestationVerifier { dump_dcap_quotes: false, override_azure_outdated_tcb: false, internal_pccs: None, + gcp_provenance_checker: GcpProvenanceChecker::new(), } } @@ -335,6 +345,7 @@ impl AttestationVerifier { dump_dcap_quotes: false, override_azure_outdated_tcb: false, internal_pccs: Some(Pccs::new(Some(pccs_url))), + gcp_provenance_checker: GcpProvenanceChecker::new(), } } @@ -399,17 +410,23 @@ impl AttestationVerifier { } } AttestationType::DcapTdx | AttestationType::GcpTdx | AttestationType::QemuTdx => { - dcap::verify_dcap_attestation( + let (measurements, quote) = dcap::verify_dcap_attestation( attestation_exchange_message.attestation, expected_input_data, self.internal_pccs.clone(), ) - .await? + .await?; + + if attestation_type == AttestationType::GcpTdx { + self.gcp_provenance_checker.verify_provenance(quote).await?; + } + + measurements } }; // Do a measurement / attestation type policy check - self.measurement_policy.check_measurement(&measurements)?; + self.measurement_policy.check_measurement(attestation_type, &measurements)?; tracing::debug!("Verification successful"); Ok(Some(measurements)) @@ -461,16 +478,22 @@ impl AttestationVerifier { #[cfg(not(any(test, feature = "mock")))] let pccs = self.internal_pccs.clone().ok_or(AttestationError::NoPccs)?; - dcap::verify_dcap_attestation_sync( + let (measurements, quote) = dcap::verify_dcap_attestation_sync( attestation_exchange_message.attestation, expected_input_data, pccs, - )? + )?; + + if attestation_type == AttestationType::GcpTdx { + self.gcp_provenance_checker.verify_provenance_sync("e)?; + } + + measurements } }; // Do a measurement / attestation type policy check - self.measurement_policy.check_measurement(&measurements)?; + self.measurement_policy.check_measurement(attestation_type, &measurements)?; tracing::debug!("Verification successful"); Ok(Some(measurements)) @@ -586,6 +609,8 @@ pub enum AttestationError { QuoteGeneration(#[from] tdx_attest::TdxAttestError), #[error("DCAP verification: {0}")] DcapVerification(#[from] DcapVerificationError), + #[error("GCP provenance: {0}")] + GcpProvenance(#[from] GcpProvenanceError), #[error("Attestation type not supported")] AttestationTypeNotSupported, #[error("Attestation type not accepted")] diff --git a/crates/attestation/src/measurements.rs b/crates/attestation/src/measurements.rs index db8c2c9..2da2352 100644 --- a/crates/attestation/src/measurements.rs +++ b/crates/attestation/src/measurements.rs @@ -275,6 +275,8 @@ pub struct MeasurementRecord { /// An identifier, for example the name and version of the corresponding /// OS image pub measurement_id: String, + /// The attestation type this record accepts + pub attestation_type: AttestationType, /// The expected measurement register values pub measurements: ExpectedMeasurements, } @@ -283,6 +285,7 @@ impl MeasurementRecord { pub fn allow_no_attestation() -> Self { Self { measurement_id: "Allow no attestation".to_string(), + attestation_type: AttestationType::None, measurements: ExpectedMeasurements::NoAttestation, } } @@ -290,6 +293,7 @@ impl MeasurementRecord { pub fn allow_any_measurement(attestation_type: AttestationType) -> Self { Self { measurement_id: format!("Any measurement for {attestation_type}"), + attestation_type, measurements: match attestation_type { AttestationType::None => ExpectedMeasurements::NoAttestation, AttestationType::AzureTdx => ExpectedMeasurements::Azure(HashMap::new()), @@ -359,6 +363,7 @@ impl MeasurementPolicy { Self { accepted_measurements: vec![MeasurementRecord { measurement_id: "test".to_string(), + attestation_type: AttestationType::DcapTdx, measurements: ExpectedMeasurements::Dcap(HashMap::from([ (DcapMeasurementRegister::MRTD, vec![mock_tdx::MOCK_MRTD]), (DcapMeasurementRegister::RTMR0, vec![mock_tdx::MOCK_RTMR0]), @@ -374,10 +379,14 @@ impl MeasurementPolicy { /// they are acceptable pub fn check_measurement( &self, + attestation_type: AttestationType, measurements: &MultiMeasurements, ) -> Result<(), AttestationError> { if self.accepted_measurements.iter().any(|measurement_record| match measurements { MultiMeasurements::Dcap(dcap_measurements) => { + if measurement_record.attestation_type != attestation_type { + return false; + } if let ExpectedMeasurements::Dcap(expected) = &measurement_record.measurements { // All measurements in our policy must be given and must match for (k, v) in expected.iter() { @@ -391,6 +400,9 @@ impl MeasurementPolicy { false } MultiMeasurements::Azure(azure_measurements) => { + if measurement_record.attestation_type != attestation_type { + return false; + } if let ExpectedMeasurements::Azure(expected) = &measurement_record.measurements { for (k, v) in expected.iter() { match azure_measurements.get(k) { @@ -403,6 +415,9 @@ impl MeasurementPolicy { false } MultiMeasurements::NoAttestation => { + if measurement_record.attestation_type != attestation_type { + return false; + } matches!(measurement_record.measurements, ExpectedMeasurements::NoAttestation) } }) { @@ -547,6 +562,7 @@ impl MeasurementPolicy { measurement_policy.push(MeasurementRecord { measurement_id: record.measurement_id.unwrap_or_default(), + attestation_type, measurements: expected_measurements, }); } else { @@ -612,20 +628,27 @@ mod tests { // Will not match mock measurements assert!(matches!( - specific_measurements.check_measurement(&mock_dcap_measurements()).unwrap_err(), + specific_measurements + .check_measurement(AttestationType::DcapTdx, &mock_dcap_measurements()) + .unwrap_err(), AttestationError::MeasurementsNotAccepted )); // Will not match another attestation type assert!(matches!( - specific_measurements.check_measurement(&MultiMeasurements::NoAttestation).unwrap_err(), + specific_measurements + .check_measurement(AttestationType::None, &MultiMeasurements::NoAttestation) + .unwrap_err(), AttestationError::MeasurementsNotAccepted )); // A non-specific measurement fails assert!(matches!( specific_measurements - .check_measurement(&MultiMeasurements::Azure(HashMap::new())) + .check_measurement( + AttestationType::AzureTdx, + &MultiMeasurements::Azure(HashMap::new()) + ) .unwrap_err(), AttestationError::MeasurementsNotAccepted )); @@ -638,17 +661,32 @@ mod tests { let allowed_attestation_type = MeasurementPolicy::from_file("test-assets/measurements_2.json".into()).await.unwrap(); - allowed_attestation_type.check_measurement(&mock_dcap_measurements()).unwrap(); + allowed_attestation_type + .check_measurement(AttestationType::DcapTdx, &mock_dcap_measurements()) + .unwrap(); // Will not match another attestation type assert!(matches!( allowed_attestation_type - .check_measurement(&MultiMeasurements::NoAttestation) + .check_measurement(AttestationType::None, &MultiMeasurements::NoAttestation) .unwrap_err(), AttestationError::MeasurementsNotAccepted )); } + #[test] + fn gcp_policy_rejects_dcap_labeled_measurements() { + let policy = MeasurementPolicy::single_attestation_type(AttestationType::GcpTdx); + let measurements = mock_dcap_measurements(); + + policy.check_measurement(AttestationType::GcpTdx, &measurements).unwrap(); + + assert!(matches!( + policy.check_measurement(AttestationType::DcapTdx, &measurements).unwrap_err(), + AttestationError::MeasurementsNotAccepted + )); + } + #[tokio::test] async fn test_buildernet_measurements() { // Refresh this fixture explicitly with: @@ -662,13 +700,20 @@ mod tests { assert!(!policy.accepted_measurements.is_empty()); assert!(matches!( - policy.check_measurement(&MultiMeasurements::NoAttestation).unwrap_err(), + policy + .check_measurement(AttestationType::None, &MultiMeasurements::NoAttestation) + .unwrap_err(), AttestationError::MeasurementsNotAccepted )); // A non-specific measurement fails assert!(matches!( - policy.check_measurement(&MultiMeasurements::Azure(HashMap::new())).unwrap_err(), + policy + .check_measurement( + AttestationType::AzureTdx, + &MultiMeasurements::Azure(HashMap::new()) + ) + .unwrap_err(), AttestationError::MeasurementsNotAccepted )); } @@ -724,17 +769,17 @@ mod tests { // First value should match let measurements1 = MultiMeasurements::Dcap(HashMap::from([(DcapMeasurementRegister::MRTD, [0u8; 48])])); - assert!(policy.check_measurement(&measurements1).is_ok()); + assert!(policy.check_measurement(AttestationType::DcapTdx, &measurements1).is_ok()); // Second value should also match let measurements2 = MultiMeasurements::Dcap(HashMap::from([(DcapMeasurementRegister::MRTD, [0x11u8; 48])])); - assert!(policy.check_measurement(&measurements2).is_ok()); + assert!(policy.check_measurement(AttestationType::DcapTdx, &measurements2).is_ok()); // Different value should not match let measurements3 = MultiMeasurements::Dcap(HashMap::from([(DcapMeasurementRegister::MRTD, [0x22u8; 48])])); - assert!(policy.check_measurement(&measurements3).is_err()); + assert!(policy.check_measurement(AttestationType::DcapTdx, &measurements3).is_err()); } #[tokio::test] @@ -814,21 +859,21 @@ mod tests { (DcapMeasurementRegister::MRTD, [0u8; 48]), (DcapMeasurementRegister::RTMR0, [0x11u8; 48]), ])); - assert!(policy.check_measurement(&measurements1).is_ok()); + assert!(policy.check_measurement(AttestationType::DcapTdx, &measurements1).is_ok()); // Both match (single + second of any) let measurements2 = MultiMeasurements::Dcap(HashMap::from([ (DcapMeasurementRegister::MRTD, [0u8; 48]), (DcapMeasurementRegister::RTMR0, [0x22u8; 48]), ])); - assert!(policy.check_measurement(&measurements2).is_ok()); + assert!(policy.check_measurement(AttestationType::DcapTdx, &measurements2).is_ok()); // Single matches but any doesn't let measurements3 = MultiMeasurements::Dcap(HashMap::from([ (DcapMeasurementRegister::MRTD, [0u8; 48]), (DcapMeasurementRegister::RTMR0, [0x33u8; 48]), ])); - assert!(policy.check_measurement(&measurements3).is_err()); + assert!(policy.check_measurement(AttestationType::DcapTdx, &measurements3).is_err()); } #[tokio::test]