diff --git a/Cargo.lock b/Cargo.lock index 4c76e33..f88657a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -110,6 +110,12 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "adler2" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" + [[package]] name = "aho-corasick" version = "1.1.4" @@ -152,19 +158,41 @@ version = "1.0.102" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" +[[package]] +name = "arrayref" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb" + [[package]] name = "arrayvec" version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" +[[package]] +name = "asn1-rs" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5493c3bedbacf7fd7382c6346bbd66687d12bbaad3a89a2d2c303ee6cf20b048" +dependencies = [ + "asn1-rs-derive 0.5.1", + "asn1-rs-impl", + "displaydoc", + "nom", + "num-traits", + "rusticata-macros", + "thiserror 1.0.69", + "time", +] + [[package]] name = "asn1-rs" version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56624a96882bb8c26d61312ae18cb45868e5a9992ea73c58e45c3101e56a1e60" dependencies = [ - "asn1-rs-derive", + "asn1-rs-derive 0.6.0", "asn1-rs-impl", "displaydoc", "nom", @@ -174,6 +202,18 @@ dependencies = [ "time", ] +[[package]] +name = "asn1-rs-derive" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "965c2d33e53cb6b267e148a4cb0760bc01f4904c1cd4bb4002a085bb016d1490" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + [[package]] name = "asn1-rs-derive" version = "0.6.0" @@ -228,7 +268,7 @@ dependencies = [ "az-tdx-vtpm", "base64 0.22.1", "configfs-tsm", - "dcap-qvl", + "dcap-qvl 0.3.12 (git+https://github.com/flashbots/dcap-qvl.git?branch=peg%2Fazure-outdated-tcp-override)", "hex", "http 1.4.0", "num-bigint", @@ -249,12 +289,27 @@ dependencies = [ "tokio-rustls", "tracing", "tss-esapi", - "x509-parser", + "x509-parser 0.18.1", ] [[package]] name = "attested-tls" version = "0.0.1" +dependencies = [ + "anyhow", + "attestation", + "nested-tls", + "ra-tls", + "rcgen 0.14.7", + "rustls", + "serde_json", + "sha2", + "thiserror 2.0.18", + "tokio", + "tracing", + "x509-parser 0.18.1", + "yasna 0.5.2", +] [[package]] name = "autocfg" @@ -438,6 +493,29 @@ dependencies = [ "wyz", ] +[[package]] +name = "blake2" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe" +dependencies = [ + "digest", +] + +[[package]] +name = "blake3" +version = "1.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2468ef7d57b3fb7e16b576e8377cdbde2320c60e1491e961d11da40fc4f02a2d" +dependencies = [ + "arrayref", + "arrayvec", + "cc", + "cfg-if", + "constant_time_eq", + "cpufeatures", +] + [[package]] name = "block-buffer" version = "0.10.4" @@ -447,6 +525,31 @@ dependencies = [ "generic-array", ] +[[package]] +name = "bon" +version = "3.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d13a61f2963b88eef9c1be03df65d42f6996dfeac1054870d950fcf66686f83" +dependencies = [ + "bon-macros", + "rustversion", +] + +[[package]] +name = "bon-macros" +version = "3.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d314cc62af2b6b0c65780555abb4d02a03dd3b799cd42419044f0c38d99738c0" +dependencies = [ + "darling", + "ident_case", + "prettyplease", + "proc-macro2", + "quote", + "rustversion", + "syn", +] + [[package]] name = "borsh" version = "1.6.0" @@ -542,6 +645,23 @@ dependencies = [ "shlex", ] +[[package]] +name = "cc-eventlog" +version = "0.5.7" +source = "git+https://github.com/Dstack-TEE/dstack.git#31cfd481b178fd36b2137fc8fc1e5c728f89145d" +dependencies = [ + "anyhow", + "digest", + "ez-hash", + "fs-err", + "hex", + "parity-scale-codec", + "serde", + "serde-human-bytes", + "serde_json", + "sha2", +] + [[package]] name = "cfg-if" version = "1.0.4" @@ -637,6 +757,18 @@ version = "0.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "187437900921c8172f33316ad51a3267df588e99a2aebfa5ca1a2ed44df9e703" +[[package]] +name = "console" +version = "0.15.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "054ccb5b10f9f2cbf51eb355ca1d05c2d279ce1804688d0db74b4733a5aeafd8" +dependencies = [ + "encode_unicode", + "libc", + "once_cell", + "windows-sys 0.59.0", +] + [[package]] name = "const-oid" version = "0.9.6" @@ -663,6 +795,12 @@ dependencies = [ "unicode-xid", ] +[[package]] +name = "constant_time_eq" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d52eff69cd5e647efe296129160853a42795992097e8af39800e1060caeea9b" + [[package]] name = "convert_case" version = "0.10.0" @@ -681,6 +819,15 @@ dependencies = [ "libc", ] +[[package]] +name = "crc32fast" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" +dependencies = [ + "cfg-if", +] + [[package]] name = "criterion" version = "0.5.1" @@ -811,12 +958,83 @@ dependencies = [ "syn", ] +[[package]] +name = "darling" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25ae13da2f202d56bd7f91c25fba009e7717a1e4a1cc98a76d844b65ae912e9d" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9865a50f7c335f53564bb694ef660825eb8610e0a53d3e11bf1b0d3df31e03b0" +dependencies = [ + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn", +] + +[[package]] +name = "darling_macro" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3984ec7bd6cfa798e62b4a642426a5be0e68f9401cfc2a01e3fa9ea2fcdb8d" +dependencies = [ + "darling_core", + "quote", + "syn", +] + [[package]] name = "data-encoding" version = "2.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d7a1e2f27636f116493b8b860f5546edb47c8d8f8ea73e1d2a20be88e28d1fea" +[[package]] +name = "dcap-qvl" +version = "0.3.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67e7842b81018f3b991dc65ec0a95ff347332de58478c4ac43459095af00cc89" +dependencies = [ + "anyhow", + "asn1_der", + "base64 0.22.1", + "borsh", + "byteorder", + "chrono", + "const-oid", + "dcap-qvl-webpki", + "der", + "derive_more 2.1.1", + "futures", + "hex", + "log", + "p256", + "parity-scale-codec", + "pem", + "reqwest", + "ring", + "rustls-pki-types", + "scale-info", + "serde", + "serde-human-bytes", + "serde_json", + "sha2", + "signature", + "tracing", + "urlencoding", + "wasm-bindgen-futures", + "x509-cert", +] + [[package]] name = "dcap-qvl" version = "0.3.12" @@ -884,13 +1102,27 @@ dependencies = [ "zeroize", ] +[[package]] +name = "der-parser" +version = "9.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5cd0a5c643689626bec213c4d8bd4d96acc8ffdb4ad4bb6bc16abf27d5f4b553" +dependencies = [ + "asn1-rs 0.6.2", + "displaydoc", + "nom", + "num-bigint", + "num-traits", + "rusticata-macros", +] + [[package]] name = "der-parser" version = "10.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07da5016415d5a3c4dd39b11ed26f915f52fc4e0dc197d87908bc916e51bc1a6" dependencies = [ - "asn1-rs", + "asn1-rs 0.7.1", "displaydoc", "nom", "num-bigint", @@ -1005,6 +1237,43 @@ dependencies = [ "syn", ] +[[package]] +name = "dstack-attest" +version = "0.5.7" +source = "git+https://github.com/Dstack-TEE/dstack.git#31cfd481b178fd36b2137fc8fc1e5c728f89145d" +dependencies = [ + "anyhow", + "cc-eventlog", + "dcap-qvl 0.3.12 (registry+https://github.com/rust-lang/crates.io-index)", + "dstack-types", + "errify", + "ez-hash", + "fs-err", + "hex", + "hex_fmt", + "insta", + "or-panic", + "parity-scale-codec", + "serde", + "serde-human-bytes", + "serde_json", + "sha2", + "sha3", + "tdx-attest", +] + +[[package]] +name = "dstack-types" +version = "0.5.7" +source = "git+https://github.com/Dstack-TEE/dstack.git#31cfd481b178fd36b2137fc8fc1e5c728f89145d" +dependencies = [ + "parity-scale-codec", + "serde", + "serde-human-bytes", + "sha3", + "size-parser", +] + [[package]] name = "dunce" version = "1.0.5" @@ -1064,6 +1333,7 @@ dependencies = [ "ff", "generic-array", "group", + "pem-rfc7468", "pkcs8", "rand_core 0.6.4", "sec1", @@ -1071,6 +1341,12 @@ dependencies = [ "zeroize", ] +[[package]] +name = "encode_unicode" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" + [[package]] name = "encoding_rs" version = "0.8.35" @@ -1118,6 +1394,28 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" +[[package]] +name = "errify" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4bb818c3c01af9cdeb367f7e92e290b9a080935cdc5fb6cc0c1193ae17032849" +dependencies = [ + "anyhow", + "errify-macros", +] + +[[package]] +name = "errify-macros" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e87afa19e6030c2cf5514b00d5a242a3ea9492a2aa618635076914f5d15e7af" +dependencies = [ + "proc-macro2", + "proc-macro2-diagnostics", + "quote", + "syn", +] + [[package]] name = "errno" version = "0.3.14" @@ -1128,6 +1426,21 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "ez-hash" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42b3b3adc5fbbc9e21416d5b721b1bccb501a87d7b32ac89f2c7cea229d40772" +dependencies = [ + "blake2", + "blake3", + "digest", + "md-5", + "sha1", + "sha2", + "sha3", +] + [[package]] name = "fastrand" version = "2.3.0" @@ -1162,6 +1475,16 @@ version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b7ac824320a75a52197e8f2d787f6a38b6718bb6897a35142d749af3c0e8f4fe" +[[package]] +name = "flate2" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "843fba2746e448b37e26a819579957415c8cef339bf08564fe8b7ddbd959573c" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + [[package]] name = "fnv" version = "1.0.7" @@ -1198,6 +1521,15 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "fs-err" +version = "3.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73fde052dbfc920003cfd2c8e2c6e6d4cc7c1091538c3a24226cec0665ab08c0" +dependencies = [ + "autocfg", +] + [[package]] name = "fs_extra" version = "1.3.0" @@ -1404,6 +1736,12 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "hex_fmt" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b07f60793ff0a4d9cef0f18e63b5357e06209987153a64648c972c1e5aff336f" + [[package]] name = "hickory-proto" version = "0.25.2" @@ -1450,6 +1788,15 @@ dependencies = [ "tracing", ] +[[package]] +name = "hkdf" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b5f8eb2ad728638ea2c7d47a21db23b7b58a72ed6a38256b8a1849f15fbbdf7" +dependencies = [ + "hmac", +] + [[package]] name = "hmac" version = "0.12.1" @@ -1669,6 +2016,12 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + [[package]] name = "idna" version = "1.1.0" @@ -1725,6 +2078,18 @@ dependencies = [ "serde_core", ] +[[package]] +name = "insta" +version = "1.46.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e82db8c87c7f1ccecb34ce0c24399b8a73081427f3c7c50a5d597925356115e4" +dependencies = [ + "console", + "once_cell", + "similar", + "tempfile", +] + [[package]] name = "iocuddle" version = "0.1.1" @@ -1834,6 +2199,15 @@ dependencies = [ "sha2", ] +[[package]] +name = "keccak" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb26cec98cce3a3d96cbb7bced3c4b16e3d13f27ec56dbd62cbc8f39cfb9d653" +dependencies = [ + "cpufeatures", +] + [[package]] name = "language-tags" version = "0.3.2" @@ -1925,6 +2299,16 @@ dependencies = [ "stable_deref_trait", ] +[[package]] +name = "md-5" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf" +dependencies = [ + "cfg-if", + "digest", +] + [[package]] name = "memchr" version = "2.8.0" @@ -1952,6 +2336,16 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" +[[package]] +name = "miniz_oxide" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" +dependencies = [ + "adler2", + "simd-adler32", +] + [[package]] name = "mio" version = "1.1.1" @@ -1992,13 +2386,26 @@ dependencies = [ "criterion", "impl-more 0.3.1", "pin-project-lite", - "rcgen", + "rcgen 0.14.7", "rustls", "tokio", "tokio-rustls", "tracing", ] +[[package]] +name = "nix" +version = "0.31.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d6d0705320c1e6ba1d912b5e37cf18071b6c2e9b7fa8215a1e8a7651966f5d3" +dependencies = [ + "bitflags 2.11.0", + "cfg-if", + "cfg_aliases", + "libc", + "memoffset", +] + [[package]] name = "nom" version = "7.1.3" @@ -2091,13 +2498,22 @@ dependencies = [ "serde", ] +[[package]] +name = "oid-registry" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8d8034d9489cdaf79228eb9f6a3b8d7bb32ba00d6645ebd48eef4077ceb5bd9" +dependencies = [ + "asn1-rs 0.6.2", +] + [[package]] name = "oid-registry" version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "12f40cff3dde1b6087cc5d5f5d4d65712f34016a03ed60e9c08dcc392736b5b7" dependencies = [ - "asn1-rs", + "asn1-rs 0.7.1", ] [[package]] @@ -2170,6 +2586,12 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "or-panic" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "596a79faf55e869e7bc0c2162cf2f18a54d4d1112876bceae587ad954fcbd574" + [[package]] name = "p256" version = "0.13.2" @@ -2452,6 +2874,19 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "proc-macro2-diagnostics" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af066a9c399a26e020ada66a034357a868728e72cd426f3adcd35f80d88d88c8" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "version_check", + "yansi", +] + [[package]] name = "quinn" version = "0.11.9" @@ -2528,6 +2963,43 @@ version = "6.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8dcc9c7d52a811697d2151c701e0d08956f92b0e24136cf4cf27b57a6a0d9bf" +[[package]] +name = "ra-tls" +version = "0.5.7" +source = "git+https://github.com/Dstack-TEE/dstack.git#31cfd481b178fd36b2137fc8fc1e5c728f89145d" +dependencies = [ + "anyhow", + "bon", + "cc-eventlog", + "dcap-qvl 0.3.12 (registry+https://github.com/rust-lang/crates.io-index)", + "dstack-attest", + "dstack-types", + "elliptic-curve", + "errify", + "ez-hash", + "flate2", + "fs-err", + "hex", + "hex_fmt", + "hkdf", + "or-panic", + "p256", + "parity-scale-codec", + "rand 0.8.5", + "rcgen 0.13.2", + "ring", + "rmp-serde", + "rustls-pki-types", + "serde", + "serde-human-bytes", + "serde_json", + "sha2", + "sha3", + "tracing", + "x509-parser 0.16.0", + "yasna 0.5.2", +] + [[package]] name = "radium" version = "0.7.0" @@ -2540,6 +3012,7 @@ version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ + "libc", "rand_chacha 0.3.1", "rand_core 0.6.4", ] @@ -2612,6 +3085,20 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "rcgen" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75e669e5202259b5314d1ea5397316ad400819437857b90861765f24c4cf80a2" +dependencies = [ + "pem", + "ring", + "rustls-pki-types", + "time", + "x509-parser 0.16.0", + "yasna 0.5.2", +] + [[package]] name = "rcgen" version = "0.14.7" @@ -2622,7 +3109,7 @@ dependencies = [ "ring", "rustls-pki-types", "time", - "x509-parser", + "x509-parser 0.18.1", "yasna 0.5.2", ] @@ -2756,6 +3243,25 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rmp" +version = "0.8.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ba8be72d372b2c9b35542551678538b562e7cf86c3315773cae48dfbfe7790c" +dependencies = [ + "num-traits", +] + +[[package]] +name = "rmp-serde" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72f81bee8c8ef9b577d1681a70ebbc962c232461e397b22c208c43c04b67a155" +dependencies = [ + "rmp", + "serde", +] + [[package]] name = "rsa" version = "0.9.10" @@ -3038,6 +3544,17 @@ dependencies = [ "uuid", ] +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sha2" version = "0.10.9" @@ -3049,6 +3566,16 @@ dependencies = [ "digest", ] +[[package]] +name = "sha3" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75872d278a8f37ef87fa0ddbda7802605cb18344497949862c0d4dcb291eba60" +dependencies = [ + "digest", + "keccak", +] + [[package]] name = "shlex" version = "1.3.0" @@ -3075,6 +3602,28 @@ dependencies = [ "rand_core 0.6.4", ] +[[package]] +name = "simd-adler32" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" + +[[package]] +name = "similar" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa" + +[[package]] +name = "size-parser" +version = "0.5.7" +source = "git+https://github.com/Dstack-TEE/dstack.git#31cfd481b178fd36b2137fc8fc1e5c728f89145d" +dependencies = [ + "anyhow", + "serde", + "thiserror 2.0.18", +] + [[package]] name = "slab" version = "0.4.12" @@ -3135,6 +3684,12 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + [[package]] name = "subtle" version = "2.6.1" @@ -3190,6 +3745,25 @@ version = "0.12.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" +[[package]] +name = "tdx-attest" +version = "0.5.7" +source = "git+https://github.com/Dstack-TEE/dstack.git#31cfd481b178fd36b2137fc8fc1e5c728f89145d" +dependencies = [ + "anyhow", + "cc-eventlog", + "fs-err", + "hex", + "libc", + "parity-scale-codec", + "serde", + "serde-human-bytes", + "serde_json", + "sha2", + "thiserror 2.0.18", + "vsock", +] + [[package]] name = "tdx-quote" version = "0.0.5" @@ -3624,6 +4198,16 @@ version = "0.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "051eb1abcf10076295e815102942cc58f9d5e3b4560e46e53c21e8ff6f3af7b1" +[[package]] +name = "vsock" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b82aeb12ad864eb8cd26a6c21175d0bdc66d398584ee6c93c76964c3bcfc78ff" +dependencies = [ + "libc", + "nix", +] + [[package]] name = "walkdir" version = "2.5.0" @@ -3828,6 +4412,15 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets 0.52.6", +] + [[package]] name = "windows-sys" version = "0.61.2" @@ -4103,18 +4696,36 @@ dependencies = [ "x509-cert", ] +[[package]] +name = "x509-parser" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcbc162f30700d6f3f82a24bf7cc62ffe7caea42c0b2cba8bf7f3ae50cf51f69" +dependencies = [ + "asn1-rs 0.6.2", + "data-encoding", + "der-parser 9.0.0", + "lazy_static", + "nom", + "oid-registry 0.7.1", + "ring", + "rusticata-macros", + "thiserror 1.0.69", + "time", +] + [[package]] name = "x509-parser" version = "0.18.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d43b0f71ce057da06bc0851b23ee24f3f86190b07203dd8f567d0b706a185202" dependencies = [ - "asn1-rs", + "asn1-rs 0.7.1", "data-encoding", - "der-parser", + "der-parser 10.0.0", "lazy_static", "nom", - "oid-registry", + "oid-registry 0.8.1", "ring", "rusticata-macros", "thiserror 2.0.18", @@ -4142,6 +4753,12 @@ dependencies = [ "x509-ocsp", ] +[[package]] +name = "yansi" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" + [[package]] name = "yasna" version = "0.4.0" diff --git a/Cargo.toml b/Cargo.toml index 8067bbd..8d84292 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,3 +20,4 @@ unused_async = "warn" rustls = { version = "0.23.37", default-features = false, features = ["brotli"] } tokio = { version = "1.50.0", features = ["default"] } tokio-rustls = { version = "0.26.4", default-features = false } +attestation = { path = "crates/attestation" } diff --git a/crates/attestation/src/dcap.rs b/crates/attestation/src/dcap.rs index bf29dcd..5122c9a 100644 --- a/crates/attestation/src/dcap.rs +++ b/crates/attestation/src/dcap.rs @@ -111,6 +111,7 @@ pub async fn verify_dcap_attestation_with_given_timestamp( } #[cfg(any(test, feature = "mock"))] +#[allow(clippy::unused_async)] pub async fn verify_dcap_attestation( input: Vec, expected_input_data: [u8; 64], diff --git a/crates/attested-tls/Cargo.toml b/crates/attested-tls/Cargo.toml index 989d69c..baf403c 100644 --- a/crates/attested-tls/Cargo.toml +++ b/crates/attested-tls/Cargo.toml @@ -3,5 +3,24 @@ name = "attested-tls" version = "0.0.1" edition = "2024" +[dependencies] +tokio = { workspace = true } +rustls = { workspace = true, default-features = false } +attestation = { workspace = true } +rcgen = "0.14.7" +thiserror = "2.0.17" +ra-tls = { git = "https://github.com/Dstack-TEE/dstack.git", version = "0.5.7", features = ["quote"] } +anyhow = "1.0.102" +x509-parser = "0.18.1" +serde_json = "1.0.149" +yasna = "0.5.2" +sha2 = "0.10.9" +tracing = "0.1.41" + +[dev-dependencies] +attestation = { workspace = true, features = ["mock"] } +nested-tls = { path = "../nested-tls" } +rustls = { workspace = true, default-features = false, features = ["aws_lc_rs"] } + [lints] workspace = true diff --git a/crates/attested-tls/src/lib.rs b/crates/attested-tls/src/lib.rs index 8b13789..ff7a5a0 100644 --- a/crates/attested-tls/src/lib.rs +++ b/crates/attested-tls/src/lib.rs @@ -1 +1,1350 @@ +//! An attested TLS certificate resolver and verifier +use std::{ + collections::HashMap, + fmt, + sync::{Arc, RwLock}, + time::{Duration, SystemTime}, +}; +pub use attestation::{ + AttestationExchangeMessage, + AttestationGenerator, + AttestationType, + AttestationVerifier, +}; +pub use ra_tls::cert::CaCert; +use ra_tls::{ + attestation::{Attestation, AttestationQuote, VersionedAttestation}, + cert::CertRequest, + rcgen::{KeyPair, PKCS_ECDSA_P256_SHA256}, +}; +use rustls::{ + DigitallySignedStruct, + DistinguishedName, + RootCertStore, + SignatureScheme, + client::{ + ResolvesClientCert, + VerifierBuilderError, + WebPkiServerVerifier, + danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}, + verify_server_name, + }, + crypto::CryptoProvider, + pki_types::{ + CertificateDer, + PrivateKeyDer, + PrivatePkcs8KeyDer, + ServerName, + UnixTime, + pem::PemObject, + }, + server::{ + ParsedCertificate, + ResolvesServerCert, + WebPkiClientVerifier, + danger::{ClientCertVerified, ClientCertVerifier}, + }, + sign::{CertifiedKey, SigningKey}, +}; +use sha2::{Digest as _, Sha512}; +use thiserror::Error; +use x509_parser::{certificate::X509Certificate, oid_registry::Oid}; + +/// The length of time a certificate is valid for +#[cfg(not(test))] +const CERTIFICATE_VALIDITY: Duration = Duration::from_secs(30 * 60); +#[cfg(test)] +const CERTIFICATE_VALIDITY: Duration = Duration::from_secs(4); + +/// How long before expiry to renew certificate +#[cfg(not(test))] +const CERTIFICATE_RENEWAL_LEAD_TIME: Duration = Duration::from_secs(5 * 60); +#[cfg(test)] +const CERTIFICATE_RENEWAL_LEAD_TIME: Duration = Duration::from_secs(2); + +/// How long to wait before re-trying certificate renewal on failure +#[cfg(not(test))] +const CERTIFICATE_RENEWAL_RETRY_DELAY: Duration = Duration::from_secs(30); +#[cfg(test)] +const CERTIFICATE_RENEWAL_RETRY_DELAY: Duration = Duration::from_millis(200); + +/// A TLS certificate resolver which includes an attestation as a +/// certificate extension +#[derive(Clone, Debug)] +pub struct AttestedCertificateResolver { + /// Cloneable inner state + state: Arc, +} + +/// Internal state used by the resolver and its renewal loop +struct ResolverState { + /// The private TLS key in a format ready to be + /// used in handshake + key: Arc, + /// Optional CA used to sign leaf certificates - default is self-signed + ca: Option>, + /// The private TLS key in a format ready to be used by to sign + /// certificates if no CA is used + key_pair_der: Vec, + /// The current certificate with attestation + certificate: RwLock>>, + /// Attestation generator used when renewing ceritifcate + attestation_generator: AttestationGenerator, + /// Primary DNS name used as certificate subject / common name. + primary_name: String, + /// DNS subject alternative names, including the primary name. + subject_alt_names: Vec, +} + +impl fmt::Debug for ResolverState { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let certificate_chain_len = self.certificate.read().ok().map(|certs| certs.len()); + + f.debug_struct("ResolverState") + .field("key", &"") + .field("ca_present", &self.ca.is_some()) + .field("key_pair_der_len", &self.key_pair_der.len()) + .field("certificate_chain_len", &certificate_chain_len) + .field("attestation_generator", &self.attestation_generator) + .field("primary_name", &self.primary_name) + .field("subject_alt_names", &self.subject_alt_names) + .finish() + } +} + +impl AttestedCertificateResolver { + /// Create a certificate resolver with a given attestation generator + /// A private cerificate authority can also be given - otherwise + /// certificates will be self signed + pub async fn new( + attestation_generator: AttestationGenerator, + ca: Option, + primary_name: String, + subject_alt_names: Vec, + ) -> Result { + Self::new_with_provider( + attestation_generator, + ca, + primary_name, + subject_alt_names, + default_crypto_provider()?, + ) + .await + } + + /// Also provide a crypto provider + pub async fn new_with_provider( + attestation_generator: AttestationGenerator, + ca: Option, + primary_name: String, + subject_alt_names: Vec, + provider: Arc, + ) -> Result { + debug_assert!(CERTIFICATE_RENEWAL_LEAD_TIME < CERTIFICATE_VALIDITY); + let subject_alt_names = + normalized_subject_alt_names(primary_name.as_str(), subject_alt_names); + + // Generate keypair + let key_pair = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256)?; + let key_pair_der = key_pair.serialize_der(); + let key = Self::load_signing_key(&key_pair, provider)?; + + // Generate initial attested certificate + let certificate = Self::issue_ra_cert_chain( + &key_pair, + ca.as_ref(), + primary_name.as_str(), + &subject_alt_names, + &attestation_generator, + ) + .await?; + + let state = Arc::new(ResolverState { + key, + certificate: RwLock::new(certificate), + ca: ca.map(Arc::new), + key_pair_der, + attestation_generator, + primary_name, + subject_alt_names, + }); + + // Start a loop which will periodically renew the certificate + Self::spawn_renewal_task(Arc::downgrade(&state)); + + Ok(Self { state }) + } + + /// Create an attested certificate chain - either self-signed or with + /// the provided CA + async fn issue_ra_cert_chain( + key: &KeyPair, + ca: Option<&CaCert>, + primary_name: &str, + subject_alt_names: &[String], + attestation_generator: &AttestationGenerator, + ) -> Result>, AttestedTlsError> { + tracing::debug!("Generating new remote-attested ceritifcate for {primary_name}"); + let pubkey = key.public_key_der(); + let now = SystemTime::now(); + let not_after = now + CERTIFICATE_VALIDITY; + + let attestation = Self::create_attestation_payload( + pubkey, + now, + not_after, + primary_name, + attestation_generator, + ) + .await?; + + let cert_request = CertRequest::builder() + .key(key) + .subject(primary_name) + .alt_names(subject_alt_names) + .not_before(now) + .not_after(not_after) + .usage_server_auth(true) + .usage_client_auth(true) + .attestation(&attestation) + .build(); + + let leaf = match ca { + Some(ca) => ca.sign(cert_request).map_err(AttestedTlsError::RaTls)?, + None => cert_request.self_signed().map_err(AttestedTlsError::RaTls)?, + }; + + let mut chain = vec![leaf.der().to_vec().into()]; + if let Some(ca) = ca { + chain.push(CertificateDer::from_pem_slice(ca.pem_cert.as_bytes())?); + } + + Ok(chain) + } + + /// Get keypair into a format ready to be used in handshakes + fn load_signing_key( + key_pair: &KeyPair, + provider: Arc, + ) -> Result, AttestedTlsError> { + let private_key = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key_pair.serialize_der())); + + Ok(provider.key_provider.load_private_key(private_key)?) + } + + /// Create an attestation, and format it to be used in certificate + /// extension + async fn create_attestation_payload( + pubkey: Vec, + not_before: SystemTime, + not_after: SystemTime, + primary_name: &str, + attestation_generator: &AttestationGenerator, + ) -> Result { + let report_data = + create_report_data(pubkey, not_before, not_after, primary_name.as_bytes())?; + let attestation = attestation_generator.generate_attestation(report_data).await?; + Ok(VersionedAttestation::V0 { + attestation: Attestation { + quote: ra_tls::attestation::AttestationQuote::DstackTdx( + ra_tls::attestation::TdxQuote { + quote: serde_json::to_vec(&attestation)?, + event_log: Vec::new(), + }, + ), + runtime_events: Vec::new(), + report_data, + config: String::new(), + report: (), + }, + }) + } + + /// Start a loop which periodically renews the certificate + fn spawn_renewal_task(state: std::sync::Weak) { + tokio::spawn(async move { + let renewal_delay = CERTIFICATE_VALIDITY - CERTIFICATE_RENEWAL_LEAD_TIME; + let mut next_delay = renewal_delay; + + loop { + tokio::time::sleep(next_delay).await; + let Some(current) = state.upgrade() else { + tracing::warn!("Resolver has been dropped - stopping renewal loop"); + break; + }; + + let key_pair = match KeyPair::try_from(current.key_pair_der.clone()) { + Ok(key_pair) => key_pair, + Err(e) => { + tracing::error!("Failed to load keypair: {e}"); + next_delay = CERTIFICATE_RENEWAL_RETRY_DELAY; + continue; + } + }; + + next_delay = match Self::issue_ra_cert_chain( + &key_pair, + current.ca.as_deref(), + current.primary_name.as_str(), + ¤t.subject_alt_names, + ¤t.attestation_generator, + ) + .await + { + Ok(certificate) => { + *current.certificate.write().expect("Certificate lock poisoned") = + certificate; + renewal_delay + } + Err(e) => { + tracing::error!("Failed to renew attested certificate: {e}"); + CERTIFICATE_RENEWAL_RETRY_DELAY + } + }; + } + }); + } +} + +impl ResolvesServerCert for AttestedCertificateResolver { + fn resolve(&self, _: rustls::server::ClientHello<'_>) -> Option> { + self.current_certified_key() + } +} + +impl ResolvesClientCert for AttestedCertificateResolver { + fn resolve(&self, _: &[&[u8]], _: &[SignatureScheme]) -> Option> { + self.current_certified_key() + } + + fn has_certs(&self) -> bool { + !self.state.certificate.read().expect("Certificate lock poisoned").is_empty() + } +} + +impl AttestedCertificateResolver { + fn current_certified_key(&self) -> Option> { + let certificate = self.state.certificate.read().expect("Certificate lock poisoned").clone(); + Some(Arc::new(CertifiedKey::new(certificate, self.state.key.clone()))) + } +} + +fn default_crypto_provider() -> Result, AttestedTlsError> { + CryptoProvider::get_default().cloned().ok_or(AttestedTlsError::CryptoProviderUnavailable) +} + +/// Ensures that SAN contains the primary hostname +fn normalized_subject_alt_names(primary_name: &str, subject_alt_names: Vec) -> Vec { + let mut normalized = Vec::with_capacity(subject_alt_names.len() + 1); + normalized.push(primary_name.to_string()); + + for name in subject_alt_names { + if !normalized.iter().any(|existing| existing == &name) { + normalized.push(name); + } + } + + normalized +} + +/// Make input data for the attestation by hashing together public key, +/// validity period and hostname +fn create_report_data( + public_key: Vec, + not_before: SystemTime, + not_after: SystemTime, + hostname: &[u8], +) -> Result<[u8; 64], AttestedTlsError> { + let not_before = not_before + .duration_since(SystemTime::UNIX_EPOCH) + .map_err(AttestedTlsError::SystemTime)? + .as_secs() + .to_be_bytes(); + let not_after = not_after + .duration_since(SystemTime::UNIX_EPOCH) + .map_err(AttestedTlsError::SystemTime)? + .as_secs() + .to_be_bytes(); + + let mut hasher = Sha512::new(); + hasher.update(public_key); + hasher.update(not_before); + hasher.update(not_after); + hasher.update(hostname); + + Ok(hasher.finalize().into()) +} + +/// Verifies attested TLS server or client certificates during TLS handshake +#[derive(Debug)] +pub struct AttestedCertificateVerifier { + /// Underlying verifier when used with a private CA rather than + /// self-signed + server_inner: Option>, + /// Underlying client verifier when used with a private CA rather than + /// self-signed + client_inner: Option>, + /// Underlying cryptography provider + provider: Arc, + /// Configured for verifying attestations + attestation_verifier: AttestationVerifier, + /// Report data of pre-trusted certificates with cache expiry time + trusted_certificates: Arc>>, +} + +impl AttestedCertificateVerifier { + /// Create a certificate verifier with given attestation verification + /// and optionally a private CA root of trust + pub fn new( + root_store: Option, + attestation_verifier: AttestationVerifier, + ) -> Result { + Self::new_with_provider(root_store, attestation_verifier, default_crypto_provider()?) + } + + /// Also provide a crypto provider + pub fn new_with_provider( + root_store: Option, + attestation_verifier: AttestationVerifier, + provider: Arc, + ) -> Result { + let (server_inner, client_inner) = match root_store { + Some(root_store) => { + let root_store = Arc::new(root_store); + let server_inner = WebPkiServerVerifier::builder_with_provider( + root_store.clone(), + provider.clone(), + ) + .build() + .map_err(AttestedTlsError::VerifierBuilder)?; + let client_inner = + WebPkiClientVerifier::builder_with_provider(root_store, provider.clone()) + .build() + .map_err(AttestedTlsError::VerifierBuilder)?; + + (Some(server_inner), Some(client_inner)) + } + None => (None, None), + }; + + Ok(Self { + server_inner, + client_inner, + provider, + attestation_verifier, + trusted_certificates: Default::default(), + }) + } + + /// Given a TLS certificate, return the embedded attestation + fn extract_custom_attestation_from_cert( + cert: &CertificateDer<'_>, + ) -> Result { + // First try to parse using ra_tls which assumes DCAP + if let Ok(Some(attestation)) = ra_tls::attestation::from_der(cert.as_ref()) && + let AttestationQuote::DstackTdx(tdx_quote) = attestation.quote + { + return Ok(AttestationExchangeMessage { + attestation_type: AttestationType::DcapTdx, + attestation: tdx_quote.quote, + }); + } + + // If that fails, extract and parse the extension + let cert = Self::parse_x509_certificate(cert)?; + let oid = Oid::from(ra_tls::oids::PHALA_RATLS_TDX_QUOTE) + .map_err(|err| rustls::Error::General(format!("invalid attestation OID: {err}")))?; + let ext = cert + .get_extension_unique(&oid) + .map_err(|err| Self::bad_encoding(format!("invalid attestation extension: {err}")))? + .ok_or_else(|| Self::bad_encoding("missing attestation extension"))?; + let payload = yasna::parse_der(ext.value, |reader| reader.read_bytes()) + .map_err(|err| Self::bad_encoding(format!("invalid attestation DER payload: {err}")))?; + serde_json::from_slice(&payload) + .map_err(|err| Self::bad_encoding(format!("invalid attestation JSON payload: {err}"))) + } + + /// Given a certificate, return the attestation report input data based + /// on public key and expriy, as well as the expiry time + fn cert_binding_data(cert: &CertificateDer<'_>) -> Result<([u8; 64], UnixTime), rustls::Error> { + let cert = Self::parse_x509_certificate(cert)?; + let not_before: u64 = cert + .validity() + .not_before + .timestamp() + .try_into() + .map_err(|_| rustls::Error::General("invalid certificate not_before".into()))?; + let not_after: u64 = cert + .validity() + .not_after + .timestamp() + .try_into() + .map_err(|_| rustls::Error::General("invalid certificate not_after".into()))?; + let hostname = Self::hostname_from_cert(&cert)?; + let expected_input_data = create_report_data( + cert.public_key().raw.to_vec(), + SystemTime::UNIX_EPOCH + Duration::from_secs(not_before), + SystemTime::UNIX_EPOCH + Duration::from_secs(not_after), + &hostname, + ) + .map_err(|err| rustls::Error::General(err.to_string()))?; + let not_after = UnixTime::since_unix_epoch(Duration::from_secs(not_after)); + + Ok((expected_input_data, not_after)) + } + + /// Given a cerificate and the current time, check if it is currently + /// valid + fn verify_cert_time_validity( + cert: &CertificateDer<'_>, + now: UnixTime, + ) -> Result<(), rustls::Error> { + let cert = Self::parse_x509_certificate(cert)?; + Self::verify_cert_time_validity_parsed(&cert, now) + } + + /// Given a parsed cerificate and the current time, check if it is + /// currently valid + fn verify_cert_time_validity_parsed( + cert: &X509Certificate<'_>, + now: UnixTime, + ) -> Result<(), rustls::Error> { + let now = now.as_secs(); + let not_before: u64 = cert + .validity() + .not_before + .timestamp() + .try_into() + .map_err(|_| rustls::Error::General("invalid certificate not_before".into()))?; + let not_after: u64 = cert + .validity() + .not_after + .timestamp() + .try_into() + .map_err(|_| rustls::Error::General("invalid certificate not_after".into()))?; + + if now < not_before { + return Err(rustls::Error::InvalidCertificate( + rustls::CertificateError::NotValidYetContext { + time: UnixTime::since_unix_epoch(Duration::from_secs(now)), + not_before: UnixTime::since_unix_epoch(Duration::from_secs(not_before)), + }, + )); + } + + if now > not_after { + return Err(rustls::Error::InvalidCertificate( + rustls::CertificateError::ExpiredContext { + time: UnixTime::since_unix_epoch(Duration::from_secs(now)), + not_after: UnixTime::since_unix_epoch(Duration::from_secs(not_after)), + }, + )); + } + + Ok(()) + } + + /// Verify server name and time validity for self-signed certs + fn verify_server_cert_constraints( + cert: &CertificateDer<'_>, + server_name: &ServerName<'_>, + now: UnixTime, + ) -> Result<(), rustls::Error> { + let parsed = ParsedCertificate::try_from(cert)?; + let cert = Self::parse_x509_certificate(cert)?; + Self::verify_cert_time_validity_parsed(&cert, now)?; + verify_server_name(&parsed, server_name) + } + + /// Given a certificate with embedded attestation, verify the + /// attestation if it has not already been verified + fn verify_attestation_binding( + &self, + end_entity: &CertificateDer<'_>, + now: UnixTime, + ) -> Result<(), rustls::Error> { + let (expected_input_data, expiry) = Self::cert_binding_data(end_entity)?; + + // First check if we have already successfully verified the attestation + // associated with this certificate + { + let trusted_certificates = self.trusted_certificates.read().map_err(|_| { + rustls::Error::General("Trusted certificate cache lock poisoned".into()) + })?; + if trusted_certificates.get(&expected_input_data).is_some_and(|expiry| *expiry >= now) { + tracing::debug!("Skipping attestation verification for trusted certificate"); + return Ok(()); + } + } + + let attestation = Self::extract_custom_attestation_from_cert(end_entity)?; + + tokio::task::block_in_place(|| { + tokio::runtime::Handle::current().block_on(async { + self.attestation_verifier + .verify_attestation(attestation, expected_input_data) + .await + .map_err(|err| { + tracing::warn!( + "Rejecting certificate after attestation verification failure: {err}" + ); + rustls::Error::InvalidCertificate( + rustls::CertificateError::ApplicationVerificationFailure, + ) + }) + }) + })?; + + let mut trusted_certificates = self.trusted_certificates.write().map_err(|_| { + rustls::Error::General("Trusted certificate cache lock poisoned".into()) + })?; + + // Remove any expired entries + trusted_certificates.retain(|_, cached_expiry| *cached_expiry >= now); + // Write trusted certificate details to cache + trusted_certificates.insert(expected_input_data, expiry); + + Ok(()) + } + + /// Helper for creating encoding related verification errors + fn bad_encoding(message: impl Into) -> rustls::Error { + let message = message.into(); + tracing::debug!("Rejecting malformed certificate or attestation payload: {message}"); + rustls::Error::InvalidCertificate(rustls::CertificateError::BadEncoding) + } + + /// Helper to parse a certificate and map the error for rustls + fn parse_x509_certificate<'a>( + cert: &'a CertificateDer<'_>, + ) -> Result, rustls::Error> { + x509_parser::parse_x509_certificate(cert.as_ref()) + .map(|(_, parsed)| parsed) + .map_err(|err| Self::bad_encoding(format!("Invalid X.509 DER: {err}"))) + } + + /// Given a certificate get the hostname for report input data + fn hostname_from_cert(cert: &X509Certificate<'_>) -> Result, rustls::Error> { + cert.subject() + .iter_common_name() + .next() + .ok_or_else(|| Self::bad_encoding("Missing common name"))? + .as_str() + .map(|hostname| hostname.as_bytes().to_vec()) + .map_err(|err| Self::bad_encoding(format!("Invalid common name: {err}"))) + } +} + +impl ServerCertVerifier for AttestedCertificateVerifier { + fn verify_server_cert( + &self, + end_entity: &CertificateDer<'_>, + intermediates: &[CertificateDer<'_>], + server_name: &ServerName<'_>, + ocsp_response: &[u8], + now: UnixTime, + ) -> Result { + if let Some(server_inner) = &self.server_inner { + match server_inner.verify_server_cert( + end_entity, + intermediates, + server_name, + ocsp_response, + now, + ) { + Err(rustls::Error::InvalidCertificate(rustls::CertificateError::UnknownIssuer)) => { + // handle self-signed certs differently + Self::verify_server_cert_constraints(end_entity, server_name, now)?; + } + Err(err) => return Err(err), + Ok(_) => {} + } + } else { + Self::verify_server_cert_constraints(end_entity, server_name, now)?; + } + self.verify_attestation_binding(end_entity, now)?; + Ok(ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> Result { + rustls::crypto::verify_tls12_signature( + message, + cert, + dss, + &self.provider.signature_verification_algorithms, + ) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> Result { + rustls::crypto::verify_tls13_signature( + message, + cert, + dss, + &self.provider.signature_verification_algorithms, + ) + } + + fn supported_verify_schemes(&self) -> Vec { + self.provider.signature_verification_algorithms.supported_schemes() + } + + fn root_hint_subjects(&self) -> Option<&[DistinguishedName]> { + self.server_inner.as_ref().and_then(|server_inner| server_inner.root_hint_subjects()) + } +} + +impl ClientCertVerifier for AttestedCertificateVerifier { + fn offer_client_auth(&self) -> bool { + self.client_inner.as_ref().is_none_or(|client_inner| client_inner.offer_client_auth()) + } + + fn client_auth_mandatory(&self) -> bool { + self.client_inner.as_ref().is_none_or(|client_inner| client_inner.client_auth_mandatory()) + } + + fn root_hint_subjects(&self) -> &[DistinguishedName] { + self.client_inner.as_ref().map_or(&[], |client_inner| client_inner.root_hint_subjects()) + } + + fn verify_client_cert( + &self, + end_entity: &CertificateDer<'_>, + intermediates: &[CertificateDer<'_>], + now: UnixTime, + ) -> Result { + if let Some(client_inner) = &self.client_inner { + match client_inner.verify_client_cert(end_entity, intermediates, now) { + Err(rustls::Error::InvalidCertificate(rustls::CertificateError::UnknownIssuer)) => { + Self::verify_cert_time_validity(end_entity, now)?; + } + Err(err) => return Err(err), + Ok(_) => {} + } + } else { + Self::verify_cert_time_validity(end_entity, now)?; + } + self.verify_attestation_binding(end_entity, now)?; + Ok(ClientCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> Result { + rustls::crypto::verify_tls12_signature( + message, + cert, + dss, + &self.provider.signature_verification_algorithms, + ) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> Result { + rustls::crypto::verify_tls13_signature( + message, + cert, + dss, + &self.provider.signature_verification_algorithms, + ) + } + + fn supported_verify_schemes(&self) -> Vec { + self.provider.signature_verification_algorithms.supported_schemes() + } +} + +#[derive(Debug, Error)] +pub enum AttestedTlsError { + #[error("Failed to generate certificate key pair: {0}")] + CertificateKeyGeneration(#[source] rcgen::Error), + #[error("Failed to build certificate parameters: {0}")] + CertificateParams(#[source] rcgen::Error), + #[error("Failed to self-sign certificate: {0}")] + CertificateSigning(#[source] rcgen::Error), + #[error("Failed to build certificate verifier: {0}")] + VerifierBuilder(#[source] VerifierBuilderError), + #[error("Cetificate generation: {0}")] + RcGen(#[from] ra_tls::rcgen::Error), + #[error("RA-TLS: {0}")] + RaTls(#[source] anyhow::Error), + #[error("Rustls: {0}")] + Rustls(#[from] rustls::Error), + #[error("Failed to parse PEM certificate: {0}")] + Pem(#[from] rustls::pki_types::pem::Error), + #[error("System time: {0}")] + SystemTime(#[source] std::time::SystemTimeError), + #[error("No rustls CryptoProvider is installed")] + CryptoProviderUnavailable, + #[error("JSON: {0}")] + Json(#[from] serde_json::Error), + #[error("Attestation: {0}")] + Attestation(#[from] attestation::AttestationError), +} + +#[cfg(test)] +mod tests { + use std::{io::Cursor, sync::Arc}; + + use ra_tls::rcgen::{BasicConstraints, CertificateParams, IsCa}; + use rustls::{ + CertificateError, + ClientConfig, + ClientConnection, + Error, + RootCertStore, + ServerConfig, + ServerConnection, + crypto::aws_lc_rs, + }; + + use super::*; + + /// Test helper to verify a certificate + fn verify_server_cert_direct( + verifier: &AttestedCertificateVerifier, + end_entity: &CertificateDer<'_>, + server_name: &ServerName<'_>, + now: UnixTime, + ) -> Result { + rustls::client::danger::ServerCertVerifier::verify_server_cert( + verifier, + end_entity, + &[], + server_name, + &[], + now, + ) + } + + #[tokio::test(flavor = "multi_thread")] + async fn certificate_resolver_creates_initial_certificate() { + let provider: Arc = aws_lc_rs::default_provider().into(); + let resolver = AttestedCertificateResolver::new_with_provider( + AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + None, + "foo".to_string(), + vec![], + provider, + ) + .await + .unwrap(); + + let certificate = resolver.state.certificate.read().unwrap(); + + assert_eq!(certificate.len(), 1); + } + + #[tokio::test(flavor = "multi_thread")] + async fn server_and_client_configs_complete_a_handshake() { + let provider: Arc = aws_lc_rs::default_provider().into(); + let server_name = "foo"; + let resolver = AttestedCertificateResolver::new_with_provider( + AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + None, + server_name.to_string(), + vec![], + provider.clone(), + ) + .await + .unwrap(); + + let verifier = AttestedCertificateVerifier::new_with_provider( + None, + AttestationVerifier::mock(), + provider.clone(), + ) + .unwrap(); + + let server_config = ServerConfig::builder_with_provider(provider.clone()) + .with_safe_default_protocol_versions() + .unwrap() + .with_no_client_auth() + .with_cert_resolver(Arc::new(resolver)); + let client_config = ClientConfig::builder_with_provider(provider) + .with_safe_default_protocol_versions() + .unwrap() + .dangerous() + .with_custom_certificate_verifier(Arc::new(verifier)) + .with_no_client_auth(); + + let mut client = ClientConnection::new( + Arc::new(client_config), + ServerName::try_from(server_name).unwrap(), + ) + .unwrap(); + + let mut server = ServerConnection::new(Arc::new(server_config)).unwrap(); + + while client.is_handshaking() || server.is_handshaking() { + transfer_tls_client_to_server(&mut client, &mut server); + transfer_tls_server_to_client(&mut server, &mut client); + } + + assert!(!client.is_handshaking()); + assert!(!server.is_handshaking()); + } + + #[tokio::test(flavor = "multi_thread")] + async fn ca_signed_server_and_client_configs_complete_a_handshake() { + let provider: Arc = aws_lc_rs::default_provider().into(); + let server_name = "foo"; + let ca = test_ca(); + let ca_cert = CertificateDer::from_pem_slice(ca.pem_cert.as_bytes()).unwrap(); + + let resolver = AttestedCertificateResolver::new_with_provider( + AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + Some(ca), + server_name.to_string(), + vec![], + provider.clone(), + ) + .await + .unwrap(); + + let certificate_chain = resolver.state.certificate.read().unwrap().clone(); + + assert_eq!(certificate_chain.len(), 2); + + let mut roots = RootCertStore::empty(); + roots.add(ca_cert).unwrap(); + + let verifier = AttestedCertificateVerifier::new_with_provider( + Some(roots), + AttestationVerifier::mock(), + provider.clone(), + ) + .unwrap(); + + let server_config = ServerConfig::builder_with_provider(provider.clone()) + .with_safe_default_protocol_versions() + .unwrap() + .with_no_client_auth() + .with_cert_resolver(Arc::new(resolver)); + + let client_config = ClientConfig::builder_with_provider(provider) + .with_safe_default_protocol_versions() + .unwrap() + .dangerous() + .with_custom_certificate_verifier(Arc::new(verifier)) + .with_no_client_auth(); + + let mut client = ClientConnection::new( + Arc::new(client_config), + ServerName::try_from(server_name).unwrap(), + ) + .unwrap(); + + let mut server = ServerConnection::new(Arc::new(server_config)).unwrap(); + + while client.is_handshaking() || server.is_handshaking() { + transfer_tls_client_to_server(&mut client, &mut server); + transfer_tls_server_to_client(&mut server, &mut client); + } + + assert!(!client.is_handshaking()); + assert!(!server.is_handshaking()); + } + + #[tokio::test(flavor = "multi_thread")] + async fn certificate_is_renewed_before_expiry() { + let provider: Arc = aws_lc_rs::default_provider().into(); + let resolver = AttestedCertificateResolver::new_with_provider( + AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + None, + "foo".to_string(), + vec![], + provider, + ) + .await + .unwrap(); + let initial_certificate = + resolver.state.certificate.read().unwrap().first().unwrap().clone(); + + tokio::time::sleep( + CERTIFICATE_VALIDITY - CERTIFICATE_RENEWAL_LEAD_TIME + Duration::from_secs(1), + ) + .await; + + let renewed_certificate = + resolver.state.certificate.read().unwrap().first().unwrap().clone(); + + assert_ne!(initial_certificate.as_ref(), renewed_certificate.as_ref()); + } + + #[tokio::test(flavor = "multi_thread")] + async fn server_and_client_configs_complete_a_mutual_auth_handshake() { + let provider: Arc = aws_lc_rs::default_provider().into(); + let server_name = "foo"; + + let server_resolver = AttestedCertificateResolver::new_with_provider( + AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + None, + server_name.to_string(), + vec![], + provider.clone(), + ) + .await + .unwrap(); + + let client_resolver = AttestedCertificateResolver::new_with_provider( + AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + None, + "client".to_string(), + vec![], + provider.clone(), + ) + .await + .unwrap(); + + let server_verifier = AttestedCertificateVerifier::new_with_provider( + None, + AttestationVerifier::mock(), + provider.clone(), + ) + .unwrap(); + let client_verifier = AttestedCertificateVerifier::new_with_provider( + None, + AttestationVerifier::mock(), + provider.clone(), + ) + .unwrap(); + + let server_config = ServerConfig::builder_with_provider(provider.clone()) + .with_safe_default_protocol_versions() + .unwrap() + .with_client_cert_verifier(Arc::new(server_verifier)) + .with_cert_resolver(Arc::new(server_resolver)); + let client_config = ClientConfig::builder_with_provider(provider) + .with_safe_default_protocol_versions() + .unwrap() + .dangerous() + .with_custom_certificate_verifier(Arc::new(client_verifier)) + .with_client_cert_resolver(Arc::new(client_resolver)); + + let mut client = ClientConnection::new( + Arc::new(client_config), + ServerName::try_from(server_name).unwrap(), + ) + .unwrap(); + let mut server = ServerConnection::new(Arc::new(server_config)).unwrap(); + + while client.is_handshaking() || server.is_handshaking() { + transfer_tls_client_to_server(&mut client, &mut server); + transfer_tls_server_to_client(&mut server, &mut client); + } + + assert!(!client.is_handshaking()); + assert!(!server.is_handshaking()); + assert!(client.peer_certificates().is_some()); + assert!(server.peer_certificates().is_some()); + } + + #[tokio::test(flavor = "multi_thread")] + async fn alternate_san_completes_a_handshake() { + let provider: Arc = aws_lc_rs::default_provider().into(); + let primary_name = "foo"; + let alternate_name = "bar"; + let resolver = AttestedCertificateResolver::new_with_provider( + AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + None, + primary_name.to_string(), + vec![alternate_name.to_string(), primary_name.to_string()], + provider.clone(), + ) + .await + .unwrap(); + let verifier = AttestedCertificateVerifier::new_with_provider( + None, + AttestationVerifier::mock(), + provider.clone(), + ) + .unwrap(); + + let server_config = ServerConfig::builder_with_provider(provider.clone()) + .with_safe_default_protocol_versions() + .unwrap() + .with_no_client_auth() + .with_cert_resolver(Arc::new(resolver)); + let client_config = ClientConfig::builder_with_provider(provider) + .with_safe_default_protocol_versions() + .unwrap() + .dangerous() + .with_custom_certificate_verifier(Arc::new(verifier)) + .with_no_client_auth(); + + let mut client = ClientConnection::new( + Arc::new(client_config), + ServerName::try_from(alternate_name).unwrap(), + ) + .unwrap(); + let mut server = ServerConnection::new(Arc::new(server_config)).unwrap(); + + while client.is_handshaking() || server.is_handshaking() { + transfer_tls_client_to_server(&mut client, &mut server); + transfer_tls_server_to_client(&mut server, &mut client); + } + + assert!(!client.is_handshaking()); + assert!(!server.is_handshaking()); + } + + #[tokio::test(flavor = "multi_thread")] + async fn malformed_certificate_returns_bad_encoding() { + let provider: Arc = aws_lc_rs::default_provider().into(); + let verifier = AttestedCertificateVerifier::new_with_provider( + None, + AttestationVerifier::mock(), + provider, + ) + .unwrap(); + let cert = CertificateDer::from(vec![1_u8, 2, 3, 4]); + + let result = verify_server_cert_direct( + &verifier, + &cert, + &ServerName::try_from("foo").unwrap(), + UnixTime::now(), + ); + + assert_eq!(result.unwrap_err(), Error::InvalidCertificate(CertificateError::BadEncoding)); + } + + #[tokio::test(flavor = "multi_thread")] + async fn certificate_without_attestation_extension_returns_bad_encoding() { + let provider: Arc = aws_lc_rs::default_provider().into(); + let cert = plain_self_signed_certificate("foo"); + let mut roots = RootCertStore::empty(); + roots.add(cert.clone()).unwrap(); + let verifier = AttestedCertificateVerifier::new_with_provider( + Some(roots), + AttestationVerifier::mock(), + provider, + ) + .unwrap(); + + let result = verify_server_cert_direct( + &verifier, + &cert, + &ServerName::try_from("foo").unwrap(), + UnixTime::now(), + ); + + assert_eq!(result.unwrap_err(), Error::InvalidCertificate(CertificateError::BadEncoding)); + } + + #[tokio::test(flavor = "multi_thread")] + async fn self_signed_attested_certificate_with_wrong_name_is_rejected() { + let provider: Arc = aws_lc_rs::default_provider().into(); + let resolver = AttestedCertificateResolver::new_with_provider( + AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + None, + "foo".to_string(), + vec![], + provider.clone(), + ) + .await + .unwrap(); + let verifier = AttestedCertificateVerifier::new_with_provider( + None, + AttestationVerifier::mock(), + provider, + ) + .unwrap(); + let cert = resolver.state.certificate.read().unwrap().first().unwrap().clone(); + + let result = verify_server_cert_direct( + &verifier, + &cert, + &ServerName::try_from("bar").unwrap(), + UnixTime::now(), + ); + + assert!(matches!( + result.unwrap_err(), + Error::InvalidCertificate(CertificateError::NotValidForNameContext { .. }) + )); + } + + #[tokio::test(flavor = "multi_thread")] + async fn certificate_binding_changes_when_identity_changes() { + let provider: Arc = aws_lc_rs::default_provider().into(); + let resolver = AttestedCertificateResolver::new_with_provider( + AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + None, + "foo".to_string(), + vec![], + provider.clone(), + ) + .await + .unwrap(); + let original_cert = resolver.state.certificate.read().unwrap().first().unwrap().clone(); + let (original_report_data, original_not_after) = + AttestedCertificateVerifier::cert_binding_data(&original_cert).unwrap(); + let parsed_cert = + AttestedCertificateVerifier::parse_x509_certificate(&original_cert).unwrap(); + let not_before = parsed_cert.validity().not_before.timestamp() as u64; + let not_after = parsed_cert.validity().not_after.timestamp() as u64; + let key_pair = KeyPair::try_from(resolver.state.key_pair_der.clone()).unwrap(); + let replay_name = "bar".to_string(); + let replay_alt_names = vec![replay_name.clone()]; + let replayed_cert_request = CertRequest::builder() + .key(&key_pair) + .subject(&replay_name) + .alt_names(&replay_alt_names) + .not_before(SystemTime::UNIX_EPOCH + Duration::from_secs(not_before)) + .not_after(SystemTime::UNIX_EPOCH + Duration::from_secs(not_after)) + .usage_server_auth(true) + .usage_client_auth(true) + .build(); + let replayed_cert: CertificateDer<'static> = + replayed_cert_request.self_signed().unwrap().der().to_vec().into(); + let (replayed_report_data, replayed_not_after) = + AttestedCertificateVerifier::cert_binding_data(&replayed_cert).unwrap(); + + assert_eq!(original_not_after, replayed_not_after); + assert_ne!(original_report_data, replayed_report_data); + } + + #[tokio::test(flavor = "multi_thread")] + async fn attestation_rejection_returns_application_verification_failure() { + let provider: Arc = aws_lc_rs::default_provider().into(); + let resolver = AttestedCertificateResolver::new_with_provider( + AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + None, + "foo".to_string(), + vec![], + provider.clone(), + ) + .await + .unwrap(); + let verifier = AttestedCertificateVerifier::new_with_provider( + None, + AttestationVerifier::expect_none(), + provider, + ) + .unwrap(); + let cert = resolver.state.certificate.read().unwrap().first().unwrap().clone(); + + let result = verify_server_cert_direct( + &verifier, + &cert, + &ServerName::try_from("foo").unwrap(), + UnixTime::now(), + ); + + assert_eq!( + result.unwrap_err(), + Error::InvalidCertificate(CertificateError::ApplicationVerificationFailure) + ); + } + + #[tokio::test(flavor = "multi_thread")] + async fn verifier_reuses_trusted_certificate_cache() { + let provider: Arc = aws_lc_rs::default_provider().into(); + let resolver = AttestedCertificateResolver::new_with_provider( + AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + None, + "foo".to_string(), + vec![], + provider.clone(), + ) + .await + .unwrap(); + let mut verifier = AttestedCertificateVerifier::new_with_provider( + None, + AttestationVerifier::mock(), + provider, + ) + .unwrap(); + let cert = resolver.state.certificate.read().unwrap().first().unwrap().clone(); + let (expected_input_data, not_after) = + AttestedCertificateVerifier::cert_binding_data(&cert).unwrap(); + + verify_server_cert_direct( + &verifier, + &cert, + &ServerName::try_from("foo").unwrap(), + UnixTime::now(), + ) + .unwrap(); + assert_eq!( + verifier.trusted_certificates.read().unwrap().get(&expected_input_data), + Some(¬_after) + ); + + verifier.attestation_verifier = AttestationVerifier::expect_none(); + + verify_server_cert_direct( + &verifier, + &cert, + &ServerName::try_from("foo").unwrap(), + UnixTime::now(), + ) + .unwrap(); + } + + /// Helper to create a private cerificate authority + fn test_ca() -> CaCert { + let key = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).unwrap(); + let mut params = CertificateParams::new(vec!["test-ca".to_string()]).unwrap(); + params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained); + let cert = params.self_signed(&key).unwrap(); + + CaCert::from_parts(key, cert) + } + + /// Helper to create a self signed cert with no attestation + fn plain_self_signed_certificate(subject_name: &str) -> CertificateDer<'static> { + let key = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).unwrap(); + let params = CertificateParams::new(vec![subject_name.to_string()]).unwrap(); + params.self_signed(&key).unwrap().der().to_vec().into() + } + + fn transfer_tls_client_to_server(client: &mut ClientConnection, server: &mut ServerConnection) { + let mut tls = Vec::new(); + + while client.wants_write() { + client.write_tls(&mut tls).unwrap(); + } + + if tls.is_empty() { + return; + } + + server.read_tls(&mut Cursor::new(tls)).unwrap(); + server.process_new_packets().unwrap(); + } + + fn transfer_tls_server_to_client(server: &mut ServerConnection, client: &mut ClientConnection) { + let mut tls = Vec::new(); + + while server.wants_write() { + server.write_tls(&mut tls).unwrap(); + } + + if tls.is_empty() { + return; + } + + client.read_tls(&mut Cursor::new(tls)).unwrap(); + client.process_new_packets().unwrap(); + } +} diff --git a/crates/attested-tls/tests/nested_tls.rs b/crates/attested-tls/tests/nested_tls.rs new file mode 100644 index 0000000..6bc818a --- /dev/null +++ b/crates/attested-tls/tests/nested_tls.rs @@ -0,0 +1,119 @@ +//! Provides a test demonstrating using nested-tls and attested-tls together +use std::sync::Arc; + +use attestation::{AttestationGenerator, AttestationType, AttestationVerifier}; +use attested_tls::{AttestedCertificateResolver, AttestedCertificateVerifier}; +use nested_tls::{client::NestingTlsConnector, server::NestingTlsAcceptor}; +use ra_tls::rcgen::{KeyPair, PKCS_ECDSA_P256_SHA256}; +use rustls::{ + ClientConfig, + RootCertStore, + ServerConfig, + crypto::{CryptoProvider, aws_lc_rs}, + pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer, ServerName}, +}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; + +#[tokio::test(flavor = "multi_thread")] +async fn nested_tls_uses_attested_tls_for_inner_session() { + let provider: Arc = aws_lc_rs::default_provider().into(); + let (outer_server, outer_client) = plain_tls_config_pair(provider.clone()); + let inner_server = attested_server_config("localhost", provider.clone()).await; + let inner_client = attested_client_config(provider.clone()); + + let acceptor = NestingTlsAcceptor::new(Arc::new(outer_server), Arc::new(inner_server)); + let connector = NestingTlsConnector::new(Arc::new(outer_client), Arc::new(inner_client)); + + let (client_io, server_io) = duplex(16 * 1024); + + let server = tokio::spawn(async move { + let mut stream = acceptor.accept(server_io).await.unwrap(); + + let mut req = [0_u8; 5]; + stream.read_exact(&mut req).await.unwrap(); + assert_eq!(&req, b"hello"); + + stream.write_all(b"world").await.unwrap(); + stream.flush().await.unwrap(); + }); + + let domain = ServerName::try_from("localhost").unwrap(); + let mut stream = connector.connect(domain, client_io).await.unwrap(); + + stream.write_all(b"hello").await.unwrap(); + stream.flush().await.unwrap(); + + let mut resp = [0_u8; 5]; + stream.read_exact(&mut resp).await.unwrap(); + assert_eq!(&resp, b"world"); + + server.await.unwrap(); +} + +/// Create vanilla TLS server and client config for outer session +fn plain_tls_config_pair(provider: Arc) -> (ServerConfig, ClientConfig) { + let subject_name = "localhost"; + let key = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).unwrap(); + let mut params = ra_tls::rcgen::CertificateParams::new(vec![subject_name.to_string()]).unwrap(); + params + .subject_alt_names + .push(ra_tls::rcgen::SanType::DnsName(subject_name.try_into().unwrap())); + let cert = params.self_signed(&key).unwrap(); + let cert_der: CertificateDer<'static> = cert.der().clone(); + let key_der = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key.serialize_der())); + + let server = ServerConfig::builder_with_provider(provider.clone()) + .with_safe_default_protocol_versions() + .unwrap() + .with_no_client_auth() + .with_single_cert(vec![cert_der.clone()], key_der) + .unwrap(); + + let mut roots = RootCertStore::empty(); + roots.add(cert_der).unwrap(); + + let client = ClientConfig::builder_with_provider(provider) + .with_safe_default_protocol_versions() + .unwrap() + .with_root_certificates(roots) + .with_no_client_auth(); + + (server, client) +} + +/// Create attested server TLS config with mock DCAP attestation and +/// self-signed certs +async fn attested_server_config(server_name: &str, provider: Arc) -> ServerConfig { + let resolver = AttestedCertificateResolver::new_with_provider( + AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + None, + server_name.to_string(), + vec![], + provider.clone(), + ) + .await + .unwrap(); + + ServerConfig::builder_with_provider(provider) + .with_safe_default_protocol_versions() + .unwrap() + .with_no_client_auth() + .with_cert_resolver(Arc::new(resolver)) +} + +/// Create client TLS config with attestation verification +fn attested_client_config(provider: Arc) -> ClientConfig { + let verifier = AttestedCertificateVerifier::new_with_provider( + None, + AttestationVerifier::mock(), + provider.clone(), + ) + .unwrap(); + + ClientConfig::builder_with_provider(provider) + .with_safe_default_protocol_versions() + .unwrap() + .dangerous() + .with_custom_certificate_verifier(Arc::new(verifier)) + .with_no_client_auth() +}