diff --git a/crates/attestation/src/gcp.rs b/crates/attestation/src/gcp.rs new file mode 100644 index 0000000..069b477 --- /dev/null +++ b/crates/attestation/src/gcp.rs @@ -0,0 +1,47 @@ +use std::{ + collections::HashMap, + sync::{Arc, RwLock}, +}; + +use attest_measure::dcap::DcapFirmware; +use thiserror::Error; + +/// Maps MRTD values to GCP firmware to avoid re-fetching on subsequent +/// verification +#[derive(Clone, Debug, Default)] +pub(crate) struct GcpFirmwareCache { + cache: Arc>>, +} + +impl GcpFirmwareCache { + pub(crate) fn new() -> Self { + Self { cache: Default::default() } + } + + /// Retrieve firmware from cache or fetch if not present + pub(crate) fn get_or_fetch( + &self, + mrtd: [u8; 48], + ) -> Result { + if let Some(firmware) = + self.cache.read().map_err(|_| GcpFirmwareCacheError::CacheLock)?.get(&mrtd).cloned() + { + return Ok(firmware); + } + + let firmware = DcapFirmware::from_google(mrtd)?; + self.cache + .write() + .map_err(|_| GcpFirmwareCacheError::CacheLock)? + .insert(mrtd, firmware.clone()); + Ok(firmware) + } +} + +#[derive(Debug, Error)] +pub(crate) enum GcpFirmwareCacheError { + #[error("Cache lock poisoned")] + CacheLock, + #[error("Firmware fetch: {0}")] + Firmware(#[from] attest_measure::dcap::GoogleError), +} diff --git a/crates/attestation/src/lib.rs b/crates/attestation/src/lib.rs index d47c7c6..a649cca 100644 --- a/crates/attestation/src/lib.rs +++ b/crates/attestation/src/lib.rs @@ -3,8 +3,8 @@ #[cfg(feature = "azure")] pub mod azure; pub mod dcap; +mod gcp; pub mod measurements; - use std::{ fmt::{self, Display, Formatter}, io::Read, @@ -345,14 +345,17 @@ pub struct AttestationVerifier { pub override_azure_outdated_tcb: bool, /// Internal cache for collateral pub internal_pccs: Option, + /// Cached GCP firmware blobs indexed by MRTD + known_gcp_firmware: gcp::GcpFirmwareCache, } impl AttestationVerifier { - pub fn new( + fn build( measurement_policy: MeasurementPolicy, pccs_url: Option, dump_dcap_quotes: bool, override_azure_outdated_tcb: bool, + known_gcp_firmware: gcp::GcpFirmwareCache, ) -> Self { Self { measurement_policy, @@ -360,9 +363,25 @@ impl AttestationVerifier { dump_dcap_quotes, override_azure_outdated_tcb, internal_pccs: Some(Pccs::new(pccs_url)), + known_gcp_firmware, } } + pub fn new( + measurement_policy: MeasurementPolicy, + pccs_url: Option, + dump_dcap_quotes: bool, + override_azure_outdated_tcb: bool, + ) -> Self { + Self::build( + measurement_policy, + pccs_url, + dump_dcap_quotes, + override_azure_outdated_tcb, + gcp::GcpFirmwareCache::new(), + ) + } + /// Create an [AttestationVerifier] which will only allow no attestation /// and will reject if one is given pub fn expect_none() -> Self { @@ -372,6 +391,7 @@ impl AttestationVerifier { dump_dcap_quotes: false, override_azure_outdated_tcb: false, internal_pccs: None, + known_gcp_firmware: gcp::GcpFirmwareCache::new(), } } @@ -384,6 +404,7 @@ impl AttestationVerifier { dump_dcap_quotes: false, override_azure_outdated_tcb: false, internal_pccs: None, + known_gcp_firmware: gcp::GcpFirmwareCache::new(), } } @@ -396,6 +417,7 @@ impl AttestationVerifier { dump_dcap_quotes: false, override_azure_outdated_tcb: false, internal_pccs: Some(Pccs::new(Some(pccs_url))), + known_gcp_firmware: gcp::GcpFirmwareCache::new(), } } @@ -482,7 +504,11 @@ impl AttestationVerifier { .attestation_evidence .as_ref() .map(|evidence| evidence.platform.clone()); - self.measurement_policy.check_measurement(&measurements, platform_metadata)?; + self.measurement_policy.check_measurement_with_gcp_cache( + &measurements, + platform_metadata, + Some(&self.known_gcp_firmware), + )?; tracing::debug!("Verification successful"); Ok(Some(measurements)) @@ -555,7 +581,11 @@ impl AttestationVerifier { .attestation_evidence .as_ref() .map(|evidence| evidence.platform.clone()); - self.measurement_policy.check_measurement(&measurements, platform_metadata)?; + self.measurement_policy.check_measurement_with_gcp_cache( + &measurements, + platform_metadata, + Some(&self.known_gcp_firmware), + )?; tracing::debug!("Verification successful"); Ok(Some(measurements)) diff --git a/crates/attestation/src/measurements.rs b/crates/attestation/src/measurements.rs index c743bad..c644cf8 100644 --- a/crates/attestation/src/measurements.rs +++ b/crates/attestation/src/measurements.rs @@ -10,7 +10,12 @@ use serde::Deserialize; use thiserror::Error; use tracing::warn; -use crate::{AttestationError, AttestationType, dcap::DcapVerificationError}; +use crate::{ + AttestationError, + AttestationType, + dcap::DcapVerificationError, + gcp::{GcpFirmwareCache, GcpFirmwareCacheError}, +}; /// Represents the measurement register types in a TDX quote #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -380,6 +385,15 @@ impl MeasurementPolicy { &self, measurements: &MultiMeasurements, platform_metadata: Option, + ) -> Result<(), AttestationError> { + self.check_measurement_with_gcp_cache(measurements, platform_metadata, None) + } + + pub(crate) fn check_measurement_with_gcp_cache( + &self, + measurements: &MultiMeasurements, + platform_metadata: Option, + known_gcp_firmware: Option<&GcpFirmwareCache>, ) -> Result<(), AttestationError> { if self.accepted_measurements.iter().any(|measurement_record| match measurements { MultiMeasurements::Dcap(dcap_measurements) => { @@ -408,24 +422,37 @@ impl MeasurementPolicy { ); return false; }; - match DcapFirmware::from_google(*mrtd) { + + let result = if let Some(cache) = known_gcp_firmware { + cache.get_or_fetch(*mrtd) + } else { + DcapFirmware::from_google(*mrtd) + .map_err(GcpFirmwareCacheError::from) + }; + match result { Ok(firmware) => Some(firmware), Err(err) => { - warn!("Could not match image hash measurement - failed to fetch or verify Google firmware: {err:?}"); - return false - }, + warn!( + "Could not match image hash measurement - failed to fetch or verify Google firmware: {err:?}" + ); + return false; + } } } ImageAttestationType::SelfHostedTdx => None, ImageAttestationType::AzureTdx => return false, }; - let Ok(expected_measurements) = expected_dcap_registers( + let expected_measurements = match expected_dcap_registers( image_hashes, platform_metadata, firmware.as_ref(), - ) else { - return false; // TODO should we bail here + ) { + Ok(expected) => expected, + Err(err) => { + warn!("Failed to compute expected DCAP registers: {err:?}"); + return false; // TODO should we bail here + } }; if let Some(expected_mrtd) = expected_measurements.mrtd {