From 744bf2fd812cca1e97718d84a21c6290f5fb4f5c Mon Sep 17 00:00:00 2001 From: irving ou Date: Thu, 2 Apr 2026 15:39:26 -0400 Subject: [PATCH] =?UTF-8?q?feat:=20QUIC=20agent=20tunnel=20=E2=80=94=20pro?= =?UTF-8?q?tocol,=20listener,=20agent=20client?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add QUIC-based agent tunnel core infrastructure. Agents in private networks connect outbound to Gateway via QUIC/mTLS, advertise reachable subnets and domains, and proxy TCP connections on behalf of Gateway. Protocol (agent-tunnel-proto crate): - RouteAdvertise with subnets + domain advertisements - ConnectMessage/ConnectResponse for session stream setup - Heartbeat/HeartbeatAck for liveness detection - Protocol version negotiation (v2) Gateway (agent_tunnel module): - QUIC listener with mTLS authentication - Agent registry with subnet/domain tracking - Certificate authority for agent enrollment - Enrollment token store (one-time tokens) - Bidirectional proxy stream multiplexing Agent (devolutions-agent): - QUIC client with auto-reconnect and exponential backoff - Agent enrollment with config merge (preserves existing settings) - Domain auto-detection (Windows: USERDNSDOMAIN, Linux: resolv.conf) - Subnet validation on incoming connections - Certificate file permissions (0o600 on Unix) API endpoints: - POST /jet/agent-tunnel/enroll — agent enrollment - GET /jet/agent-tunnel/agents — list agents - GET /jet/agent-tunnel/agents/{id} — get agent - DELETE /jet/agent-tunnel/agents/{id} — delete agent - POST /jet/agent-tunnel/agents/resolve-target — routing diagnostics Co-Authored-By: Claude Opus 4.6 (1M context) --- Cargo.lock | 393 +++++++-- crates/agent-tunnel-proto/Cargo.toml | 21 + crates/agent-tunnel-proto/src/control.rs | 308 +++++++ crates/agent-tunnel-proto/src/error.rs | 15 + crates/agent-tunnel-proto/src/lib.rs | 24 + crates/agent-tunnel-proto/src/session.rs | 215 +++++ crates/agent-tunnel-proto/src/version.rs | 37 + devolutions-agent/Cargo.toml | 16 +- devolutions-agent/src/config.rs | 75 +- devolutions-agent/src/domain_detect.rs | 112 +++ devolutions-agent/src/enrollment.rs | 192 +++++ devolutions-agent/src/lib.rs | 3 + devolutions-agent/src/main.rs | 235 +++++- devolutions-agent/src/service.rs | 7 +- devolutions-agent/src/tunnel.rs | 512 ++++++++++++ devolutions-gateway/Cargo.toml | 15 + devolutions-gateway/src/agent_tunnel/cert.rs | 433 ++++++++++ .../src/agent_tunnel/enrollment_store.rs | 126 +++ .../src/agent_tunnel/listener.rs | 336 ++++++++ devolutions-gateway/src/agent_tunnel/mod.rs | 15 + .../src/agent_tunnel/registry.rs | 773 ++++++++++++++++++ .../src/agent_tunnel/stream.rs | 37 + .../src/api/agent_enrollment.rs | 302 +++++++ devolutions-gateway/src/api/mod.rs | 2 + devolutions-gateway/src/api/webapp.rs | 1 + devolutions-gateway/src/config.rs | 39 + devolutions-gateway/src/extract.rs | 58 ++ devolutions-gateway/src/generic_client.rs | 74 ++ devolutions-gateway/src/lib.rs | 3 + devolutions-gateway/src/listener.rs | 1 + devolutions-gateway/src/middleware/auth.rs | 8 + devolutions-gateway/src/ngrok.rs | 1 + devolutions-gateway/src/rd_clean_path.rs | 4 +- devolutions-gateway/src/service.rs | 32 +- devolutions-gateway/src/token.rs | 16 +- devolutions-gateway/tests/config.rs | 6 + 36 files changed, 4348 insertions(+), 99 deletions(-) create mode 100644 crates/agent-tunnel-proto/Cargo.toml create mode 100644 crates/agent-tunnel-proto/src/control.rs create mode 100644 crates/agent-tunnel-proto/src/error.rs create mode 100644 crates/agent-tunnel-proto/src/lib.rs create mode 100644 crates/agent-tunnel-proto/src/session.rs create mode 100644 crates/agent-tunnel-proto/src/version.rs create mode 100644 devolutions-agent/src/domain_detect.rs create mode 100644 devolutions-agent/src/enrollment.rs create mode 100644 devolutions-agent/src/tunnel.rs create mode 100644 devolutions-gateway/src/agent_tunnel/cert.rs create mode 100644 devolutions-gateway/src/agent_tunnel/enrollment_store.rs create mode 100644 devolutions-gateway/src/agent_tunnel/listener.rs create mode 100644 devolutions-gateway/src/agent_tunnel/mod.rs create mode 100644 devolutions-gateway/src/agent_tunnel/registry.rs create mode 100644 devolutions-gateway/src/agent_tunnel/stream.rs create mode 100644 devolutions-gateway/src/api/agent_enrollment.rs diff --git a/Cargo.lock b/Cargo.lock index 46b826db1..6b1ba94b1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -36,7 +36,7 @@ checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" dependencies = [ "cfg-if", "cipher 0.4.4", - "cpufeatures 0.2.17", + "cpufeatures", ] [[package]] @@ -47,7 +47,7 @@ checksum = "7e713c57c2a2b19159e7be83b9194600d7e8eb3b7c2cd67e671adf47ce189a05" dependencies = [ "cfg-if", "cipher 0.5.0-rc.1", - "cpufeatures 0.2.17", + "cpufeatures", ] [[package]] @@ -74,6 +74,19 @@ dependencies = [ "const-oid 0.10.2", ] +[[package]] +name = "agent-tunnel-proto" +version = "0.0.0" +dependencies = [ + "bincode", + "ipnetwork", + "proptest", + "serde", + "thiserror 2.0.18", + "tokio 1.49.0", + "uuid", +] + [[package]] name = "ahash" version = "0.8.12" @@ -167,7 +180,7 @@ checksum = "3c3610892ee6e0cbce8ae2700349fcf8f98adb0dbfbee85aec3c9179d29cc072" dependencies = [ "base64ct", "blake2", - "cpufeatures 0.2.17", + "cpufeatures", "password-hash", ] @@ -177,13 +190,29 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" +[[package]] +name = "asn1-rs" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5493c3bedbacf7fd7382c6346bbd66687d12bbaad3a89a2d2c303ee6cf20b048" +dependencies = [ + "asn1-rs-derive 0.5.1", + "asn1-rs-impl", + "displaydoc", + "nom", + "num-traits", + "rusticata-macros", + "thiserror 1.0.69", + "time", +] + [[package]] name = "asn1-rs" version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56624a96882bb8c26d61312ae18cb45868e5a9992ea73c58e45c3101e56a1e60" dependencies = [ - "asn1-rs-derive", + "asn1-rs-derive 0.6.0", "asn1-rs-impl", "displaydoc", "nom", @@ -192,6 +221,18 @@ dependencies = [ "thiserror 2.0.18", ] +[[package]] +name = "asn1-rs-derive" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "965c2d33e53cb6b267e148a4cb0760bc01f4904c1cd4bb4002a085bb016d1490" +dependencies = [ + "proc-macro2 1.0.106", + "quote 1.0.44", + "syn 2.0.114", + "synstructure", +] + [[package]] name = "asn1-rs-derive" version = "0.6.0" @@ -837,6 +878,12 @@ dependencies = [ "shlex", ] +[[package]] +name = "cesu8" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c" + [[package]] name = "ceviche" version = "0.7.0" @@ -885,18 +932,7 @@ checksum = "c3613f74bd2eac03dad61bd53dbe620703d4371614fe0bc3b9f04dd36fe4e818" dependencies = [ "cfg-if", "cipher 0.4.4", - "cpufeatures 0.2.17", -] - -[[package]] -name = "chacha20" -version = "0.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f8d983286843e49675a4b7a2d174efe136dc93a18d69130dd18198a6c167601" -dependencies = [ - "cfg-if", - "cpufeatures 0.3.0", - "rand_core 0.10.0", + "cpufeatures", ] [[package]] @@ -906,7 +942,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "10cd79432192d1c0f4e1a0fef9527696cc039165d729fb41b3f4f4f354c2dc35" dependencies = [ "aead 0.5.2", - "chacha20 0.9.1", + "chacha20", "cipher 0.4.4", "poly1305", "zeroize", @@ -1020,6 +1056,16 @@ dependencies = [ "cc", ] +[[package]] +name = "combine" +version = "4.6.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd" +dependencies = [ + "bytes 1.11.1", + "memchr", +] + [[package]] name = "concurrent-queue" version = "2.5.0" @@ -1082,15 +1128,6 @@ dependencies = [ "libc", ] -[[package]] -name = "cpufeatures" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b2a41393f66f16b0823bb79094d54ac5fbd34ab292ddafb9a0456ac9f87d201" -dependencies = [ - "libc", -] - [[package]] name = "crc32fast" version = "1.5.0" @@ -1310,7 +1347,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6f9200d1d13637f15a6acb71e758f64624048d85b31a5fdbfd8eca1e2687d0b7" dependencies = [ "cfg-if", - "cpufeatures 0.2.17", + "cpufeatures", "curve25519-dalek-derive", "digest 0.11.0-rc.3", "fiat-crypto", @@ -1330,6 +1367,20 @@ dependencies = [ "syn 2.0.114", ] +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "data-encoding" version = "2.10.0" @@ -1380,13 +1431,27 @@ dependencies = [ "zeroize", ] +[[package]] +name = "der-parser" +version = "9.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5cd0a5c643689626bec213c4d8bd4d96acc8ffdb4ad4bb6bc16abf27d5f4b553" +dependencies = [ + "asn1-rs 0.6.2", + "displaydoc", + "nom", + "num-bigint", + "num-traits", + "rusticata-macros", +] + [[package]] name = "der-parser" version = "10.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07da5016415d5a3c4dd39b11ed26f915f52fc4e0dc197d87908bc916e51bc1a6" dependencies = [ - "asn1-rs", + "asn1-rs 0.7.1", "displaydoc", "nom", "num-traits", @@ -1438,9 +1503,12 @@ dependencies = [ name = "devolutions-agent" version = "2026.1.1" dependencies = [ + "agent-tunnel-proto", "anyhow", "async-trait", "aws-lc-rs", + "base64 0.22.1", + "bincode", "bytes 1.11.1", "camino", "ceviche", @@ -1454,16 +1522,22 @@ dependencies = [ "futures", "hex", "http-client-proxy", + "ipnetwork", "ironrdp", "notify-debouncer-mini", "parking_lot", + "quinn", "rand 0.8.5", + "rcgen", "reqwest", + "rustls 0.23.37", "rustls-pemfile 2.2.0", + "rustls-pki-types", "serde", "serde_json", "sha2 0.10.9", "tap", + "tempfile", "thiserror 2.0.18", "tokio 1.49.0", "tokio-rustls", @@ -1492,12 +1566,15 @@ dependencies = [ name = "devolutions-gateway" version = "2026.1.1" dependencies = [ + "agent-tunnel-proto", "anyhow", "argon2", "async-trait", "axum 0.8.8", "axum-extra", "backoff", + "base64 0.22.1", + "bincode", "bitflags 2.10.0", "bytes 1.11.1", "cadeau", @@ -1505,6 +1582,7 @@ dependencies = [ "ceviche", "cfg-if", "chacha20poly1305", + "dashmap", "devolutions-agent-shared", "devolutions-gateway-generators", "devolutions-gateway-task", @@ -1521,6 +1599,7 @@ dependencies = [ "http-client-proxy", "hyper 1.8.1", "hyper-util", + "ipnetwork", "ironrdp-acceptor", "ironrdp-connector", "ironrdp-core", @@ -1539,15 +1618,22 @@ dependencies = [ "nonempty", "parking_lot", "pcap-file", + "pem", "picky", "picky-krb", "pin-project-lite 0.2.17", "portpicker", "proptest", + "quinn", + "rand 0.8.5", + "rcgen", "reqwest", "rstest", + "rustls 0.23.37", "rustls-cng", "rustls-native-certs", + "rustls-pemfile 2.2.0", + "rustls-pki-types", "secrecy 0.10.3", "secure-memory", "serde", @@ -1586,6 +1672,7 @@ dependencies = [ "video-streamer", "windows-sys 0.61.2", "x509-cert", + "x509-parser", "zeroize", ] @@ -2002,16 +2089,16 @@ dependencies = [ [[package]] name = "embed-resource" -version = "3.0.8" +version = "3.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "63a1d0de4f2249aa0ff5884d7080814f446bb241a559af6c170a41e878ed2d45" +checksum = "55a075fc573c64510038d7ee9abc7990635863992f83ebc52c8b433b8411a02e" dependencies = [ "cc", "memchr", "rustc_version", "toml", "vswhom", - "winreg", + "winreg 0.55.0", ] [[package]] @@ -2090,6 +2177,18 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" +[[package]] +name = "fastbloom" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e7f34442dbe69c60fe8eaf58a8cafff81a1f278816d8ab4db255b3bef4ac3c4" +dependencies = [ + "getrandom 0.3.4", + "libm", + "rand 0.9.2", + "siphasher", +] + [[package]] name = "fastrand" version = "2.3.0" @@ -2377,7 +2476,6 @@ dependencies = [ "cfg-if", "libc", "r-efi", - "rand_core 0.10.0", "wasip2", "wasip3", ] @@ -3119,15 +3217,14 @@ dependencies = [ [[package]] name = "ipconfig" -version = "0.3.4" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d40460c0ce33d6ce4b0630ad68ff63d6661961c48b6dba35e5a4d81cfb48222" +checksum = "b58db92f96b720de98181bbbe63c831e87005ab460c1bf306eb2622b4707997f" dependencies = [ - "socket2 0.6.2", + "socket2 0.5.10", "widestring 1.2.1", - "windows-registry 0.6.1", - "windows-result 0.4.1", - "windows-sys 0.61.2", + "windows-sys 0.48.0", + "winreg 0.50.0", ] [[package]] @@ -3136,6 +3233,15 @@ version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" +[[package]] +name = "ipnetwork" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf466541e9d546596ee94f9f69590f89473455f88372423e0008fc1a7daf100e" +dependencies = [ + "serde", +] + [[package]] name = "iri-string" version = "0.7.10" @@ -3295,7 +3401,7 @@ dependencies = [ "bit_field", "bitflags 2.10.0", "byteorder", - "der-parser", + "der-parser 10.0.0", "ironrdp-core", "ironrdp-error", "md-5 0.10.6", @@ -3523,6 +3629,28 @@ dependencies = [ "tracing", ] +[[package]] +name = "jni" +version = "0.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a87aa2bb7d2af34197c04845522473242e1aa17c12f4935d5856491a7fb8c97" +dependencies = [ + "cesu8", + "cfg-if", + "combine", + "jni-sys", + "log", + "thiserror 1.0.69", + "walkdir", + "windows-sys 0.45.0", +] + +[[package]] +name = "jni-sys" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" + [[package]] name = "job-queue" version = "0.0.0" @@ -3593,7 +3721,7 @@ version = "0.2.0-rc.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d546793a04a1d3049bd192856f804cfe96356e2cf36b54b4e575155babe9f41" dependencies = [ - "cpufeatures 0.2.17", + "cpufeatures", ] [[package]] @@ -3734,9 +3862,9 @@ dependencies = [ [[package]] name = "libsql" -version = "0.9.30" +version = "0.9.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30fe980ac5693ed1f3db490559fb578885e913a018df64af8a1a46e1959a78df" +checksum = "2329faffc510cc3c6b4f00169a39177cc7099d3ed7647fc92f7cf26e53a8d976" dependencies = [ "anyhow", "async-stream", @@ -3773,9 +3901,9 @@ dependencies = [ [[package]] name = "libsql-ffi" -version = "0.9.30" +version = "0.9.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0be1da6f123ceb2cd23f469883415cab9ee963286a85d61e22afb8b12e15e681" +checksum = "6cd1c1662822495393327856774f6803be25d85bfdcd5b9d4af35458f5daaf75" dependencies = [ "bindgen", "cc", @@ -3785,9 +3913,9 @@ dependencies = [ [[package]] name = "libsql-hrana" -version = "0.9.30" +version = "0.9.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3358538b52cfcf9af4fe7aeb57d6843aafed2e8af80807bd636fd1448e94ea7" +checksum = "646d0aa75e412769018422f0da798f72e93bd51964f0b2ddad4317aa779ae444" dependencies = [ "base64 0.21.7", "bytes 1.11.1", @@ -3797,9 +3925,9 @@ dependencies = [ [[package]] name = "libsql-rusqlite" -version = "0.9.30" +version = "0.9.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b646f94fc1d266e481c38a2d44d6d9d1be3ad04b56b90457acfb310dc450030e" +checksum = "5a4ce3a78c6e3c2b23b02ab6272df8340e1c53380497979d456882254f348d5f" dependencies = [ "bitflags 2.10.0", "fallible-iterator 0.2.0", @@ -3829,9 +3957,9 @@ dependencies = [ [[package]] name = "libsql-sys" -version = "0.9.30" +version = "0.9.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90725458cc4461bc82f8f7983e80b002ea4f64b5184e1462f252d0dd74b122f5" +checksum = "2a3c326fcfc36fe7578238d5ee6b58c529f8c76372acd61ec50267529cdaff95" dependencies = [ "bytes 1.11.1", "libsql-ffi", @@ -3843,9 +3971,9 @@ dependencies = [ [[package]] name = "libsql_replication" -version = "0.9.30" +version = "0.9.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3bba5c9b3a26aca06d70f6a3646ba341cf574a548355353fe135af524b1b77cc" +checksum = "1d9a2e469ac8400659bd31f81a745908bcc5cb6b40be2f2ff8de90b15bec5501" dependencies = [ "aes 0.8.4", "async-stream", @@ -4541,6 +4669,15 @@ dependencies = [ "serde", ] +[[package]] +name = "oid-registry" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8d8034d9489cdaf79228eb9f6a3b8d7bb32ba00d6645ebd48eef4077ceb5bd9" +dependencies = [ + "asn1-rs 0.6.2", +] + [[package]] name = "once_cell" version = "1.21.3" @@ -4565,9 +4702,9 @@ checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" [[package]] name = "openssl" -version = "0.10.76" +version = "0.10.75" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "951c002c75e16ea2c65b8c7e4d3d51d5530d8dfa7d060b4776828c88cfb18ecf" +checksum = "08838db121398ad17ab8531ce9de97b244589089e290a384c900cb9ff7434328" dependencies = [ "bitflags 2.10.0", "cfg-if", @@ -4603,9 +4740,9 @@ checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" [[package]] name = "openssl-sys" -version = "0.9.112" +version = "0.9.111" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57d55af3b3e226502be1526dfdba67ab0e9c96fc293004e79576b2b9edb0dbdb" +checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321" dependencies = [ "cc", "libc", @@ -4731,6 +4868,16 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099" +[[package]] +name = "pem" +version = "3.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d30c53c26bc5b31a98cd02d20f25a7c8567146caf63ed593a9d87b2775291be" +dependencies = [ + "base64 0.22.1", + "serde_core", +] + [[package]] name = "pem-rfc7468" version = "0.7.0" @@ -5111,7 +5258,7 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8159bd90725d2df49889a078b54f4f79e87f1f8a8444194cdca81d38f5393abf" dependencies = [ - "cpufeatures 0.2.17", + "cpufeatures", "opaque-debug", "universal-hash 0.5.1", ] @@ -5123,7 +5270,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1ffd40cc99d0fbb02b4b3771346b811df94194bc103983efa0203c8893755085" dependencies = [ "cfg-if", - "cpufeatures 0.2.17", + "cpufeatures", "universal-hash 0.6.0-rc.2", ] @@ -5162,9 +5309,9 @@ dependencies = [ [[package]] name = "postgres-types" -version = "0.2.13" +version = "0.2.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8dc729a129e682e8d24170cd30ae1aa01b336b096cbb56df6d534ffec133d186" +checksum = "54b858f82211e84682fecd373f68e1ceae642d8d751a1ebd13f33de6257b3e20" dependencies = [ "bytes 1.11.1", "chrono", @@ -5412,7 +5559,7 @@ dependencies = [ "system-configuration-sys 0.6.0", "url", "windows-sys 0.61.2", - "winreg", + "winreg 0.55.0", ] [[package]] @@ -5457,6 +5604,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31" dependencies = [ "bytes 1.11.1", + "fastbloom", "getrandom 0.3.4", "lru-slab", "rand 0.9.2", @@ -5464,6 +5612,7 @@ dependencies = [ "rustc-hash 2.1.1", "rustls 0.23.37", "rustls-pki-types", + "rustls-platform-verifier", "slab", "thiserror 2.0.18", "tinyvec", @@ -5536,17 +5685,6 @@ dependencies = [ "rand_core 0.9.5", ] -[[package]] -name = "rand" -version = "0.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc266eb313df6c5c09c1c7b1fbe2510961e5bcd3add930c1e31f7ed9da0feff8" -dependencies = [ - "chacha20 0.10.0", - "getrandom 0.4.1", - "rand_core 0.10.0", -] - [[package]] name = "rand_chacha" version = "0.3.1" @@ -5585,12 +5723,6 @@ dependencies = [ "getrandom 0.3.4", ] -[[package]] -name = "rand_core" -version = "0.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c8d0fd677905edcbeedbf2edb6494d676f0e98d54d5cf9bda0b061cb8fb8aba" - [[package]] name = "rand_xorshift" version = "0.4.0" @@ -5629,6 +5761,20 @@ dependencies = [ "cipher 0.5.0-rc.1", ] +[[package]] +name = "rcgen" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75e669e5202259b5314d1ea5397316ad400819437857b90861765f24c4cf80a2" +dependencies = [ + "pem", + "ring 0.17.14", + "rustls-pki-types", + "time", + "x509-parser", + "yasna", +] + [[package]] name = "redox_syscall" version = "0.5.18" @@ -6003,6 +6149,33 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rustls-platform-verifier" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d99feebc72bae7ab76ba994bb5e121b8d83d910ca40b36e0921f53becc41784" +dependencies = [ + "core-foundation 0.10.1", + "core-foundation-sys", + "jni", + "log", + "once_cell", + "rustls 0.23.37", + "rustls-native-certs", + "rustls-platform-verifier-android", + "rustls-webpki", + "security-framework", + "security-framework-sys", + "webpki-root-certs", + "windows-sys 0.52.0", +] + +[[package]] +name = "rustls-platform-verifier-android" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f" + [[package]] name = "rustls-webpki" version = "0.103.9" @@ -6359,7 +6532,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" dependencies = [ "cfg-if", - "cpufeatures 0.2.17", + "cpufeatures", "digest 0.10.7", ] @@ -6370,7 +6543,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c5e046edf639aa2e7afb285589e5405de2ef7e61d4b0ac1e30256e3eab911af9" dependencies = [ "cfg-if", - "cpufeatures 0.2.17", + "cpufeatures", "digest 0.11.0-rc.3", ] @@ -6381,7 +6554,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" dependencies = [ "cfg-if", - "cpufeatures 0.2.17", + "cpufeatures", "digest 0.10.7", ] @@ -6392,7 +6565,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d1e3878ab0f98e35b2df35fe53201d088299b41a6bb63e3e34dada2ac4abd924" dependencies = [ "cfg-if", - "cpufeatures 0.2.17", + "cpufeatures", "digest 0.11.0-rc.3", ] @@ -7067,9 +7240,9 @@ dependencies = [ [[package]] name = "tokio-postgres" -version = "0.7.17" +version = "0.7.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4dd8df5ef180f6364759a6f00f7aadda4fbbac86cdee37480826a6ff9f3574ce" +checksum = "dcea47c8f71744367793f16c2db1f11cb859d28f436bdb4ca9193eb1f787ee42" dependencies = [ "async-trait", "byteorder", @@ -7084,7 +7257,7 @@ dependencies = [ "pin-project-lite 0.2.17", "postgres-protocol", "postgres-types", - "rand 0.10.0", + "rand 0.9.2", "socket2 0.6.2", "tokio 1.49.0", "tokio-util", @@ -7571,9 +7744,9 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.23" +version = "0.3.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb7f578e5945fb242538965c2d0b04418d38ec25c79d160cd279bf0731c8d319" +checksum = "2f30143827ddab0d256fd843b7a66d164e9f271cfa0dde49142c5ca0ca291f1e" dependencies = [ "matchers", "nu-ansi-term", @@ -8164,6 +8337,15 @@ dependencies = [ "untrusted 0.9.0", ] +[[package]] +name = "webpki-root-certs" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "804f18a4ac2676ffb4e8b5b5fa9ae38af06df08162314f96a68d2a363e21a8ca" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "webpki-roots" version = "0.26.11" @@ -8791,6 +8973,16 @@ dependencies = [ "memchr", ] +[[package]] +name = "winreg" +version = "0.50.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "524e57b2c537c0f9b1e69f1965311ec12182b4122e45035b1508cd24d2adadb1" +dependencies = [ + "cfg-if", + "windows-sys 0.48.0", +] + [[package]] name = "winreg" version = "0.55.0" @@ -8951,6 +9143,24 @@ dependencies = [ "tls_codec", ] +[[package]] +name = "x509-parser" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcbc162f30700d6f3f82a24bf7cc62ffe7caea42c0b2cba8bf7f3ae50cf51f69" +dependencies = [ + "asn1-rs 0.6.2", + "data-encoding", + "der-parser 9.0.0", + "lazy_static", + "nom", + "oid-registry", + "ring 0.17.14", + "rusticata-macros", + "thiserror 1.0.69", + "time", +] + [[package]] name = "xmf-sys" version = "0.4.0" @@ -8960,6 +9170,15 @@ dependencies = [ "dlib", ] +[[package]] +name = "yasna" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e17bb3549cc1321ae1296b9cdc2698e2b6cb1992adfa19a8c72e5b7a738f44cd" +dependencies = [ + "time", +] + [[package]] name = "yoke" version = "0.8.1" diff --git a/crates/agent-tunnel-proto/Cargo.toml b/crates/agent-tunnel-proto/Cargo.toml new file mode 100644 index 000000000..5822f5908 --- /dev/null +++ b/crates/agent-tunnel-proto/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "agent-tunnel-proto" +version = "0.0.0" +authors = ["Devolutions Inc. "] +edition = "2024" +publish = false + +[lints] +workspace = true + +[dependencies] +bincode = "1.3" +ipnetwork = "0.20" +serde = { version = "1", features = ["derive"] } +thiserror = "2.0" +tokio = { version = "1.45", features = ["io-util"] } +uuid = { version = "1.17", features = ["v4", "serde"] } + +[dev-dependencies] +proptest = "1.7" +tokio = { version = "1.45", features = ["rt", "macros"] } diff --git a/crates/agent-tunnel-proto/src/control.rs b/crates/agent-tunnel-proto/src/control.rs new file mode 100644 index 000000000..3fe35358b --- /dev/null +++ b/crates/agent-tunnel-proto/src/control.rs @@ -0,0 +1,308 @@ +use ipnetwork::Ipv4Network; +use serde::{Deserialize, Serialize}; +use tokio::io::{AsyncRead, AsyncReadExt as _, AsyncWrite, AsyncWriteExt as _}; + +use crate::error::ProtoError; +use crate::version::CURRENT_PROTOCOL_VERSION; + +/// Maximum encoded message size (1 MiB) to prevent denial-of-service via oversized frames. +pub const MAX_CONTROL_MESSAGE_SIZE: u32 = 1024 * 1024; + +/// A DNS domain advertisement with its source. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct DomainAdvertisement { + /// The DNS domain (e.g., "contoso.local"). + pub domain: String, + /// Whether this domain was auto-detected (`true`) or explicitly configured (`false`). + pub auto_detected: bool, +} + +/// Control-plane messages exchanged over the dedicated control stream (stream ID 0). +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub enum ControlMessage { + /// Agent advertises subnets and domains it can reach. + RouteAdvertise { + protocol_version: u16, + /// Monotonically increasing epoch within this agent process lifetime. + epoch: u64, + /// Reachable IPv4 subnets. + subnets: Vec, + /// DNS domains this agent can resolve, with source tracking. + domains: Vec, + }, + + /// Periodic liveness probe. + Heartbeat { + protocol_version: u16, + /// Milliseconds since UNIX epoch (sender's wall clock). + timestamp_ms: u64, + /// Number of currently active proxy streams on this connection. + active_stream_count: u32, + }, + + /// Acknowledgement to a Heartbeat. + HeartbeatAck { + protocol_version: u16, + /// Echoed timestamp from the corresponding Heartbeat. + timestamp_ms: u64, + }, +} + +impl ControlMessage { + /// Create a new RouteAdvertise with the current protocol version. + pub fn route_advertise(epoch: u64, subnets: Vec, domains: Vec) -> Self { + Self::RouteAdvertise { + protocol_version: CURRENT_PROTOCOL_VERSION, + epoch, + subnets, + domains, + } + } + + /// Create a new Heartbeat with the current protocol version. + pub fn heartbeat(timestamp_ms: u64, active_stream_count: u32) -> Self { + Self::Heartbeat { + protocol_version: CURRENT_PROTOCOL_VERSION, + timestamp_ms, + active_stream_count, + } + } + + /// Create a new HeartbeatAck with the current protocol version. + pub fn heartbeat_ack(timestamp_ms: u64) -> Self { + Self::HeartbeatAck { + protocol_version: CURRENT_PROTOCOL_VERSION, + timestamp_ms, + } + } + + /// Length-prefixed bincode encode and write to an async writer. + pub async fn encode(&self, writer: &mut W) -> Result<(), ProtoError> { + let payload = bincode::serialize(self)?; + let len = u32::try_from(payload.len()).map_err(|_| ProtoError::MessageTooLarge { + size: u32::MAX, + max: MAX_CONTROL_MESSAGE_SIZE, + })?; + if MAX_CONTROL_MESSAGE_SIZE < len { + return Err(ProtoError::MessageTooLarge { + size: len, + max: MAX_CONTROL_MESSAGE_SIZE, + }); + } + writer.write_all(&len.to_be_bytes()).await?; + writer.write_all(&payload).await?; + writer.flush().await?; + Ok(()) + } + + /// Read and decode a length-prefixed bincode message from an async reader. + pub async fn decode(reader: &mut R) -> Result { + let mut len_buf = [0u8; 4]; + reader.read_exact(&mut len_buf).await?; + let len = u32::from_be_bytes(len_buf); + + if MAX_CONTROL_MESSAGE_SIZE < len { + return Err(ProtoError::MessageTooLarge { + size: len, + max: MAX_CONTROL_MESSAGE_SIZE, + }); + } + + let mut payload = vec![0u8; len as usize]; + reader.read_exact(&mut payload).await?; + let msg: Self = bincode::deserialize(&payload)?; + Ok(msg) + } + + /// Extract the protocol version from any variant. + pub fn protocol_version(&self) -> u16 { + match self { + Self::RouteAdvertise { protocol_version, .. } + | Self::Heartbeat { protocol_version, .. } + | Self::HeartbeatAck { protocol_version, .. } => *protocol_version, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn roundtrip_route_advertise() { + let msg = ControlMessage::route_advertise( + 42, + vec![ + "10.0.0.0/8".parse().expect("valid CIDR"), + "192.168.1.0/24".parse().expect("valid CIDR"), + ], + vec![], + ); + + let mut buf = Vec::new(); + msg.encode(&mut buf).await.expect("encode should succeed"); + + let decoded = ControlMessage::decode(&mut buf.as_slice()) + .await + .expect("decode should succeed"); + + assert_eq!(msg, decoded); + } + + #[tokio::test] + async fn roundtrip_route_advertise_with_domains() { + let msg = ControlMessage::route_advertise( + 42, + vec!["10.0.0.0/8".parse().expect("valid CIDR")], + vec![ + DomainAdvertisement { + domain: "contoso.local".to_owned(), + auto_detected: false, + }, + DomainAdvertisement { + domain: "finance.contoso.local".to_owned(), + auto_detected: true, + }, + ], + ); + + let mut buf = Vec::new(); + msg.encode(&mut buf).await.expect("encode should succeed"); + + let decoded = ControlMessage::decode(&mut buf.as_slice()) + .await + .expect("decode should succeed"); + + assert_eq!(msg, decoded); + + match &decoded { + ControlMessage::RouteAdvertise { domains, .. } => { + assert_eq!(domains.len(), 2); + assert_eq!(domains[0].domain, "contoso.local"); + assert!(!domains[0].auto_detected); + assert_eq!(domains[1].domain, "finance.contoso.local"); + assert!(domains[1].auto_detected); + } + _ => panic!("expected RouteAdvertise"), + } + } + + #[tokio::test] + async fn roundtrip_route_advertise_empty_domains() { + let msg = ControlMessage::route_advertise(1, vec!["192.168.1.0/24".parse().expect("valid CIDR")], vec![]); + + let mut buf = Vec::new(); + msg.encode(&mut buf).await.expect("encode should succeed"); + + let decoded = ControlMessage::decode(&mut buf.as_slice()) + .await + .expect("decode should succeed"); + + assert_eq!(msg, decoded); + } + + #[tokio::test] + async fn roundtrip_heartbeat() { + let msg = ControlMessage::heartbeat(1_700_000_000_000, 5); + + let mut buf = Vec::new(); + msg.encode(&mut buf).await.expect("encode should succeed"); + + let decoded = ControlMessage::decode(&mut buf.as_slice()) + .await + .expect("decode should succeed"); + + assert_eq!(msg, decoded); + } + + #[tokio::test] + async fn roundtrip_heartbeat_ack() { + let msg = ControlMessage::heartbeat_ack(1_700_000_000_000); + + let mut buf = Vec::new(); + msg.encode(&mut buf).await.expect("encode should succeed"); + + let decoded = ControlMessage::decode(&mut buf.as_slice()) + .await + .expect("decode should succeed"); + + assert_eq!(msg, decoded); + } + + #[tokio::test] + async fn reject_oversized_message() { + // Craft a length prefix that exceeds the maximum + let bad_len = (MAX_CONTROL_MESSAGE_SIZE + 1).to_be_bytes(); + let mut buf = bad_len.to_vec(); + buf.extend_from_slice(&[0u8; 32]); // dummy payload + + let result = ControlMessage::decode(&mut buf.as_slice()).await; + assert!(result.is_err()); + } +} + +#[cfg(test)] +mod proptests { + use proptest::prelude::*; + + use super::*; + use crate::version::CURRENT_PROTOCOL_VERSION; + + fn arb_ipv4_network() -> impl Strategy { + (any::<[u8; 4]>(), 0u8..=32).prop_map(|(octets, prefix)| { + let ip = std::net::Ipv4Addr::from(octets); + // Use network() to normalize the address for the given prefix + Ipv4Network::new(ip, prefix) + .map(|n| Ipv4Network::new(n.network(), prefix).expect("normalized network should be valid")) + .unwrap_or_else(|_| Ipv4Network::new(std::net::Ipv4Addr::UNSPECIFIED, 0).expect("0.0.0.0/0 is valid")) + }) + } + + fn arb_domain_advertisement() -> impl Strategy { + ("[a-z]{3,10}\\.[a-z]{2,5}", any::()) + .prop_map(|(domain, auto_detected)| DomainAdvertisement { domain, auto_detected }) + } + + fn arb_control_message() -> impl Strategy { + prop_oneof![ + ( + any::(), + proptest::collection::vec(arb_ipv4_network(), 0..50), + proptest::collection::vec(arb_domain_advertisement(), 0..5), + ) + .prop_map(|(epoch, subnets, domains)| { + ControlMessage::RouteAdvertise { + protocol_version: CURRENT_PROTOCOL_VERSION, + epoch, + subnets, + domains, + } + }), + (any::(), any::()).prop_map(|(timestamp_ms, active_stream_count)| { + ControlMessage::Heartbeat { + protocol_version: CURRENT_PROTOCOL_VERSION, + timestamp_ms, + active_stream_count, + } + }), + any::().prop_map(|timestamp_ms| ControlMessage::HeartbeatAck { + protocol_version: CURRENT_PROTOCOL_VERSION, + timestamp_ms, + }), + ] + } + + proptest! { + #[test] + fn control_message_roundtrip(msg in arb_control_message()) { + let rt = tokio::runtime::Builder::new_current_thread().enable_all().build().expect("tokio runtime"); + rt.block_on(async { + let mut buf = Vec::new(); + msg.encode(&mut buf).await.expect("encode should succeed"); + let decoded = ControlMessage::decode(&mut buf.as_slice()).await.expect("decode should succeed"); + prop_assert_eq!(msg, decoded); + Ok(()) + })?; + } + } +} diff --git a/crates/agent-tunnel-proto/src/error.rs b/crates/agent-tunnel-proto/src/error.rs new file mode 100644 index 000000000..e7c10e92e --- /dev/null +++ b/crates/agent-tunnel-proto/src/error.rs @@ -0,0 +1,15 @@ +/// Protocol-level errors for the agent tunnel. +#[derive(Debug, thiserror::Error)] +pub enum ProtoError { + #[error("unsupported protocol version {received} (supported: {min}..={max})")] + UnsupportedVersion { received: u16, min: u16, max: u16 }, + + #[error("message too large: {size} bytes (max: {max})")] + MessageTooLarge { size: u32, max: u32 }, + + #[error("bincode encode/decode error: {0}")] + Bincode(#[from] bincode::Error), + + #[error("I/O error: {0}")] + Io(#[from] std::io::Error), +} diff --git a/crates/agent-tunnel-proto/src/lib.rs b/crates/agent-tunnel-proto/src/lib.rs new file mode 100644 index 000000000..2ae1853bc --- /dev/null +++ b/crates/agent-tunnel-proto/src/lib.rs @@ -0,0 +1,24 @@ +//! Protocol definitions for the QUIC-based agent tunnel. +//! +//! This crate defines the binary protocol exchanged between Gateway and Agent +//! over QUIC streams. All messages use length-prefixed bincode encoding and +//! carry a `protocol_version` field for forward compatibility. +//! +//! ## Stream model +//! +//! - **Control stream** (QUIC stream 0): carries [`ControlMessage`] variants +//! (route advertisements, heartbeats). +//! - **Session streams** (QUIC streams 1..N): each stream proxies one TCP +//! connection. The first message is a [`ConnectMessage`] from Gateway, +//! followed by a [`ConnectResponse`] from Agent. After a successful +//! response, raw TCP bytes flow bidirectionally. + +pub mod control; +pub mod error; +pub mod session; +pub mod version; + +pub use control::{ControlMessage, DomainAdvertisement, MAX_CONTROL_MESSAGE_SIZE}; +pub use error::ProtoError; +pub use session::{ConnectMessage, ConnectResponse, MAX_SESSION_MESSAGE_SIZE}; +pub use version::{CURRENT_PROTOCOL_VERSION, MIN_SUPPORTED_VERSION, validate_protocol_version}; diff --git a/crates/agent-tunnel-proto/src/session.rs b/crates/agent-tunnel-proto/src/session.rs new file mode 100644 index 000000000..d202f0f1a --- /dev/null +++ b/crates/agent-tunnel-proto/src/session.rs @@ -0,0 +1,215 @@ +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; +use tokio::io::{AsyncRead, AsyncReadExt as _, AsyncWrite, AsyncWriteExt as _}; +use uuid::Uuid; + +use crate::error::ProtoError; +use crate::version::CURRENT_PROTOCOL_VERSION; + +/// Maximum encoded session message size (64 KiB). +pub const MAX_SESSION_MESSAGE_SIZE: u32 = 64 * 1024; + +/// Length-prefixed bincode encode and write to an async writer. +async fn encode_framed(msg: &T, writer: &mut W) -> Result<(), ProtoError> { + let payload = bincode::serialize(msg)?; + let len = u32::try_from(payload.len()).map_err(|_| ProtoError::MessageTooLarge { + size: u32::MAX, + max: MAX_SESSION_MESSAGE_SIZE, + })?; + if MAX_SESSION_MESSAGE_SIZE < len { + return Err(ProtoError::MessageTooLarge { + size: len, + max: MAX_SESSION_MESSAGE_SIZE, + }); + } + writer.write_all(&len.to_be_bytes()).await?; + writer.write_all(&payload).await?; + writer.flush().await?; + Ok(()) +} + +/// Read and decode a length-prefixed bincode message from an async reader. +async fn decode_framed(reader: &mut R) -> Result { + let mut len_buf = [0u8; 4]; + reader.read_exact(&mut len_buf).await?; + let len = u32::from_be_bytes(len_buf); + + if MAX_SESSION_MESSAGE_SIZE < len { + return Err(ProtoError::MessageTooLarge { + size: len, + max: MAX_SESSION_MESSAGE_SIZE, + }); + } + + let mut payload = vec![0u8; len as usize]; + reader.read_exact(&mut payload).await?; + let msg: T = bincode::deserialize(&payload)?; + Ok(msg) +} + +/// Request from Gateway to Agent to open a TCP connection to a target. +/// +/// Sent as the first message on a newly opened QUIC bidirectional stream. +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] +pub struct ConnectMessage { + pub protocol_version: u16, + /// Association/session ID from the Gateway. + pub session_id: Uuid, + /// Target address in `host:port` form (e.g., `"192.168.1.100:3389"`). + pub target: String, +} + +/// Agent's response to a ConnectMessage. +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] +pub enum ConnectResponse { + Success { protocol_version: u16 }, + Error { protocol_version: u16, reason: String }, +} + +impl ConnectMessage { + pub fn new(session_id: Uuid, target: String) -> Self { + Self { + protocol_version: CURRENT_PROTOCOL_VERSION, + session_id, + target, + } + } + + /// Length-prefixed bincode encode and write to an async writer. + pub async fn encode(&self, writer: &mut W) -> Result<(), ProtoError> { + encode_framed(self, writer).await + } + + /// Read and decode a length-prefixed bincode message from an async reader. + pub async fn decode(reader: &mut R) -> Result { + decode_framed(reader).await + } +} + +impl ConnectResponse { + pub fn success() -> Self { + Self::Success { + protocol_version: CURRENT_PROTOCOL_VERSION, + } + } + + pub fn error(reason: impl Into) -> Self { + Self::Error { + protocol_version: CURRENT_PROTOCOL_VERSION, + reason: reason.into(), + } + } + + pub fn is_success(&self) -> bool { + matches!(self, Self::Success { .. }) + } + + /// Length-prefixed bincode encode and write to an async writer. + pub async fn encode(&self, writer: &mut W) -> Result<(), ProtoError> { + encode_framed(self, writer).await + } + + /// Read and decode a length-prefixed bincode message from an async reader. + pub async fn decode(reader: &mut R) -> Result { + decode_framed(reader).await + } + + /// Extract the protocol version from any variant. + pub fn protocol_version(&self) -> u16 { + match self { + Self::Success { protocol_version } | Self::Error { protocol_version, .. } => *protocol_version, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn roundtrip_connect_message() { + let msg = ConnectMessage::new(Uuid::new_v4(), "192.168.1.100:3389".to_owned()); + + let mut buf = Vec::new(); + msg.encode(&mut buf).await.expect("encode should succeed"); + + let decoded = ConnectMessage::decode(&mut buf.as_slice()) + .await + .expect("decode should succeed"); + + assert_eq!(msg, decoded); + } + + #[tokio::test] + async fn roundtrip_connect_response_success() { + let msg = ConnectResponse::success(); + + let mut buf = Vec::new(); + msg.encode(&mut buf).await.expect("encode should succeed"); + + let decoded = ConnectResponse::decode(&mut buf.as_slice()) + .await + .expect("decode should succeed"); + + assert_eq!(msg, decoded); + } + + #[tokio::test] + async fn roundtrip_connect_response_error() { + let msg = ConnectResponse::error("connection refused"); + + let mut buf = Vec::new(); + msg.encode(&mut buf).await.expect("encode should succeed"); + + let decoded = ConnectResponse::decode(&mut buf.as_slice()) + .await + .expect("decode should succeed"); + + assert_eq!(msg, decoded); + } +} + +#[cfg(test)] +mod proptests { + use proptest::prelude::*; + + use super::*; + + fn arb_connect_message() -> impl Strategy { + ("[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}:[0-9]{1,5}") + .prop_map(|target| ConnectMessage::new(Uuid::new_v4(), target)) + } + + fn arb_connect_response() -> impl Strategy { + prop_oneof![Just(ConnectResponse::success()), ".*".prop_map(ConnectResponse::error),] + } + + proptest! { + #[test] + fn connect_message_roundtrip(msg in arb_connect_message()) { + let rt = tokio::runtime::Builder::new_current_thread().enable_all().build().expect("tokio runtime"); + rt.block_on(async { + let mut buf = Vec::new(); + msg.encode(&mut buf).await.expect("encode should succeed"); + let decoded = ConnectMessage::decode(&mut buf.as_slice()).await.expect("decode should succeed"); + // Compare fields individually because UUID is generated fresh + prop_assert_eq!(&msg.target, &decoded.target); + prop_assert_eq!(msg.protocol_version, decoded.protocol_version); + prop_assert_eq!(msg.session_id, decoded.session_id); + Ok(()) + })?; + } + + #[test] + fn connect_response_roundtrip(msg in arb_connect_response()) { + let rt = tokio::runtime::Builder::new_current_thread().enable_all().build().expect("tokio runtime"); + rt.block_on(async { + let mut buf = Vec::new(); + msg.encode(&mut buf).await.expect("encode should succeed"); + let decoded = ConnectResponse::decode(&mut buf.as_slice()).await.expect("decode should succeed"); + prop_assert_eq!(msg, decoded); + Ok(()) + })?; + } + } +} diff --git a/crates/agent-tunnel-proto/src/version.rs b/crates/agent-tunnel-proto/src/version.rs new file mode 100644 index 000000000..dcae0852b --- /dev/null +++ b/crates/agent-tunnel-proto/src/version.rs @@ -0,0 +1,37 @@ +/// Current protocol version. +pub const CURRENT_PROTOCOL_VERSION: u16 = 2; + +/// Minimum protocol version that is still accepted. +pub const MIN_SUPPORTED_VERSION: u16 = 2; + +/// Validate that a received protocol version is within the supported range. +pub fn validate_protocol_version(version: u16) -> Result<(), crate::error::ProtoError> { + if version < MIN_SUPPORTED_VERSION || CURRENT_PROTOCOL_VERSION < version { + return Err(crate::error::ProtoError::UnsupportedVersion { + received: version, + min: MIN_SUPPORTED_VERSION, + max: CURRENT_PROTOCOL_VERSION, + }); + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn accept_current_version() { + assert!(validate_protocol_version(CURRENT_PROTOCOL_VERSION).is_ok()); + } + + #[test] + fn reject_zero_version() { + assert!(validate_protocol_version(0).is_err()); + } + + #[test] + fn reject_future_version() { + assert!(validate_protocol_version(CURRENT_PROTOCOL_VERSION + 1).is_err()); + } +} diff --git a/devolutions-agent/Cargo.toml b/devolutions-agent/Cargo.toml index a93df50c8..1c2ef7f4f 100644 --- a/devolutions-agent/Cargo.toml +++ b/devolutions-agent/Cargo.toml @@ -12,8 +12,11 @@ publish = false workspace = true [dependencies] +agent-tunnel-proto = { path = "../crates/agent-tunnel-proto" } anyhow = "1" async-trait = "0.1" +bincode = "1.3" +base64 = "0.22" bytes = "1" camino = { version = "1.1", features = ["serde1"] } ceviche = "0.7" @@ -23,15 +26,22 @@ devolutions-gateway-task = { path = "../crates/devolutions-gateway-task" } devolutions-log = { path = "../crates/devolutions-log" } futures = "0.3" http-client-proxy = { path = "../crates/http-client-proxy" } +ipnetwork = "0.20" parking_lot = "0.12" +quinn = "0.11" rand = "0.8" # FIXME(@CBenoit): maybe we don't need this crate -rustls-pemfile = "2.2" # FIXME(@CBenoit): maybe we don't need this crate +rcgen = { version = "0.13", features = ["pem"] } +reqwest = { version = "0.12", default-features = false, features = ["rustls-tls-native-roots", "http2", "socks", "json"] } +rustls = { version = "0.23", default-features = false, features = ["std", "ring"] } +rustls-pemfile = "2.2" +rustls-pki-types = "1" serde_json = "1" serde = { version = "1", features = ["derive"] } tap = "1.0" tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "tls12", "ring"] } tracing = "0.1" url = { version = "2.5", features = ["serde"] } +uuid = { version = "1.17", features = ["v4", "serde"] } [dependencies.ironrdp] version = "0.14" @@ -72,6 +82,7 @@ features = [ "Win32_Foundation", "Win32_Storage_FileSystem", "Win32_Security", + "Win32_System_SystemInformation", "Win32_System_Threading", "Win32_Security_Cryptography", "Win32_Security_Authorization", @@ -82,5 +93,8 @@ features = [ [target.'cfg(windows)'.build-dependencies] embed-resource = "3.0" +[dev-dependencies] +tempfile = "3" + [target.'cfg(windows)'.dev-dependencies] expect-test = "1.5" diff --git a/devolutions-agent/src/config.rs b/devolutions-agent/src/config.rs index 1632fcc60..826117208 100644 --- a/devolutions-agent/src/config.rs +++ b/devolutions-agent/src/config.rs @@ -20,6 +20,7 @@ pub struct Conf { pub remote_desktop: RemoteDesktopConf, pub pedm: dto::PedmConf, pub session: dto::SessionConf, + pub tunnel: dto::TunnelConf, pub proxy: dto::ProxyConf, pub debug: dto::DebugConf, } @@ -48,6 +49,7 @@ impl Conf { remote_desktop, pedm: conf_file.pedm.clone().unwrap_or_default(), session: conf_file.session.clone().unwrap_or_default(), + tunnel: conf_file.tunnel.clone().unwrap_or_default(), proxy: conf_file.proxy.clone().unwrap_or_default(), debug: conf_file.debug.clone().unwrap_or_default(), }) @@ -143,14 +145,14 @@ impl ConfHandle { } } -fn save_config(conf: &dto::ConfFile) -> anyhow::Result<()> { +pub fn save_config(conf: &dto::ConfFile) -> anyhow::Result<()> { let conf_file_path = get_conf_file_path(); let json = serde_json::to_string_pretty(conf).context("failed JSON serialization of configuration")?; std::fs::write(&conf_file_path, json).with_context(|| format!("failed to write file at {conf_file_path}"))?; Ok(()) } -fn get_conf_file_path() -> Utf8PathBuf { +pub fn get_conf_file_path() -> Utf8PathBuf { get_data_dir().join("agent.json") } @@ -273,6 +275,70 @@ pub mod dto { } } + #[derive(PartialEq, Eq, Debug, Clone, Serialize, Deserialize)] + #[serde(rename_all = "PascalCase")] + pub struct TunnelConf { + /// Enable tunnel module + pub enabled: bool, + + /// Gateway QUIC endpoint (e.g., "gateway.example.com:4433") + #[serde(default, skip_serializing_if = "String::is_empty")] + pub gateway_endpoint: String, + + /// Client certificate path (issued during enrollment) + #[serde(skip_serializing_if = "Option::is_none")] + pub client_cert_path: Option, + + /// Client private key path + #[serde(skip_serializing_if = "Option::is_none")] + pub client_key_path: Option, + + /// Gateway CA certificate path + #[serde(skip_serializing_if = "Option::is_none")] + pub gateway_ca_cert_path: Option, + + /// Subnets to advertise (e.g., ["10.0.0.0/8", "192.168.1.0/24"]) + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub advertise_subnets: Vec, + + /// DNS domains to advertise (e.g., ["contoso.local"]). Auto-detected if omitted. + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub advertise_domains: Vec, + + /// Whether to auto-detect the machine's DNS domain and add it to advertise_domains (default: true) + #[serde(default = "default_true")] + pub auto_detect_domain: bool, + + /// Heartbeat interval in seconds (default: 60) + #[serde(skip_serializing_if = "Option::is_none")] + pub heartbeat_interval_secs: Option, + + /// Route advertise interval in seconds (default: 30) + #[serde(skip_serializing_if = "Option::is_none")] + pub route_advertise_interval_secs: Option, + } + + fn default_true() -> bool { + true + } + + impl Default for TunnelConf { + fn default() -> Self { + Self { + enabled: false, + gateway_endpoint: String::new(), + client_cert_path: None, + client_key_path: None, + gateway_ca_cert_path: None, + advertise_subnets: Vec::new(), + advertise_domains: Vec::new(), + auto_detect_domain: true, + heartbeat_interval_secs: Some(60), + route_advertise_interval_secs: Some(30), + } + } + } + /// Source of truth for Agent configuration /// /// This struct represents the JSON file used for configuration as close as possible @@ -304,6 +370,10 @@ pub mod dto { #[serde(default, skip_serializing_if = "Option::is_none")] pub session: Option, + /// Agent Tunnel configuration + #[serde(skip_serializing_if = "Option::is_none")] + pub tunnel: Option, + /// HTTP/SOCKS proxy configuration for outbound requests #[serde(skip_serializing_if = "Option::is_none")] pub proxy: Option, @@ -330,6 +400,7 @@ pub mod dto { proxy: None, debug: None, session: Some(SessionConf { enabled: false }), + tunnel: None, rest: serde_json::Map::new(), } } diff --git a/devolutions-agent/src/domain_detect.rs b/devolutions-agent/src/domain_detect.rs new file mode 100644 index 000000000..33f4f8423 --- /dev/null +++ b/devolutions-agent/src/domain_detect.rs @@ -0,0 +1,112 @@ +//! Auto-detection of the machine's DNS domain for agent tunnel domain advertisement. + +/// Attempts to detect the DNS domain this machine belongs to. +/// +/// Returns `None` if detection fails or the result is clearly not a valid domain +/// (e.g., ISP domain, empty string, single-label name). +pub fn detect_domain() -> Option { + let raw = detect_domain_raw()?; + let trimmed = raw.trim().trim_end_matches('.').to_ascii_lowercase(); + if is_plausible_domain(&trimmed) { + Some(trimmed) + } else { + None + } +} + +/// Returns `true` if the detected domain looks like a legitimate internal domain +/// (not a TLD, has at least two labels, all labels non-empty). +fn is_plausible_domain(domain: &str) -> bool { + let trimmed = domain.trim_end_matches('.'); + if trimmed.is_empty() { + return false; + } + let mut parts = trimmed.split('.'); + parts.next().is_some_and(|l| !l.is_empty()) && parts.next().is_some_and(|l| !l.is_empty()) +} + +#[cfg(target_os = "windows")] +fn detect_domain_raw() -> Option { + // Try USERDNSDOMAIN first (available in user logon sessions) + if let Ok(domain) = std::env::var("USERDNSDOMAIN") + && !domain.is_empty() + { + return Some(domain); + } + + // Fallback: GetComputerNameExW(ComputerNameDnsDomain) + // This works in SYSTEM service context where USERDNSDOMAIN is empty. + detect_domain_via_computer_name() +} + +#[cfg(target_os = "windows")] +fn detect_domain_via_computer_name() -> Option { + use windows::Win32::System::SystemInformation::{ComputerNameDnsDomain, GetComputerNameExW}; + use windows::core::PWSTR; + + // First call: get required buffer size. Expected to fail with ERROR_MORE_DATA. + let mut size = 0u32; + + // SAFETY: Passing null buffer with zero size to query required length. + // GetComputerNameExW writes the required size to `size` and returns ERROR_MORE_DATA. + let _ = unsafe { GetComputerNameExW(ComputerNameDnsDomain, None, &mut size) }; + + if size == 0 { + return None; + } + + let mut buf = vec![0u16; size as usize]; + + // SAFETY: `buf` is allocated with `size` elements. GetComputerNameExW writes at most + // `size` wide chars and updates `size` to the actual length (excluding null terminator). + let result = unsafe { GetComputerNameExW(ComputerNameDnsDomain, Some(PWSTR(buf.as_mut_ptr())), &mut size) }; + + if result.is_err() { + return None; + } + + let domain = String::from_utf16_lossy(&buf[..size as usize]); + + if domain.is_empty() { None } else { Some(domain) } +} + +#[cfg(not(target_os = "windows"))] +fn detect_domain_raw() -> Option { + let content = std::fs::read_to_string("/etc/resolv.conf").ok()?; + for line in content.lines() { + let line = line.trim(); + if let Some(rest) = line.strip_prefix("search ").or_else(|| line.strip_prefix("domain ")) + && let Some(domain) = rest.split_whitespace().next() + && !domain.is_empty() + { + return Some(domain.to_owned()); + } + } + None +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn plausible_domain_accepts_typical_ad_domain() { + assert!(is_plausible_domain("contoso.local")); + assert!(is_plausible_domain("corp.contoso.com")); + assert!(is_plausible_domain("ad.it-help.ninja")); + } + + #[test] + fn plausible_domain_rejects_garbage() { + assert!(!is_plausible_domain("")); + assert!(!is_plausible_domain("local")); + assert!(!is_plausible_domain("com")); + assert!(!is_plausible_domain(".")); + assert!(!is_plausible_domain("..")); + } + + #[test] + fn plausible_domain_handles_trailing_dot() { + assert!(is_plausible_domain("contoso.local.")); + } +} diff --git a/devolutions-agent/src/enrollment.rs b/devolutions-agent/src/enrollment.rs new file mode 100644 index 000000000..98d1a9a3a --- /dev/null +++ b/devolutions-agent/src/enrollment.rs @@ -0,0 +1,192 @@ +//! Agent enrollment logic for QUIC tunnel. +//! +//! This module handles the enrollment process where an agent registers with +//! the Gateway and receives its client certificate and configuration. + +use anyhow::Context as _; +use camino::{Utf8Path, Utf8PathBuf}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use crate::config; + +/// Request body for enrollment API +#[derive(Serialize)] +struct EnrollRequest { + /// Friendly name for the agent + agent_name: String, + /// PEM-encoded Certificate Signing Request + csr_pem: String, +} + +/// Response from enrollment API +#[derive(Deserialize)] +struct EnrollResponse { + agent_id: Uuid, + agent_name: String, + client_cert_pem: String, + gateway_ca_cert_pem: String, + quic_endpoint: String, +} + +#[derive(Debug, Clone)] +pub struct PersistedEnrollment { + pub agent_id: Uuid, + pub agent_name: String, + pub client_cert_path: Utf8PathBuf, + pub client_key_path: Utf8PathBuf, + pub gateway_ca_path: Utf8PathBuf, + pub quic_endpoint: String, +} + +/// Enroll an agent with the Gateway and save the configuration. +/// +/// # Arguments +/// * `gateway_url` - Base Gateway URL (e.g., "https://gateway.example.com:7171") +/// * `enrollment_token` - JWT token for enrollment +/// * `agent_name` - Friendly name for this agent +/// * `advertise_subnets` - List of subnets to advertise (e.g., ["10.0.0.0/8"]) +pub async fn enroll_agent( + gateway_url: &str, + enrollment_token: &str, + agent_name: &str, + advertise_subnets: Vec, +) -> anyhow::Result<()> { + bootstrap_and_persist(gateway_url, enrollment_token, agent_name, advertise_subnets).await?; + Ok(()) +} + +pub async fn bootstrap_and_persist( + gateway_url: &str, + enrollment_token: &str, + agent_name: &str, + advertise_subnets: Vec, +) -> anyhow::Result { + // Generate key pair and CSR locally — the private key never leaves this machine. + let (key_pem, csr_pem) = generate_key_and_csr(agent_name)?; + + let enroll_response = request_enrollment(gateway_url, enrollment_token, agent_name, &csr_pem).await?; + persist_enrollment_response(advertise_subnets, enroll_response, &key_pem) +} + +/// Generate an ECDSA P-256 key pair and a CSR containing the agent name as CN. +/// +/// Returns `(key_pem, csr_pem)`. The private key stays on the agent; only the +/// CSR is sent to the gateway. +fn generate_key_and_csr(agent_name: &str) -> anyhow::Result<(String, String)> { + let key_pair = rcgen::KeyPair::generate_for(&rcgen::PKCS_ECDSA_P256_SHA256).context("generate agent key pair")?; + let key_pem = key_pair.serialize_pem(); + + let mut params = rcgen::CertificateParams::default(); + params.distinguished_name.push(rcgen::DnType::CommonName, agent_name); + + let csr = params.serialize_request(&key_pair).context("generate CSR")?; + let csr_pem = csr.pem().context("encode CSR to PEM")?; + + Ok((key_pem, csr_pem)) +} + +async fn request_enrollment( + gateway_url: &str, + enrollment_token: &str, + agent_name: &str, + csr_pem: &str, +) -> anyhow::Result { + let client = reqwest::Client::new(); + let enroll_url = format!("{}/jet/agent-tunnel/enroll", gateway_url.trim_end_matches('/')); + + let response = client + .post(&enroll_url) + .bearer_auth(enrollment_token) + .json(&EnrollRequest { + agent_name: agent_name.to_owned(), + csr_pem: csr_pem.to_owned(), + }) + .send() + .await + .context("failed to send enrollment request")?; + + let status = response.status(); + if !status.is_success() { + let error_text = response.text().await.unwrap_or_default(); + anyhow::bail!("enrollment failed with status {}: {}", status, error_text); + } + + response.json().await.context("failed to parse enrollment response") +} + +fn persist_enrollment_response( + advertise_subnets: Vec, + enroll_response: EnrollResponse, + key_pem: &str, +) -> anyhow::Result { + let config_path = config::get_conf_file_path(); + let config_dir = config_path + .parent() + .filter(|path| !path.as_str().is_empty()) + .map(Utf8Path::to_owned) + .unwrap_or_else(|| Utf8PathBuf::from(".")); + let cert_dir = config_dir.join("certs"); + + std::fs::create_dir_all(&cert_dir) + .with_context(|| format!("failed to create certificate directory: {}", cert_dir))?; + + let client_cert_path = cert_dir.join(format!("{}-cert.pem", enroll_response.agent_id)); + let client_key_path = cert_dir.join(format!("{}-key.pem", enroll_response.agent_id)); + let gateway_ca_path = cert_dir.join("gateway-ca.pem"); + + // Write the locally-generated private key first (before cert/CA from the network). + std::fs::write(&client_key_path, key_pem) + .with_context(|| format!("failed to write client private key: {}", client_key_path))?; + + std::fs::write(&client_cert_path, &enroll_response.client_cert_pem) + .with_context(|| format!("failed to write client certificate: {}", client_cert_path))?; + + std::fs::write(&gateway_ca_path, &enroll_response.gateway_ca_cert_pem) + .with_context(|| format!("failed to write gateway CA certificate: {}", gateway_ca_path))?; + + // Restrict permissions on cert/key files (owner-only on Unix). + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt as _; + let restricted = std::fs::Permissions::from_mode(0o600); + for path in [&client_cert_path, &client_key_path, &gateway_ca_path] { + std::fs::set_permissions(path, restricted.clone()) + .with_context(|| format!("failed to set permissions on {path}"))?; + } + } + + // Load existing config and update only the Tunnel section. + // This preserves other settings (Updater, Session, PEDM, etc.) that may have been + // configured by the MSI installer or admin. + let mut conf_file = config::load_conf_file_or_generate_new().context("failed to load existing configuration")?; + + // Preserve existing domain config from previous enrollment/manual configuration. + let existing_tunnel = conf_file.tunnel.as_ref(); + + let tunnel_conf = config::dto::TunnelConf { + enabled: true, + gateway_endpoint: enroll_response.quic_endpoint.clone(), + client_cert_path: Some(client_cert_path.clone()), + client_key_path: Some(client_key_path.clone()), + gateway_ca_cert_path: Some(gateway_ca_path.clone()), + advertise_subnets, + advertise_domains: existing_tunnel.map(|t| t.advertise_domains.clone()).unwrap_or_default(), + auto_detect_domain: existing_tunnel.map(|t| t.auto_detect_domain).unwrap_or(true), + heartbeat_interval_secs: Some(60), + route_advertise_interval_secs: Some(30), + }; + + conf_file.tunnel = Some(tunnel_conf); + + config::save_config(&conf_file)?; + + Ok(PersistedEnrollment { + agent_id: enroll_response.agent_id, + agent_name: enroll_response.agent_name, + client_cert_path, + client_key_path, + gateway_ca_path, + quic_endpoint: enroll_response.quic_endpoint, + }) +} diff --git a/devolutions-agent/src/lib.rs b/devolutions-agent/src/lib.rs index 71c328f48..304192ef1 100644 --- a/devolutions-agent/src/lib.rs +++ b/devolutions-agent/src/lib.rs @@ -6,8 +6,11 @@ use ctrlc as _; extern crate tracing; pub mod config; +pub mod domain_detect; +pub mod enrollment; pub mod log; pub mod remote_desktop; +pub mod tunnel; #[cfg(windows)] pub mod session_manager; diff --git a/devolutions-agent/src/main.rs b/devolutions-agent/src/main.rs index 9005dd7b3..b2796e0d3 100644 --- a/devolutions-agent/src/main.rs +++ b/devolutions-agent/src/main.rs @@ -2,25 +2,35 @@ #![allow(clippy::print_stdout)] // Used by devolutions-agent library. +use agent_tunnel_proto as _; use anyhow as _; use async_trait as _; +use bincode as _; use camino as _; use devolutions_agent_shared as _; use devolutions_gateway_task as _; use devolutions_log as _; use futures as _; +use http_client_proxy as _; +use ipnetwork as _; use ironrdp as _; use parking_lot as _; +use quinn as _; use rand as _; +use reqwest as _; +use rustls as _; use rustls_pemfile as _; +use rustls_pki_types as _; use serde as _; use serde_json as _; use tap as _; use tokio as _; use tokio_rustls as _; +use url as _; +use uuid as _; #[cfg(windows)] use { - devolutions_pedm as _, hex as _, notify_debouncer_mini as _, reqwest as _, sha2 as _, thiserror as _, uuid as _, + aws_lc_rs as _, devolutions_pedm as _, hex as _, notify_debouncer_mini as _, sha2 as _, thiserror as _, win_api_wrappers as _, windows as _, }; @@ -32,6 +42,8 @@ mod service; use std::env; use std::sync::mpsc; +use anyhow::{Context as _, Result, bail}; +use base64::Engine as _; use ceviche::Service; use ceviche::controller::*; use devolutions_agent::AgentServiceEvent; @@ -42,6 +54,23 @@ use self::service::{AgentService, DESCRIPTION, DISPLAY_NAME, SERVICE_NAME}; const BAD_CONFIG_ERR_CODE: u32 = 1; const START_FAILED_ERR_CODE: u32 = 2; +#[derive(Debug, PartialEq, Eq)] +struct UpCommand { + gateway_url: String, + enrollment_token: String, + agent_name: String, + advertise_subnets: Vec, +} + +#[derive(Debug, serde::Deserialize)] +struct EnrollmentStringPayload { + version: u64, + api_base_url: String, + enrollment_token: String, + #[serde(default)] + name: Option, +} + fn agent_service_main( rx: mpsc::Receiver, _tx: mpsc::Sender, @@ -110,6 +139,85 @@ fn agent_service_main( Service!("agent", agent_service_main); +fn parse_required_value(args: &[String], index: &mut usize, flag: &str) -> Result { + *index += 1; + args.get(*index) + .cloned() + .with_context(|| format!("missing value for {flag}")) +} + +fn parse_advertise_subnets(value: &str) -> Vec { + value + .split(',') + .map(str::trim) + .filter(|subnet| !subnet.is_empty()) + .map(ToOwned::to_owned) + .collect() +} + +fn parse_enrollment_string(value: &str) -> Result { + const PREFIX: &str = "dgw-enroll:v1:"; + + let encoded = value.strip_prefix(PREFIX).context("invalid enrollment string prefix")?; + + let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD + .decode(encoded) + .context("invalid base64 enrollment string")?; + + let payload: EnrollmentStringPayload = + serde_json::from_slice(&decoded).context("invalid enrollment string payload")?; + + if payload.version != 1 { + bail!("unsupported enrollment string version: {}", payload.version); + } + + Ok(payload) +} + +fn parse_up_command_args(args: &[String]) -> Result { + let mut gateway_url = None; + let mut enrollment_token = None; + let mut agent_name = None; + let mut enrollment_string = None; + let mut advertise_subnets = Vec::new(); + + let mut index = 0; + while index < args.len() { + let arg = args[index].as_str(); + + match arg { + "--gateway" => gateway_url = Some(parse_required_value(args, &mut index, "--gateway")?), + "--token" | "--enrollment-token" => enrollment_token = Some(parse_required_value(args, &mut index, arg)?), + "--name" | "--agent-name" => agent_name = Some(parse_required_value(args, &mut index, arg)?), + "--enrollment-string" => enrollment_string = Some(parse_required_value(args, &mut index, arg)?), + "--advertise-routes" | "--advertise-subnets" => { + advertise_subnets.extend(parse_advertise_subnets(&parse_required_value(args, &mut index, arg)?)) + } + unexpected => bail!("unknown argument for up: {unexpected}"), + } + + index += 1; + } + + if let Some(enrollment_string) = enrollment_string { + let payload = parse_enrollment_string(&enrollment_string)?; + + gateway_url.get_or_insert(payload.api_base_url); + enrollment_token.get_or_insert(payload.enrollment_token); + + if agent_name.is_none() { + agent_name = payload.name; + } + } + + Ok(UpCommand { + gateway_url: gateway_url.context("missing required --gateway")?, + enrollment_token: enrollment_token.context("missing required --token")?, + agent_name: agent_name.context("missing required --name")?, + advertise_subnets, + }) +} + fn main() { let mut controller = Controller::new(SERVICE_NAME, DISPLAY_NAME, DESCRIPTION); @@ -152,6 +260,61 @@ fn main() { eprintln!("[ERROR] Agent configuration failed: {e}"); } } + "enroll" => { + let gateway_url = env::args() + .nth(2) + .expect("missing gateway URL (e.g., https://gateway.example.com:7171)"); + let enrollment_token = env::args().nth(3).expect("missing enrollment token"); + let agent_name = env::args().nth(4).expect("missing agent name"); + let subnets_arg = env::args().nth(5).unwrap_or_default(); + + let advertise_subnets: Vec = if subnets_arg.is_empty() { + Vec::new() + } else { + subnets_arg.split(',').map(|s| s.trim().to_owned()).collect() + }; + + let rt = tokio::runtime::Runtime::new().expect("failed to create tokio runtime"); + rt.block_on(async { + if let Err(e) = devolutions_agent::enrollment::enroll_agent( + &gateway_url, + &enrollment_token, + &agent_name, + advertise_subnets, + ) + .await + { + eprintln!("[ERROR] Enrollment failed: {e:#}"); + std::process::exit(1); + } + }); + } + "up" => { + let args: Vec = env::args().skip(2).collect(); + let command = match parse_up_command_args(&args) { + Ok(command) => command, + Err(error) => { + eprintln!("[ERROR] Invalid up arguments: {error:#}"); + std::process::exit(1); + } + }; + + let rt = tokio::runtime::Runtime::new().expect("failed to create tokio runtime"); + let result = rt.block_on(async { + devolutions_agent::enrollment::bootstrap_and_persist( + &command.gateway_url, + &command.enrollment_token, + &command.agent_name, + command.advertise_subnets, + ) + .await + }); + + if let Err(error) = result { + eprintln!("[ERROR] Bootstrap failed: {error:#}"); + std::process::exit(1); + } + } _ => { eprintln!("[ERROR] Invalid command: {cmd}"); } @@ -160,3 +323,73 @@ fn main() { let _result = controller.register(service_main_wrapper); } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_up_command_args_uses_default_config_path() { + let args = vec![ + "--gateway".to_owned(), + "https://gateway.example.com:7171".to_owned(), + "--token".to_owned(), + "bootstrap-token".to_owned(), + "--name".to_owned(), + "site-a-agent".to_owned(), + "--advertise-routes".to_owned(), + "10.0.0.0/8,192.168.1.0/24".to_owned(), + ]; + + let parsed = parse_up_command_args(&args).expect("parse up args"); + + assert_eq!( + parsed, + UpCommand { + gateway_url: "https://gateway.example.com:7171".to_owned(), + enrollment_token: "bootstrap-token".to_owned(), + agent_name: "site-a-agent".to_owned(), + advertise_subnets: vec!["10.0.0.0/8".to_owned(), "192.168.1.0/24".to_owned()], + } + ); + } + + #[test] + fn parse_up_command_args_accepts_aliases() { + let args = vec![ + "--gateway".to_owned(), + "https://gateway.example.com:7171".to_owned(), + "--enrollment-token".to_owned(), + "bootstrap-token".to_owned(), + "--agent-name".to_owned(), + "site-a-agent".to_owned(), + "--advertise-subnets".to_owned(), + "10.0.0.0/8".to_owned(), + ]; + + let parsed = parse_up_command_args(&args).expect("parse up args"); + + assert_eq!(parsed.advertise_subnets, vec!["10.0.0.0/8".to_owned()]); + } + + #[test] + fn parse_up_command_args_accepts_enrollment_string() { + let payload = serde_json::json!({ + "version": 1, + "api_base_url": "https://gateway.example.com:7171", + "enrollment_token": "bootstrap-token", + "name": "site-a-agent", + }); + let enrollment_string = format!( + "dgw-enroll:v1:{}", + base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(payload.to_string()) + ); + let args = vec!["--enrollment-string".to_owned(), enrollment_string]; + + let parsed = parse_up_command_args(&args).expect("parse up args"); + + assert_eq!(parsed.gateway_url, "https://gateway.example.com:7171"); + assert_eq!(parsed.enrollment_token, "bootstrap-token"); + assert_eq!(parsed.agent_name, "site-a-agent"); + } +} diff --git a/devolutions-agent/src/service.rs b/devolutions-agent/src/service.rs index 90dd20f58..276a2e4f6 100644 --- a/devolutions-agent/src/service.rs +++ b/devolutions-agent/src/service.rs @@ -7,6 +7,7 @@ use devolutions_agent::log::AgentLog; use devolutions_agent::remote_desktop::RemoteDesktopTask; #[cfg(windows)] use devolutions_agent::session_manager::SessionManager; +use devolutions_agent::tunnel::TunnelTask; #[cfg(windows)] use devolutions_agent::updater::UpdaterTask; use devolutions_gateway_task::{ChildTask, ShutdownHandle, ShutdownSignal}; @@ -227,7 +228,11 @@ async fn spawn_tasks(conf_handle: ConfHandle) -> anyhow::Result { let service_event_tx = None; if conf.debug.enable_unstable && conf.remote_desktop.enabled { - tasks.register(RemoteDesktopTask::new(conf_handle)); + tasks.register(RemoteDesktopTask::new(conf_handle.clone())); + } + + if conf.tunnel.enabled { + tasks.register(TunnelTask::new(conf_handle)); } Ok(TasksCtx { diff --git a/devolutions-agent/src/tunnel.rs b/devolutions-agent/src/tunnel.rs new file mode 100644 index 000000000..53e4077e0 --- /dev/null +++ b/devolutions-agent/src/tunnel.rs @@ -0,0 +1,512 @@ +//! QUIC-based Agent Tunnel client implementation (Quinn). +//! +//! This module implements a QUIC client that connects to the Gateway's agent tunnel +//! endpoint, advertises reachable subnets, and handles incoming TCP proxy requests. + +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Duration; + +use agent_tunnel_proto::{ConnectMessage, ConnectResponse, ControlMessage}; +use anyhow::{Context as _, bail}; +use async_trait::async_trait; +use devolutions_gateway_task::{ShutdownSignal, Task}; +use ipnetwork::Ipv4Network; +use tokio::net::TcpStream; + +use crate::config::ConfHandle; + +// --------------------------------------------------------------------------- +// Custom TLS verifier: verify cert chain against CA, skip hostname check +// --------------------------------------------------------------------------- + +/// Wraps a `WebPkiServerVerifier` but skips the hostname verification step. +/// +/// For our private PKI, the agent may connect by IP address (e.g., `127.0.0.1`) +/// while the server cert has the gateway's hostname (e.g., `devolutions432`). +/// The cert chain is still validated against our private CA — only the +/// hostname-to-SAN matching is bypassed. +#[derive(Debug)] +struct SkipHostnameVerification(Arc); + +impl rustls::client::danger::ServerCertVerifier for SkipHostnameVerification { + fn verify_server_cert( + &self, + end_entity: &rustls_pki_types::CertificateDer<'_>, + intermediates: &[rustls_pki_types::CertificateDer<'_>], + _server_name: &rustls_pki_types::ServerName<'_>, + ocsp_response: &[u8], + now: rustls_pki_types::UnixTime, + ) -> Result { + // Verify the cert chain against our CA, skipping hostname verification. + // We call the inner verifier with a dummy name; if it fails specifically + // because of hostname mismatch (CertNotValidForName), we accept it. + // All other errors (expired cert, unknown CA, bad signature) propagate. + self.0 + .verify_server_cert( + end_entity, + intermediates, + &rustls_pki_types::ServerName::try_from("dummy.local").expect("valid dummy server name"), + ocsp_response, + now, + ) + .or_else(|e| match e { + rustls::Error::InvalidCertificate(rustls::CertificateError::NotValidForName) => { + Ok(rustls::client::danger::ServerCertVerified::assertion()) + } + other => Err(other), + }) + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &rustls_pki_types::CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> Result { + self.0.verify_tls12_signature(message, cert, dss) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &rustls_pki_types::CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> Result { + self.0.verify_tls13_signature(message, cert, dss) + } + + fn supported_verify_schemes(&self) -> Vec { + self.0.supported_verify_schemes() + } +} + +// --------------------------------------------------------------------------- +// TunnelTask — service task with auto-reconnect +// --------------------------------------------------------------------------- + +pub struct TunnelTask { + conf_handle: ConfHandle, +} + +impl TunnelTask { + pub fn new(conf_handle: ConfHandle) -> Self { + Self { conf_handle } + } +} + +#[async_trait] +impl Task for TunnelTask { + type Output = anyhow::Result<()>; + const NAME: &'static str = "tunnel"; + + async fn run(self, mut shutdown_signal: ShutdownSignal) -> anyhow::Result<()> { + const INITIAL_BACKOFF: Duration = Duration::from_secs(1); + const MAX_BACKOFF: Duration = Duration::from_secs(60); + const CONNECTED_THRESHOLD: Duration = Duration::from_secs(30); + + info!("Starting QUIC agent tunnel (with auto-reconnect)"); + + let mut backoff = INITIAL_BACKOFF; + + loop { + let start = std::time::Instant::now(); + + match run_single_connection(&self.conf_handle, &mut shutdown_signal).await { + Ok(()) => { + info!("Tunnel task stopped"); + return Ok(()); + } + Err(error) => { + warn!(error = %format!("{error:#}"), "Tunnel connection lost"); + } + } + + if CONNECTED_THRESHOLD < start.elapsed() { + backoff = INITIAL_BACKOFF; + } + + info!(?backoff, "Reconnecting after backoff"); + + tokio::select! { + _ = shutdown_signal.wait() => { + info!("Shutdown during reconnect backoff"); + return Ok(()); + } + _ = tokio::time::sleep(backoff) => {} + } + + let jitter_factor = rand::Rng::gen_range(&mut rand::thread_rng(), 0.75..1.25); + backoff = + Duration::from_secs_f64((backoff.as_secs_f64() * 2.0 * jitter_factor).min(MAX_BACKOFF.as_secs_f64())); + } + } +} + +// --------------------------------------------------------------------------- +// Single connection lifetime +// --------------------------------------------------------------------------- + +/// Run a single QUIC tunnel connection lifetime: config → connect → event loop. +/// +/// Returns `Ok(())` on graceful shutdown (shutdown signal received). +/// Returns `Err(...)` on any failure — the caller should retry with backoff. +async fn run_single_connection(conf_handle: &ConfHandle, shutdown_signal: &mut ShutdownSignal) -> anyhow::Result<()> { + // Ensure rustls crypto provider is installed (ring). + let _ = rustls::crypto::ring::default_provider().install_default(); + + let agent_conf = conf_handle.get_conf(); + let tunnel_conf = &agent_conf.tunnel; + + let cert_path = tunnel_conf + .client_cert_path + .as_ref() + .context("client_cert_path not configured")?; + let key_path = tunnel_conf + .client_key_path + .as_ref() + .context("client_key_path not configured")?; + let ca_path = tunnel_conf + .gateway_ca_cert_path + .as_ref() + .context("gateway_ca_cert_path not configured")?; + + let advertise_subnets: Vec = tunnel_conf + .advertise_subnets + .iter() + .map(|subnet| subnet.parse()) + .collect::, _>>() + .context("failed to parse advertise_subnets")?; + + if advertise_subnets.is_empty() { + warn!("No subnets configured to advertise"); + } + + // Build domain advertisement list: explicit config + auto-detection. + let mut advertise_domains: Vec = tunnel_conf + .advertise_domains + .iter() + .map(|d| agent_tunnel_proto::DomainAdvertisement { + domain: d.clone(), + auto_detected: false, + }) + .collect(); + + if tunnel_conf.auto_detect_domain { + if let Some(detected) = crate::domain_detect::detect_domain() { + if !advertise_domains + .iter() + .any(|d| d.domain.eq_ignore_ascii_case(&detected)) + { + info!(domain = %detected, "Auto-detected DNS domain"); + advertise_domains.push(agent_tunnel_proto::DomainAdvertisement { + domain: detected, + auto_detected: true, + }); + } + } else if tunnel_conf.advertise_domains.is_empty() { + warn!( + "Domain auto-detection found nothing and no advertise_domains configured. \ + Set advertise_domains in agent config." + ); + } + } + + info!( + subnet_count = advertise_subnets.len(), + domain_count = advertise_domains.len(), + domains = ?advertise_domains.iter().map(|d| { + let source = if d.auto_detected { "auto" } else { "explicit" }; + format!("{} ({})", d.domain, source) + }).collect::>(), + "Advertising subnets and domains" + ); + + // -- Build rustls ClientConfig -- + + let certs: Vec> = rustls_pemfile::certs(&mut std::io::BufReader::new( + std::fs::File::open(cert_path.as_str()).context("open client cert file")?, + )) + .collect::, _>>() + .context("parse client certificates")?; + + let key = rustls_pemfile::private_key(&mut std::io::BufReader::new( + std::fs::File::open(key_path.as_str()).context("open client key file")?, + )) + .context("parse private key file")? + .context("no private key found in file")?; + + let mut roots = rustls::RootCertStore::empty(); + let ca_certs: Vec> = rustls_pemfile::certs(&mut std::io::BufReader::new( + std::fs::File::open(ca_path.as_str()).context("open CA cert file")?, + )) + .collect::, _>>() + .context("parse CA certificates")?; + for cert in ca_certs { + roots.add(cert)?; + } + + // Use a custom verifier that validates the cert chain against our private CA + // but skips hostname verification. This is correct for a private PKI where the + // agent connects by IP address but the server cert has the gateway's hostname. + let verifier = rustls::client::WebPkiServerVerifier::builder(Arc::new(roots)) + .build() + .context("build server cert verifier")?; + + let mut client_crypto = rustls::ClientConfig::builder() + .dangerous() + .with_custom_certificate_verifier(Arc::new(SkipHostnameVerification(verifier))) + .with_client_auth_cert(certs, key) + .context("build rustls client config with client auth")?; + client_crypto.alpn_protocols = vec![b"devolutions-agent-tunnel".to_vec()]; + + let client_config = quinn::ClientConfig::new(Arc::new( + quinn::crypto::rustls::QuicClientConfig::try_from(client_crypto) + .context("build QuicClientConfig from rustls config")?, + )); + + // -- Transport config -- + + let mut transport = quinn::TransportConfig::default(); + transport.max_idle_timeout(Some( + Duration::from_secs(120).try_into().context("idle timeout conversion")?, + )); + transport.keep_alive_interval(Some(Duration::from_secs(15))); + transport.max_concurrent_bidi_streams(100u32.into()); + + let mut client_config = client_config; + client_config.transport_config(Arc::new(transport)); + + // -- DNS resolve -- + + let gateway_addr = tokio::net::lookup_host(&tunnel_conf.gateway_endpoint) + .await + .context("failed to resolve gateway endpoint")? + .next() + .context("no addresses resolved for gateway endpoint")?; + + info!(gateway_addr = %gateway_addr, "Connecting to gateway"); + + // -- Connect -- + + let mut endpoint = + quinn::Endpoint::client("0.0.0.0:0".parse().context("parse bind address")?).context("create QUIC endpoint")?; + endpoint.set_default_client_config(client_config); + + let connection = endpoint + .connect(gateway_addr, "gateway") + .context("initiate QUIC connection")? + .await + .context("QUIC handshake")?; + info!("QUIC connection established"); + + // -- Open control stream -- + + let (mut ctrl_send, mut ctrl_recv) = connection.open_bi().await.context("open control stream")?; + + // Send initial RouteAdvertise. + let epoch = 1u64; + let msg = ControlMessage::route_advertise(epoch, advertise_subnets.clone(), advertise_domains.clone()); + msg.encode(&mut ctrl_send) + .await + .context("encode initial RouteAdvertise")?; + info!(epoch, "Sent initial RouteAdvertise"); + + // Spawn control stream reader. + tokio::spawn(async move { + let _ = handle_control_recv(&mut ctrl_recv) + .await + .inspect_err(|e| error!(%e, "Control recv stream failed")); + }); + + // -- Main loop: accept incoming session streams + periodic tasks -- + + let route_interval = tunnel_conf.route_advertise_interval_secs.unwrap_or(30); + let heartbeat_interval_secs = tunnel_conf.heartbeat_interval_secs.unwrap_or(60); + let mut route_tick = tokio::time::interval(Duration::from_secs(route_interval)); + let mut heartbeat_tick = tokio::time::interval(Duration::from_secs(heartbeat_interval_secs)); + // Skip the first immediate tick (we already sent the initial RouteAdvertise). + route_tick.tick().await; + heartbeat_tick.tick().await; + + loop { + tokio::select! { + biased; + + _ = shutdown_signal.wait() => { + info!("Tunnel task shutting down"); + connection.close(0u32.into(), b"shutting down"); + return Ok(()); + } + + _ = route_tick.tick() => { + let msg = ControlMessage::route_advertise(epoch, advertise_subnets.clone(), advertise_domains.clone()); + let _ = msg.encode(&mut ctrl_send).await + .inspect(|_| trace!(epoch, "Sent RouteAdvertise (refresh)")) + .inspect_err(|e| error!(%e, "Failed to send RouteAdvertise")); + } + + _ = heartbeat_tick.tick() => { + let msg = ControlMessage::heartbeat(current_time_millis(), 0); + let _ = msg.encode(&mut ctrl_send).await + .inspect(|_| trace!("Sent Heartbeat")) + .inspect_err(|e| error!(%e, "Failed to send Heartbeat")); + } + + result = connection.accept_bi() => { + let (send, recv) = result.context("accept incoming bidi stream")?; + let subnets = advertise_subnets.clone(); + tokio::spawn(async move { + let _ = handle_session_stream(&subnets, send, recv).await + .inspect_err(|e| error!(%e, "Session stream failed")); + }); + } + } + } +} + +// --------------------------------------------------------------------------- +// Control stream reader +// --------------------------------------------------------------------------- + +async fn handle_control_recv(recv: &mut quinn::RecvStream) -> anyhow::Result<()> { + loop { + let message = ControlMessage::decode(recv).await.context("decode control message")?; + + match message { + ControlMessage::HeartbeatAck { + protocol_version, + timestamp_ms, + } => { + if let Err(e) = agent_tunnel_proto::validate_protocol_version(protocol_version) { + warn!(%protocol_version, %e, "Ignoring HeartbeatAck: unsupported protocol version"); + continue; + } + let rtt = current_time_millis().saturating_sub(timestamp_ms); + debug!(rtt_ms = rtt, "Received HeartbeatAck"); + } + unexpected => { + warn!(message = ?unexpected, "Unexpected control message from gateway"); + } + } + } +} + +// --------------------------------------------------------------------------- +// Session stream handler +// --------------------------------------------------------------------------- + +async fn handle_session_stream( + advertise_subnets: &[Ipv4Network], + mut send: quinn::SendStream, + mut recv: quinn::RecvStream, +) -> anyhow::Result<()> { + // Read ConnectMessage (length-prefixed) directly from the Quinn stream. + let connect_msg = ConnectMessage::decode(&mut recv) + .await + .context("decode ConnectMessage")?; + + info!( + session_id = %connect_msg.session_id, + target = %connect_msg.target, + "Received ConnectMessage" + ); + + if let Err(e) = agent_tunnel_proto::validate_protocol_version(connect_msg.protocol_version) { + warn!( + protocol_version = %connect_msg.protocol_version, + %e, + "Rejecting ConnectMessage: unsupported protocol version" + ); + let response = ConnectResponse::error(format!("unsupported protocol version: {e}")); + response + .encode(&mut send) + .await + .context("send ConnectResponse error for unsupported version")?; + bail!("unsupported protocol version in ConnectMessage"); + } + + // Validate and connect to target. + let candidates = resolve_target_candidates(&connect_msg.target, advertise_subnets).await?; + let (tcp_stream, selected_target) = connect_to_target(&candidates).await?; + info!(target = %selected_target, "TCP connection established"); + + // Send ConnectResponse::Success. + ConnectResponse::success() + .encode(&mut send) + .await + .context("send ConnectResponse")?; + info!("Sent ConnectResponse::Success"); + + // Bidirectional proxy using tokio::io::copy. + let (mut tcp_read, mut tcp_write) = tcp_stream.into_split(); + + let quic_to_tcp = tokio::io::copy(&mut recv, &mut tcp_write); + let tcp_to_quic = tokio::io::copy(&mut tcp_read, &mut send); + + tokio::select! { + r = quic_to_tcp => { + r.inspect_err(|e| debug!(%e, "QUIC->TCP copy ended"))?; + } + r = tcp_to_quic => { + r.inspect_err(|e| debug!(%e, "TCP->QUIC copy ended"))?; + } + } + + Ok(()) +} + +// --------------------------------------------------------------------------- +// Utilities (no QUIC involvement) +// --------------------------------------------------------------------------- + +async fn resolve_target_candidates(target: &str, advertise_subnets: &[Ipv4Network]) -> anyhow::Result> { + let resolved: Vec = tokio::net::lookup_host(target) + .await + .with_context(|| format!("resolve target {target}"))? + .collect(); + + if resolved.is_empty() { + bail!("no addresses resolved for target {target}"); + } + + let reachable: Vec = resolved + .into_iter() + .filter(|addr| match addr.ip() { + std::net::IpAddr::V4(ipv4) => advertise_subnets.iter().any(|subnet| subnet.contains(ipv4)), + // TODO: Support IPv6. + std::net::IpAddr::V6(_) => false, + }) + .collect(); + + if reachable.is_empty() { + bail!("target {target} is not in advertised subnets"); + } + + Ok(reachable) +} + +async fn connect_to_target(candidates: &[SocketAddr]) -> anyhow::Result<(TcpStream, SocketAddr)> { + let mut last_error = None; + + for candidate in candidates { + match TcpStream::connect(candidate).await { + Ok(stream) => return Ok((stream, *candidate)), + Err(error) => last_error = Some((candidate, error)), + } + } + + let Some((candidate, error)) = last_error else { + bail!("no target candidates available"); + }; + + Err(error).with_context(|| format!("TCP connect failed for {candidate}")) +} + +fn current_time_millis() -> u64 { + let elapsed = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("system time should be after unix epoch"); + + u64::try_from(elapsed.as_millis()).expect("millisecond timestamp should fit in u64") +} diff --git a/devolutions-gateway/Cargo.toml b/devolutions-gateway/Cargo.toml index 40226fb8b..c3a9acff3 100644 --- a/devolutions-gateway/Cargo.toml +++ b/devolutions-gateway/Cargo.toml @@ -35,6 +35,7 @@ terminal-streamer.path = "../crates/terminal-streamer" network-monitor.path = "../crates/network-monitor" sysevent.path = "../crates/sysevent" sysevent-codes.path = "../crates/sysevent-codes" +agent-tunnel-proto.path = "../crates/agent-tunnel-proto" ironrdp-pdu = { version = "0.7", features = ["std"] } ironrdp-core = { version = "0.1", features = ["std"] } ironrdp-rdcleanpath = "0.2" @@ -69,6 +70,11 @@ thiserror = "2" typed-builder = "0.21" backoff = "0.4" bitflags = "2.9" +base64 = "0.22" +bincode = "1.3" +ipnetwork = "0.20" +dashmap = "6.1" +rand = "0.8" # Security, crypto… picky = { version = "7.0.0-rc.15", default-features = false, features = ["jose", "x509", "pkcs12", "time_conversion"] } @@ -81,6 +87,15 @@ x509-cert = { version = "0.2", default-features = false, features = ["std"] } sha2 = "0.10" hex = "0.4" rustls-native-certs = "0.8" +pem = "3.0" +rcgen = { version = "0.13", features = ["pem", "x509-parser"] } +x509-parser = "0.16" + +# QUIC (agent tunnel) +quinn = "0.11" +rustls = { version = "0.23", default-features = false, features = ["ring", "logging", "std", "tls12"] } +rustls-pki-types = "1" +rustls-pemfile = "2" # Logging tracing = "0.1" diff --git a/devolutions-gateway/src/agent_tunnel/cert.rs b/devolutions-gateway/src/agent_tunnel/cert.rs new file mode 100644 index 000000000..8fdda3645 --- /dev/null +++ b/devolutions-gateway/src/agent_tunnel/cert.rs @@ -0,0 +1,433 @@ +//! CA certificate management for the QUIC agent tunnel. +//! +//! Manages a self-signed CA that issues client certificates to agents during enrollment, +//! and a server certificate for the QUIC listener. + +use std::time::Duration; + +use anyhow::Context as _; +use camino::{Utf8Path, Utf8PathBuf}; +use rcgen::{CertificateParams, DnType, ExtendedKeyUsagePurpose, IsCa, KeyPair, KeyUsagePurpose, SanType}; +use sha2::{Digest, Sha256}; +use uuid::Uuid; + +const CA_CERT_FILENAME: &str = "agent-tunnel-ca-cert.pem"; +const CA_KEY_FILENAME: &str = "agent-tunnel-ca-key.pem"; +const SERVER_CERT_FILENAME: &str = "agent-tunnel-server-cert.pem"; +const SERVER_KEY_FILENAME: &str = "agent-tunnel-server-key.pem"; +const CA_VALIDITY_DAYS: u32 = 3650; // ~10 years +const SERVER_CERT_VALIDITY_DAYS: u32 = 365; // 1 year +const AGENT_CERT_VALIDITY_DAYS: u32 = 365; // 1 year + +const CA_COMMON_NAME: &str = "Devolutions Gateway Agent Tunnel CA"; +const CA_ORG_NAME: &str = "Devolutions Inc."; + +/// Build the standard CA `CertificateParams` (same DN every time so that +/// reconstructed certificates match the on-disk CA for chain validation). +fn make_ca_params() -> CertificateParams { + let mut params = CertificateParams::default(); + params.distinguished_name.push(DnType::CommonName, CA_COMMON_NAME); + params.distinguished_name.push(DnType::OrganizationName, CA_ORG_NAME); + params.is_ca = IsCa::Ca(rcgen::BasicConstraints::Unconstrained); + params.key_usages.push(KeyUsagePurpose::KeyCertSign); + params.key_usages.push(KeyUsagePurpose::CrlSign); + params.not_before = time::OffsetDateTime::now_utc(); + params.not_after = time::OffsetDateTime::now_utc() + Duration::from_secs(u64::from(CA_VALIDITY_DAYS) * 86400); + params +} + +/// Manages the CA used to sign agent client certificates and the QUIC server certificate. +pub struct CaManager { + ca_cert_pem: String, + ca_key_pair: KeyPair, + data_dir: Utf8PathBuf, +} + +/// Bundle returned to a newly enrolled agent (private key never leaves the agent). +pub struct SignedAgentCert { + pub client_cert_pem: String, + pub ca_cert_pem: String, +} + +impl CaManager { + /// Load an existing CA from disk, or generate a new one. + pub fn load_or_generate(data_dir: &Utf8Path) -> anyhow::Result { + let cert_path = data_dir.join(CA_CERT_FILENAME); + let key_path = data_dir.join(CA_KEY_FILENAME); + + if cert_path.exists() && key_path.exists() { + info!(%cert_path, "Loading existing agent tunnel CA"); + let ca_cert_pem = + std::fs::read_to_string(&cert_path).with_context(|| format!("read CA cert from {cert_path}"))?; + let ca_key_pem = + std::fs::read_to_string(&key_path).with_context(|| format!("read CA key from {key_path}"))?; + let ca_key_pair = KeyPair::from_pem(&ca_key_pem).context("parse CA key pair from PEM")?; + Ok(Self { + ca_cert_pem, + ca_key_pair, + data_dir: data_dir.to_owned(), + }) + } else { + info!("Generating new agent tunnel CA certificate"); + let ca_key_pair = KeyPair::generate_for(&rcgen::PKCS_ECDSA_P256_SHA256).context("generate CA key pair")?; + + let ca_params = make_ca_params(); + let ca_cert = ca_params + .self_signed(&ca_key_pair) + .context("self-sign CA certificate")?; + let ca_cert_pem = ca_cert.pem(); + + std::fs::create_dir_all(data_dir).with_context(|| format!("create data directory {data_dir}"))?; + std::fs::write(&cert_path, &ca_cert_pem).with_context(|| format!("write CA cert to {cert_path}"))?; + std::fs::write(&key_path, ca_key_pair.serialize_pem()) + .with_context(|| format!("write CA key to {key_path}"))?; + + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt as _; + std::fs::set_permissions(&key_path, std::fs::Permissions::from_mode(0o600)) + .with_context(|| format!("set permissions on {key_path}"))?; + } + + info!(%cert_path, "Agent tunnel CA certificate generated and saved"); + + Ok(Self { + ca_cert_pem, + ca_key_pair, + data_dir: data_dir.to_owned(), + }) + } + } + + /// Reconstruct a `Certificate` object from the stored key pair. + /// + /// The reconstructed cert uses the same DN as the original CA, so the + /// issuer field in signed certificates will match the on-disk CA cert. + fn reconstruct_ca_cert(&self) -> anyhow::Result { + make_ca_params() + .self_signed(&self.ca_key_pair) + .context("reconstruct CA certificate for signing") + } + + /// Sign an agent's CSR, producing a client certificate. + /// + /// The agent generates its own key pair and sends only the CSR. + /// The private key never leaves the agent. + pub fn sign_agent_csr(&self, agent_id: Uuid, agent_name: &str, csr_pem: &str) -> anyhow::Result { + // Parse and verify the CSR (signature check included). + let csr_params = rcgen::CertificateSigningRequestParams::from_pem(csr_pem) + .map_err(|e| anyhow::anyhow!("invalid CSR: {e}"))?; + + // Build our own cert params (we control CN, SAN, EKU, validity — not the CSR). + let mut agent_params = CertificateParams::default(); + agent_params.distinguished_name.push(DnType::CommonName, agent_name); + agent_params + .distinguished_name + .push(DnType::OrganizationName, CA_ORG_NAME); + agent_params.subject_alt_names.push(SanType::Rfc822Name( + format!("urn:uuid:{agent_id}").try_into().context("SAN URI")?, + )); + agent_params + .extended_key_usages + .push(ExtendedKeyUsagePurpose::ClientAuth); + agent_params.not_before = time::OffsetDateTime::now_utc(); + agent_params.not_after = + time::OffsetDateTime::now_utc() + Duration::from_secs(u64::from(AGENT_CERT_VALIDITY_DAYS) * 86400); + + // Sign with the CA, embedding the public key from the CSR. + let ca_cert = self.reconstruct_ca_cert()?; + let agent_cert = agent_params + .signed_by(&csr_params.public_key, &ca_cert, &self.ca_key_pair) + .context("sign agent certificate with CA")?; + + info!(%agent_id, %agent_name, "Signed agent CSR and issued client certificate"); + + Ok(SignedAgentCert { + client_cert_pem: agent_cert.pem(), + ca_cert_pem: self.ca_cert_pem.clone(), + }) + } + + /// Ensure a server certificate exists for the QUIC listener (signed by our CA). + /// + /// Returns `(cert_path, key_path)` on disk. + pub fn ensure_server_cert(&self, hostname: &str) -> anyhow::Result<(Utf8PathBuf, Utf8PathBuf)> { + let cert_path = self.data_dir.join(SERVER_CERT_FILENAME); + let key_path = self.data_dir.join(SERVER_KEY_FILENAME); + + if cert_path.exists() && key_path.exists() { + // TODO: check cert expiry and regenerate if near/past expiration (365-day validity). + info!(%cert_path, "Using existing agent tunnel server certificate"); + return Ok((cert_path, key_path)); + } + + info!(%hostname, "Generating agent tunnel server certificate"); + + let server_key_pair = + KeyPair::generate_for(&rcgen::PKCS_ECDSA_P256_SHA256).context("generate server key pair")?; + + let mut server_params = CertificateParams::default(); + server_params.distinguished_name.push(DnType::CommonName, hostname); + server_params + .distinguished_name + .push(DnType::OrganizationName, CA_ORG_NAME); + server_params + .subject_alt_names + .push(SanType::DnsName(hostname.try_into().context("DNS SAN")?)); + server_params + .extended_key_usages + .push(ExtendedKeyUsagePurpose::ServerAuth); + server_params.not_before = time::OffsetDateTime::now_utc(); + server_params.not_after = + time::OffsetDateTime::now_utc() + Duration::from_secs(u64::from(SERVER_CERT_VALIDITY_DAYS) * 86400); + + let ca_cert = self.reconstruct_ca_cert()?; + + let server_cert = server_params + .signed_by(&server_key_pair, &ca_cert, &self.ca_key_pair) + .context("sign server certificate with CA")?; + + std::fs::write(&cert_path, server_cert.pem()).with_context(|| format!("write server cert to {cert_path}"))?; + std::fs::write(&key_path, server_key_pair.serialize_pem()) + .with_context(|| format!("write server key to {key_path}"))?; + + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt as _; + std::fs::set_permissions(&key_path, std::fs::Permissions::from_mode(0o600)) + .with_context(|| format!("set permissions on {key_path}"))?; + } + + info!(%cert_path, %hostname, "Server certificate generated and saved"); + + Ok((cert_path, key_path)) + } + + /// Get the CA certificate in PEM format. + pub fn ca_cert_pem(&self) -> &str { + &self.ca_cert_pem + } + + /// Get the CA certificate file path on disk. + pub fn ca_cert_path(&self) -> Utf8PathBuf { + self.data_dir.join(CA_CERT_FILENAME) + } + + /// Build a `rustls::ServerConfig` for the QUIC listener with mTLS client verification. + /// + /// The server certificate is signed by our CA; clients must present a certificate + /// also signed by our CA (mutual TLS). + pub fn build_server_tls_config(&self, hostname: &str) -> anyhow::Result { + use std::io::BufReader; + + use rustls::pki_types::{CertificateDer, PrivateKeyDer}; + + // Ensure rustls crypto provider is installed (ring). + let _ = rustls::crypto::ring::default_provider().install_default(); + + let (cert_path, key_path) = self.ensure_server_cert(hostname)?; + + // Load server certificate chain (server cert + CA cert). + let cert_file = std::fs::File::open(cert_path.as_std_path()) + .with_context(|| format!("open server cert file {cert_path}"))?; + let server_certs: Vec> = rustls_pemfile::certs(&mut BufReader::new(cert_file)) + .collect::, _>>() + .context("parse server certificate PEM")?; + + // Also include CA cert in chain. + let ca_cert_path = self.ca_cert_path(); + let ca_file = std::fs::File::open(ca_cert_path.as_std_path()) + .with_context(|| format!("open CA cert file {ca_cert_path}"))?; + let ca_certs: Vec> = rustls_pemfile::certs(&mut BufReader::new(ca_file)) + .collect::, _>>() + .context("parse CA certificate PEM")?; + + let mut cert_chain = server_certs; + cert_chain.extend(ca_certs.clone()); + + // Load server private key. + let key_file = + std::fs::File::open(key_path.as_std_path()).with_context(|| format!("open server key file {key_path}"))?; + let private_key: PrivateKeyDer<'static> = rustls_pemfile::private_key(&mut BufReader::new(key_file)) + .context("parse server private key PEM")? + .context("no private key found in PEM file")?; + + // Build root cert store with our CA for client verification. + let mut roots = rustls::RootCertStore::empty(); + for ca_cert in &ca_certs { + roots.add(ca_cert.clone()).context("add CA cert to root store")?; + } + + // Require client certificates signed by our CA. + let client_verifier = rustls::server::WebPkiClientVerifier::builder(roots.into()) + .build() + .context("build client certificate verifier")?; + + let mut tls_config = rustls::ServerConfig::builder() + .with_client_cert_verifier(client_verifier) + .with_single_cert(cert_chain, private_key) + .context("build rustls ServerConfig")?; + + tls_config.alpn_protocols = vec![b"devolutions-agent-tunnel".to_vec()]; + + Ok(tls_config) + } +} + +/// Compute SHA-256 fingerprint of a PEM-encoded certificate (hex string). +pub fn cert_fingerprint_from_pem(pem_str: &str) -> anyhow::Result { + let pem = pem::parse(pem_str).context("parse PEM for fingerprint")?; + let digest = Sha256::digest(pem.contents()); + Ok(hex::encode(digest)) +} + +/// Compute SHA-256 fingerprint of a DER-encoded certificate (hex string). +pub fn cert_fingerprint_from_der(der_bytes: &[u8]) -> String { + let digest = Sha256::digest(der_bytes); + hex::encode(digest) +} + +/// Extract agent_id from a PEM-encoded certificate's SAN (urn:uuid:{id}). +pub fn extract_agent_id_from_pem(pem_str: &str) -> anyhow::Result { + let pem = pem::parse(pem_str).context("parse PEM for agent ID extraction")?; + extract_agent_id_from_der(pem.contents()) +} + +/// Extract the Common Name (CN) from a DER-encoded certificate. +pub fn extract_agent_name_from_der(cert_der: &[u8]) -> anyhow::Result { + let (_, cert) = + x509_parser::parse_x509_certificate(cert_der).map_err(|e| anyhow::anyhow!("parse certificate: {e}"))?; + + for attr in cert.subject().iter_common_name() { + if let Ok(cn) = attr.as_str() { + return Ok(cn.to_owned()); + } + } + + anyhow::bail!("no Common Name found in certificate") +} + +/// Extract agent_id from a DER-encoded certificate's SAN (urn:uuid:{id}). +pub fn extract_agent_id_from_der(der_bytes: &[u8]) -> anyhow::Result { + let (_, cert) = x509_parser::parse_x509_certificate(der_bytes).context("parse X.509 certificate")?; + + for ext in cert.extensions() { + if let x509_parser::extensions::ParsedExtension::SubjectAlternativeName(san) = ext.parsed_extension() { + for name in &san.general_names { + if let x509_parser::extensions::GeneralName::RFC822Name(val) = name + && let Some(uuid_str) = val.strip_prefix("urn:uuid:") + { + return uuid_str.parse().context("parse UUID from SAN"); + } + } + } + } + + anyhow::bail!("no urn:uuid: SAN found in certificate") +} + +#[cfg(test)] +mod tests { + use super::*; + + /// Helper: generate a CSR PEM for testing. + fn generate_test_csr(cn: &str) -> (KeyPair, String) { + let key_pair = KeyPair::generate_for(&rcgen::PKCS_ECDSA_P256_SHA256).expect("generate key pair"); + let mut params = CertificateParams::default(); + params.distinguished_name.push(DnType::CommonName, cn); + let csr = params.serialize_request(&key_pair).expect("serialize CSR"); + (key_pair, csr.pem().expect("CSR to PEM")) + } + + #[test] + fn generate_ca_and_sign_agent_csr() { + let temp_dir = std::env::temp_dir().join(format!("dgw-cert-test-{}", Uuid::new_v4())); + let data_dir = Utf8PathBuf::from_path_buf(temp_dir.clone()).expect("temp path should be UTF-8"); + + let ca = CaManager::load_or_generate(&data_dir).expect("CA generation should succeed"); + assert!(ca.ca_cert_pem().contains("BEGIN CERTIFICATE")); + + let agent_id = Uuid::new_v4(); + let (_key_pair, csr_pem) = generate_test_csr("test-agent"); + let signed = ca + .sign_agent_csr(agent_id, "test-agent", &csr_pem) + .expect("sign CSR should succeed"); + + assert!(signed.client_cert_pem.contains("BEGIN CERTIFICATE")); + assert_eq!(signed.ca_cert_pem, ca.ca_cert_pem()); + + // Reload CA from disk. + let ca2 = CaManager::load_or_generate(&data_dir).expect("CA reload should succeed"); + assert_eq!(ca2.ca_cert_pem(), ca.ca_cert_pem()); + + // Sign a CSR from the reloaded CA and verify it works. + let (_key_pair2, csr_pem2) = generate_test_csr("test-agent-2"); + let signed2 = ca2 + .sign_agent_csr(Uuid::new_v4(), "test-agent-2", &csr_pem2) + .expect("sign CSR from reloaded CA should succeed"); + assert!(signed2.client_cert_pem.contains("BEGIN CERTIFICATE")); + + // Fingerprint. + let fp = cert_fingerprint_from_pem(&signed.client_cert_pem).expect("fingerprint should work"); + assert_eq!(fp.len(), 64); // SHA-256 hex = 64 chars + + // Extract agent_id from PEM. + let extracted_id = extract_agent_id_from_pem(&signed.client_cert_pem).expect("agent ID extraction should work"); + assert_eq!(extracted_id, agent_id); + + // Server certificate. + let (server_cert_path, server_key_path) = ca + .ensure_server_cert("test-gateway.local") + .expect("server cert should succeed"); + assert!(server_cert_path.exists()); + assert!(server_key_path.exists()); + + // Cleanup. + let _ = std::fs::remove_dir_all(&temp_dir); + } + + #[test] + fn sign_csr_produces_valid_cert() { + let temp_dir = std::env::temp_dir().join(format!("dgw-cert-test-{}", Uuid::new_v4())); + let data_dir = Utf8PathBuf::from_path_buf(temp_dir.clone()).expect("temp path should be UTF-8"); + let ca = CaManager::load_or_generate(&data_dir).expect("CA generation should succeed"); + + let agent_id = Uuid::new_v4(); + let (_key_pair, csr_pem) = generate_test_csr("csr-test-agent"); + + let signed = ca + .sign_agent_csr(agent_id, "csr-test-agent", &csr_pem) + .expect("sign CSR should succeed"); + + assert!(signed.client_cert_pem.contains("BEGIN CERTIFICATE")); + + // Verify the cert contains the agent UUID in SAN. + let extracted_id = + extract_agent_id_from_pem(&signed.client_cert_pem).expect("should extract agent ID from signed cert"); + assert_eq!(extracted_id, agent_id); + + let _ = std::fs::remove_dir_all(&temp_dir); + } + + #[test] + fn sign_csr_rejects_tampered_csr() { + let temp_dir = std::env::temp_dir().join(format!("dgw-cert-test-{}", Uuid::new_v4())); + let data_dir = Utf8PathBuf::from_path_buf(temp_dir.clone()).expect("temp path should be UTF-8"); + let ca = CaManager::load_or_generate(&data_dir).expect("CA generation should succeed"); + + let (_key_pair, csr_pem) = generate_test_csr("tampered-agent"); + + // Decode PEM, flip a byte in the DER, re-encode. + let parsed = pem::parse(&csr_pem).expect("parse CSR PEM"); + let mut der_bytes = parsed.contents().to_vec(); + // Flip a byte near the end (in the signature area). + let len = der_bytes.len(); + der_bytes[len - 2] ^= 0xFF; + let tampered_pem = pem::encode(&pem::Pem::new("CERTIFICATE REQUEST", der_bytes)); + + let result = ca.sign_agent_csr(Uuid::new_v4(), "tampered-agent", &tampered_pem); + assert!(result.is_err(), "tampered CSR should be rejected"); + + let _ = std::fs::remove_dir_all(&temp_dir); + } +} diff --git a/devolutions-gateway/src/agent_tunnel/enrollment_store.rs b/devolutions-gateway/src/agent_tunnel/enrollment_store.rs new file mode 100644 index 000000000..6a27b50e2 --- /dev/null +++ b/devolutions-gateway/src/agent_tunnel/enrollment_store.rs @@ -0,0 +1,126 @@ +//! In-memory store for one-time enrollment tokens. +//! +//! Tokens are generated by the webapp enrollment-string endpoint and consumed +//! by the agent enrollment endpoint. Each token is single-use and has an expiry. + +use std::time::{SystemTime, UNIX_EPOCH}; + +use dashmap::DashMap; + +/// Default token lifetime if not specified: 1 hour. +const DEFAULT_TOKEN_LIFETIME_SECS: u64 = 3600; + +/// A single enrollment token entry. +#[derive(Debug, Clone)] +pub struct EnrollmentTokenEntry { + /// When this token expires (UNIX timestamp in seconds). + pub expires_at: u64, + /// Optional agent name hint associated with this token. + pub agent_name: Option, +} + +/// Thread-safe in-memory store for one-time enrollment tokens. +/// +/// Tokens are stored in a `DashMap` keyed by the token string. +/// They are removed on consumption (one-time use) or on explicit cleanup. +#[derive(Debug)] +pub struct EnrollmentTokenStore { + tokens: DashMap, +} + +impl EnrollmentTokenStore { + /// Creates a new, empty token store. + pub fn new() -> Self { + Self { tokens: DashMap::new() } + } + + /// Inserts a new enrollment token. + /// + /// Also cleans up any expired tokens to prevent unbounded growth. + pub fn insert(&self, token: String, agent_name: Option, lifetime_secs: Option) { + self.cleanup_expired(); + + let lifetime = lifetime_secs.unwrap_or(DEFAULT_TOKEN_LIFETIME_SECS); + let now = current_time_secs(); + let expires_at = now + lifetime; + + self.tokens + .insert(token, EnrollmentTokenEntry { expires_at, agent_name }); + } + + /// Consumes a token if it exists and is not expired. + /// + /// Returns `true` if the token was valid and has been consumed (removed). + /// Returns `false` if the token doesn't exist or is expired. + pub fn consume(&self, token: &str) -> bool { + let now = current_time_secs(); + + if let Some((_, entry)) = self.tokens.remove(token) + && entry.expires_at > now + { + return true; + } + + false + } + + /// Removes all expired tokens from the store. + pub fn cleanup_expired(&self) { + let now = current_time_secs(); + self.tokens.retain(|_, entry| entry.expires_at > now); + } +} + +impl Default for EnrollmentTokenStore { + fn default() -> Self { + Self::new() + } +} + +fn current_time_secs() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn insert_and_consume() { + let store = EnrollmentTokenStore::new(); + store.insert("tok-123".to_owned(), Some("my-agent".to_owned()), Some(3600)); + + assert!(store.consume("tok-123")); + // Second consume should fail (one-time use). + assert!(!store.consume("tok-123")); + } + + #[test] + fn consume_nonexistent_returns_false() { + let store = EnrollmentTokenStore::new(); + assert!(!store.consume("does-not-exist")); + } + + #[test] + fn expired_token_not_consumable() { + let store = EnrollmentTokenStore::new(); + // Insert with 0 lifetime → already expired. + store.insert("expired-tok".to_owned(), None, Some(0)); + assert!(!store.consume("expired-tok")); + } + + #[test] + fn cleanup_removes_expired() { + let store = EnrollmentTokenStore::new(); + store.insert("expired".to_owned(), None, Some(0)); + store.insert("valid".to_owned(), None, Some(3600)); + + store.cleanup_expired(); + + assert!(!store.consume("expired")); + assert!(store.consume("valid")); + } +} diff --git a/devolutions-gateway/src/agent_tunnel/listener.rs b/devolutions-gateway/src/agent_tunnel/listener.rs new file mode 100644 index 000000000..eea89ee77 --- /dev/null +++ b/devolutions-gateway/src/agent_tunnel/listener.rs @@ -0,0 +1,336 @@ +//! QUIC listener for agent tunnel connections (Quinn-based). +//! +//! Manages a QUIC endpoint using Quinn, accepts connections from agents with mTLS, +//! processes control messages (route advertisements, heartbeats), and +//! creates proxy streams on demand. + +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Duration; + +use agent_tunnel_proto::{ConnectMessage, ConnectResponse, ControlMessage}; +use anyhow::Context as _; +use async_trait::async_trait; +use dashmap::DashMap; +use uuid::Uuid; + +use super::cert::CaManager; +use super::enrollment_store::EnrollmentTokenStore; +use super::registry::{AgentPeer, AgentRegistry}; +use super::stream::TunnelStream; + +// --------------------------------------------------------------------------- +// Public API +// --------------------------------------------------------------------------- + +/// Handle for external code to interact with the running agent tunnel. +/// +/// Cloneable and safe to share across tasks. +#[derive(Clone)] +pub struct AgentTunnelHandle { + registry: Arc, + /// Map of agent_id → live Quinn connection, used for opening new streams. + agent_connections: Arc>, + ca_manager: Arc, + enrollment_token_store: Arc, +} + +impl AgentTunnelHandle { + pub fn registry(&self) -> &AgentRegistry { + &self.registry + } + + pub fn ca_manager(&self) -> &CaManager { + &self.ca_manager + } + + pub fn enrollment_token_store(&self) -> &EnrollmentTokenStore { + &self.enrollment_token_store + } + + /// Open a proxy stream through a connected agent. + pub async fn connect_via_agent( + &self, + agent_id: Uuid, + session_id: Uuid, + target: &str, + ) -> anyhow::Result { + let conn = self + .agent_connections + .get(&agent_id) + .map(|entry| entry.value().clone()) + .ok_or_else(|| anyhow::anyhow!("agent {} not connected", agent_id))?; + + let (mut send, mut recv) = conn.open_bi().await.context("open bidirectional stream to agent")?; + + // Send ConnectMessage. + let connect_msg = ConnectMessage::new(session_id, target.to_owned()); + connect_msg + .encode(&mut send) + .await + .map_err(|e| anyhow::anyhow!("encode ConnectMessage: {e}"))?; + + // Read ConnectResponse. + let response = ConnectResponse::decode(&mut recv) + .await + .map_err(|e| anyhow::anyhow!("decode ConnectResponse: {e}"))?; + + if !response.is_success() { + let reason = match &response { + ConnectResponse::Error { reason, .. } => reason.clone(), + _ => "unknown".to_owned(), + }; + anyhow::bail!("agent refused connection: {reason}"); + } + + info!( + %agent_id, + %session_id, + %target, + "Proxy stream established via agent tunnel" + ); + + Ok(TunnelStream { send, recv }) + } +} + +// --------------------------------------------------------------------------- +// Listener task +// --------------------------------------------------------------------------- + +pub struct AgentTunnelListener { + endpoint: quinn::Endpoint, + registry: Arc, + agent_connections: Arc>, +} + +impl AgentTunnelListener { + pub async fn bind( + listen_addr: SocketAddr, + ca_manager: Arc, + hostname: &str, + ) -> anyhow::Result<(Self, AgentTunnelHandle)> { + let tls_config = ca_manager + .build_server_tls_config(hostname) + .context("build server TLS config")?; + + let quic_server_config = quinn::crypto::rustls::QuicServerConfig::try_from(Arc::new(tls_config)) + .context("create QUIC server config from TLS config")?; + + let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(quic_server_config)); + + // Configure transport parameters. + let mut transport = quinn::TransportConfig::default(); + transport.max_idle_timeout(Some( + Duration::from_secs(120) + .try_into() + .expect("120s should be a valid idle timeout"), + )); + transport.keep_alive_interval(Some(Duration::from_secs(15))); + transport.max_concurrent_bidi_streams(100u32.into()); + server_config.transport_config(Arc::new(transport)); + + let endpoint = quinn::Endpoint::server(server_config, listen_addr) + .with_context(|| format!("bind QUIC endpoint on {listen_addr}"))?; + + info!(%listen_addr, "Agent tunnel QUIC endpoint bound"); + + let registry = Arc::new(AgentRegistry::new()); + let agent_connections: Arc> = Arc::new(DashMap::new()); + let enrollment_token_store = Arc::new(EnrollmentTokenStore::new()); + + let handle = AgentTunnelHandle { + registry: Arc::clone(®istry), + agent_connections: Arc::clone(&agent_connections), + ca_manager, + enrollment_token_store, + }; + + let listener = Self { + endpoint, + registry, + agent_connections, + }; + + Ok((listener, handle)) + } +} + +#[async_trait] +impl devolutions_gateway_task::Task for AgentTunnelListener { + type Output = anyhow::Result<()>; + const NAME: &'static str = "agent-tunnel-listener"; + + async fn run(self, mut shutdown_signal: devolutions_gateway_task::ShutdownSignal) -> anyhow::Result<()> { + let local_addr = self.endpoint.local_addr()?; + info!(%local_addr, "Agent tunnel listener started"); + + loop { + tokio::select! { + biased; + + _ = shutdown_signal.wait() => { + info!("Agent tunnel listener shutting down"); + self.endpoint.close(0u32.into(), b"shutdown"); + break; + } + + incoming = self.endpoint.accept() => { + let Some(incoming) = incoming else { + info!("QUIC endpoint closed"); + break; + }; + + let registry = Arc::clone(&self.registry); + let agent_connections = Arc::clone(&self.agent_connections); + + tokio::spawn(async move { + if let Err(e) = handle_agent_connection(registry, agent_connections, incoming).await { + warn!(error = format!("{e:#}"), "Agent connection handler failed"); + } + }); + } + } + } + + Ok(()) + } +} + +// --------------------------------------------------------------------------- +// Per-connection handler +// --------------------------------------------------------------------------- + +async fn handle_agent_connection( + registry: Arc, + agent_connections: Arc>, + incoming: quinn::Incoming, +) -> anyhow::Result<()> { + let peer_addr = incoming.remote_address(); + info!(%peer_addr, "Accepting new QUIC connection"); + + let conn = incoming.await.context("QUIC handshake failed")?; + + // Extract peer certificate to identify the agent. + let peer_identity = conn.peer_identity().context("no peer identity after handshake")?; + + let peer_certs = peer_identity + .downcast::>>() + .map_err(|_| anyhow::anyhow!("unexpected peer identity type"))?; + + let peer_cert_der = peer_certs.first().context("no peer certificate in chain")?; + + let agent_id = + super::cert::extract_agent_id_from_der(peer_cert_der).context("extract agent_id from peer certificate")?; + + let agent_name = + super::cert::extract_agent_name_from_der(peer_cert_der).unwrap_or_else(|_| format!("agent-{agent_id}")); + + let fingerprint = super::cert::cert_fingerprint_from_der(peer_cert_der); + + info!(%agent_id, %agent_name, %peer_addr, "Agent authenticated via mTLS"); + + let peer = Arc::new(AgentPeer::new(agent_id, agent_name, fingerprint)); + registry.register(Arc::clone(&peer)); + agent_connections.insert(agent_id, conn.clone()); + + // Accept the first bidirectional stream as the control stream. + let control_result = handle_control_stream(&conn, agent_id, ®istry).await; + + // Agent disconnected — clean up. + info!(%agent_id, "Agent QUIC connection closed"); + registry.unregister(&agent_id); + agent_connections.remove(&agent_id); + + control_result +} + +async fn handle_control_stream( + conn: &quinn::Connection, + agent_id: Uuid, + registry: &AgentRegistry, +) -> anyhow::Result<()> { + let (mut control_send, mut control_recv) = conn.accept_bi().await.context("accept control stream")?; + + info!(%agent_id, "Control stream accepted"); + + loop { + tokio::select! { + // Read control messages from the agent. + msg_result = ControlMessage::decode(&mut control_recv) => { + let msg = match msg_result { + Ok(msg) => msg, + Err(agent_tunnel_proto::ProtoError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => { + debug!(%agent_id, "Control stream EOF"); + break; + } + Err(e) => { + warn!(%agent_id, error = %e, "Control stream decode error"); + break; + } + }; + + handle_control_message(registry, agent_id, &mut control_send, msg).await; + } + + // Detect connection close. + reason = conn.closed() => { + debug!(%agent_id, ?reason, "QUIC connection closed"); + break; + } + } + } + + Ok(()) +} + +async fn handle_control_message( + registry: &AgentRegistry, + agent_id: Uuid, + control_send: &mut quinn::SendStream, + msg: ControlMessage, +) { + match msg { + ControlMessage::RouteAdvertise { + protocol_version, + epoch, + subnets, + domains, + .. + } => { + if let Err(e) = agent_tunnel_proto::validate_protocol_version(protocol_version) { + warn!(%agent_id, %protocol_version, %e, "Rejecting route advertisement: unsupported protocol version"); + return; + } + info!( + %agent_id, + epoch, + subnet_count = subnets.len(), + domain_count = domains.len(), + "Received route advertisement" + ); + if let Some(peer) = registry.get(&agent_id) { + peer.update_routes(epoch, subnets, domains); + peer.touch(); + } + } + ControlMessage::Heartbeat { + timestamp_ms, + active_stream_count, + .. + } => { + debug!(%agent_id, timestamp_ms, active_stream_count, "Received heartbeat"); + if let Some(peer) = registry.get(&agent_id) { + peer.touch(); + } + + let ack = ControlMessage::heartbeat_ack(timestamp_ms); + if let Err(e) = ack.encode(control_send).await { + warn!(%agent_id, error = %e, "Failed to send heartbeat ack"); + } + } + ControlMessage::HeartbeatAck { .. } => { + debug!(%agent_id, "Unexpected HeartbeatAck from agent"); + } + } +} diff --git a/devolutions-gateway/src/agent_tunnel/mod.rs b/devolutions-gateway/src/agent_tunnel/mod.rs new file mode 100644 index 000000000..aa4b094eb --- /dev/null +++ b/devolutions-gateway/src/agent_tunnel/mod.rs @@ -0,0 +1,15 @@ +//! QUIC-based agent tunnel (Quinn). +//! +//! Provides a reliable, multiplexed tunnel between the gateway and remote agents +//! using QUIC with mutual TLS authentication. + +pub mod cert; +pub mod enrollment_store; +pub mod listener; +pub mod registry; +pub mod stream; + +pub use enrollment_store::EnrollmentTokenStore; +pub use listener::{AgentTunnelHandle, AgentTunnelListener}; +pub use registry::AgentRegistry; +pub use stream::TunnelStream; diff --git a/devolutions-gateway/src/agent_tunnel/registry.rs b/devolutions-gateway/src/agent_tunnel/registry.rs new file mode 100644 index 000000000..439e183fc --- /dev/null +++ b/devolutions-gateway/src/agent_tunnel/registry.rs @@ -0,0 +1,773 @@ +use std::net::IpAddr; +use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +use agent_tunnel_proto::DomainAdvertisement; +use dashmap::DashMap; +use ipnetwork::Ipv4Network; +use parking_lot::RwLock; +use serde::Serialize; +use uuid::Uuid; + +/// Duration after which an agent is considered offline if no heartbeat has been received. +pub const AGENT_OFFLINE_TIMEOUT: Duration = Duration::from_secs(90); + +/// Tracks route advertisements received from an agent. +/// +/// The epoch-based update protocol works as follows: +/// - A higher epoch replaces the entire route set (new process or config reload). +/// - The same epoch only refreshes `updated_at` (periodic re-advertisement). +#[derive(Debug, Clone)] +pub struct RouteAdvertisementState { + /// Monotonically increasing epoch within an agent process lifetime. + pub epoch: u64, + /// IPv4 subnets this agent can reach. + pub subnets: Vec, + /// DNS domains this agent can resolve, with source tracking. + pub domains: Vec, + /// When this route set was first received (used for tie-breaking). + pub received_at: SystemTime, + /// Last time this route set was refreshed. + pub updated_at: SystemTime, +} + +/// Represents a QUIC-connected agent peer tracked by the gateway. +#[derive(Debug)] +pub struct AgentPeer { + /// Unique identifier for this agent. + pub agent_id: Uuid, + /// Human-readable name of the agent. + pub name: String, + /// SHA-256 fingerprint of the agent's client certificate. + pub cert_fingerprint: String, + /// Last heartbeat timestamp in milliseconds since UNIX epoch (updated atomically). + pub(crate) last_seen: AtomicU64, + /// Current route advertisement state, if any. + route_state: RwLock>, +} + +impl AgentPeer { + /// Creates a new agent peer with the current time as last_seen. + pub fn new(agent_id: Uuid, name: String, cert_fingerprint: String) -> Self { + let now_ms = current_time_millis(); + Self { + agent_id, + name, + cert_fingerprint, + last_seen: AtomicU64::new(now_ms), + route_state: RwLock::new(None), + } + } + + /// Updates the last-seen timestamp to the current time. + pub fn touch(&self) { + let now_ms = current_time_millis(); + self.last_seen.store(now_ms, Ordering::Release); + } + + /// Returns the last-seen timestamp as milliseconds since UNIX epoch. + pub fn last_seen_ms(&self) -> u64 { + self.last_seen.load(Ordering::Acquire) + } + + /// Checks whether this agent is considered online. + /// + /// An agent is online if the elapsed time since `last_seen` is less than `timeout`. + pub fn is_online(&self, timeout: Duration) -> bool { + let last_ms = self.last_seen.load(Ordering::Acquire); + let now_ms = current_time_millis(); + // Saturating subtraction handles clock skew gracefully. + let elapsed_ms = now_ms.saturating_sub(last_ms); + elapsed_ms < u64::try_from(timeout.as_millis()).expect("timeout in milliseconds should fit in u64") + } + + /// Returns a clone of the current route advertisement state, if any. + pub fn route_state(&self) -> Option { + self.route_state.read().clone() + } + + /// Updates the route advertisement state using epoch-based logic. + /// + /// - If `epoch` is greater than the current epoch, the route set is replaced entirely + /// and both `received_at` and `updated_at` are set to now. + /// - If `epoch` equals the current epoch, only `updated_at` is refreshed (re-advertisement). + /// - If `epoch` is less than the current epoch, the update is ignored (stale). + pub fn update_routes(&self, epoch: u64, subnets: Vec, domains: Vec) { + let mut state = self.route_state.write(); + let now = SystemTime::now(); + + match state.as_ref() { + Some(current) if epoch < current.epoch => { + // Stale epoch; ignore. + debug!( + agent_id = %self.agent_id, + received_epoch = epoch, + current_epoch = current.epoch, + "Ignoring stale route advertisement" + ); + } + Some(current) if epoch == current.epoch => { + // Same epoch: refresh timestamp only, do not replace subnets or domains. + debug!( + agent_id = %self.agent_id, + epoch, + subnet_count = subnets.len(), + domain_count = current.domains.len(), + "Refreshing route advertisement (same epoch)" + ); + *state = Some(RouteAdvertisementState { + epoch, + subnets: current.subnets.clone(), + domains: current.domains.clone(), + received_at: current.received_at, + updated_at: now, + }); + } + _ => { + // New epoch (or first advertisement): replace everything. + info!( + agent_id = %self.agent_id, + epoch, + subnet_count = subnets.len(), + domain_count = domains.len(), + "Accepted new route advertisement" + ); + *state = Some(RouteAdvertisementState { + epoch, + subnets, + domains, + received_at: now, + updated_at: now, + }); + } + } + } + + /// Returns `true` if this agent can route traffic to the given IP address. + pub fn can_reach(&self, target_ip: IpAddr) -> bool { + self.route_state + .read() + .as_ref() + .map(|route_state| match target_ip { + IpAddr::V4(ipv4) => route_state.subnets.iter().any(|subnet| subnet.contains(ipv4)), + IpAddr::V6(_) => false, + }) + .unwrap_or(false) + } +} + +/// Thread-safe registry of online QUIC-connected agents. +/// +/// Agents are indexed by their `Uuid`. The registry supports concurrent reads and writes +/// through `DashMap`, and provides route-based agent lookup for proxy target resolution. +#[derive(Debug, Clone)] +pub struct AgentRegistry { + agents: Arc>>, +} + +impl AgentRegistry { + /// Creates a new, empty agent registry. + pub fn new() -> Self { + Self { + agents: Arc::new(DashMap::new()), + } + } + + /// Registers a new agent peer. If an agent with the same ID already exists, it is replaced. + pub fn register(&self, peer: Arc) { + info!( + agent_id = %peer.agent_id, + name = %peer.name, + "Agent registered" + ); + self.agents.insert(peer.agent_id, peer); + } + + /// Removes an agent from the registry by ID. + pub fn unregister(&self, agent_id: &Uuid) -> Option> { + let removed = self.agents.remove(agent_id).map(|(_, peer)| peer); + if let Some(ref peer) = removed { + info!( + agent_id = %peer.agent_id, + name = %peer.name, + "Agent unregistered" + ); + } + removed + } + + /// Looks up an agent by ID. + pub fn get(&self, agent_id: &Uuid) -> Option> { + self.agents.get(agent_id).map(|entry| Arc::clone(entry.value())) + } + + /// Returns the number of agents currently in the registry (including offline ones). + pub fn len(&self) -> usize { + self.agents.len() + } + + /// Returns `true` when no agent is registered. + pub fn is_empty(&self) -> bool { + self.agents.is_empty() + } + + /// Returns the number of agents considered online. + pub fn online_count(&self) -> usize { + self.agents + .iter() + .filter(|entry| entry.value().is_online(AGENT_OFFLINE_TIMEOUT)) + .count() + } + + /// Finds all online agents whose advertised subnets include the given target IP. + /// + /// Results are sorted by `received_at` in descending order (most recently received first). + pub fn find_agents_for_target(&self, target_ip: IpAddr) -> Vec> { + let mut candidates: Vec<(SystemTime, Arc)> = self + .agents + .iter() + .filter(|entry| entry.value().is_online(AGENT_OFFLINE_TIMEOUT)) + .filter_map(|entry| { + let agent = Arc::clone(entry.value()); + let route_state = agent.route_state()?; + let matches = match target_ip { + IpAddr::V4(ipv4) => route_state.subnets.iter().any(|subnet| subnet.contains(ipv4)), + IpAddr::V6(_) => false, + }; + + if matches { + Some((route_state.received_at, agent)) + } else { + None + } + }) + .collect(); + + // Sort by received_at descending (most recent first). + candidates.sort_by(|a, b| b.0.cmp(&a.0)); + + candidates.into_iter().map(|(_, agent)| agent).collect() + } + + /// Selects a single online agent that can route to the given target IP. + /// + /// When multiple agents match, the one with the most recent `received_at` wins. + pub fn select_agent_for_target(&self, target_ip: IpAddr) -> Option> { + self.find_agents_for_target(target_ip).into_iter().next() + } + + /// Finds all online agents whose advertised domains match the given hostname via suffix match. + /// + /// Uses longest suffix match: if agent-A advertises "contoso.local" and agent-B advertises + /// "finance.contoso.local", hostname "db01.finance.contoso.local" matches agent-B only. + /// + /// Results are sorted by `received_at` descending (most recently received first). + pub fn select_agents_for_domain(&self, hostname: &str) -> Vec> { + let hostname_lower = hostname.to_ascii_lowercase(); + + let mut best_suffix_len: usize = 0; + let mut candidates: Vec<(SystemTime, Arc)> = Vec::new(); + + for entry in self.agents.iter() { + let agent = entry.value(); + if !agent.is_online(AGENT_OFFLINE_TIMEOUT) { + continue; + } + + let route_state = match agent.route_state() { + Some(rs) => rs, + None => continue, + }; + + for domain_adv in &route_state.domains { + let domain_lower = domain_adv.domain.to_ascii_lowercase(); + let matches = hostname_lower == domain_lower + || (hostname_lower.len() > domain_lower.len() + && hostname_lower.as_bytes()[hostname_lower.len() - domain_lower.len() - 1] == b'.' + && hostname_lower.ends_with(domain_lower.as_str())); + + if matches { + if best_suffix_len < domain_lower.len() { + best_suffix_len = domain_lower.len(); + candidates.clear(); + candidates.push((route_state.received_at, Arc::clone(agent))); + } else if domain_lower.len() == best_suffix_len { + candidates.push((route_state.received_at, Arc::clone(agent))); + } + } + } + } + + candidates.sort_by(|a, b| b.0.cmp(&a.0)); + candidates.into_iter().map(|(_, agent)| agent).collect() + } + + /// Returns information about a single agent by ID. + pub fn agent_info(&self, agent_id: &Uuid) -> Option { + self.agents.get(agent_id).map(|entry| AgentInfo::from(entry.value())) + } + + /// Collects information about all registered agents for API responses. + pub fn agent_infos(&self) -> Vec { + self.agents.iter().map(|entry| AgentInfo::from(entry.value())).collect() + } +} + +impl Default for AgentRegistry { + fn default() -> Self { + Self::new() + } +} + +/// Domain info with source tracking for API responses. +#[derive(Debug, Clone, Serialize)] +pub struct DomainInfo { + pub domain: String, + pub auto_detected: bool, +} + +/// Serializable snapshot of an agent's state, suitable for API responses. +#[derive(Debug, Clone, Serialize)] +pub struct AgentInfo { + pub agent_id: Uuid, + pub name: String, + pub cert_fingerprint: String, + pub is_online: bool, + pub last_seen_ms: u64, + pub subnets: Vec, + pub domains: Vec, + pub route_epoch: Option, +} + +impl From<&Arc> for AgentInfo { + fn from(agent: &Arc) -> Self { + let route_state = agent.route_state(); + Self { + agent_id: agent.agent_id, + name: agent.name.clone(), + cert_fingerprint: agent.cert_fingerprint.clone(), + is_online: agent.is_online(AGENT_OFFLINE_TIMEOUT), + last_seen_ms: agent.last_seen_ms(), + subnets: route_state + .as_ref() + .map(|rs| rs.subnets.iter().map(ToString::to_string).collect()) + .unwrap_or_default(), + domains: route_state + .as_ref() + .map(|rs| { + rs.domains + .iter() + .map(|d| DomainInfo { + domain: d.domain.clone(), + auto_detected: d.auto_detected, + }) + .collect() + }) + .unwrap_or_default(), + route_epoch: route_state.as_ref().map(|rs| rs.epoch), + } + } +} + +/// Returns the current time as milliseconds since UNIX epoch. +fn current_time_millis() -> u64 { + u64::try_from( + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or(Duration::ZERO) + .as_millis(), + ) + .expect("millisecond timestamp should fit in u64") +} + +#[cfg(test)] +mod tests { + use super::*; + + fn make_peer(name: &str) -> Arc { + Arc::new(AgentPeer::new( + Uuid::new_v4(), + String::from(name), + String::from("sha256:deadbeef"), + )) + } + + #[test] + fn register_and_lookup() { + let registry = AgentRegistry::new(); + let peer = make_peer("test-agent"); + let agent_id = peer.agent_id; + + registry.register(Arc::clone(&peer)); + assert_eq!(registry.len(), 1); + + let found = registry.get(&agent_id).expect("agent should be found"); + assert_eq!(found.agent_id, agent_id); + } + + #[test] + fn unregister_removes_agent() { + let registry = AgentRegistry::new(); + let peer = make_peer("test-agent"); + let agent_id = peer.agent_id; + + registry.register(Arc::clone(&peer)); + let removed = registry.unregister(&agent_id); + assert!(removed.is_some()); + assert_eq!(registry.len(), 0); + assert!(registry.get(&agent_id).is_none()); + } + + #[test] + fn is_online_within_timeout() { + let peer = make_peer("online-agent"); + peer.touch(); + assert!(peer.is_online(AGENT_OFFLINE_TIMEOUT)); + } + + #[test] + fn is_offline_after_timeout() { + let peer = AgentPeer::new( + Uuid::new_v4(), + String::from("offline-agent"), + String::from("sha256:deadbeef"), + ); + // Simulate a very old last_seen timestamp. + peer.last_seen.store(0, Ordering::Release); + assert!(!peer.is_online(AGENT_OFFLINE_TIMEOUT)); + } + + #[test] + fn update_routes_new_epoch_replaces() { + let peer = make_peer("route-agent"); + let subnet: Ipv4Network = "10.0.0.0/8".parse().expect("valid CIDR"); + + peer.update_routes(1, vec![subnet], vec![]); + let state = peer.route_state().expect("route state should exist"); + assert_eq!(state.epoch, 1); + assert_eq!(state.subnets.len(), 1); + + let new_subnet: Ipv4Network = "192.168.0.0/16".parse().expect("valid CIDR"); + peer.update_routes(2, vec![new_subnet], vec![]); + let state = peer.route_state().expect("route state should exist"); + assert_eq!(state.epoch, 2); + assert_eq!(state.subnets.len(), 1); + assert_eq!(state.subnets[0], new_subnet); + } + + #[test] + fn update_routes_same_epoch_refreshes_only() { + let peer = make_peer("refresh-agent"); + let subnet: Ipv4Network = "10.0.0.0/8".parse().expect("valid CIDR"); + + peer.update_routes(1, vec![subnet], vec![]); + let state_before = peer.route_state().expect("route state should exist"); + let received_at_before = state_before.received_at; + + // Same epoch with different subnets should NOT replace subnets. + let different_subnet: Ipv4Network = "172.16.0.0/12".parse().expect("valid CIDR"); + peer.update_routes(1, vec![different_subnet], vec![]); + + let state_after = peer.route_state().expect("route state should exist"); + assert_eq!(state_after.epoch, 1); + // Subnets should remain unchanged (original advertisement). + assert_eq!(state_after.subnets[0], subnet); + // received_at should remain unchanged. + assert_eq!(state_after.received_at, received_at_before); + // updated_at should have been refreshed. + assert!(state_after.updated_at >= state_before.updated_at); + } + + #[test] + fn update_routes_stale_epoch_ignored() { + let peer = make_peer("stale-agent"); + let subnet: Ipv4Network = "10.0.0.0/8".parse().expect("valid CIDR"); + + peer.update_routes(5, vec![subnet], vec![]); + let old_subnet: Ipv4Network = "172.16.0.0/12".parse().expect("valid CIDR"); + peer.update_routes(3, vec![old_subnet], vec![]); + + let state = peer.route_state().expect("route state should exist"); + assert_eq!(state.epoch, 5); + assert_eq!(state.subnets[0], subnet); + } + + #[test] + fn can_reach_matching_subnet() { + let peer = make_peer("reachable-agent"); + let subnet: Ipv4Network = "10.0.0.0/8".parse().expect("valid CIDR"); + peer.update_routes(1, vec![subnet], vec![]); + + assert!(peer.can_reach("10.1.2.3".parse().expect("valid IP"))); + assert!(!peer.can_reach("192.168.1.1".parse().expect("valid IP"))); + } + + #[test] + fn can_reach_returns_false_for_ipv6() { + let peer = make_peer("v4-only-agent"); + let subnet: Ipv4Network = "10.0.0.0/8".parse().expect("valid CIDR"); + peer.update_routes(1, vec![subnet], vec![]); + + assert!(!peer.can_reach("::1".parse().expect("valid IP"))); + } + + #[test] + fn select_agent_for_target_picks_most_recent() { + let registry = AgentRegistry::new(); + + let agent_a = make_peer("agent-a"); + let subnet: Ipv4Network = "10.0.0.0/8".parse().expect("valid CIDR"); + agent_a.update_routes(1, vec![subnet], vec![]); + registry.register(Arc::clone(&agent_a)); + + // Small delay to ensure different received_at timestamps. + std::thread::sleep(Duration::from_millis(10)); + + let agent_b = make_peer("agent-b"); + agent_b.update_routes(1, vec![subnet], vec![]); + registry.register(Arc::clone(&agent_b)); + + let target: IpAddr = "10.5.5.5".parse().expect("valid IP"); + let winner = registry.select_agent_for_target(target).expect("should find an agent"); + // agent_b was registered later, so its received_at is more recent. + assert_eq!(winner.agent_id, agent_b.agent_id); + } + + #[test] + fn find_agents_for_target_returns_sorted() { + let registry = AgentRegistry::new(); + + let agent_a = make_peer("agent-a"); + let subnet: Ipv4Network = "10.0.0.0/8".parse().expect("valid CIDR"); + agent_a.update_routes(1, vec![subnet], vec![]); + registry.register(Arc::clone(&agent_a)); + + std::thread::sleep(Duration::from_millis(10)); + + let agent_b = make_peer("agent-b"); + agent_b.update_routes(1, vec![subnet], vec![]); + registry.register(Arc::clone(&agent_b)); + + let target: IpAddr = "10.5.5.5".parse().expect("valid IP"); + let agents = registry.find_agents_for_target(target); + assert_eq!(agents.len(), 2); + // Most recent first. + assert_eq!(agents[0].agent_id, agent_b.agent_id); + assert_eq!(agents[1].agent_id, agent_a.agent_id); + } + + #[test] + fn find_agents_excludes_offline() { + let registry = AgentRegistry::new(); + + let agent = make_peer("offline-agent"); + let subnet: Ipv4Network = "10.0.0.0/8".parse().expect("valid CIDR"); + agent.update_routes(1, vec![subnet], vec![]); + // Force agent to appear offline. + agent.last_seen.store(0, Ordering::Release); + registry.register(agent); + + let target: IpAddr = "10.5.5.5".parse().expect("valid IP"); + let agents = registry.find_agents_for_target(target); + assert!(agents.is_empty()); + } + + #[test] + fn agent_infos_snapshot() { + let registry = AgentRegistry::new(); + let peer = make_peer("info-agent"); + let subnet: Ipv4Network = "10.0.0.0/8".parse().expect("valid CIDR"); + peer.update_routes(1, vec![subnet], vec![]); + registry.register(peer); + + let infos = registry.agent_infos(); + assert_eq!(infos.len(), 1); + assert_eq!(infos[0].name, "info-agent"); + assert!(infos[0].is_online); + assert_eq!(infos[0].subnets, vec!["10.0.0.0/8"]); + assert_eq!(infos[0].route_epoch, Some(1)); + } + + #[test] + fn online_count_accuracy() { + let registry = AgentRegistry::new(); + + let online_agent = make_peer("online"); + registry.register(Arc::clone(&online_agent)); + + let offline_agent = make_peer("offline"); + offline_agent.last_seen.store(0, Ordering::Release); + registry.register(offline_agent); + + assert_eq!(registry.len(), 2); + assert_eq!(registry.online_count(), 1); + } + + #[test] + fn default_trait_creates_empty_registry() { + let registry = AgentRegistry::default(); + assert_eq!(registry.len(), 0); + } + + // ── Domain routing tests ────────────────────────────────────────── + + fn domain(name: &str, auto: bool) -> DomainAdvertisement { + DomainAdvertisement { + domain: name.to_owned(), + auto_detected: auto, + } + } + + #[test] + fn update_routes_stores_domains_with_source() { + let peer = make_peer("domain-agent"); + let subnet: Ipv4Network = "10.0.0.0/8".parse().expect("valid CIDR"); + + peer.update_routes(1, vec![subnet], vec![domain("contoso.local", false)]); + let state = peer.route_state().expect("route state should exist"); + assert_eq!(state.domains.len(), 1); + assert_eq!(state.domains[0].domain, "contoso.local"); + assert!(!state.domains[0].auto_detected); + } + + #[test] + fn update_routes_new_epoch_replaces_domains() { + let peer = make_peer("domain-agent"); + let subnet: Ipv4Network = "10.0.0.0/8".parse().expect("valid CIDR"); + + peer.update_routes(1, vec![subnet], vec![domain("old.local", false)]); + peer.update_routes(2, vec![subnet], vec![domain("new.local", true)]); + + let state = peer.route_state().expect("route state should exist"); + assert_eq!(state.epoch, 2); + assert_eq!(state.domains[0].domain, "new.local"); + assert!(state.domains[0].auto_detected); + } + + #[test] + fn update_routes_same_epoch_preserves_domains() { + let peer = make_peer("domain-agent"); + let subnet: Ipv4Network = "10.0.0.0/8".parse().expect("valid CIDR"); + + peer.update_routes(1, vec![subnet], vec![domain("contoso.local", false)]); + peer.update_routes(1, vec![subnet], vec![domain("different.local", true)]); + + let state = peer.route_state().expect("route state should exist"); + assert_eq!(state.domains[0].domain, "contoso.local"); + assert!(!state.domains[0].auto_detected); + } + + #[test] + fn select_agent_for_domain_suffix_match() { + let registry = AgentRegistry::new(); + let peer = make_peer("agent-a"); + let agent_id = peer.agent_id; + let subnet: Ipv4Network = "10.0.0.0/8".parse().expect("valid CIDR"); + peer.update_routes(1, vec![subnet], vec![domain("contoso.local", false)]); + registry.register(peer); + + let agents = registry.select_agents_for_domain("dc01.contoso.local"); + assert_eq!(agents.len(), 1); + assert_eq!(agents[0].agent_id, agent_id); + } + + #[test] + fn select_agent_for_domain_no_match() { + let registry = AgentRegistry::new(); + let peer = make_peer("agent-a"); + let subnet: Ipv4Network = "10.0.0.0/8".parse().expect("valid CIDR"); + peer.update_routes(1, vec![subnet], vec![domain("contoso.local", false)]); + registry.register(peer); + + let agents = registry.select_agents_for_domain("dc01.other.local"); + assert!(agents.is_empty()); + } + + #[test] + fn select_agent_for_domain_longest_suffix_wins() { + let registry = AgentRegistry::new(); + + let agent_a = make_peer("agent-a"); + let id_a = agent_a.agent_id; + let subnet_a: Ipv4Network = "10.1.0.0/16".parse().expect("valid CIDR"); + agent_a.update_routes(1, vec![subnet_a], vec![domain("contoso.local", false)]); + registry.register(agent_a); + + let agent_b = make_peer("agent-b"); + let id_b = agent_b.agent_id; + let subnet_b: Ipv4Network = "10.2.0.0/16".parse().expect("valid CIDR"); + agent_b.update_routes(1, vec![subnet_b], vec![domain("finance.contoso.local", false)]); + registry.register(agent_b); + + let agents = registry.select_agents_for_domain("db01.finance.contoso.local"); + assert_eq!(agents.len(), 1); + assert_eq!(agents[0].agent_id, id_b); + + let agents = registry.select_agents_for_domain("dc01.contoso.local"); + assert_eq!(agents.len(), 1); + assert_eq!(agents[0].agent_id, id_a); + } + + #[test] + fn select_agent_for_domain_multiple_agents_same_domain() { + let registry = AgentRegistry::new(); + + let agent_a = make_peer("agent-a"); + let subnet_a: Ipv4Network = "10.1.0.0/16".parse().expect("valid CIDR"); + agent_a.update_routes(1, vec![subnet_a], vec![domain("contoso.local", false)]); + registry.register(Arc::clone(&agent_a)); + + std::thread::sleep(Duration::from_millis(10)); + + let agent_b = make_peer("agent-b"); + let id_b = agent_b.agent_id; + let subnet_b: Ipv4Network = "10.2.0.0/16".parse().expect("valid CIDR"); + agent_b.update_routes(1, vec![subnet_b], vec![domain("contoso.local", false)]); + registry.register(Arc::clone(&agent_b)); + + let agents = registry.select_agents_for_domain("dc01.contoso.local"); + assert_eq!(agents.len(), 2); + assert_eq!(agents[0].agent_id, id_b); + } + + #[test] + fn select_agent_for_domain_excludes_offline() { + let registry = AgentRegistry::new(); + + let agent = make_peer("offline-agent"); + let subnet: Ipv4Network = "10.0.0.0/8".parse().expect("valid CIDR"); + agent.update_routes(1, vec![subnet], vec![domain("contoso.local", false)]); + agent.last_seen.store(0, Ordering::Release); + registry.register(agent); + + let agents = registry.select_agents_for_domain("dc01.contoso.local"); + assert!(agents.is_empty()); + } + + #[test] + fn select_agent_for_domain_exact_match() { + let registry = AgentRegistry::new(); + let peer = make_peer("agent-a"); + let agent_id = peer.agent_id; + let subnet: Ipv4Network = "10.0.0.0/8".parse().expect("valid CIDR"); + peer.update_routes(1, vec![subnet], vec![domain("contoso.local", false)]); + registry.register(peer); + + let agents = registry.select_agents_for_domain("contoso.local"); + assert_eq!(agents.len(), 1); + assert_eq!(agents[0].agent_id, agent_id); + } + + #[test] + fn select_agent_for_domain_bare_hostname_no_match() { + let registry = AgentRegistry::new(); + let peer = make_peer("agent-a"); + let subnet: Ipv4Network = "10.0.0.0/8".parse().expect("valid CIDR"); + peer.update_routes(1, vec![subnet], vec![domain("contoso.local", false)]); + registry.register(peer); + + let agents = registry.select_agents_for_domain("server01"); + assert!(agents.is_empty()); + } +} diff --git a/devolutions-gateway/src/agent_tunnel/stream.rs b/devolutions-gateway/src/agent_tunnel/stream.rs new file mode 100644 index 000000000..f979e22a8 --- /dev/null +++ b/devolutions-gateway/src/agent_tunnel/stream.rs @@ -0,0 +1,37 @@ +//! Wrapper around Quinn's `SendStream` + `RecvStream` providing a single +//! `AsyncRead + AsyncWrite` type for use with the gateway's proxy infrastructure. + +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +/// A bidirectional QUIC stream backed by Quinn's `SendStream` and `RecvStream`. +/// +/// Implements `AsyncRead` (delegating to `recv`) and `AsyncWrite` (delegating +/// to `send`), so callers can treat it as a single bidirectional transport. +pub struct TunnelStream { + pub send: quinn::SendStream, + pub recv: quinn::RecvStream, +} + +impl AsyncRead for TunnelStream { + fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { + AsyncRead::poll_read(Pin::new(&mut self.recv), cx, buf) + } +} + +impl AsyncWrite for TunnelStream { + fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + AsyncWrite::poll_write(Pin::new(&mut self.send), cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + AsyncWrite::poll_flush(Pin::new(&mut self.send), cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + AsyncWrite::poll_shutdown(Pin::new(&mut self.send), cx) + } +} diff --git a/devolutions-gateway/src/api/agent_enrollment.rs b/devolutions-gateway/src/api/agent_enrollment.rs new file mode 100644 index 000000000..f65f6dc43 --- /dev/null +++ b/devolutions-gateway/src/api/agent_enrollment.rs @@ -0,0 +1,302 @@ +use std::net::IpAddr; + +use axum::extract::{Path, State}; +use axum::http::HeaderMap; +use axum::{Json, Router}; +use uuid::Uuid; + +use crate::DgwState; +use crate::extract::{AgentManagementReadAccess, AgentManagementWriteAccess}; +use crate::http::HttpError; + +/// Timing-safe byte comparison to prevent side-channel attacks on secret comparison. +/// +/// Both inputs are hashed with SHA-256 first, producing fixed 32-byte digests. +/// The digest comparison runs in constant time (fixed-length XOR fold). +/// SHA-256 itself runs in time proportional to input length, but this only +/// reveals the length of the attacker's guess — not the secret's length or content. +fn constant_time_eq(a: &[u8], b: &[u8]) -> bool { + use sha2::{Digest, Sha256}; + let da = Sha256::digest(a); + let db = Sha256::digest(b); + da.iter().zip(db.iter()).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0 +} + +#[derive(Deserialize)] +pub struct EnrollRequest { + /// Friendly name for the agent. + pub agent_name: String, + /// PEM-encoded Certificate Signing Request from the agent. + pub csr_pem: String, +} + +#[derive(Serialize)] +pub struct EnrollResponse { + /// Assigned agent ID. + pub agent_id: Uuid, + /// Agent name. + pub agent_name: String, + /// PEM-encoded client certificate (signed by the gateway CA). + pub client_cert_pem: String, + /// PEM-encoded gateway CA certificate (for server verification). + pub gateway_ca_cert_pem: String, + /// QUIC endpoint to connect to (`host:port`). + pub quic_endpoint: String, +} + +pub fn make_router(state: DgwState) -> Router { + Router::new() + .route("/enroll", axum::routing::post(enroll_agent)) + .route("/agents", axum::routing::get(list_agents)) + .route("/agents/{agent_id}", axum::routing::get(get_agent).delete(delete_agent)) + .route("/agents/resolve-target", axum::routing::post(resolve_target)) + .with_state(state) +} + +/// Enroll a new agent. +/// +/// Requires a Bearer token matching the configured enrollment secret +/// or a valid one-time enrollment token from the store. +/// +/// The agent generates its own key pair and sends a CSR. The gateway signs it +/// and returns the certificate. The private key never leaves the agent. +async fn enroll_agent( + State(DgwState { + conf_handle, + agent_tunnel_handle, + .. + }): State, + headers: HeaderMap, + Json(req): Json, +) -> Result, HttpError> { + // Validate agent name: 1-255 printable ASCII characters. + if req.agent_name.is_empty() + || 255 < req.agent_name.len() + || req.agent_name.bytes().any(|b| !(0x20..=0x7E).contains(&b)) + { + return Err(HttpError::bad_request().msg("agent name must be 1-255 printable ASCII characters")); + } + + let conf = conf_handle.get_conf(); + + // Extract the Bearer token. + let auth_header = headers + .get(axum::http::header::AUTHORIZATION) + .and_then(|v| v.to_str().ok()) + .ok_or_else(|| HttpError::unauthorized().msg("missing Authorization header"))?; + + let provided_token = auth_header + .strip_prefix("Bearer ") + .ok_or_else(|| HttpError::unauthorized().msg("expected Bearer token"))?; + + let handle = agent_tunnel_handle + .as_ref() + .ok_or_else(|| HttpError::not_found().msg("agent enrollment is not configured"))?; + + // Try one-time enrollment token from the store first. + let token_valid = handle.enrollment_token_store().consume(provided_token); + + if !token_valid { + // Fall back to the static enrollment secret. + let enrollment_secret = conf + .agent_tunnel + .enrollment_secret + .as_deref() + .ok_or_else(|| HttpError::not_found().msg("agent enrollment is not configured"))?; + + if !constant_time_eq(provided_token.as_bytes(), enrollment_secret.as_bytes()) { + return Err(HttpError::forbidden().msg("invalid enrollment token")); + } + } + + let agent_id = Uuid::new_v4(); + + let signed = handle + .ca_manager() + .sign_agent_csr(agent_id, &req.agent_name, &req.csr_pem) + .map_err(HttpError::bad_request().with_msg("invalid CSR").err())?; + + let quic_endpoint = format!("{}:{}", conf.hostname, conf.agent_tunnel.listen_port); + + info!( + %agent_id, + agent_name = %req.agent_name, + "Agent enrolled successfully", + ); + + Ok(Json(EnrollResponse { + agent_id, + agent_name: req.agent_name, + client_cert_pem: signed.client_cert_pem, + gateway_ca_cert_pem: signed.ca_cert_pem, + quic_endpoint, + })) +} + +/// List connected agents and their status. +async fn list_agents( + State(DgwState { + agent_tunnel_handle, .. + }): State, + _access: AgentManagementReadAccess, +) -> Result>, HttpError> { + let handle = agent_tunnel_handle + .as_ref() + .ok_or_else(|| HttpError::not_found().msg("agent tunnel not configured"))?; + + let agents = handle.registry().agent_infos(); + + Ok(Json(agents)) +} + +/// Get a single agent by ID. +async fn get_agent( + State(DgwState { + agent_tunnel_handle, .. + }): State, + _access: AgentManagementReadAccess, + Path(agent_id): Path, +) -> Result, HttpError> { + let handle = agent_tunnel_handle + .as_ref() + .ok_or_else(|| HttpError::not_found().msg("agent tunnel not configured"))?; + + let info = handle + .registry() + .agent_info(&agent_id) + .ok_or_else(|| HttpError::not_found().msg("agent not found"))?; + + Ok(Json(info)) +} + +/// Delete (unregister) an agent by ID. +async fn delete_agent( + State(DgwState { + agent_tunnel_handle, .. + }): State, + _access: AgentManagementWriteAccess, + Path(agent_id): Path, +) -> Result { + let handle = agent_tunnel_handle + .as_ref() + .ok_or_else(|| HttpError::not_found().msg("agent tunnel not configured"))?; + + handle + .registry() + .unregister(&agent_id) + .ok_or_else(|| HttpError::not_found().msg("agent not found"))?; + + info!(%agent_id, "Agent deleted via API"); + + Ok(axum::http::StatusCode::NO_CONTENT) +} + +#[derive(Deserialize)] +struct ResolveTargetRequest { + target: String, +} + +#[derive(Serialize)] +struct ResolveTargetResponse { + target: String, + target_ip: Option, + reachable_agents: Vec, + target_reachable: bool, +} + +/// Resolve a target string to find which agents can reach it. +async fn resolve_target( + State(DgwState { + agent_tunnel_handle, .. + }): State, + _access: AgentManagementReadAccess, + Json(req): Json, +) -> Result, HttpError> { + let handle = agent_tunnel_handle + .as_ref() + .ok_or_else(|| HttpError::not_found().msg("agent tunnel not configured"))?; + + let target_ip = parse_target_ip(&req.target); + + // Use the same routing logic as fwd.rs: IP → subnet match, hostname → domain suffix match + let matching_peers = if let Some(ip) = target_ip { + handle.registry().find_agents_for_target(ip) + } else { + let hostname = strip_scheme_and_port(&req.target); + handle.registry().select_agents_for_domain(hostname) + }; + + let reachable_agents: Vec<_> = matching_peers + .iter() + .map(crate::agent_tunnel::registry::AgentInfo::from) + .collect(); + + let target_reachable = !reachable_agents.is_empty(); + + Ok(Json(ResolveTargetResponse { + target: req.target, + target_ip, + reachable_agents, + target_reachable, + })) +} + +/// Strip scheme prefix and port from a target string, returning the bare host. +/// +/// Handles `tcp://host:port`, `http://host:port`, `host:port`, and bare hostnames. +fn strip_scheme_and_port(target: &str) -> &str { + let host_port = target + .strip_prefix("tcp://") + .or_else(|| target.strip_prefix("http://")) + .or_else(|| target.strip_prefix("https://")) + .unwrap_or(target); + + let host = if let Some((h, _port)) = host_port.rsplit_once(':') { + h + } else { + host_port + }; + + // Strip brackets for IPv6 literals like [::1]. + host.strip_prefix('[').and_then(|h| h.strip_suffix(']')).unwrap_or(host) +} + +fn parse_target_ip(target: &str) -> Option { + strip_scheme_and_port(target).parse::().ok() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_target_ip_bare_ipv4() { + assert_eq!(parse_target_ip("10.0.0.1"), Some("10.0.0.1".parse().expect("test"))); + } + + #[test] + fn parse_target_ip_with_port() { + assert_eq!( + parse_target_ip("10.0.0.1:3389"), + Some("10.0.0.1".parse().expect("test")) + ); + } + + #[test] + fn parse_target_ip_tcp_scheme() { + assert_eq!( + parse_target_ip("tcp://192.168.1.1:22"), + Some("192.168.1.1".parse().expect("test")) + ); + } + + #[test] + fn parse_target_ip_hostname_returns_none() { + assert_eq!(parse_target_ip("myserver.local:3389"), None); + } + + #[test] + fn parse_target_ip_bare_hostname_returns_none() { + assert_eq!(parse_target_ip("myserver"), None); + } +} diff --git a/devolutions-gateway/src/api/mod.rs b/devolutions-gateway/src/api/mod.rs index a5cbbc643..c867389dd 100644 --- a/devolutions-gateway/src/api/mod.rs +++ b/devolutions-gateway/src/api/mod.rs @@ -1,3 +1,4 @@ +pub mod agent_enrollment; pub mod ai; pub mod config; pub mod diagnostics; @@ -35,6 +36,7 @@ pub fn make_router(state: crate::DgwState) -> axum::Router { .nest("/jet/webapp", webapp::make_router(state.clone())) .nest("/jet/net", net::make_router(state.clone())) .nest("/jet/traffic", traffic::make_router(state.clone())) + .nest("/jet/agent-tunnel", agent_enrollment::make_router(state.clone())) .route("/jet/update", axum::routing::post(update::trigger_update_check)); if state.conf_handle.get_conf().web_app.enabled { diff --git a/devolutions-gateway/src/api/webapp.rs b/devolutions-gateway/src/api/webapp.rs index 2c6b99f89..f266f4207 100644 --- a/devolutions-gateway/src/api/webapp.rs +++ b/devolutions-gateway/src/api/webapp.rs @@ -342,6 +342,7 @@ pub(crate) async fn sign_session_token( exp, jti, cert_thumb256: None, + jet_agent_id: None, } .pipe(serde_json::to_value) .map(|mut claims| { diff --git a/devolutions-gateway/src/config.rs b/devolutions-gateway/src/config.rs index 5ca9fc49f..122408795 100644 --- a/devolutions-gateway/src/config.rs +++ b/devolutions-gateway/src/config.rs @@ -193,6 +193,7 @@ pub struct Conf { pub verbosity_profile: dto::VerbosityProfile, pub web_app: WebAppConf, pub ai_gateway: AiGatewayConf, + pub agent_tunnel: dto::AgentTunnelConf, pub proxy: dto::ProxyConf, pub debug: dto::DebugConf, } @@ -925,6 +926,7 @@ impl Conf { .as_ref() .map(AiGatewayConf::from_dto) .unwrap_or_default(), + agent_tunnel: conf_file.agent_tunnel.clone().unwrap_or_default(), proxy: conf_file.proxy.clone().unwrap_or_default(), debug: conf_file.debug.clone().unwrap_or_default(), }) @@ -1725,6 +1727,10 @@ pub mod dto { #[serde(skip_serializing_if = "Option::is_none")] pub proxy: Option, + /// (Unstable) Agent tunnel configuration (QUIC-based agent tunnel) + #[serde(skip_serializing_if = "Option::is_none")] + pub agent_tunnel: Option, + /// (Unstable) Unsafe debug options for developers #[serde(rename = "__debug__", skip_serializing_if = "Option::is_none")] pub debug: Option, @@ -1780,6 +1786,7 @@ pub mod dto { ai_gateway: None, job_queue_database: None, traffic_audit_database: None, + agent_tunnel: None, proxy: None, debug: None, rest: serde_json::Map::new(), @@ -1914,6 +1921,38 @@ pub mod dto { pub kdc_url: Option, } + /// (Unstable) QUIC-based agent tunnel configuration + #[derive(PartialEq, Eq, Debug, Clone, Serialize, Deserialize)] + #[serde(rename_all = "PascalCase")] + pub struct AgentTunnelConf { + /// Whether the agent tunnel listener is enabled + #[serde(default)] + pub enabled: bool, + /// UDP port for the QUIC listener (default: 4433) + #[serde(default = "AgentTunnelConf::default_listen_port")] + pub listen_port: u16, + /// Shared secret for agent enrollment. + /// If set, agents can enroll by providing this secret as a Bearer token. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub enrollment_secret: Option, + } + + impl AgentTunnelConf { + fn default_listen_port() -> u16 { + 4433 + } + } + + impl Default for AgentTunnelConf { + fn default() -> Self { + Self { + enabled: false, + listen_port: Self::default_listen_port(), + enrollment_secret: None, + } + } + } + /// Unsafe debug options that should only ever be used at development stage /// /// These options might change or get removed without further notice. diff --git a/devolutions-gateway/src/extract.rs b/devolutions-gateway/src/extract.rs index 9f2450854..ada08ce9a 100644 --- a/devolutions-gateway/src/extract.rs +++ b/devolutions-gateway/src/extract.rs @@ -386,6 +386,64 @@ where } } +/// Grants read access to agent management endpoints. +/// +/// Accepts a scope token with `DiagnosticsRead`, `ConfigWrite`, or `Wildcard` scope. +#[derive(Clone, Copy)] +pub struct AgentManagementReadAccess; + +impl FromRequestParts for AgentManagementReadAccess +where + S: Send + Sync, +{ + type Rejection = HttpError; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let claims = Extension::::from_request_parts(parts, state) + .await + .map_err(HttpError::internal().err())? + .0; + + // DiagnosticsRead is accepted because DVLS maps its AgentRead scope + // to GatewayDiagnosticsRead, which serializes as "gateway.diagnostics.read". + match claims { + AccessTokenClaims::Scope(scope) => match scope.scope { + AccessScope::Wildcard | AccessScope::DiagnosticsRead | AccessScope::ConfigWrite => Ok(Self), + _ => Err(HttpError::forbidden().msg("invalid scope for agent management read")), + }, + _ => Err(HttpError::forbidden().msg("scope token required for agent management read")), + } + } +} + +/// Grants write access to agent management endpoints (e.g. enrollment, delete). +/// +/// Accepts scope tokens with `ConfigWrite` (or `Wildcard`) scope only. +#[derive(Clone, Copy)] +pub struct AgentManagementWriteAccess; + +impl FromRequestParts for AgentManagementWriteAccess +where + S: Send + Sync, +{ + type Rejection = HttpError; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let claims = Extension::::from_request_parts(parts, state) + .await + .map_err(HttpError::internal().err())? + .0; + + match claims { + AccessTokenClaims::Scope(scope) => match scope.scope { + AccessScope::Wildcard | AccessScope::ConfigWrite => Ok(Self), + _ => Err(HttpError::forbidden().msg("invalid scope for agent management write")), + }, + _ => Err(HttpError::forbidden().msg("scope token required for agent management write")), + } + } +} + #[derive(Clone)] pub struct WebAppToken(pub WebAppTokenClaims); diff --git a/devolutions-gateway/src/generic_client.rs b/devolutions-gateway/src/generic_client.rs index 13c1d9c48..d8209ce79 100644 --- a/devolutions-gateway/src/generic_client.rs +++ b/devolutions-gateway/src/generic_client.rs @@ -6,6 +6,7 @@ use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt as _}; use tracing::field; use typed_builder::TypedBuilder; +use crate::agent_tunnel::AgentTunnelHandle; use crate::config::Conf; use crate::credential::CredentialStoreHandle; use crate::proxy::Proxy; @@ -27,6 +28,8 @@ pub struct GenericClient { subscriber_tx: SubscriberSender, active_recordings: Arc, credential_store: CredentialStoreHandle, + #[builder(default)] + agent_tunnel_handle: Option>, } impl GenericClient @@ -49,6 +52,7 @@ where subscriber_tx, active_recordings, credential_store, + agent_tunnel_handle, } = self; let span = tracing::Span::current(); @@ -109,6 +113,76 @@ where RecordingPolicy::Proxy => anyhow::bail!("can't meet recording policy"), } + // Route via agent tunnel if jet_agent_id is specified. + if let Some(agent_id) = claims.jet_agent_id { + let handle = agent_tunnel_handle.context("agent tunnel not configured on this gateway")?; + + let mut selected_target = None; + let mut server_stream = None; + let mut last_error = None; + + for candidate in targets.iter() { + let target_str = format!("{}:{}", candidate.host(), candidate.port()); + + info!(%agent_id, %target_str, "Routing via agent tunnel"); + + match handle.connect_via_agent(agent_id, claims.jet_aid, &target_str).await { + Ok(stream) => { + selected_target = Some(candidate.clone()); + server_stream = Some(stream); + break; + } + Err(error) => { + warn!( + %agent_id, + %target_str, + error = format!("{error:#}"), + "Agent tunnel target failed" + ); + last_error = Some(error); + } + } + } + + let selected_target = selected_target.ok_or_else(|| { + last_error.unwrap_or_else(|| anyhow::anyhow!("agent tunnel target selection failed")) + })?; + span.record("target", selected_target.to_string()); + let server_stream = server_stream.expect("server stream should be present when target is selected"); + + let info = SessionInfo::builder() + .id(claims.jet_aid) + .application_protocol(claims.jet_ap) + .details(ConnectionModeDetails::Fwd { + destination_host: selected_target.clone(), + }) + .time_to_live(claims.jet_ttl) + .recording_policy(claims.jet_rec) + .filtering_policy(claims.jet_flt) + .build(); + + let disconnect_interest = DisconnectInterest::from_reconnection_policy(claims.jet_reuse); + + // Agent handles the TCP connection; no leftover bytes to forward. + // Use a placeholder server address since the actual target is behind the agent. + let server_addr: SocketAddr = "0.0.0.0:0".parse().expect("valid placeholder"); + + return Proxy::builder() + .conf(conf) + .session_info(info) + .address_a(client_addr) + .transport_a(client_stream) + .address_b(server_addr) + .transport_b(server_stream) + .sessions(sessions) + .subscriber_tx(subscriber_tx) + .disconnect_interest(disconnect_interest) + .build() + .select_dissector_and_forward() + .await + .context("encountered a failure during agent tunnel traffic proxying"); + } + trace!("Select and connect to target"); let ((mut server_stream, server_addr), selected_target) = diff --git a/devolutions-gateway/src/lib.rs b/devolutions-gateway/src/lib.rs index ed1a28099..93d782530 100644 --- a/devolutions-gateway/src/lib.rs +++ b/devolutions-gateway/src/lib.rs @@ -12,6 +12,7 @@ extern crate tracing; #[cfg(feature = "openapi")] pub mod openapi; +pub mod agent_tunnel; pub mod ai; pub mod api; pub mod cli; @@ -61,6 +62,7 @@ pub struct DgwState { pub credential_store: credential::CredentialStoreHandle, pub monitoring_state: Arc, pub traffic_audit_handle: traffic_audit::TrafficAuditHandle, + pub agent_tunnel_handle: Option>, } #[doc(hidden)] @@ -100,6 +102,7 @@ impl DgwState { traffic_audit_handle, credential_store, monitoring_state, + agent_tunnel_handle: None, }; let handles = MockHandles { diff --git a/devolutions-gateway/src/listener.rs b/devolutions-gateway/src/listener.rs index db5926be5..0b7ce2740 100644 --- a/devolutions-gateway/src/listener.rs +++ b/devolutions-gateway/src/listener.rs @@ -159,6 +159,7 @@ async fn handle_tcp_peer(stream: TcpStream, state: DgwState, peer_addr: SocketAd .subscriber_tx(state.subscriber_tx) .active_recordings(state.recordings.active_recordings) .credential_store(state.credential_store) + .agent_tunnel_handle(state.agent_tunnel_handle) .build() .serve() .await?; diff --git a/devolutions-gateway/src/middleware/auth.rs b/devolutions-gateway/src/middleware/auth.rs index f07e6e1b0..18f08bb66 100644 --- a/devolutions-gateway/src/middleware/auth.rs +++ b/devolutions-gateway/src/middleware/auth.rs @@ -95,6 +95,14 @@ const AUTH_EXCEPTIONS: &[AuthException] = &[ path: "/jet/ai", exact_match: false, }, + // Agent Tunnel: only /enroll skips auth (it uses its own bearer token). + // TODO: add rate limiting on this endpoint (tokens are 122-bit UUIDs so brute-force + // is infeasible, but rate limiting is good defense-in-depth). + AuthException { + method: Method::POST, + path: "/jet/agent-tunnel/enroll", + exact_match: true, + }, ]; pub async fn auth_middleware( diff --git a/devolutions-gateway/src/ngrok.rs b/devolutions-gateway/src/ngrok.rs index 8d32f58d4..71c0c005f 100644 --- a/devolutions-gateway/src/ngrok.rs +++ b/devolutions-gateway/src/ngrok.rs @@ -238,6 +238,7 @@ async fn run_tcp_tunnel(mut tunnel: ngrok::tunnel::TcpTunnel, state: DgwState) { .subscriber_tx(state.subscriber_tx) .active_recordings(state.recordings.active_recordings) .credential_store(state.credential_store) + .agent_tunnel_handle(state.agent_tunnel_handle) .build() .serve() .await diff --git a/devolutions-gateway/src/rd_clean_path.rs b/devolutions-gateway/src/rd_clean_path.rs index 6d4614b5e..9855a8987 100644 --- a/devolutions-gateway/src/rd_clean_path.rs +++ b/devolutions-gateway/src/rd_clean_path.rs @@ -669,9 +669,7 @@ impl From<&CleanPathError> for RDCleanPathPdu { } fn io_to_rdcleanpath_err(err: &io::Error) -> RDCleanPathPdu { - if let Some(tokio_rustls::rustls::Error::AlertReceived(tls_alert)) = err - .get_ref() - .and_then(|e| e.downcast_ref::()) + if let Some(rustls::Error::AlertReceived(tls_alert)) = err.get_ref().and_then(|e| e.downcast_ref::()) { RDCleanPathPdu::new_tls_error(u8::from(*tls_alert)) } else { diff --git a/devolutions-gateway/src/service.rs b/devolutions-gateway/src/service.rs index 64dde91c4..45ae07264 100644 --- a/devolutions-gateway/src/service.rs +++ b/devolutions-gateway/src/service.rs @@ -10,7 +10,7 @@ use devolutions_gateway::recording::recording_message_channel; use devolutions_gateway::session::session_manager_channel; use devolutions_gateway::subscriber::subscriber_channel; use devolutions_gateway::token::{CurrentJrl, JrlTokenClaims}; -use devolutions_gateway::{DgwState, SYSTEM_LOGGER, config}; +use devolutions_gateway::{DgwState, SYSTEM_LOGGER, agent_tunnel, config}; use devolutions_gateway_task::{ChildTask, ShutdownHandle, ShutdownSignal}; use devolutions_log::{self, LoggerGuard}; use parking_lot::Mutex; @@ -275,6 +275,35 @@ async fn spawn_tasks(conf_handle: ConfHandle) -> anyhow::Result { ); let monitoring_state = Arc::new(network_monitor::State::new(Arc::new(filesystem_monitor_config_cache))?); + // Initialize agent tunnel if configured. + let agent_tunnel_handle = if conf.agent_tunnel.enabled { + let data_dir = config::get_data_dir(); + let hostname = &conf.hostname; + + let ca_manager = Arc::new( + agent_tunnel::cert::CaManager::load_or_generate(&data_dir) + .context("failed to initialize agent tunnel CA")?, + ); + + let listen_addr = std::net::SocketAddr::from((std::net::Ipv4Addr::UNSPECIFIED, conf.agent_tunnel.listen_port)); + + let (listener, handle) = + agent_tunnel::AgentTunnelListener::bind(listen_addr, Arc::clone(&ca_manager), hostname) + .await + .context("failed to bind agent tunnel listener")?; + + tasks.register(listener); + + info!( + port = conf.agent_tunnel.listen_port, + "Agent tunnel QUIC listener started", + ); + + Some(Arc::new(handle)) + } else { + None + }; + let state = DgwState { conf_handle: conf_handle.clone(), token_cache: Arc::clone(&token_cache), @@ -287,6 +316,7 @@ async fn spawn_tasks(conf_handle: ConfHandle) -> anyhow::Result { credential_store: credential_store.clone(), monitoring_state, traffic_audit_handle: traffic_audit_task.handle(), + agent_tunnel_handle, }; for listener in &conf.listeners { diff --git a/devolutions-gateway/src/token.rs b/devolutions-gateway/src/token.rs index 75b7d112e..5912d7dbf 100644 --- a/devolutions-gateway/src/token.rs +++ b/devolutions-gateway/src/token.rs @@ -425,6 +425,12 @@ pub struct AssociationTokenClaims { /// Optional SHA-256 thumbprint of target server certificate (for anchored TLS validation) pub cert_thumb256: Option, + + /// Optional agent ID for routing connections through an enrolled agent tunnel. + /// + /// When set alongside `ConnectionMode::Fwd`, the Gateway will proxy the connection + /// through the specified agent instead of connecting directly to the target. + pub jet_agent_id: Option, } // ----- scope claims ----- // @@ -466,15 +472,15 @@ pub enum AccessScope { NetMonitorDrain, } -#[derive(Clone, Deserialize)] +#[derive(Clone, Serialize, Deserialize)] pub struct ScopeTokenClaims { pub scope: AccessScope, /// JWT expiration time claim. - exp: i64, + pub exp: i64, /// JWT "JWT ID" claim, the unique ID for this token - jti: Uuid, + pub jti: Uuid, } // ----- bridge claims ----- // @@ -1312,6 +1318,8 @@ mod serde_impl { jti: Uuid, #[serde(default)] cert_thumb256: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + jet_agent_id: Option, } #[derive(Deserialize)] @@ -1420,6 +1428,7 @@ mod serde_impl { exp: self.exp, jti: self.jti, cert_thumb256: self.cert_thumb256.as_ref().map(|thumb| SmolStr::new(thumb.as_str())), + jet_agent_id: self.jet_agent_id, } .serialize(serializer) } @@ -1469,6 +1478,7 @@ mod serde_impl { .map(crate::tls::thumbprint::normalize_sha256_thumbprint) .transpose() .map_err(de::Error::custom)?, + jet_agent_id: claims.jet_agent_id, }) } } diff --git a/devolutions-gateway/tests/config.rs b/devolutions-gateway/tests/config.rs index 4cb015e3f..e2eeb8a7c 100644 --- a/devolutions-gateway/tests/config.rs +++ b/devolutions-gateway/tests/config.rs @@ -97,6 +97,7 @@ fn hub_sample() -> Sample { verbosity_profile: Some(VerbosityProfile::Tls), web_app: None, ai_gateway: None, + agent_tunnel: None, proxy: None, debug: None, rest: Default::default(), @@ -146,6 +147,7 @@ fn legacy_sample() -> Sample { verbosity_profile: None, web_app: None, ai_gateway: None, + agent_tunnel: None, proxy: None, debug: None, rest: Default::default(), @@ -194,6 +196,7 @@ fn system_store_sample() -> Sample { verbosity_profile: None, web_app: None, ai_gateway: None, + agent_tunnel: None, proxy: None, debug: None, rest: Default::default(), @@ -274,6 +277,7 @@ fn standalone_custom_auth_sample() -> Sample { static_root_path: None, }), ai_gateway: None, + agent_tunnel: None, proxy: None, debug: None, rest: Default::default(), @@ -354,6 +358,7 @@ fn standalone_no_auth_sample() -> Sample { static_root_path: Some("/path/to/webapp/static/root".into()), }), ai_gateway: None, + agent_tunnel: None, proxy: None, debug: None, rest: Default::default(), @@ -439,6 +444,7 @@ fn proxy_sample() -> Sample { ], }), debug: None, + agent_tunnel: None, rest: Default::default(), }, }