From 0a382f8156fcf786ae6a512ef2a21b8c01bce861 Mon Sep 17 00:00:00 2001 From: Mathias Myrland Date: Sat, 25 Apr 2026 08:16:03 +0200 Subject: [PATCH 1/6] Harden auth and generated tooling --- Cargo.lock | 381 +++++++++++++++++- Cargo.toml | 2 +- .../examples/google_oauth2.rs | 5 +- .../identity/ras-identity-oauth2/src/lib.rs | 4 +- .../identity/ras-identity-session/src/lib.rs | 149 ++++++- crates/rest/ras-file-macro/Cargo.toml | 2 +- crates/rest/ras-file-macro/src/client.rs | 2 +- crates/rest/ras-file-macro/src/server.rs | 6 +- .../rest/ras-file-macro/tests/integration.rs | 17 +- .../rest/ras-file-macro/tests/minimal_test.rs | 1 - .../rest/ras-file-macro/tests/simple_test.rs | 1 - crates/rest/ras-rest-macro/src/lib.rs | 2 +- .../rest/ras-rest-macro/src/static_hosting.rs | 23 +- .../ras-rest-macro/tests/http_integration.rs | 13 +- .../tests/xss_protection_test.rs | 9 + .../src/connection.rs | 9 +- .../src/handler.rs | 28 +- .../src/manager.rs | 23 +- .../src/router.rs | 2 +- .../src/service.rs | 48 ++- .../src/upgrade.rs | 57 ++- .../examples/explorer_params_demo.rs | 16 +- .../src/jsonrpc_explorer_template.html | 25 +- .../tests/explorer_token_storage_test.rs | 8 + .../tests/http_integration.rs | 7 - crates/specs/openrpc-types/src/schema.rs | 60 +-- .../tools/openrpc-to-bruno/src/converter.rs | 60 ++- crates/tools/openrpc-to-bruno/src/error.rs | 3 + .../openrpc-to-bruno/tests/integration.rs | 92 ++++- .../bidirectional-chat/server/src/main.rs | 1 + .../server/tests/server_tests.rs | 2 +- .../server/tests/websocket_tests.rs | 12 +- examples/oauth2-demo/server/src/main.rs | 7 +- .../rest-wasm-example/rest-api/Cargo.toml | 2 +- 34 files changed, 905 insertions(+), 174 deletions(-) create mode 100644 crates/rpc/ras-jsonrpc-macro/tests/explorer_token_storage_test.rs diff --git a/Cargo.lock b/Cargo.lock index ee97d4d..052ce78 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -322,6 +322,12 @@ dependencies = [ "windows-targets", ] +[[package]] +name = "base16ct" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c7f02d4ea65f2c1853089ffd8d2787bdbc63de2f0d29dedbcf8ccdfa0ccd4cf" + [[package]] name = "base64" version = "0.21.7" @@ -666,6 +672,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "const-oid" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" + [[package]] name = "const-random" version = "0.1.18" @@ -787,6 +799,18 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929" +[[package]] +name = "crypto-bigint" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0dc92fb57ca44df6db8059111ab3af99a63d5d0f8375d9972e319a379c6bab76" +dependencies = [ + "generic-array", + "rand_core 0.6.4", + "subtle", + "zeroize", +] + [[package]] name = "crypto-common" version = "0.1.6" @@ -820,6 +844,33 @@ dependencies = [ "syn", ] +[[package]] +name = "curve25519-dalek" +version = "4.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97fb8b7c4503de7d6ae7b42ab72a5a59857b4c937ec27a3d4539dba95b5ab2be" +dependencies = [ + "cfg-if 1.0.0", + "cpufeatures", + "curve25519-dalek-derive", + "digest", + "fiat-crypto", + "rustc_version", + "subtle", + "zeroize", +] + +[[package]] +name = "curve25519-dalek-derive" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "darling" version = "0.20.11" @@ -893,6 +944,17 @@ version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "092966b41edc516079bdf31ec78a2e0588d1d0c08f78b91d8307215928642b2b" +[[package]] +name = "der" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb" +dependencies = [ + "const-oid", + "pem-rfc7468", + "zeroize", +] + [[package]] name = "deranged" version = "0.4.0" @@ -915,6 +977,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer", + "const-oid", "crypto-common", "subtle", ] @@ -1058,12 +1121,71 @@ version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1c7a8fb8a9fbf66c1f703fe16184d10ca0ee9d23be5b4436400408ba54a95005" +[[package]] +name = "ecdsa" +version = "0.16.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee27f32b5c5292967d2d4a9d7f1e0b0aed2c15daded5a60300e4abb9d8020bca" +dependencies = [ + "der", + "digest", + "elliptic-curve", + "rfc6979", + "signature", + "spki", +] + +[[package]] +name = "ed25519" +version = "2.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "115531babc129696a58c64a4fef0a8bf9e9698629fb97e9e40767d235cfbcd53" +dependencies = [ + "pkcs8", + "signature", +] + +[[package]] +name = "ed25519-dalek" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70e796c081cee67dc755e1a36a0a172b897fab85fc3f6bc48307991f64e4eca9" +dependencies = [ + "curve25519-dalek", + "ed25519", + "serde", + "sha2", + "subtle", + "zeroize", +] + [[package]] name = "either" version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +[[package]] +name = "elliptic-curve" +version = "0.13.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5e6043086bf7973472e0c7dff2142ea0b680d30e18d9cc40f267efbf222bd47" +dependencies = [ + "base16ct", + "crypto-bigint", + "digest", + "ff", + "generic-array", + "group", + "hkdf", + "pem-rfc7468", + "pkcs8", + "rand_core 0.6.4", + "sec1", + "subtle", + "zeroize", +] + [[package]] name = "email_address" version = "0.2.9" @@ -1141,6 +1263,22 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" +[[package]] +name = "ff" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0b50bfb653653f9ca9095b427bed08ab8d75a137839d9ad64eb11810d5b6393" +dependencies = [ + "rand_core 0.6.4", + "subtle", +] + +[[package]] +name = "fiat-crypto" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28dea519a9695b9977216879a3ebfddf92f1c08c05d984f8996aecd6ecdc811d" + [[package]] name = "file-service-api" version = "0.1.0" @@ -1360,6 +1498,7 @@ checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" dependencies = [ "typenum", "version_check", + "zeroize", ] [[package]] @@ -1381,10 +1520,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" dependencies = [ "cfg-if 1.0.0", - "js-sys", "libc", "wasi 0.11.0+wasi-snapshot-preview1", - "wasm-bindgen", ] [[package]] @@ -1477,6 +1614,17 @@ dependencies = [ "web-sys", ] +[[package]] +name = "group" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0f9ef7462f7c099f518d754361858f86d8a07af53ba9af0fe635bbccb151a63" +dependencies = [ + "ff", + "rand_core 0.6.4", + "subtle", +] + [[package]] name = "h2" version = "0.4.10" @@ -1538,6 +1686,24 @@ version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f154ce46856750ed433c8649605bf7ed2de3bc35fd9d2a9f30cddd873c80cb08" +[[package]] +name = "hkdf" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b5f8eb2ad728638ea2c7d47a21db23b7b58a72ed6a38256b8a1849f15fbbdf7" +dependencies = [ + "hmac", +] + +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + [[package]] name = "http" version = "1.3.1" @@ -1904,16 +2070,24 @@ dependencies = [ [[package]] name = "jsonwebtoken" -version = "9.3.1" +version = "10.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a87cc7a48537badeae96744432de36f4be2b4a34a05a5ef32e9dd8a1c169dde" +checksum = "0529410abe238729a60b108898784df8984c87f6054c9c4fcacc47e4803c1ce1" dependencies = [ "base64 0.22.1", + "ed25519-dalek", + "getrandom 0.2.16", + "hmac", "js-sys", + "p256", + "p384", "pem", - "ring", + "rand 0.8.5", + "rsa", "serde", "serde_json", + "sha2", + "signature", "simple_asn1", ] @@ -1922,6 +2096,9 @@ name = "lazy_static" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" +dependencies = [ + "spin", +] [[package]] name = "libc" @@ -1929,6 +2106,12 @@ version = "0.2.172" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa" +[[package]] +name = "libm" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6d2cec3eae94f9f509c767b45932f1ada8350c4bdb85af2fcab4a3c14807981" + [[package]] name = "linux-raw-sys" version = "0.4.15" @@ -2128,6 +2311,22 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-bigint-dig" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e661dda6640fad38e827a6d4a310ff4763082116fe217f279885c97f511bb0b7" +dependencies = [ + "lazy_static", + "libm", + "num-integer", + "num-iter", + "num-traits", + "rand 0.8.5", + "smallvec", + "zeroize", +] + [[package]] name = "num-conv" version = "0.1.0" @@ -2143,6 +2342,17 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-iter" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1429034a0490724d0075ebb2bc9e875d6503c3cf69e235a8941aa757d83ef5bf" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.19" @@ -2150,6 +2360,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" dependencies = [ "autocfg", + "libm", ] [[package]] @@ -2362,6 +2573,30 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" +[[package]] +name = "p256" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c9863ad85fa8f4460f9c48cb909d38a0d689dba1f6f6988a5e3e0d31071bcd4b" +dependencies = [ + "ecdsa", + "elliptic-curve", + "primeorder", + "sha2", +] + +[[package]] +name = "p384" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe42f1670a52a47d448f14b6a5c61dd78fce51856e68edaa38f7ae3a46b8d6b6" +dependencies = [ + "ecdsa", + "elliptic-curve", + "primeorder", + "sha2", +] + [[package]] name = "parking_lot" version = "0.12.4" @@ -2418,6 +2653,15 @@ dependencies = [ "serde", ] +[[package]] +name = "pem-rfc7468" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412" +dependencies = [ + "base64ct", +] + [[package]] name = "percent-encoding" version = "2.3.1" @@ -2543,6 +2787,27 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "pkcs1" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8ffb9f10fa047879315e6625af03c164b16962a5368d724ed16323b68ace47f" +dependencies = [ + "der", + "pkcs8", + "spki", +] + +[[package]] +name = "pkcs8" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" +dependencies = [ + "der", + "spki", +] + [[package]] name = "pkg-config" version = "0.3.32" @@ -2593,6 +2858,15 @@ dependencies = [ "syn", ] +[[package]] +name = "primeorder" +version = "0.13.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "353e1ca18966c16d9deb1c69278edbc5f194139612772bd9537af60ac231e1e6" +dependencies = [ + "elliptic-curve", +] + [[package]] name = "proc-macro2" version = "1.0.95" @@ -3200,6 +3474,16 @@ dependencies = [ "uuid", ] +[[package]] +name = "rfc6979" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8dd2a808d456c4a54e300a23e9f5a67e122c3024119acbfd73e3bf664491cb2" +dependencies = [ + "hmac", + "subtle", +] + [[package]] name = "ring" version = "0.17.14" @@ -3226,6 +3510,26 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "rsa" +version = "0.9.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8573f03f5883dcaebdfcf4725caa1ecb9c15b2ef50c43a07b816e06799bb12d" +dependencies = [ + "const-oid", + "digest", + "num-bigint-dig", + "num-integer", + "num-traits", + "pkcs1", + "pkcs8", + "rand_core 0.6.4", + "signature", + "spki", + "subtle", + "zeroize", +] + [[package]] name = "rust-ini" version = "0.20.0" @@ -3257,6 +3561,15 @@ version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" +[[package]] +name = "rustc_version" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" +dependencies = [ + "semver", +] + [[package]] name = "rustix" version = "0.38.44" @@ -3377,6 +3690,20 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "sec1" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3e97a565f76233a6003f9f5c54be1d9c5bdfa3eccfb189469f11ec4901c47dc" +dependencies = [ + "base16ct", + "der", + "generic-array", + "pkcs8", + "subtle", + "zeroize", +] + [[package]] name = "security-framework" version = "2.11.1" @@ -3400,12 +3727,19 @@ dependencies = [ "libc", ] +[[package]] +name = "semver" +version = "1.0.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a7852d02fc848982e0c167ef163aaff9cd91dc640ba85e263cb1ce46fae51cd" + [[package]] name = "serde" -version = "1.0.219" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" dependencies = [ + "serde_core", "serde_derive", ] @@ -3420,11 +3754,20 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + [[package]] name = "serde_derive" -version = "1.0.219" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" dependencies = [ "proc-macro2", "quote", @@ -3565,6 +3908,16 @@ dependencies = [ "libc", ] +[[package]] +name = "signature" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de" +dependencies = [ + "digest", + "rand_core 0.6.4", +] + [[package]] name = "simple_asn1" version = "0.6.3" @@ -3614,6 +3967,16 @@ version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +[[package]] +name = "spki" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d" +dependencies = [ + "base64ct", + "der", +] + [[package]] name = "stable_deref_trait" version = "1.2.0" diff --git a/Cargo.toml b/Cargo.toml index 6043c3b..ee05510 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -49,7 +49,7 @@ gloo-net = "0.6" gloo-utils = "0.2" http = "1.0" js-sys = "0.3" -jsonwebtoken = "9.3" +jsonwebtoken = { version = "10.3", features = ["rust_crypto"] } mime_guess = "2.0" once_cell = "1.20" opentelemetry = "0.28" diff --git a/crates/identity/ras-identity-oauth2/examples/google_oauth2.rs b/crates/identity/ras-identity-oauth2/examples/google_oauth2.rs index f31f436..098d4ee 100644 --- a/crates/identity/ras-identity-oauth2/examples/google_oauth2.rs +++ b/crates/identity/ras-identity-oauth2/examples/google_oauth2.rs @@ -51,8 +51,9 @@ async fn main() -> Result<(), Box> { let oauth2_provider = OAuth2Provider::new(oauth2_config, state_store); // Create session service - let session_config = SessionConfig::default(); - let session_service = SessionService::new(session_config); + let session_config = + SessionConfig::new("oauth2-example-secret-that-is-long-enough-for-tests").unwrap(); + let session_service = SessionService::new(session_config).unwrap(); // Register OAuth2 provider with session service session_service diff --git a/crates/identity/ras-identity-oauth2/src/lib.rs b/crates/identity/ras-identity-oauth2/src/lib.rs index 6c860fa..baf54b0 100644 --- a/crates/identity/ras-identity-oauth2/src/lib.rs +++ b/crates/identity/ras-identity-oauth2/src/lib.rs @@ -19,7 +19,9 @@ pub use config::{OAuth2Config, OAuth2ProviderConfig}; pub use error::{OAuth2Error, OAuth2Result}; pub use provider::{OAuth2AuthPayload, OAuth2Provider, OAuth2Response}; pub use state::{InMemoryStateStore, OAuth2State, OAuth2StateStore}; -pub use types::{AuthorizationRequest, AuthorizationResponse, TokenResponse, UserInfoResponse}; +pub use types::{ + AuthorizationRequest, AuthorizationResponse, ProviderMetadata, TokenResponse, UserInfoResponse, +}; // Re-export common types for convenience pub use ras_identity_core::{IdentityProvider, VerifiedIdentity}; diff --git a/crates/identity/ras-identity-session/src/lib.rs b/crates/identity/ras-identity-session/src/lib.rs index d9da41b..45589ff 100644 --- a/crates/identity/ras-identity-session/src/lib.rs +++ b/crates/identity/ras-identity-session/src/lib.rs @@ -25,6 +25,9 @@ pub enum SessionError { #[error("Invalid session")] InvalidSession, + + #[error("Invalid session configuration: {0}")] + InvalidConfig(String), } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -49,17 +52,60 @@ pub struct SessionConfig { pub algorithm: Algorithm, } -impl Default for SessionConfig { - fn default() -> Self { - Self { - jwt_secret: "change-me-in-production".to_string(), +impl SessionConfig { + pub fn new(jwt_secret: impl Into) -> Result { + let config = Self { + jwt_secret: jwt_secret.into(), jwt_ttl: Duration::hours(24), refresh_enabled: true, enforce_active_sessions: true, algorithm: Algorithm::HS256, + }; + config.validate()?; + Ok(config) + } + + pub fn validate(&self) -> Result<(), SessionError> { + validate_jwt_secret(&self.jwt_secret)?; + + if self.jwt_ttl <= Duration::zero() { + return Err(SessionError::InvalidConfig( + "jwt_ttl must be positive".to_string(), + )); } + + Ok(()) } } + +fn validate_jwt_secret(secret: &str) -> Result<(), SessionError> { + let trimmed = secret.trim(); + let insecure_placeholders = [ + "change-me-in-production", + "change-me", + "secret", + "test-secret", + "test-secret-key", + ]; + + if trimmed.len() < 32 { + return Err(SessionError::InvalidConfig( + "jwt_secret must be at least 32 bytes".to_string(), + )); + } + + if insecure_placeholders + .iter() + .any(|placeholder| trimmed.eq_ignore_ascii_case(placeholder)) + { + return Err(SessionError::InvalidConfig( + "jwt_secret must not use a placeholder value".to_string(), + )); + } + + Ok(()) +} + pub struct SessionService { config: SessionConfig, providers: Arc>>>, @@ -67,13 +113,14 @@ pub struct SessionService { permissions_provider: Option>, } impl SessionService { - pub fn new(config: SessionConfig) -> Self { - Self { + pub fn new(config: SessionConfig) -> Result { + config.validate()?; + Ok(Self { config, providers: Arc::new(RwLock::new(HashMap::new())), active_sessions: Arc::new(RwLock::new(HashMap::new())), permissions_provider: None, - } + }) } pub fn with_permissions(mut self, provider: Arc) -> Self { @@ -95,6 +142,10 @@ impl SessionService { provider_id: &str, auth_payload: serde_json::Value, ) -> Result { + if self.config.enforce_active_sessions { + self.cleanup_expired_sessions().await; + } + let providers = self.providers.read().await; let provider = providers .get(provider_id) @@ -139,10 +190,18 @@ impl SessionService { } pub async fn verify_session(&self, token: &str) -> Result { + if self.config.enforce_active_sessions { + self.cleanup_expired_sessions().await; + } + + let mut validation = Validation::new(self.config.algorithm); + validation.set_required_spec_claims(&["exp"]); + validation.validate_exp = true; + let token_data = decode::( token, &DecodingKey::from_secret(self.config.jwt_secret.as_bytes()), - &Validation::new(self.config.algorithm), + &validation, )?; if self.config.enforce_active_sessions { @@ -159,6 +218,14 @@ impl SessionService { let mut sessions = self.active_sessions.write().await; sessions.remove(jti) } + + pub async fn cleanup_expired_sessions(&self) -> usize { + let now = Utc::now().timestamp(); + let mut sessions = self.active_sessions.write().await; + let before = sessions.len(); + sessions.retain(|_, claims| claims.exp > now); + before - sessions.len() + } } #[derive(Clone)] @@ -206,10 +273,12 @@ mod tests { use ras_identity_core::StaticPermissions; use ras_identity_local::LocalUserProvider; + const TEST_SECRET: &str = "test-secret-that-is-long-enough-for-hs256"; + #[tokio::test] async fn test_session_lifecycle() { - let config = SessionConfig::default(); - let session_service = SessionService::new(config); + let config = SessionConfig::new(TEST_SECRET).unwrap(); + let session_service = SessionService::new(config).unwrap(); let local_provider = LocalUserProvider::new(); local_provider @@ -248,12 +317,14 @@ mod tests { #[tokio::test] async fn test_session_with_permissions() { - let config = SessionConfig::default(); + let config = SessionConfig::new(TEST_SECRET).unwrap(); let permissions_provider = Arc::new(StaticPermissions::new(vec![ "read".to_string(), "write".to_string(), ])); - let session_service = SessionService::new(config).with_permissions(permissions_provider); + let session_service = SessionService::new(config) + .unwrap() + .with_permissions(permissions_provider); let local_provider = LocalUserProvider::new(); local_provider @@ -286,4 +357,58 @@ mod tests { assert!(claims.permissions.contains("read")); assert!(claims.permissions.contains("write")); } + + #[test] + fn test_rejects_placeholder_secret() { + let result = SessionConfig::new("change-me-in-production"); + assert!(matches!(result, Err(SessionError::InvalidConfig(_)))); + } + + #[tokio::test] + async fn test_cleanup_expired_sessions() { + let config = SessionConfig::new(TEST_SECRET).unwrap(); + let service = SessionService::new(config).unwrap(); + + { + let mut sessions = service.active_sessions.write().await; + sessions.insert( + "expired".to_string(), + JwtClaims { + sub: "user".to_string(), + exp: Utc::now().timestamp() - 1, + iat: Utc::now().timestamp() - 10, + jti: "expired".to_string(), + provider_id: "local".to_string(), + email: None, + display_name: None, + permissions: HashSet::new(), + metadata: None, + }, + ); + } + + assert_eq!(service.cleanup_expired_sessions().await, 1); + } + + #[tokio::test] + async fn test_malformed_exp_claim_is_rejected() { + let config = SessionConfig::new(TEST_SECRET).unwrap(); + let service = SessionService::new(config).unwrap(); + + let token = encode( + &Header::new(Algorithm::HS256), + &serde_json::json!({ + "sub": "user", + "exp": "not-a-number", + "iat": Utc::now().timestamp(), + "jti": "malformed", + "provider_id": "local", + "permissions": [], + }), + &EncodingKey::from_secret(TEST_SECRET.as_bytes()), + ) + .unwrap(); + + assert!(service.verify_session(&token).await.is_err()); + } } diff --git a/crates/rest/ras-file-macro/Cargo.toml b/crates/rest/ras-file-macro/Cargo.toml index 545450f..96c205c 100644 --- a/crates/rest/ras-file-macro/Cargo.toml +++ b/crates/rest/ras-file-macro/Cargo.toml @@ -22,4 +22,4 @@ serde = { workspace = true } serde_json = { workspace = true } ras-auth-core = { path = "../../core/ras-auth-core" } thiserror = { workspace = true } -async-trait = { workspace = true } \ No newline at end of file +async-trait = { workspace = true } diff --git a/crates/rest/ras-file-macro/src/client.rs b/crates/rest/ras-file-macro/src/client.rs index 3d71a38..b687c98 100644 --- a/crates/rest/ras-file-macro/src/client.rs +++ b/crates/rest/ras-file-macro/src/client.rs @@ -292,7 +292,7 @@ fn generate_wasm_client(definition: &FileServiceDefinition) -> TokenStream { let wasm_methods = generate_wasm_methods(&definition.endpoints); quote! { - #[cfg(all(target_arch = "wasm32", feature = "wasm-client"))] + #[cfg(target_arch = "wasm32")] pub mod wasm_client { use super::*; use wasm_bindgen::prelude::*; diff --git a/crates/rest/ras-file-macro/src/server.rs b/crates/rest/ras-file-macro/src/server.rs index 1d7573b..8e1c92b 100644 --- a/crates/rest/ras-file-macro/src/server.rs +++ b/crates/rest/ras-file-macro/src/server.rs @@ -106,9 +106,9 @@ pub fn generate_server(definition: &FileServiceDefinition) -> TokenStream { #error_name::NotFound => (StatusCode::NOT_FOUND, self.to_string()), #error_name::InvalidFormat => (StatusCode::BAD_REQUEST, self.to_string()), #error_name::FileTooLarge => (StatusCode::PAYLOAD_TOO_LARGE, self.to_string()), - #error_name::UploadFailed(msg) => (StatusCode::BAD_REQUEST, msg), - #error_name::DownloadFailed(msg) => (StatusCode::INTERNAL_SERVER_ERROR, msg), - #error_name::Internal(msg) => (StatusCode::INTERNAL_SERVER_ERROR, msg), + #error_name::UploadFailed(_) => (StatusCode::BAD_REQUEST, "Upload failed".to_string()), + #error_name::DownloadFailed(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Download failed".to_string()), + #error_name::Internal(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Internal server error".to_string()), }; <(::axum::http::StatusCode, String) as ::axum::response::IntoResponse>::into_response((status, message)) diff --git a/crates/rest/ras-file-macro/tests/integration.rs b/crates/rest/ras-file-macro/tests/integration.rs index 84e3947..abe7983 100644 --- a/crates/rest/ras-file-macro/tests/integration.rs +++ b/crates/rest/ras-file-macro/tests/integration.rs @@ -179,13 +179,13 @@ mod tests { client.set_bearer_token(Some("test-token")); // Test that client methods exist - assert!(true); // Basic compilation test + let _ = client; } #[test] fn test_file_error_variants() { // Test that all error variants exist - let _errors = vec![ + let _errors = [ TestFileServiceFileError::NotFound, TestFileServiceFileError::UploadFailed("test".to_string()), TestFileServiceFileError::DownloadFailed("test".to_string()), @@ -195,8 +195,8 @@ mod tests { ]; } - #[test] - fn test_error_into_response() { + #[tokio::test] + async fn test_error_into_response() { use axum::http::StatusCode; use axum::response::IntoResponse; @@ -229,5 +229,14 @@ mod tests { let (parts, _) = response.into_parts(); assert_eq!(parts.status, expected_status); } + + let response = TestFileServiceFileError::Internal("database password leaked".to_string()) + .into_response(); + let body = axum::body::to_bytes(response.into_body(), usize::MAX) + .await + .unwrap(); + let body = String::from_utf8(body.to_vec()).unwrap(); + assert!(body.contains("Internal server error")); + assert!(!body.contains("database password leaked")); } } diff --git a/crates/rest/ras-file-macro/tests/minimal_test.rs b/crates/rest/ras-file-macro/tests/minimal_test.rs index 7fe190b..e76413e 100644 --- a/crates/rest/ras-file-macro/tests/minimal_test.rs +++ b/crates/rest/ras-file-macro/tests/minimal_test.rs @@ -46,5 +46,4 @@ fn test_compilation() { let service = MyService; let auth = DummyAuth; let _builder = MinimalServiceBuilder::new(service).auth_provider(auth); - assert!(true); } diff --git a/crates/rest/ras-file-macro/tests/simple_test.rs b/crates/rest/ras-file-macro/tests/simple_test.rs index f696099..aaa3f0e 100644 --- a/crates/rest/ras-file-macro/tests/simple_test.rs +++ b/crates/rest/ras-file-macro/tests/simple_test.rs @@ -12,5 +12,4 @@ file_service!({ #[test] fn test_compilation() { // If it compiles, the test passes - assert!(true); } diff --git a/crates/rest/ras-rest-macro/src/lib.rs b/crates/rest/ras-rest-macro/src/lib.rs index 2e7725c..d7b97eb 100644 --- a/crates/rest/ras-rest-macro/src/lib.rs +++ b/crates/rest/ras-rest-macro/src/lib.rs @@ -542,7 +542,7 @@ fn generate_service_code(service_def: ServiceDefinition) -> syn::ResultJWT Token +
@@ -705,7 +709,7 @@ pub fn generate_static_hosting_code( // Global state let apiSpec = null; let currentEndpoint = null; - let jwtToken = localStorage.getItem('jwt-token') || ''; + let jwtToken = ''; // Initialize the application document.addEventListener('DOMContentLoaded', async () => {{ @@ -732,9 +736,13 @@ pub fn generate_static_hosting_code( function initializeAuth() {{ const tokenInput = document.getElementById('jwt-token'); const authStatus = document.getElementById('auth-status'); + const rememberToken = document.getElementById('remember-token'); + const rememberedToken = sessionStorage.getItem('jwt-token') || ''; - if (jwtToken) {{ + if (rememberedToken) {{ + jwtToken = rememberedToken; tokenInput.value = jwtToken; + rememberToken.checked = true; authStatus.classList.add('authenticated'); }} }} @@ -773,13 +781,17 @@ pub fn generate_static_hosting_code( function saveToken() {{ const tokenInput = document.getElementById('jwt-token'); const authStatus = document.getElementById('auth-status'); + const rememberToken = document.getElementById('remember-token'); jwtToken = tokenInput.value.trim(); - localStorage.setItem('jwt-token', jwtToken); + sessionStorage.removeItem('jwt-token'); + if (jwtToken && rememberToken.checked) {{ + sessionStorage.setItem('jwt-token', jwtToken); + }} if (jwtToken) {{ authStatus.classList.add('authenticated'); - showSuccess('Token saved successfully'); + showSuccess(rememberToken.checked ? 'Token saved for this tab' : 'Token ready for this page'); }} else {{ authStatus.classList.remove('authenticated'); }} @@ -788,8 +800,9 @@ pub fn generate_static_hosting_code( // Clear JWT token function clearToken() {{ jwtToken = ''; - localStorage.removeItem('jwt-token'); + sessionStorage.removeItem('jwt-token'); document.getElementById('jwt-token').value = ''; + document.getElementById('remember-token').checked = false; document.getElementById('auth-status').classList.remove('authenticated'); showSuccess('Token cleared'); }} diff --git a/crates/rest/ras-rest-macro/tests/http_integration.rs b/crates/rest/ras-rest-macro/tests/http_integration.rs index a1f3051..c2b53fb 100644 --- a/crates/rest/ras-rest-macro/tests/http_integration.rs +++ b/crates/rest/ras-rest-macro/tests/http_integration.rs @@ -2,7 +2,6 @@ use rand::Rng; use ras_jsonrpc_core::{AuthError, AuthFuture, AuthProvider, AuthenticatedUser}; use ras_rest_core::{RestError, RestResponse}; use ras_rest_macro::rest_service; -use reqwest; use serde::{Deserialize, Serialize}; use serde_json::{Value, json}; use std::collections::HashSet; @@ -59,12 +58,6 @@ struct PostsResponse { total: usize, } -#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)] -struct ErrorResponse { - error: String, - details: Option, -} - // Simple test auth provider struct TestRestAuthProvider { valid_tokens: HashSet, @@ -940,7 +933,6 @@ async fn test_openapi_generation() { // The fact that this compiles means the REST service macro generated the builder correctly // with OpenAPI configuration enabled - assert!(true, "OpenAPI generation compiled successfully"); } #[tokio::test] @@ -951,7 +943,6 @@ async fn test_missing_dependencies() { // This test ensures that our future handling is working correctly let handles: Vec> = vec![]; let _results = join_all(handles).await; - assert!(true, "Futures dependency is working"); } #[tokio::test] @@ -1049,7 +1040,7 @@ async fn test_generated_rest_client() { assert_eq!(resp.total, 2); - let _resp = client + client .delete_users_by_id_with_timeout(resp.users[0].id.unwrap(), None) .await .expect("failed to get users"); @@ -1118,7 +1109,7 @@ async fn test_query_parameters_with_auth() { assert_eq!(response.status(), 200); let posts_response: PostsResponse = response.json().await.unwrap(); assert!(posts_response.posts[0].tags.contains(&"test".to_string())); - assert_eq!(posts_response.posts[0].published, true); + assert!(posts_response.posts[0].published); // Test with no query parameters - all optional let response = make_rest_request( diff --git a/crates/rest/ras-rest-macro/tests/xss_protection_test.rs b/crates/rest/ras-rest-macro/tests/xss_protection_test.rs index d4b284f..c1b1784 100644 --- a/crates/rest/ras-rest-macro/tests/xss_protection_test.rs +++ b/crates/rest/ras-rest-macro/tests/xss_protection_test.rs @@ -17,6 +17,15 @@ fn test_xss_protection_in_generated_html() { } } +#[test] +fn test_generated_docs_do_not_store_jwt_in_local_storage() { + let source = include_str!("../src/static_hosting.rs"); + assert!(!source.contains("localStorage.getItem('jwt-token')")); + assert!(!source.contains("localStorage.setItem('jwt-token'")); + assert!(!source.contains("localStorage.removeItem('jwt-token'")); + assert!(source.contains("sessionStorage.setItem('jwt-token'")); +} + fn escape_html(unsafe_str: &str) -> String { unsafe_str .replace('&', "&") diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/connection.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/connection.rs index 7ea454c..9c9e1cc 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/connection.rs +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/connection.rs @@ -9,15 +9,12 @@ use tokio::sync::{RwLock, mpsc}; #[derive(Debug, Clone)] pub struct ChannelMessageSender { connection_id: ConnectionId, - sender: mpsc::UnboundedSender, + sender: mpsc::Sender, } impl ChannelMessageSender { /// Create a new channel message sender - pub fn new( - connection_id: ConnectionId, - sender: mpsc::UnboundedSender, - ) -> Self { + pub fn new(connection_id: ConnectionId, sender: mpsc::Sender) -> Self { Self { connection_id, sender, @@ -26,7 +23,7 @@ impl ChannelMessageSender { /// Send a message through the channel pub async fn send(&self, message: BidirectionalMessage) -> Result<(), String> { - self.sender.send(message).map_err(|e| e.to_string()) + self.sender.send(message).await.map_err(|e| e.to_string()) } /// Get the connection ID diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/handler.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/handler.rs index dbae66f..5e46482 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/handler.rs +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/handler.rs @@ -93,7 +93,8 @@ pub struct WebSocketHandler { /// Connection context context: Arc, /// Channel for receiving messages to send to client - message_rx: mpsc::UnboundedReceiver, + message_rx: mpsc::Receiver, + max_message_size: usize, } impl WebSocketHandler { @@ -101,12 +102,14 @@ impl WebSocketHandler { pub fn new( handler: Arc, context: Arc, - message_rx: mpsc::UnboundedReceiver, + message_rx: mpsc::Receiver, + max_message_size: usize, ) -> Self { Self { handler, context, message_rx, + max_message_size, } } @@ -205,10 +208,22 @@ impl WebSocketHandler { ) -> ServerResult<()> { match msg { Message::Text(text) => { - debug!("Received text message: {}", text); + if text.len() > self.max_message_size { + warn!("Received oversized text message: {} bytes", text.len()); + return Err(ServerError::InvalidRequest( + "Message exceeds maximum size".to_string(), + )); + } + debug!("Received text message ({} bytes)", text.len()); self.handle_text_message(text.to_string(), socket).await } Message::Binary(data) => { + if data.len() > self.max_message_size { + warn!("Received oversized binary message: {} bytes", data.len()); + return Err(ServerError::InvalidRequest( + "Message exceeds maximum size".to_string(), + )); + } debug!("Received binary message ({} bytes)", data.len()); // Try to parse as UTF-8 text match String::from_utf8(data.to_vec()) { @@ -259,10 +274,9 @@ impl WebSocketHandler { } // If neither worked, return error - Err(ServerError::InvalidRequest(format!( - "Could not parse message as JSON-RPC or bidirectional message: {}", - text - ))) + Err(ServerError::InvalidRequest( + "Could not parse message as JSON-RPC or bidirectional message".to_string(), + )) } /// Handle bidirectional messages diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/manager.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/manager.rs index 1d9cbf2..ea9ddf0 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/manager.rs +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/manager.rs @@ -85,7 +85,7 @@ impl DefaultConnectionManager { impl ConnectionManager for DefaultConnectionManager { async fn add_connection(&self, info: ConnectionInfo) -> Result<()> { // Create a dummy sender - real senders should be added via add_connection_with_sender - let (tx, _rx) = mpsc::unbounded_channel(); + let (tx, _rx) = mpsc::channel(1); let sender = ChannelMessageSender::new(info.id, tx); self.connections.insert(info.id, (info.clone(), sender)); info!("Added connection: {}", info.id); @@ -185,7 +185,7 @@ impl ConnectionManager for DefaultConnectionManager { // Update topic subscriptions self.subscriptions .entry(topic.clone()) - .or_insert_with(Vec::new) + .or_default() .push(id); // Update connection subscriptions @@ -234,7 +234,7 @@ impl ConnectionManager for DefaultConnectionManager { .1 .send(message) .await - .map_err(|e| ras_jsonrpc_bidirectional_types::BidirectionalError::SendError(e))?; + .map_err(ras_jsonrpc_bidirectional_types::BidirectionalError::SendError)?; } else { warn!("Attempted to send to non-existent connection: {}", id); } @@ -340,7 +340,7 @@ impl ConnectionManager for DefaultConnectionManager { ) -> Result<()> { self.pending_requests .entry(connection_id) - .or_insert_with(HashMap::new) + .or_default() .insert(request_id, response_sender); debug!( @@ -373,17 +373,16 @@ impl ConnectionManager for DefaultConnectionManager { connection_id: ConnectionId, response: ras_jsonrpc_types::JsonRpcResponse, ) -> Result { - if let Some(request_id) = &response.id { - if let Some(sender) = self + if let Some(request_id) = &response.id + && let Some(sender) = self .remove_pending_request(connection_id, request_id) .await? - { - if let Err(_) = sender.send(response) { - warn!("Failed to send response to pending request - receiver dropped"); - } - debug!("Handled pending response for connection: {}", connection_id); - return Ok(true); + { + if sender.send(response).is_err() { + warn!("Failed to send response to pending request - receiver dropped"); } + debug!("Handled pending response for connection: {}", connection_id); + return Ok(true); } Ok(false) } diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/router.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/router.rs index 407046b..7c0030f 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/router.rs +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/router.rs @@ -219,7 +219,7 @@ mod tests { // Create test context let connection_id = ConnectionId::new(); - let (tx, _rx) = mpsc::unbounded_channel(); + let (tx, _rx) = mpsc::channel(1); let sender = crate::connection::ChannelMessageSender::new(connection_id, tx); let context = Arc::new(ConnectionContext::new(connection_id, sender)); diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/service.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/service.rs index 3b85587..2ac6df4 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/service.rs +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/service.rs @@ -16,7 +16,11 @@ use std::sync::Arc; use tokio::sync::mpsc; use tracing::{error, info}; +const DEFAULT_MESSAGE_CHANNEL_CAPACITY: usize = 1024; +const DEFAULT_MAX_MESSAGE_SIZE: usize = 1024 * 1024; + /// Trait for services that handle WebSocket JSON-RPC communication +#[allow(async_fn_in_trait)] pub trait WebSocketService: Clone + Send + Sync + 'static { /// The message handler type type Handler: MessageHandler; @@ -37,6 +41,16 @@ pub trait WebSocketService: Clone + Send + Sync + 'static { /// Check if authentication is required fn require_auth(&self) -> bool; + /// Maximum queued outbound messages per connection. + fn message_channel_capacity(&self) -> usize { + DEFAULT_MESSAGE_CHANNEL_CAPACITY + } + + /// Maximum accepted inbound WebSocket message size in bytes. + fn max_message_size(&self) -> usize { + DEFAULT_MAX_MESSAGE_SIZE + } + /// Handle WebSocket upgrade async fn handle_upgrade( &self, @@ -73,7 +87,8 @@ pub trait WebSocketService: Clone + Send + Sync + 'static { info!("New WebSocket connection: {}", connection_id); // Create message channel for this connection - let (message_tx, message_rx) = mpsc::unbounded_channel(); + let channel_capacity = service.message_channel_capacity().max(1); + let (message_tx, message_rx) = mpsc::channel(channel_capacity); let sender = ChannelMessageSender::new(connection_id, message_tx); // Create connection info and add to manager @@ -93,10 +108,15 @@ pub trait WebSocketService: Clone + Send + Sync + 'static { .connection_manager() .add_connection_with_sender(info, Box::new(sender.clone())) .await - .map_err(|e| ServerError::ConnectionError(e))?; + .map_err(ServerError::ConnectionError)?; // Create and run WebSocket handler - let handler = WebSocketHandler::new(service.handler(), context.clone(), message_rx); + let handler = WebSocketHandler::new( + service.handler(), + context.clone(), + message_rx, + service.max_message_size(), + ); // Handle the connection (this will block until connection closes) let result = handler.run(socket).await; @@ -127,6 +147,12 @@ pub struct WebSocketServiceBuilder { /// Whether authentication is required #[builder(default = false)] require_auth: bool, + /// Maximum queued outbound messages per connection + #[builder(default = DEFAULT_MESSAGE_CHANNEL_CAPACITY)] + message_channel_capacity: usize, + /// Maximum accepted inbound WebSocket message size in bytes + #[builder(default = DEFAULT_MAX_MESSAGE_SIZE)] + max_message_size: usize, } impl WebSocketServiceBuilder @@ -143,6 +169,8 @@ where .connection_manager .unwrap_or_else(|| Arc::new(DefaultConnectionManager::new())), require_auth: self.require_auth, + message_channel_capacity: self.message_channel_capacity, + max_message_size: self.max_message_size, } } } @@ -160,6 +188,8 @@ where auth_provider: self.auth_provider, connection_manager: manager, require_auth: self.require_auth, + message_channel_capacity: self.message_channel_capacity, + max_message_size: self.max_message_size, } } } @@ -170,6 +200,8 @@ pub struct BuiltWebSocketService { auth_provider: Arc, connection_manager: Arc, require_auth: bool, + message_channel_capacity: usize, + max_message_size: usize, } impl Clone for BuiltWebSocketService { @@ -179,6 +211,8 @@ impl Clone for BuiltWebSocketService { auth_provider: self.auth_provider.clone(), connection_manager: self.connection_manager.clone(), require_auth: self.require_auth, + message_channel_capacity: self.message_channel_capacity, + max_message_size: self.max_message_size, } } } @@ -208,6 +242,14 @@ where fn require_auth(&self) -> bool { self.require_auth } + + fn message_channel_capacity(&self) -> usize { + self.message_channel_capacity + } + + fn max_message_size(&self) -> usize { + self.max_message_size + } } /// Convenience function to create a simple router-based service diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/upgrade.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/upgrade.rs index e51017d..98b34a4 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/upgrade.rs +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/upgrade.rs @@ -26,31 +26,31 @@ impl WebSocketUpgrade { /// Extract authentication token from headers pub fn extract_auth_token(&self) -> Option { // Try Authorization header first (Bearer token) - if let Some(auth_header) = self.headers.get("authorization") { - if let Ok(auth_str) = auth_header.to_str() { - if auth_str.starts_with("Bearer ") { - return Some(auth_str[7..].to_string()); - } - // Also support just the token without "Bearer " prefix - return Some(auth_str.to_string()); + if let Some(auth_header) = self.headers.get("authorization") + && let Ok(auth_str) = auth_header.to_str() + { + if let Some(token) = auth_str.strip_prefix("Bearer ") { + return Some(token.to_string()); } + // Also support just the token without "Bearer " prefix + return Some(auth_str.to_string()); } // Try custom WebSocket auth headers - if let Some(token_header) = self.headers.get("sec-websocket-protocol") { - if let Ok(token_str) = token_header.to_str() { - // Support protocols like "token.{jwt_token}" - if token_str.starts_with("token.") { - return Some(token_str[6..].to_string()); - } + if let Some(token_header) = self.headers.get("sec-websocket-protocol") + && let Ok(token_str) = token_header.to_str() + { + // Support protocols like "token.{jwt_token}" + if let Some(token) = token_str.strip_prefix("token.") { + return Some(token.to_string()); } } // Try X-Auth-Token header - if let Some(token_header) = self.headers.get("x-auth-token") { - if let Ok(token_str) = token_header.to_str() { - return Some(token_str.to_string()); - } + if let Some(token_header) = self.headers.get("x-auth-token") + && let Ok(token_str) = token_header.to_str() + { + return Some(token_str.to_string()); } None @@ -167,7 +167,7 @@ impl WebSocketUpgrade { for header_name in &ip_headers { if let Some(value) = self.get_header(header_name) { // For X-Forwarded-For, take the first IP - let ip = value.split(',').next().unwrap_or(&value).trim(); + let ip = value.split(',').next().unwrap_or(value.as_str()).trim(); if !ip.is_empty() { return Some(ip.to_string()); } @@ -220,22 +220,21 @@ mod tests { // Test Bearer token extraction logic headers.insert("authorization", "Bearer abc123".parse().unwrap()); - if let Some(auth_header) = headers.get("authorization") { - if let Ok(auth_str) = auth_header.to_str() { - if auth_str.starts_with("Bearer ") { - assert_eq!(&auth_str[7..], "abc123"); - } - } + if let Some(auth_header) = headers.get("authorization") + && let Ok(auth_str) = auth_header.to_str() + && let Some(token) = auth_str.strip_prefix("Bearer ") + { + assert_eq!(token, "abc123"); } // Test X-Forwarded-For parsing logic headers.clear(); headers.insert("x-forwarded-for", "192.168.1.1, 10.0.0.1".parse().unwrap()); - if let Some(header_value) = headers.get("x-forwarded-for") { - if let Ok(value) = header_value.to_str() { - let ip = value.split(',').next().unwrap_or(&value).trim(); - assert_eq!(ip, "192.168.1.1"); - } + if let Some(header_value) = headers.get("x-forwarded-for") + && let Ok(value) = header_value.to_str() + { + let ip = value.split(',').next().unwrap_or(value).trim(); + assert_eq!(ip, "192.168.1.1"); } } diff --git a/crates/rpc/ras-jsonrpc-macro/examples/explorer_params_demo.rs b/crates/rpc/ras-jsonrpc-macro/examples/explorer_params_demo.rs index aa5a710..180cd42 100644 --- a/crates/rpc/ras-jsonrpc-macro/examples/explorer_params_demo.rs +++ b/crates/rpc/ras-jsonrpc-macro/examples/explorer_params_demo.rs @@ -178,15 +178,15 @@ impl UserManagementServiceTrait for UserManagementServiceImpl { let filtered_users: Vec = users .into_iter() .filter(|u| { - if let Some(pattern) = &req.username_pattern { - if !u.username.contains(pattern) { - return false; - } + if let Some(pattern) = &req.username_pattern + && !u.username.contains(pattern) + { + return false; } - if let Some(pattern) = &req.email_pattern { - if !u.email.contains(pattern) { - return false; - } + if let Some(pattern) = &req.email_pattern + && !u.email.contains(pattern) + { + return false; } true }) diff --git a/crates/rpc/ras-jsonrpc-macro/src/jsonrpc_explorer_template.html b/crates/rpc/ras-jsonrpc-macro/src/jsonrpc_explorer_template.html index 6194b56..9d0eb48 100644 --- a/crates/rpc/ras-jsonrpc-macro/src/jsonrpc_explorer_template.html +++ b/crates/rpc/ras-jsonrpc-macro/src/jsonrpc_explorer_template.html @@ -571,6 +571,10 @@
+
@@ -709,7 +713,7 @@

Select a method

// Global state let openrpcDoc = null; let currentMethod = null; - let jwtToken = localStorage.getItem('jwt-token') || ''; + let jwtToken = ''; // Initialize the application document.addEventListener('DOMContentLoaded', async () => { @@ -736,9 +740,13 @@

Select a method

function initializeAuth() { const tokenInput = document.getElementById('jwt-token'); const authStatus = document.getElementById('auth-status'); + const rememberToken = document.getElementById('remember-token'); + const rememberedToken = sessionStorage.getItem('jwt-token') || ''; - if (jwtToken) { + if (rememberedToken) { + jwtToken = rememberedToken; tokenInput.value = jwtToken; + rememberToken.checked = true; authStatus.classList.add('authenticated'); } } @@ -772,13 +780,17 @@

Select a method

function saveToken() { const tokenInput = document.getElementById('jwt-token'); const authStatus = document.getElementById('auth-status'); + const rememberToken = document.getElementById('remember-token'); jwtToken = tokenInput.value.trim(); - localStorage.setItem('jwt-token', jwtToken); + sessionStorage.removeItem('jwt-token'); + if (jwtToken && rememberToken.checked) { + sessionStorage.setItem('jwt-token', jwtToken); + } if (jwtToken) { authStatus.classList.add('authenticated'); - showToast('Token saved successfully', 'success'); + showToast(rememberToken.checked ? 'Token saved for this tab' : 'Token ready for this page', 'success'); } else { authStatus.classList.remove('authenticated'); } @@ -787,8 +799,9 @@

Select a method

// Clear JWT token function clearToken() { jwtToken = ''; - localStorage.removeItem('jwt-token'); + sessionStorage.removeItem('jwt-token'); document.getElementById('jwt-token').value = ''; + document.getElementById('remember-token').checked = false; document.getElementById('auth-status').classList.remove('authenticated'); showToast('Token cleared', 'success'); } @@ -1196,4 +1209,4 @@

Authentication Required

} - \ No newline at end of file + diff --git a/crates/rpc/ras-jsonrpc-macro/tests/explorer_token_storage_test.rs b/crates/rpc/ras-jsonrpc-macro/tests/explorer_token_storage_test.rs new file mode 100644 index 0000000..cd5704a --- /dev/null +++ b/crates/rpc/ras-jsonrpc-macro/tests/explorer_token_storage_test.rs @@ -0,0 +1,8 @@ +#[test] +fn test_generated_explorer_does_not_store_jwt_in_local_storage() { + let template = include_str!("../src/jsonrpc_explorer_template.html"); + assert!(!template.contains("localStorage.getItem('jwt-token')")); + assert!(!template.contains("localStorage.setItem('jwt-token'")); + assert!(!template.contains("localStorage.removeItem('jwt-token'")); + assert!(template.contains("sessionStorage.setItem('jwt-token'")); +} diff --git a/crates/rpc/ras-jsonrpc-macro/tests/http_integration.rs b/crates/rpc/ras-jsonrpc-macro/tests/http_integration.rs index 79f4cc2..3d7b370 100644 --- a/crates/rpc/ras-jsonrpc-macro/tests/http_integration.rs +++ b/crates/rpc/ras-jsonrpc-macro/tests/http_integration.rs @@ -1,7 +1,6 @@ use rand::Rng; use ras_jsonrpc_core::{AuthError, AuthFuture, AuthProvider, AuthenticatedUser}; use ras_jsonrpc_macro::jsonrpc_service; -use reqwest; use serde::{Deserialize, Serialize}; use serde_json::{Value, json}; use std::collections::HashSet; @@ -61,12 +60,6 @@ struct ProcessingResult { success: bool, } -#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)] -struct ErrorResponse { - error: String, - details: Option, -} - // Simple test auth provider struct TestAuthProvider { valid_tokens: HashSet, diff --git a/crates/specs/openrpc-types/src/schema.rs b/crates/specs/openrpc-types/src/schema.rs index 7b776b4..68c6431 100644 --- a/crates/specs/openrpc-types/src/schema.rs +++ b/crates/specs/openrpc-types/src/schema.rs @@ -484,45 +484,45 @@ impl Default for Schema { impl Validate for Schema { fn validate(&self) -> OpenRpcResult<()> { // Validate numeric constraints - if let (Some(min), Some(max)) = (self.minimum, self.maximum) { - if min > max { - return Err(crate::error::OpenRpcError::validation( - "minimum cannot be greater than maximum", - )); - } + if let (Some(min), Some(max)) = (self.minimum, self.maximum) + && min > max + { + return Err(crate::error::OpenRpcError::validation( + "minimum cannot be greater than maximum", + )); } - if let (Some(min), Some(max)) = (self.min_length, self.max_length) { - if min > max { - return Err(crate::error::OpenRpcError::validation( - "minLength cannot be greater than maxLength", - )); - } + if let (Some(min), Some(max)) = (self.min_length, self.max_length) + && min > max + { + return Err(crate::error::OpenRpcError::validation( + "minLength cannot be greater than maxLength", + )); } - if let (Some(min), Some(max)) = (self.min_items, self.max_items) { - if min > max { - return Err(crate::error::OpenRpcError::validation( - "minItems cannot be greater than maxItems", - )); - } + if let (Some(min), Some(max)) = (self.min_items, self.max_items) + && min > max + { + return Err(crate::error::OpenRpcError::validation( + "minItems cannot be greater than maxItems", + )); } - if let (Some(min), Some(max)) = (self.min_properties, self.max_properties) { - if min > max { - return Err(crate::error::OpenRpcError::validation( - "minProperties cannot be greater than maxProperties", - )); - } + if let (Some(min), Some(max)) = (self.min_properties, self.max_properties) + && min > max + { + return Err(crate::error::OpenRpcError::validation( + "minProperties cannot be greater than maxProperties", + )); } // Validate multipleOf - if let Some(multiple_of) = self.multiple_of { - if multiple_of <= 0.0 { - return Err(crate::error::OpenRpcError::validation( - "multipleOf must be greater than 0", - )); - } + if let Some(multiple_of) = self.multiple_of + && multiple_of <= 0.0 + { + return Err(crate::error::OpenRpcError::validation( + "multipleOf must be greater than 0", + )); } // Validate pattern if present diff --git a/crates/tools/openrpc-to-bruno/src/converter.rs b/crates/tools/openrpc-to-bruno/src/converter.rs index 6a500f5..9a02f56 100644 --- a/crates/tools/openrpc-to-bruno/src/converter.rs +++ b/crates/tools/openrpc-to-bruno/src/converter.rs @@ -6,6 +6,7 @@ use openrpc_types::{ Schema, SchemaType, }; use serde_json::{Map, Value}; +use std::path::Path; use tokio::fs; pub struct OpenRpcToBrunoConverter { @@ -240,7 +241,8 @@ impl OpenRpcToBrunoConverter { } let content = bruno_request.to_bru_format(); - let path = self.args.output.join(format!("{}.bru", method.name)); + let file_name = safe_method_file_name(&method.name, sequence)?; + let path = self.safe_output_path(&file_name).await?; fs::write(&path, content) .await @@ -252,6 +254,31 @@ impl OpenRpcToBrunoConverter { Ok(()) } + async fn safe_output_path(&self, file_name: &str) -> Result { + let output_dir = + fs::canonicalize(&self.args.output) + .await + .map_err(|e| ToolError::OutputDirCreate { + path: self.args.output.clone(), + source: e, + })?; + + let candidate = output_dir.join(file_name); + let parent = candidate.parent().unwrap_or(&output_dir); + let parent = fs::canonicalize(parent) + .await + .map_err(|e| ToolError::OutputDirCreate { + path: parent.to_path_buf(), + source: e, + })?; + + if !parent.starts_with(&output_dir) { + return Err(ToolError::UnsafeMethodName(file_name.to_string())); + } + + Ok(candidate) + } + fn create_jsonrpc_request_body(&self, method: &Method) -> Result { let mut request = Map::new(); request.insert("jsonrpc".to_string(), Value::String("2.0".to_string())); @@ -331,3 +358,34 @@ impl OpenRpcToBrunoConverter { } } } + +fn safe_method_file_name(method_name: &str, sequence: u32) -> Result { + let name = method_name.trim(); + if name.is_empty() + || name == "." + || name == ".." + || name.contains('\0') + || name.contains('/') + || name.contains('\\') + || Path::new(name).is_absolute() + { + return Err(ToolError::UnsafeMethodName(method_name.to_string())); + } + + let sanitized: String = name + .chars() + .map(|ch| { + if ch.is_ascii_alphanumeric() || matches!(ch, '.' | '_' | '-') { + ch + } else { + '_' + } + }) + .collect(); + + if sanitized.is_empty() || sanitized == "." || sanitized == ".." { + return Err(ToolError::UnsafeMethodName(method_name.to_string())); + } + + Ok(format!("{sequence:03}_{sanitized}.bru")) +} diff --git a/crates/tools/openrpc-to-bruno/src/error.rs b/crates/tools/openrpc-to-bruno/src/error.rs index f2001a5..7f3309a 100644 --- a/crates/tools/openrpc-to-bruno/src/error.rs +++ b/crates/tools/openrpc-to-bruno/src/error.rs @@ -32,6 +32,9 @@ pub enum ToolError { source: std::io::Error, }, + #[error("Unsafe OpenRPC method name for file generation: {0}")] + UnsafeMethodName(String), + #[error("Invalid base URL: {0}")] #[allow(dead_code)] InvalidBaseUrl(String), diff --git a/crates/tools/openrpc-to-bruno/tests/integration.rs b/crates/tools/openrpc-to-bruno/tests/integration.rs index 7430aad..15f62a7 100644 --- a/crates/tools/openrpc-to-bruno/tests/integration.rs +++ b/crates/tools/openrpc-to-bruno/tests/integration.rs @@ -35,8 +35,10 @@ async fn test_conversion( assert!(env_file.exists(), "environment file should be created"); // Check that method files were created - for method in expected_methods { - let method_file = output_dir.path().join(format!("{}.bru", method)); + for (index, method) in expected_methods.iter().enumerate() { + let method_file = output_dir + .path() + .join(format!("{:03}_{}.bru", index + 1, method)); assert!( method_file.exists(), "method file {} should be created", @@ -62,6 +64,92 @@ async fn test_conversion( Ok(()) } +#[tokio::test] +async fn test_rejects_path_traversal_method_name() { + use clap::Parser; + use openrpc_to_bruno::{cli::Args, error::ToolError}; + + let temp = tempdir().unwrap(); + let input_path = temp.path().join("openrpc.json"); + let output_dir = temp.path().join("out"); + let escaped = temp.path().join("evil.bru"); + + let spec = serde_json::json!({ + "openrpc": "1.3.2", + "info": { + "title": "Unsafe API", + "version": "1.0.0" + }, + "methods": [{ + "name": "../evil", + "params": [], + "result": { + "name": "result", + "schema": { "type": "string" } + } + }] + }); + fs::write(&input_path, serde_json::to_vec(&spec).unwrap()) + .await + .unwrap(); + + let args = Args::try_parse_from(vec![ + "openrpc-to-bruno", + "--input", + input_path.to_str().unwrap(), + "--output", + output_dir.to_str().unwrap(), + "--force", + ]) + .unwrap(); + + let err = args.run().await.unwrap_err(); + assert!(matches!(err, ToolError::UnsafeMethodName(_))); + assert!(!escaped.exists()); +} + +#[tokio::test] +async fn test_sanitizes_safe_method_filename() { + use clap::Parser; + use openrpc_to_bruno::cli::Args; + + let temp = tempdir().unwrap(); + let input_path = temp.path().join("openrpc.json"); + let output_dir = temp.path().join("out"); + + let spec = serde_json::json!({ + "openrpc": "1.3.2", + "info": { + "title": "Safe API", + "version": "1.0.0" + }, + "methods": [{ + "name": "system status", + "params": [], + "result": { + "name": "result", + "schema": { "type": "string" } + } + }] + }); + fs::write(&input_path, serde_json::to_vec(&spec).unwrap()) + .await + .unwrap(); + + let args = Args::try_parse_from(vec![ + "openrpc-to-bruno", + "--input", + input_path.to_str().unwrap(), + "--output", + output_dir.to_str().unwrap(), + "--force", + ]) + .unwrap(); + + args.run().await.unwrap(); + assert!(output_dir.join("001_system_status.bru").exists()); +} + #[tokio::test] async fn test_simple_conversion() { test_conversion("simple-api-basic.json", &["hello"]) diff --git a/examples/bidirectional-chat/server/src/main.rs b/examples/bidirectional-chat/server/src/main.rs index 52cf3a1..840cb04 100644 --- a/examples/bidirectional-chat/server/src/main.rs +++ b/examples/bidirectional-chat/server/src/main.rs @@ -1500,6 +1500,7 @@ async fn main() -> Result<()> { ); let session_service = Arc::new( SessionService::new(session_config) + .map_err(anyhow::Error::from)? .with_permissions(Arc::new(ChatPermissions::new(config.admin.users.clone()))), ); diff --git a/examples/bidirectional-chat/server/tests/server_tests.rs b/examples/bidirectional-chat/server/tests/server_tests.rs index 29f7a15..d784052 100644 --- a/examples/bidirectional-chat/server/tests/server_tests.rs +++ b/examples/bidirectional-chat/server/tests/server_tests.rs @@ -123,7 +123,7 @@ async fn create_test_config() -> Result<(Config, TempDir)> { cors: Default::default(), }, auth: AuthConfig { - jwt_secret: "test-secret-key".to_string(), + jwt_secret: "test-secret-key-that-is-long-enough".to_string(), jwt_ttl_seconds: 3600, refresh_enabled: true, jwt_algorithm: "HS256".to_string(), diff --git a/examples/bidirectional-chat/server/tests/websocket_tests.rs b/examples/bidirectional-chat/server/tests/websocket_tests.rs index f9e7b9a..d983b4c 100644 --- a/examples/bidirectional-chat/server/tests/websocket_tests.rs +++ b/examples/bidirectional-chat/server/tests/websocket_tests.rs @@ -58,7 +58,7 @@ impl TestChatServer { cors: Default::default(), }, auth: AuthConfig { - jwt_secret: "test-secret-key".to_string(), + jwt_secret: "test-secret-key-that-is-long-enough".to_string(), jwt_ttl_seconds: 3600, refresh_enabled: true, jwt_algorithm: "HS256".to_string(), @@ -145,9 +145,13 @@ impl TestChatServer { algorithm: jsonwebtoken::Algorithm::HS256, }; - let session_service = Arc::new(SessionService::new(session_config).with_permissions( - Arc::new(TestChatPermissions::new(config.admin.users.clone())), - )); + let session_service = Arc::new( + SessionService::new(session_config) + .unwrap() + .with_permissions(Arc::new(TestChatPermissions::new( + config.admin.users.clone(), + ))), + ); // Register identity provider with session service let session_identity_provider = LocalUserProvider::new(); diff --git a/examples/oauth2-demo/server/src/main.rs b/examples/oauth2-demo/server/src/main.rs index 971e37c..038be11 100644 --- a/examples/oauth2-demo/server/src/main.rs +++ b/examples/oauth2-demo/server/src/main.rs @@ -45,7 +45,7 @@ impl AppConfig { redirect_uri: std::env::var("REDIRECT_URI") .unwrap_or_else(|_| "http://localhost:3000/auth/callback".to_string()), jwt_secret: std::env::var("JWT_SECRET") - .unwrap_or_else(|_| "change-me-in-production-please".to_string()), + .context("JWT_SECRET environment variable is required")?, server_host: std::env::var("SERVER_HOST").unwrap_or_else(|_| "0.0.0.0".to_string()), server_port: std::env::var("SERVER_PORT") .unwrap_or_else(|_| "3000".to_string()) @@ -164,8 +164,9 @@ fn create_session_service(config: &AppConfig) -> Result { }; let permissions_provider = Arc::new(GoogleOAuth2Permissions::new()); - let session_service = - SessionService::new(session_config).with_permissions(permissions_provider); + let session_service = SessionService::new(session_config) + .map_err(anyhow::Error::from)? + .with_permissions(permissions_provider); Ok(session_service) } diff --git a/examples/rest-wasm-example/rest-api/Cargo.toml b/examples/rest-wasm-example/rest-api/Cargo.toml index a7ac0a6..fa68cfe 100644 --- a/examples/rest-wasm-example/rest-api/Cargo.toml +++ b/examples/rest-wasm-example/rest-api/Cargo.toml @@ -4,7 +4,7 @@ version = "0.1.0" edition = "2024" [lib] -crate-type = ["cdylib", "rlib"] +crate-type = ["rlib"] [dependencies] ras-rest-macro = { path = "../../../crates/rest/ras-rest-macro" } From 3215170cfb659ce70961fdccc2f2e1ee04af8087 Mon Sep 17 00:00:00 2001 From: Mathias Myrland Date: Sat, 25 Apr 2026 09:37:00 +0200 Subject: [PATCH 2/6] Add ras-test-helpers crate and per-macro e2e + bench suites Introduces a shared dev-only crate at crates/test-utils/ras-test-helpers with MockAuthProvider, mock_user, and spawn_http/spawn_tcp helpers, then wires per-macro tests/e2e.rs and benches/ for jsonrpc_service!, file_service!, and jsonrpc_bidirectional_service!. Each suite drives the generated reqwest/tokio-tungstenite client all the way through the axum router and back, plus a criterion bench measuring per-call latency. REST macro coverage lands in the next commit so its query-param e2e cases can land together with the query-param client codegen. Co-Authored-By: Claude Opus 4.7 (1M context) --- Cargo.lock | 226 +++++++++++++++++- Cargo.toml | 2 + crates/rest/ras-file-macro/Cargo.toml | 8 + .../rest/ras-file-macro/benches/streaming.rs | 131 ++++++++++ crates/rest/ras-file-macro/tests/e2e.rs | 163 +++++++++++++ .../Cargo.toml | 8 +- .../benches/roundtrip.rs | 118 +++++++++ .../tests/e2e.rs | 202 ++++++++++++++++ crates/rpc/ras-jsonrpc-macro/Cargo.toml | 9 + .../rpc/ras-jsonrpc-macro/benches/dispatch.rs | 66 +++++ crates/rpc/ras-jsonrpc-macro/tests/e2e.rs | 188 +++++++++++++++ crates/test-utils/ras-test-helpers/Cargo.toml | 12 + .../test-utils/ras-test-helpers/src/auth.rs | 133 +++++++++++ crates/test-utils/ras-test-helpers/src/lib.rs | 11 + .../test-utils/ras-test-helpers/src/server.rs | 40 ++++ 15 files changed, 1310 insertions(+), 7 deletions(-) create mode 100644 crates/rest/ras-file-macro/benches/streaming.rs create mode 100644 crates/rest/ras-file-macro/tests/e2e.rs create mode 100644 crates/rpc/bidirectional/ras-jsonrpc-bidirectional-macro/benches/roundtrip.rs create mode 100644 crates/rpc/bidirectional/ras-jsonrpc-bidirectional-macro/tests/e2e.rs create mode 100644 crates/rpc/ras-jsonrpc-macro/benches/dispatch.rs create mode 100644 crates/rpc/ras-jsonrpc-macro/tests/e2e.rs create mode 100644 crates/test-utils/ras-test-helpers/Cargo.toml create mode 100644 crates/test-utils/ras-test-helpers/src/auth.rs create mode 100644 crates/test-utils/ras-test-helpers/src/lib.rs create mode 100644 crates/test-utils/ras-test-helpers/src/server.rs diff --git a/Cargo.lock b/Cargo.lock index 052ce78..03834dc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -59,6 +59,12 @@ dependencies = [ "libc", ] +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + [[package]] name = "anstream" version = "0.6.19" @@ -538,6 +544,12 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53" +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + [[package]] name = "castaway" version = "0.2.3" @@ -583,6 +595,33 @@ dependencies = [ "windows-link", ] +[[package]] +name = "ciborium" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" + +[[package]] +name = "ciborium-ll" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" +dependencies = [ + "ciborium-io", + "half", +] + [[package]] name = "clap" version = "4.5.39" @@ -762,6 +801,63 @@ dependencies = [ "libc", ] +[[package]] +name = "criterion" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "futures", + "is-terminal", + "itertools 0.10.5", + "num-traits", + "once_cell", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_derive", + "serde_json", + "tinytemplate", + "tokio", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools 0.10.5", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-utils" version = "0.8.21" @@ -1644,6 +1740,17 @@ dependencies = [ "tracing", ] +[[package]] +name = "half" +version = "2.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ea2d84b969582b4b1864a92dc5d27cd2b77b622a8d79306834f1be5ba20d84b" +dependencies = [ + "cfg-if 1.0.0", + "crunchy", + "zerocopy", +] + [[package]] name = "hashbrown" version = "0.14.5" @@ -2026,12 +2133,32 @@ dependencies = [ "serde", ] +[[package]] +name = "is-terminal" +version = "0.4.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46" +dependencies = [ + "hermit-abi", + "libc", + "windows-sys 0.59.0", +] + [[package]] name = "is_terminal_polyfill" version = "1.70.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" +[[package]] +name = "itertools" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +dependencies = [ + "either", +] + [[package]] name = "itertools" version = "0.13.0" @@ -2437,6 +2564,12 @@ version = "1.70.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a4895175b425cb1f87721b59f0f286c2092bd4af812243672510e1ac53e2e0ad" +[[package]] +name = "oorandom" +version = "11.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" + [[package]] name = "openrpc-to-bruno" version = "0.1.0" @@ -2814,6 +2947,34 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" +[[package]] +name = "plotters" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a" + +[[package]] +name = "plotters-svg" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670" +dependencies = [ + "plotters-backend", +] + [[package]] name = "potential_utf" version = "0.1.2" @@ -2986,14 +3147,18 @@ version = "0.1.0" dependencies = [ "async-trait", "axum", + "axum-test", + "criterion", "proc-macro2", "quote", "ras-auth-core", + "ras-test-helpers", "reqwest", "schemars", "serde", "serde_json", "syn", + "tempfile", "thiserror 2.0.12", "tokio", "tokio-util", @@ -3007,6 +3172,7 @@ dependencies = [ "serde", "serde_json", "thiserror 2.0.12", + "tokio", ] [[package]] @@ -3105,6 +3271,7 @@ dependencies = [ "async-trait", "axum", "chrono", + "criterion", "futures", "http", "proc-macro2", @@ -3115,6 +3282,7 @@ dependencies = [ "ras-jsonrpc-bidirectional-server", "ras-jsonrpc-bidirectional-types", "ras-jsonrpc-types", + "ras-test-helpers", "serde", "serde_json", "syn", @@ -3182,7 +3350,9 @@ version = "0.1.1" dependencies = [ "async-trait", "axum", + "axum-test", "bon", + "criterion", "futures", "proc-macro2", "quote", @@ -3191,6 +3361,7 @@ dependencies = [ "ras-identity-session", "ras-jsonrpc-core", "ras-jsonrpc-types", + "ras-test-helpers", "reqwest", "schemars", "serde", @@ -3256,7 +3427,9 @@ dependencies = [ "async-trait", "axum", "axum-extra", + "axum-test", "chrono", + "criterion", "futures", "hyper", "proc-macro2", @@ -3266,6 +3439,7 @@ dependencies = [ "ras-identity-session", "ras-jsonrpc-core", "ras-rest-core", + "ras-test-helpers", "reqwest", "schemars", "serde", @@ -3277,6 +3451,16 @@ dependencies = [ "wiremock", ] +[[package]] +name = "ras-test-helpers" +version = "0.0.0" +dependencies = [ + "axum", + "axum-test", + "ras-auth-core", + "tokio", +] + [[package]] name = "ratatui" version = "0.29.0" @@ -3289,7 +3473,7 @@ dependencies = [ "crossterm", "indoc", "instability", - "itertools", + "itertools 0.13.0", "lru", "paste", "strum", @@ -3298,6 +3482,26 @@ dependencies = [ "unicode-width 0.2.0", ] +[[package]] +name = "rayon" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb39b166781f92d482534ef4b4b1b2568f42613b53e5b6c160e24cfbfa30926d" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + [[package]] name = "redox_syscall" version = "0.5.12" @@ -4188,6 +4392,16 @@ dependencies = [ "zerovec", ] +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "tokio" version = "1.45.1" @@ -4534,7 +4748,7 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b3644627a5af5fa321c95b9b235a72fd24cd29c648c2c379431e6628655627bf" dependencies = [ - "itertools", + "itertools 0.13.0", "unicode-segmentation", "unicode-width 0.1.14", ] @@ -5085,18 +5299,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.25" +version = "0.8.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1702d9583232ddb9174e01bb7c15a2ab8fb1bc6f227aa1233858c351a3ba0cb" +checksum = "eed437bf9d6692032087e337407a86f04cd8d6a16a37199ed57949d415bd68e9" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.25" +version = "0.8.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28a6e20d751156648aa063f3800b706ee209a32c0b4d9f24be3d980b01be55ef" +checksum = "70e3cd084b1788766f53af483dd21f93881ff30d7320490ec3ef7526d203bad4" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index ee05510..a4bf3cd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ members = [ "crates/rpc/ras-jsonrpc-macro", "crates/rpc/ras-jsonrpc-types", "crates/specs/*", + "crates/test-utils/*", "crates/tools/*", "examples/basic-jsonrpc/*", "examples/bidirectional-chat/api", @@ -32,6 +33,7 @@ base64 = "0.22" bon = "3.2" console = "0.15" console_error_panic_hook = "0.1" +criterion = "0.5" crossterm = "0.28" dashmap = "6.1" dialoguer = "0.11" diff --git a/crates/rest/ras-file-macro/Cargo.toml b/crates/rest/ras-file-macro/Cargo.toml index 96c205c..243e3a9 100644 --- a/crates/rest/ras-file-macro/Cargo.toml +++ b/crates/rest/ras-file-macro/Cargo.toml @@ -23,3 +23,11 @@ serde_json = { workspace = true } ras-auth-core = { path = "../../core/ras-auth-core" } thiserror = { workspace = true } async-trait = { workspace = true } +ras-test-helpers = { path = "../../test-utils/ras-test-helpers" } +axum-test = { workspace = true } +tempfile = { workspace = true } +criterion = { workspace = true, features = ["async_tokio"] } + +[[bench]] +name = "streaming" +harness = false diff --git a/crates/rest/ras-file-macro/benches/streaming.rs b/crates/rest/ras-file-macro/benches/streaming.rs new file mode 100644 index 0000000..d6981a8 --- /dev/null +++ b/crates/rest/ras-file-macro/benches/streaming.rs @@ -0,0 +1,131 @@ +//! Criterion bench measuring 1 MiB upload + download through the file_service! +//! generated client and router. + +use std::io::Write; +use std::sync::{Arc, Mutex}; + +use axum::{ + body::Body, + http::StatusCode, + response::{IntoResponse, Response}, +}; +use criterion::{Criterion, criterion_group, criterion_main}; +use ras_auth_core::AuthenticatedUser; +use ras_file_macro::file_service; +use ras_test_helpers::{MockAuthProvider, spawn_http}; +use serde::{Deserialize, Serialize}; +use tokio::runtime::Runtime; + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct UploadResponse { + file_id: String, + size: u64, +} + +file_service!({ + service_name: BenchSvc, + base_path: "/files", + endpoints: [ + DOWNLOAD UNAUTHORIZED download/{file_id: String}(), + UPLOAD WITH_PERMISSIONS(["user"]) upload() -> UploadResponse, + ] +}); + +type Storage = Arc)>>>; + +#[derive(Clone)] +struct BenchImpl { + storage: Storage, +} + +#[async_trait::async_trait] +impl BenchSvcTrait for BenchImpl { + async fn download(&self, file_id: String) -> Result { + let bytes = self + .storage + .lock() + .unwrap() + .iter() + .find_map(|(id, data)| (id == &file_id).then(|| data.clone())) + .ok_or(BenchSvcFileError::NotFound)?; + Ok(Response::builder() + .status(StatusCode::OK) + .body(Body::from(bytes)) + .unwrap()) + } + + async fn upload( + &self, + _user: &AuthenticatedUser, + mut multipart: axum::extract::Multipart, + ) -> Result { + let field = multipart + .next_field() + .await + .map_err(|e| BenchSvcFileError::UploadFailed(e.to_string()))? + .ok_or_else(|| BenchSvcFileError::UploadFailed("no field".into()))?; + let data = field + .bytes() + .await + .map_err(|e| BenchSvcFileError::UploadFailed(e.to_string()))?; + let id = format!("file-{}", self.storage.lock().unwrap().len()); + let size = data.len() as u64; + self.storage + .lock() + .unwrap() + .push((id.clone(), data.to_vec())); + Ok(UploadResponse { file_id: id, size }) + } +} + +fn build_router() -> (axum::Router, Storage) { + let storage: Storage = Arc::new(Mutex::new(Vec::new())); + let router = BenchSvcBuilder::::new(BenchImpl { + storage: storage.clone(), + }) + .auth_provider(MockAuthProvider::default()) + .build(); + (router, storage) +} + +fn bench_streaming(c: &mut Criterion) { + let rt = Runtime::new().unwrap(); + + // Prepare 1 MiB payload on disk and a live server. + let mut tmp = tempfile::NamedTempFile::new().expect("tempfile"); + let payload: Vec = (0u8..=255).cycle().take(1024 * 1024).collect(); + tmp.write_all(&payload).unwrap(); + tmp.flush().unwrap(); + let path = tmp.path().to_path_buf(); + + let (client, _server) = rt.block_on(async { + let (router, _storage) = build_router(); + let server = spawn_http(router); + let base = server.server_address().unwrap(); + let base_str = base.as_str().trim_end_matches('/').to_string(); + let client = BenchSvcClient::builder(base_str) + .build() + .expect("client build"); + client.set_bearer_token(Some("user-token".to_string())); + (client, server) + }); + + c.bench_function("file_upload_download_1mib", |b| { + b.to_async(&rt).iter(|| { + let client = &client; + let path = path.clone(); + async move { + let r = client + .upload(&path, Some("blob.bin"), Some("application/octet-stream")) + .await + .expect("upload"); + let resp = client.download(r.file_id).await.expect("download"); + let bytes = resp.bytes().await.expect("body"); + std::hint::black_box(bytes); + } + }); + }); +} + +criterion_group!(benches, bench_streaming); +criterion_main!(benches); diff --git a/crates/rest/ras-file-macro/tests/e2e.rs b/crates/rest/ras-file-macro/tests/e2e.rs new file mode 100644 index 0000000..c0c64f8 --- /dev/null +++ b/crates/rest/ras-file-macro/tests/e2e.rs @@ -0,0 +1,163 @@ +//! End-to-end test for the file_service! macro: generated reqwest client → +//! axum router → handler. Exercises upload + download with byte-equality and +//! a missing-token rejection. + +use std::io::Write; +use std::sync::{Arc, Mutex}; + +use axum::{ + body::Body, + http::StatusCode, + response::{IntoResponse, Response}, +}; +use ras_auth_core::AuthenticatedUser; +use ras_file_macro::file_service; +use ras_test_helpers::{MockAuthProvider, spawn_http}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct UploadResponse { + file_id: String, + size: u64, +} + +file_service!({ + service_name: Demo, + base_path: "/files", + endpoints: [ + DOWNLOAD UNAUTHORIZED download/{file_id: String}(), + UPLOAD WITH_PERMISSIONS(["user"]) upload() -> UploadResponse, + ] +}); + +type Storage = Arc)>>>; + +#[derive(Clone)] +struct DemoImpl { + storage: Storage, +} + +#[async_trait::async_trait] +impl DemoTrait for DemoImpl { + async fn download(&self, file_id: String) -> Result { + let store = self.storage.lock().unwrap(); + let bytes = store + .iter() + .find_map(|(id, data)| (id == &file_id).then(|| data.clone())) + .ok_or(DemoFileError::NotFound)?; + + Ok(Response::builder() + .status(StatusCode::OK) + .header("content-type", "application/octet-stream") + .body(Body::from(bytes)) + .unwrap()) + } + + async fn upload( + &self, + _user: &AuthenticatedUser, + mut multipart: axum::extract::Multipart, + ) -> Result { + let field = multipart + .next_field() + .await + .map_err(|e| DemoFileError::UploadFailed(e.to_string()))? + .ok_or_else(|| DemoFileError::UploadFailed("no field".into()))?; + let data = field + .bytes() + .await + .map_err(|e| DemoFileError::UploadFailed(e.to_string()))?; + let id = format!("file-{}", self.storage.lock().unwrap().len()); + let size = data.len() as u64; + self.storage + .lock() + .unwrap() + .push((id.clone(), data.to_vec())); + Ok(UploadResponse { file_id: id, size }) + } +} + +fn router(storage: Storage) -> axum::Router { + DemoBuilder::::new(DemoImpl { storage }) + .auth_provider(MockAuthProvider::default()) + .build() +} + +fn write_tempfile(bytes: &[u8]) -> tempfile::NamedTempFile { + let mut f = tempfile::NamedTempFile::new().expect("tempfile"); + f.write_all(bytes).expect("write tempfile"); + f.flush().expect("flush tempfile"); + f +} + +#[tokio::test] +async fn upload_and_download_round_trips_bytes() { + let storage: Storage = Arc::new(Mutex::new(Vec::new())); + let server = spawn_http(router(storage.clone())); + let base = server.server_address().unwrap(); + let base_str = base.as_str().trim_end_matches('/').to_string(); + + let payload: Vec = (0u8..=255).cycle().take(64 * 1024).collect(); + let tmp = write_tempfile(&payload); + + let client = DemoClient::builder(base_str.clone()) + .build() + .expect("client build"); + client.set_bearer_token(Some("user-token".to_string())); + + let upload = client + .upload( + tmp.path(), + Some("blob.bin"), + Some("application/octet-stream"), + ) + .await + .expect("upload ok"); + assert_eq!(upload.size, payload.len() as u64); + + let resp = client.download(upload.file_id).await.expect("download ok"); + let bytes = resp.bytes().await.expect("read body"); + assert_eq!(bytes.as_ref(), payload.as_slice()); +} + +#[tokio::test] +async fn upload_rejected_without_token() { + let storage: Storage = Arc::new(Mutex::new(Vec::new())); + let server = spawn_http(router(storage)); + let base = server.server_address().unwrap(); + let base_str = base.as_str().trim_end_matches('/').to_string(); + + let payload = b"hello world"; + let tmp = write_tempfile(payload); + + let client = DemoClient::builder(base_str).build().expect("client build"); + // No bearer token. + + let result = client + .upload(tmp.path(), Some("hi.txt"), Some("text/plain")) + .await; + // The server short-circuits with 401 before consuming the multipart body, + // so reqwest may surface that either as the parsed status or as a generic + // connection error depending on how the upload stream was cut. Either is a + // valid signal of rejection — the only outcome we want to rule out is + // success. + assert!( + result.is_err(), + "upload must be rejected without a bearer token, got: {result:?}" + ); +} + +#[tokio::test] +async fn download_unknown_file_returns_404() { + let storage: Storage = Arc::new(Mutex::new(Vec::new())); + let server = spawn_http(router(storage)); + let base = server.server_address().unwrap(); + let base_str = base.as_str().trim_end_matches('/').to_string(); + + let client = DemoClient::builder(base_str).build().expect("client build"); + let err = client + .download("does-not-exist".to_string()) + .await + .expect_err("missing file must error"); + assert!(err.to_string().contains("404"), "got: {err}"); +} diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-macro/Cargo.toml b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-macro/Cargo.toml index 123ca6b..175447b 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-macro/Cargo.toml +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-macro/Cargo.toml @@ -24,6 +24,7 @@ ras-jsonrpc-bidirectional-server = { path = "../ras-jsonrpc-bidirectional-server ras-jsonrpc-bidirectional-client = { path = "../ras-jsonrpc-bidirectional-client" } ras-auth-core = { path = "../../../core/ras-auth-core" } ras-jsonrpc-types = { path = "../../ras-jsonrpc-types" } +ras-test-helpers = { path = "../../../test-utils/ras-test-helpers" } tokio = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } @@ -37,4 +38,9 @@ async-trait = { workspace = true } uuid = { workspace = true } anyhow = { workspace = true } thiserror = { workspace = true } -chrono = { workspace = true } \ No newline at end of file +chrono = { workspace = true } +criterion = { workspace = true, features = ["async_tokio"] } + +[[bench]] +name = "roundtrip" +harness = false \ No newline at end of file diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-macro/benches/roundtrip.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-macro/benches/roundtrip.rs new file mode 100644 index 0000000..c3ce50a --- /dev/null +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-macro/benches/roundtrip.rs @@ -0,0 +1,118 @@ +//! Criterion bench measuring c2s call round-trip latency through a real +//! WebSocket connection (tokio-tungstenite client → axum server). + +use std::sync::Arc; +use std::time::Duration; + +use async_trait::async_trait; +use axum::{Router, routing::get}; +use criterion::{Criterion, criterion_group, criterion_main}; +use ras_auth_core::AuthenticatedUser; +use ras_jsonrpc_bidirectional_macro::jsonrpc_bidirectional_service; +use ras_jsonrpc_bidirectional_server::DefaultConnectionManager; +use ras_jsonrpc_bidirectional_server::service::{BuiltWebSocketService, websocket_handler}; +use ras_jsonrpc_bidirectional_types::ConnectionId; +use ras_test_helpers::{MockAuthProvider, spawn_tcp}; +use serde::{Deserialize, Serialize}; +use tokio::runtime::Runtime; + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct EchoIn { + msg: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct EchoOut { + msg: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Ignored; + +jsonrpc_bidirectional_service!({ + service_name: BenchSvc, + client_to_server: [ + WITH_PERMISSIONS(["user"]) echo(EchoIn) -> EchoOut, + ], + server_to_client: [ + unused(Ignored), + ], + server_to_client_calls: [ + ] +}); + +#[derive(Clone)] +struct BenchImpl; + +#[async_trait] +impl BenchSvcService for BenchImpl { + async fn echo( + &self, + _client: ConnectionId, + _conns: &dyn ras_jsonrpc_bidirectional_types::ConnectionManager, + _user: &AuthenticatedUser, + req: EchoIn, + ) -> Result> { + Ok(EchoOut { msg: req.msg }) + } + + async fn notify_unused( + &self, + _connection_id: ConnectionId, + _params: Ignored, + ) -> ras_jsonrpc_bidirectional_types::Result<()> { + Ok(()) + } +} + +async fn start_server() -> String { + let cm = Arc::new(DefaultConnectionManager::new()); + let handler = Arc::new(BenchSvcHandler::new(Arc::new(BenchImpl), cm.clone())); + let svc = ras_jsonrpc_bidirectional_server::WebSocketServiceBuilder::builder() + .handler(handler) + .auth_provider(Arc::new(MockAuthProvider::default())) + .require_auth(false) + .build() + .build_with_manager(cm); + + type SvcType = BuiltWebSocketService< + BenchSvcHandler, + MockAuthProvider, + DefaultConnectionManager, + >; + let app: Router = Router::new() + .route("/ws", get(websocket_handler::)) + .with_state(svc); + let (addr, _h) = spawn_tcp(app).await; + tokio::time::sleep(Duration::from_millis(50)).await; + format!("ws://{addr}/ws") +} + +fn bench_roundtrip(c: &mut Criterion) { + let rt = Runtime::new().unwrap(); + + let client = rt.block_on(async { + let url = start_server().await; + let client = BenchSvcClientBuilder::new(url) + .with_jwt_token("user-token".to_string()) + .build() + .await + .expect("client build"); + client.connect().await.expect("connect"); + client + }); + + c.bench_function("ws_echo_roundtrip", |b| { + b.to_async(&rt).iter(|| async { + let r = client.echo(EchoIn { msg: "x".into() }).await.expect("echo"); + std::hint::black_box(r); + }); + }); + + rt.block_on(async { + let _ = client.disconnect().await; + }); +} + +criterion_group!(benches, bench_roundtrip); +criterion_main!(benches); diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-macro/tests/e2e.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-macro/tests/e2e.rs new file mode 100644 index 0000000..8d3d917 --- /dev/null +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-macro/tests/e2e.rs @@ -0,0 +1,202 @@ +//! End-to-end test for `jsonrpc_bidirectional_service!`: +//! generated client → real WebSocket → server handler → response/notification. +//! +//! Existing `bidirectional_integration.rs` exercises this thoroughly. This file +//! is a slim companion test that uses the shared `MockAuthProvider` and proves +//! the helper integration works. + +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::time::Duration; + +use async_trait::async_trait; +use axum::{Router, routing::get}; +use ras_auth_core::AuthenticatedUser; +use ras_jsonrpc_bidirectional_macro::jsonrpc_bidirectional_service; +use ras_jsonrpc_bidirectional_server::DefaultConnectionManager; +use ras_jsonrpc_bidirectional_server::service::{BuiltWebSocketService, websocket_handler}; +use ras_jsonrpc_bidirectional_types::ConnectionId; +use ras_test_helpers::{MockAuthProvider, spawn_tcp}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct EchoIn { + pub msg: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct EchoOut { + pub msg: String, + pub user: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PushNote { + pub kind: String, +} + +jsonrpc_bidirectional_service!({ + service_name: Demo, + client_to_server: [ + UNAUTHORIZED hello(String) -> String, + WITH_PERMISSIONS(["user"]) echo(EchoIn) -> EchoOut, + ], + server_to_client: [ + ping(PushNote), + ], + server_to_client_calls: [ + ] +}); + +#[derive(Clone)] +struct DemoImpl; + +#[async_trait] +impl DemoService for DemoImpl { + async fn hello( + &self, + _client: ConnectionId, + _conns: &dyn ras_jsonrpc_bidirectional_types::ConnectionManager, + name: String, + ) -> Result> { + Ok(format!("hello, {name}")) + } + + async fn echo( + &self, + client: ConnectionId, + conns: &dyn ras_jsonrpc_bidirectional_types::ConnectionManager, + user: &AuthenticatedUser, + req: EchoIn, + ) -> Result> { + // Also push a server→client notification so the test can observe it. + let note = ras_jsonrpc_bidirectional_types::ServerNotification { + method: "ping".to_string(), + params: serde_json::to_value(PushNote { + kind: "after-echo".into(), + }) + .unwrap(), + metadata: None, + }; + let _ = conns + .send_to_connection( + client, + ras_jsonrpc_bidirectional_types::BidirectionalMessage::ServerNotification(note), + ) + .await; + + Ok(EchoOut { + msg: req.msg, + user: user.user_id.clone(), + }) + } + + async fn notify_ping( + &self, + _connection_id: ConnectionId, + _params: PushNote, + ) -> ras_jsonrpc_bidirectional_types::Result<()> { + Ok(()) + } +} + +async fn start_server() -> String { + let connection_manager = Arc::new(DefaultConnectionManager::new()); + let handler = Arc::new(DemoHandler::new( + Arc::new(DemoImpl), + connection_manager.clone(), + )); + + let ws_service = ras_jsonrpc_bidirectional_server::WebSocketServiceBuilder::builder() + .handler(handler) + .auth_provider(Arc::new(MockAuthProvider::default())) + .require_auth(false) + .build() + .build_with_manager(connection_manager); + + type SvcType = BuiltWebSocketService< + DemoHandler, + MockAuthProvider, + DefaultConnectionManager, + >; + let app: Router = Router::new() + .route("/ws", get(websocket_handler::)) + .with_state(ws_service); + + let (addr, _handle) = spawn_tcp(app).await; + // Give axum a tick to start serving. + tokio::time::sleep(Duration::from_millis(50)).await; + format!("ws://{addr}/ws") +} + +#[tokio::test] +async fn unauthorized_method_round_trips() { + let url = start_server().await; + let client = DemoClientBuilder::new(url) + .build() + .await + .expect("client build"); + client.connect().await.expect("connect"); + + let resp = client.hello("alice".to_string()).await.expect("hello ok"); + assert_eq!(resp, "hello, alice"); + + client.disconnect().await.expect("disconnect"); +} + +#[tokio::test] +async fn auth_method_succeeds_and_pushes_notification() { + let url = start_server().await; + let mut client = DemoClientBuilder::new(url) + .with_jwt_token("user-token".to_string()) + .build() + .await + .expect("client build"); + client.connect().await.expect("connect"); + + let pushed = Arc::new(AtomicBool::new(false)); + let pushed_flag = pushed.clone(); + client.on_ping(move |_n: PushNote| { + pushed_flag.store(true, Ordering::SeqCst); + }); + + let resp = client + .echo(EchoIn { + msg: "hi".to_string(), + }) + .await + .expect("echo ok"); + assert_eq!(resp.msg, "hi"); + assert_eq!(resp.user, "user-1"); + + // Wait briefly for the push to land. + let deadline = std::time::Instant::now() + Duration::from_secs(2); + while !pushed.load(Ordering::SeqCst) && std::time::Instant::now() < deadline { + tokio::time::sleep(Duration::from_millis(20)).await; + } + assert!( + pushed.load(Ordering::SeqCst), + "expected ping notification to arrive" + ); + + client.disconnect().await.expect("disconnect"); +} + +#[tokio::test] +async fn auth_method_rejected_for_readonly_user() { + let url = start_server().await; + let client = DemoClientBuilder::new(url) + .with_jwt_token("readonly-token".to_string()) + .build() + .await + .expect("client build"); + client.connect().await.expect("connect"); + + let result = client.echo(EchoIn { msg: "nope".into() }).await; + assert!( + result.is_err(), + "readonly token must not be able to call echo" + ); + + client.disconnect().await.expect("disconnect"); +} diff --git a/crates/rpc/ras-jsonrpc-macro/Cargo.toml b/crates/rpc/ras-jsonrpc-macro/Cargo.toml index 542d528..a7630eb 100644 --- a/crates/rpc/ras-jsonrpc-macro/Cargo.toml +++ b/crates/rpc/ras-jsonrpc-macro/Cargo.toml @@ -46,3 +46,12 @@ axum = { workspace = true } ras-jsonrpc-core = { path = "../ras-jsonrpc-core" } ras-auth-core = { path = "../../core/ras-auth-core" } async-trait = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +ras-test-helpers = { path = "../../test-utils/ras-test-helpers" } +axum-test = { workspace = true } +criterion = { workspace = true, features = ["async_tokio"] } + +[[bench]] +name = "dispatch" +harness = false diff --git a/crates/rpc/ras-jsonrpc-macro/benches/dispatch.rs b/crates/rpc/ras-jsonrpc-macro/benches/dispatch.rs new file mode 100644 index 0000000..ff87b8f --- /dev/null +++ b/crates/rpc/ras-jsonrpc-macro/benches/dispatch.rs @@ -0,0 +1,66 @@ +//! Criterion bench measuring per-call latency of an authenticated JSON-RPC +//! method through the full stack: generated client → axum router → handler. +//! +//! Run with `cargo bench -p ras-jsonrpc-macro`. + +use criterion::{Criterion, criterion_group, criterion_main}; +use ras_jsonrpc_macro::jsonrpc_service; +use ras_test_helpers::{MockAuthProvider, spawn_http}; +use serde::{Deserialize, Serialize}; +use tokio::runtime::Runtime; + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct AddRequest { + a: i64, + b: i64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct AddResponse { + sum: i64, +} + +jsonrpc_service!({ + service_name: BenchSvc, + openrpc: false, + methods: [ + WITH_PERMISSIONS(["user"]) add(AddRequest) -> AddResponse, + ] +}); + +fn build_router() -> axum::Router { + BenchSvcBuilder::new("/rpc") + .auth_provider(MockAuthProvider::default()) + .add_handler(|_user, req: AddRequest| async move { Ok(AddResponse { sum: req.a + req.b }) }) + .build() + .expect("router build") +} + +fn bench_dispatch(c: &mut Criterion) { + let rt = Runtime::new().unwrap(); + + // Spin the server up once and reuse across every iteration. + let (client, _server) = rt.block_on(async { + let server = spawn_http(build_router()); + let url = server.server_url("/rpc").unwrap().to_string(); + let mut client = BenchSvcClientBuilder::new() + .server_url(url) + .build() + .expect("client build"); + client.set_bearer_token(Some("user-token".to_string())); + (client, server) + }); + + c.bench_function("jsonrpc_add_dispatch", |b| { + b.to_async(&rt).iter(|| { + let client = client.clone(); + async move { + let r = client.add(AddRequest { a: 1, b: 2 }).await.expect("add ok"); + std::hint::black_box(r); + } + }); + }); +} + +criterion_group!(benches, bench_dispatch); +criterion_main!(benches); diff --git a/crates/rpc/ras-jsonrpc-macro/tests/e2e.rs b/crates/rpc/ras-jsonrpc-macro/tests/e2e.rs new file mode 100644 index 0000000..8652a56 --- /dev/null +++ b/crates/rpc/ras-jsonrpc-macro/tests/e2e.rs @@ -0,0 +1,188 @@ +//! End-to-end test that exercises the full chain: +//! generated reqwest client → axum router → handler → response → client. +//! +//! Covers: success path, missing-permission rejection, malformed input. + +use ras_jsonrpc_macro::jsonrpc_service; +use ras_test_helpers::{MockAuthProvider, spawn_http}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +struct EchoRequest { + msg: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +struct EchoResponse { + msg: String, + user_id: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +struct AddRequest { + a: i64, + b: i64, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +struct AddResponse { + sum: i64, +} + +jsonrpc_service!({ + service_name: Demo, + openrpc: false, + methods: [ + UNAUTHORIZED ping(EchoRequest) -> EchoResponse, + WITH_PERMISSIONS(["user"]) add(AddRequest) -> AddResponse, + WITH_PERMISSIONS(["admin"]) admin_only(EchoRequest) -> EchoResponse, + ] +}); + +fn router() -> axum::Router { + DemoBuilder::new("/rpc") + .auth_provider(MockAuthProvider::default()) + .ping_handler(|req: EchoRequest| async move { + Ok(EchoResponse { + msg: req.msg, + user_id: None, + }) + }) + .add_handler(|_user, req: AddRequest| async move { Ok(AddResponse { sum: req.a + req.b }) }) + .admin_only_handler(|user, req: EchoRequest| async move { + Ok(EchoResponse { + msg: req.msg, + user_id: Some(user.user_id), + }) + }) + .build() + .expect("build router") +} + +fn client(url: String) -> DemoClient { + DemoClientBuilder::new() + .server_url(url) + .build() + .expect("client build") +} + +#[tokio::test] +async fn unauth_method_round_trips() { + let server = spawn_http(router()); + let url = server.server_url("/rpc").expect("server url").to_string(); + + let mut c = client(url); + c.set_bearer_token(Option::::None); + + let resp = c + .ping(EchoRequest { + msg: "hello".to_string(), + }) + .await + .expect("ping ok"); + + assert_eq!(resp.msg, "hello"); + assert_eq!(resp.user_id, None); +} + +#[tokio::test] +async fn permission_required_method_rejects_anonymous() { + let server = spawn_http(router()); + let url = server.server_url("/rpc").unwrap().to_string(); + + let mut c = client(url); + c.set_bearer_token(Option::::None); + + let err = c + .add(AddRequest { a: 2, b: 3 }) + .await + .expect_err("anonymous add must be rejected"); + + let s = err.to_string(); + assert!( + s.contains("Authentication") || s.contains("AUTH") || s.contains("auth"), + "expected auth-related error, got: {s}" + ); +} + +#[tokio::test] +async fn permission_required_method_rejects_wrong_perms() { + let server = spawn_http(router()); + let url = server.server_url("/rpc").unwrap().to_string(); + + let mut c = client(url); + c.set_bearer_token(Some("readonly-token".to_string())); + + let err = c + .add(AddRequest { a: 2, b: 3 }) + .await + .expect_err("readonly user must not be allowed to call add"); + let s = err.to_string(); + assert!( + s.contains("permission") || s.contains("Permission") || s.contains("PERMISSION"), + "expected permission-related error, got: {s}" + ); +} + +#[tokio::test] +async fn permission_required_method_succeeds_with_correct_perms() { + let server = spawn_http(router()); + let url = server.server_url("/rpc").unwrap().to_string(); + + let mut c = client(url); + c.set_bearer_token(Some("user-token".to_string())); + + let resp = c.add(AddRequest { a: 7, b: 35 }).await.expect("add ok"); + assert_eq!(resp.sum, 42); +} + +#[tokio::test] +async fn admin_method_succeeds_with_admin_token() { + let server = spawn_http(router()); + let url = server.server_url("/rpc").unwrap().to_string(); + + let mut c = client(url); + c.set_bearer_token(Some("admin-token".to_string())); + + let resp = c + .admin_only(EchoRequest { + msg: "secret".to_string(), + }) + .await + .expect("admin call ok"); + + assert_eq!(resp.msg, "secret"); + assert_eq!(resp.user_id.as_deref(), Some("admin-1")); +} + +#[tokio::test] +async fn malformed_params_yield_jsonrpc_error() { + // Bypass the typed client to send a malformed body and confirm the + // server returns a JSON-RPC `invalid_params` error rather than a panic. + let server = spawn_http(router()); + let url = server.server_url("/rpc").unwrap().to_string(); + + let body = serde_json::json!({ + "jsonrpc": "2.0", + "method": "ping", + "params": { "bogus": 1 }, + "id": 1, + }); + + let resp: serde_json::Value = reqwest::Client::new() + .post(url) + .json(&body) + .send() + .await + .unwrap() + .json() + .await + .unwrap(); + + assert!( + resp.get("error").is_some(), + "expected error in response: {resp}" + ); + let code = resp["error"]["code"].as_i64().unwrap(); + assert_eq!(code, -32602, "expected invalid_params (-32602), got {code}"); +} diff --git a/crates/test-utils/ras-test-helpers/Cargo.toml b/crates/test-utils/ras-test-helpers/Cargo.toml new file mode 100644 index 0000000..68bab77 --- /dev/null +++ b/crates/test-utils/ras-test-helpers/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "ras-test-helpers" +version = "0.0.0" +edition = "2024" +publish = false +description = "Internal test helpers for the rust-agent-stack workspace (dev-dependency only)" + +[dependencies] +ras-auth-core = { path = "../../core/ras-auth-core" } +axum = { workspace = true } +axum-test = { workspace = true } +tokio = { workspace = true } diff --git a/crates/test-utils/ras-test-helpers/src/auth.rs b/crates/test-utils/ras-test-helpers/src/auth.rs new file mode 100644 index 0000000..5c18644 --- /dev/null +++ b/crates/test-utils/ras-test-helpers/src/auth.rs @@ -0,0 +1,133 @@ +use std::collections::{HashMap, HashSet}; + +use ras_auth_core::{AuthError, AuthFuture, AuthProvider, AuthenticatedUser}; + +/// A small fixed-token auth provider for tests. +/// +/// The default token table: +/// - `"user-token"` → user `user-1`, perms `["user"]` +/// - `"admin-token"` → user `admin-1`, perms `["admin", "user"]` +/// - `"readonly-token"` → user `ro-1`, perms `["read"]` +/// +/// Any other (or empty) token returns [`AuthError::InvalidToken`]. +#[derive(Clone, Debug)] +pub struct MockAuthProvider { + table: HashMap, +} + +impl Default for MockAuthProvider { + fn default() -> Self { + let mut table = HashMap::new(); + table.insert("user-token".to_string(), mock_user("user-1", &["user"])); + table.insert( + "admin-token".to_string(), + mock_user("admin-1", &["admin", "user"]), + ); + table.insert("readonly-token".to_string(), mock_user("ro-1", &["read"])); + Self { table } + } +} + +impl MockAuthProvider { + /// New empty auth provider with no recognized tokens. + pub fn empty() -> Self { + Self { + table: HashMap::new(), + } + } + + /// Insert or replace a token → user mapping. Useful for adding bespoke + /// fixtures on top of the default table. + pub fn with_token(mut self, token: impl Into, user: AuthenticatedUser) -> Self { + self.table.insert(token.into(), user); + self + } +} + +impl AuthProvider for MockAuthProvider { + fn authenticate(&self, token: String) -> AuthFuture<'_> { + let result = self + .table + .get(&token) + .cloned() + .ok_or(AuthError::InvalidToken); + Box::pin(async move { result }) + } +} + +/// Build an [`AuthenticatedUser`] from a string id and a slice of permission +/// names. Convenience for tests that need to construct a user by hand. +pub fn mock_user(user_id: &str, perms: &[&str]) -> AuthenticatedUser { + AuthenticatedUser { + user_id: user_id.to_string(), + permissions: perms + .iter() + .map(|p| (*p).to_string()) + .collect::>(), + metadata: None, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn mock_user_builds_expected_fields() { + let u = mock_user("alice", &["a", "b"]); + assert_eq!(u.user_id, "alice"); + assert!(u.permissions.contains("a")); + assert!(u.permissions.contains("b")); + assert!(u.metadata.is_none()); + } + + #[tokio::test] + async fn default_provider_resolves_well_known_tokens() { + let p = MockAuthProvider::default(); + let user = p.authenticate("user-token".to_string()).await.unwrap(); + assert_eq!(user.user_id, "user-1"); + assert!(user.permissions.contains("user")); + + let admin = p.authenticate("admin-token".to_string()).await.unwrap(); + assert!(admin.permissions.contains("admin")); + assert!(admin.permissions.contains("user")); + + let ro = p.authenticate("readonly-token".to_string()).await.unwrap(); + assert!(ro.permissions.contains("read")); + + let err = p + .authenticate("totally-bogus".to_string()) + .await + .unwrap_err(); + assert!(matches!(err, ras_auth_core::AuthError::InvalidToken)); + } + + #[tokio::test] + async fn empty_provider_rejects_everything() { + let p = MockAuthProvider::empty(); + let err = p.authenticate("user-token".to_string()).await.unwrap_err(); + assert!(matches!(err, ras_auth_core::AuthError::InvalidToken)); + } + + #[tokio::test] + async fn with_token_extends_table() { + let p = MockAuthProvider::empty().with_token("custom", mock_user("zed", &["god"])); + let user = p.authenticate("custom".to_string()).await.unwrap(); + assert_eq!(user.user_id, "zed"); + assert!(user.permissions.contains("god")); + } + + #[test] + fn check_permissions_returns_specific_error() { + let p = MockAuthProvider::default(); + let user = mock_user("u", &["read"]); + // Has the permission → ok. + p.check_permissions(&user, &["read".into()]).unwrap(); + // Missing → InsufficientPermissions. + let err = p.check_permissions(&user, &["admin".into()]).unwrap_err(); + assert!(matches!( + err, + ras_auth_core::AuthError::InsufficientPermissions { .. } + )); + } +} diff --git a/crates/test-utils/ras-test-helpers/src/lib.rs b/crates/test-utils/ras-test-helpers/src/lib.rs new file mode 100644 index 0000000..222ceb8 --- /dev/null +++ b/crates/test-utils/ras-test-helpers/src/lib.rs @@ -0,0 +1,11 @@ +//! Internal test helpers shared across the rust-agent-stack workspace. +//! +//! This crate is `publish = false` and intended only as a `dev-dependency` for +//! integration tests and benches. It exists to avoid duplicating mock auth +//! providers and server-spawn boilerplate across crates. + +mod auth; +mod server; + +pub use auth::{MockAuthProvider, mock_user}; +pub use server::{spawn_http, spawn_tcp}; diff --git a/crates/test-utils/ras-test-helpers/src/server.rs b/crates/test-utils/ras-test-helpers/src/server.rs new file mode 100644 index 0000000..397ac16 --- /dev/null +++ b/crates/test-utils/ras-test-helpers/src/server.rs @@ -0,0 +1,40 @@ +use std::net::SocketAddr; + +use axum::Router; +use axum_test::TestServer; +use tokio::net::TcpListener; +use tokio::task::JoinHandle; + +/// Spawn the given router behind an `axum-test::TestServer` configured with a +/// real TCP listener on a random port. The returned [`TestServer`] exposes a +/// real `http://127.0.0.1:PORT` URL via [`TestServer::server_address`], which +/// lets generated reqwest-based clients talk to it. +/// +/// Use this for HTTP / JSON-RPC over HTTP / file service tests. +pub fn spawn_http(router: Router) -> TestServer { + TestServer::builder() + .http_transport() + .build(router) + .expect("failed to start axum-test TestServer with http transport") +} + +/// Spawn the given router on a freshly-bound `127.0.0.1` port using a real +/// `axum::serve` task. Returns the bound address and the join handle for the +/// server task. Drop the handle to abort the server. +/// +/// Use this for WebSocket tests where the generated client uses +/// `tokio-tungstenite` and needs a genuine TCP socket. +pub async fn spawn_tcp(router: Router) -> (SocketAddr, JoinHandle<()>) { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("failed to bind ephemeral test port"); + let addr = listener + .local_addr() + .expect("failed to read local addr from test listener"); + + let handle = tokio::spawn(async move { + let _ = axum::serve(listener, router).await; + }); + + (addr, handle) +} From fef459777860443ac60186965732c295f71ddba2 Mon Sep 17 00:00:00 2001 From: Mathias Myrland Date: Sat, 25 Apr 2026 09:37:19 +0200 Subject: [PATCH 3/6] Wire query parameters into rest_service! generated client MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The rest_service! macro already parsed `?` query syntax for the server side, but the generated reqwest client emitted methods without query arguments — callers had to drop down to raw reqwest to actually exercise those endpoints. This commit teaches the client codegen to: - include each query parameter in the function signature, in the same order as the macro syntax (path → query → body), - detect Option at the type level and serialize Some(_) only, - url-encode required parameters via reqwest's `.query()` helper. Adds the REST e2e suite + dispatch bench in the same commit so the new typed-client query-param test cases (required + optional + body + path combinations) land alongside the wiring they cover. Co-Authored-By: Claude Opus 4.7 (1M context) --- crates/rest/ras-rest-macro/Cargo.toml | 10 +- .../rest/ras-rest-macro/benches/dispatch.rs | 72 +++++ crates/rest/ras-rest-macro/src/client.rs | 61 ++++ crates/rest/ras-rest-macro/tests/e2e.rs | 285 ++++++++++++++++++ 4 files changed, 427 insertions(+), 1 deletion(-) create mode 100644 crates/rest/ras-rest-macro/benches/dispatch.rs create mode 100644 crates/rest/ras-rest-macro/tests/e2e.rs diff --git a/crates/rest/ras-rest-macro/Cargo.toml b/crates/rest/ras-rest-macro/Cargo.toml index 45da79c..820c33c 100644 --- a/crates/rest/ras-rest-macro/Cargo.toml +++ b/crates/rest/ras-rest-macro/Cargo.toml @@ -49,4 +49,12 @@ async-trait = { workspace = true } # Server dependencies for tests axum = { workspace = true } ras-auth-core = { path = "../../core/ras-auth-core" } -ras-rest-core = { path = "../ras-rest-core" } \ No newline at end of file +ras-rest-core = { path = "../ras-rest-core" } +ras-test-helpers = { path = "../../test-utils/ras-test-helpers" } +axum-test = { workspace = true } +schemars = { workspace = true } +criterion = { workspace = true, features = ["async_tokio"] } + +[[bench]] +name = "dispatch" +harness = false \ No newline at end of file diff --git a/crates/rest/ras-rest-macro/benches/dispatch.rs b/crates/rest/ras-rest-macro/benches/dispatch.rs new file mode 100644 index 0000000..6e15e7c --- /dev/null +++ b/crates/rest/ras-rest-macro/benches/dispatch.rs @@ -0,0 +1,72 @@ +//! Criterion bench measuring per-call latency of an authenticated REST GET +//! through the full stack: generated client → axum router → handler. +//! +//! Run with `cargo bench -p ras-rest-macro`. + +use criterion::{Criterion, criterion_group, criterion_main}; +use ras_auth_core::AuthenticatedUser; +use ras_rest_core::{RestResponse, RestResult}; +use ras_rest_macro::rest_service; +use ras_test_helpers::{MockAuthProvider, spawn_http}; +use serde::{Deserialize, Serialize}; +use tokio::runtime::Runtime; + +#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)] +struct Item { + id: u32, + name: String, +} + +rest_service!({ + service_name: BenchSvc, + base_path: "/api", + openapi: false, + serve_docs: false, + endpoints: [ + GET WITH_PERMISSIONS(["user"]) items/{id: u32}() -> Item, + ] +}); + +struct BenchImpl; + +#[async_trait::async_trait] +impl BenchSvcTrait for BenchImpl { + async fn get_items_by_id(&self, _user: &AuthenticatedUser, id: u32) -> RestResult { + Ok(RestResponse::ok(Item { + id, + name: "x".into(), + })) + } +} + +fn build_router() -> axum::Router { + BenchSvcBuilder::new(BenchImpl) + .auth_provider(MockAuthProvider::default()) + .build() +} + +fn bench_dispatch(c: &mut Criterion) { + let rt = Runtime::new().unwrap(); + let (client, _server) = rt.block_on(async { + let server = spawn_http(build_router()); + let base = server.server_address().unwrap().to_string(); + let mut client = BenchSvcClient::builder(&base) + .build() + .expect("client build"); + client.set_bearer_token(Some("user-token".to_string())); + (client, server) + }); + + c.bench_function("rest_get_dispatch", |b| { + b.to_async(&rt).iter(|| { + let client = client.clone(); + async move { + let r = client.get_items_by_id(1).await.expect("get ok"); + std::hint::black_box(r); + } + }); + }); +} + +criterion_group!(benches, bench_dispatch); +criterion_main!(benches); diff --git a/crates/rest/ras-rest-macro/src/client.rs b/crates/rest/ras-rest-macro/src/client.rs index dbc0525..bd2f3cd 100644 --- a/crates/rest/ras-rest-macro/src/client.rs +++ b/crates/rest/ras-rest-macro/src/client.rs @@ -1,5 +1,18 @@ use crate::{EndpointDefinition, HttpMethod, ServiceDefinition}; use quote::quote; +use syn::Type; + +/// True if `ty` is syntactically `Option<...>`. Matches the bare `Option` +/// segment as well as fully-qualified forms like `std::option::Option` / +/// `core::option::Option` — anything whose last path segment is `Option`. +fn is_option_type(ty: &Type) -> bool { + if let Type::Path(type_path) = ty { + if let Some(last) = type_path.path.segments.last() { + return last.ident == "Option"; + } + } + false +} /// Generate client code for REST service pub fn generate_client_code(service_def: &ServiceDefinition) -> proc_macro2::TokenStream { @@ -148,6 +161,14 @@ fn generate_client_method(endpoint: &EndpointDefinition) -> proc_macro2::TokenSt call_args.push(quote! { #param_name }); } + // Add query parameters (mirroring the macro syntax order: path → query → body). + for query_param in endpoint.query_params.iter() { + let param_name = &query_param.name; + let param_type = &query_param.param_type; + params.push(quote! { #param_name: #param_type }); + call_args.push(quote! { #param_name }); + } + // Add request body parameter if present if endpoint.request_type.is_some() { let request_type = endpoint.request_type.as_ref().unwrap(); @@ -211,6 +232,44 @@ fn generate_client_method_with_timeout(endpoint: &EndpointDefinition) -> proc_ma }; } + // Build query-string handling. Required params are always serialized; + // `Option` params are skipped when `None`. Values are converted with + // `ToString` and url-encoded by reqwest's `.query()` helper. + let query_handling = if endpoint.query_params.is_empty() { + quote! {} + } else { + let pushes = endpoint.query_params.iter().map(|qp| { + let param_name = &qp.name; + let param_str = qp.name.to_string(); + if is_option_type(&qp.param_type) { + quote! { + if let Some(__v) = &#param_name { + __query_pairs.push((#param_str, __v.to_string())); + } + } + } else { + quote! { + __query_pairs.push((#param_str, #param_name.to_string())); + } + } + }); + quote! { + let mut __query_pairs: Vec<(&'static str, String)> = Vec::new(); + #(#pushes)* + if !__query_pairs.is_empty() { + request_builder = request_builder.query(&__query_pairs); + } + } + }; + + // Add query parameters to the function signature (after path params, + // before the body — matches macro syntax order). + for query_param in endpoint.query_params.iter() { + let param_name = &query_param.name; + let param_type = &query_param.param_type; + params.push(quote! { #param_name: #param_type }); + } + // Add request body parameter if present let request_body_handling = if let Some(request_type) = &endpoint.request_type { params.push(quote! { body: #request_type }); @@ -266,6 +325,8 @@ fn generate_client_method_with_timeout(endpoint: &EndpointDefinition) -> proc_ma request_builder = request_builder.header("Authorization", format!("Bearer {}", token)); } + #query_handling + #request_body_handling // Override timeout if provided (not supported in WASM builds) diff --git a/crates/rest/ras-rest-macro/tests/e2e.rs b/crates/rest/ras-rest-macro/tests/e2e.rs new file mode 100644 index 0000000..2c95e29 --- /dev/null +++ b/crates/rest/ras-rest-macro/tests/e2e.rs @@ -0,0 +1,285 @@ +//! End-to-end test: generated reqwest client → axum router → trait impl → +//! response → client. Covers GET, POST with body, path params, query params, +//! and auth-related rejection paths. + +use ras_auth_core::AuthenticatedUser; +use ras_rest_core::{RestError, RestResponse, RestResult}; +use ras_rest_macro::rest_service; +use ras_test_helpers::{MockAuthProvider, spawn_http}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)] +struct Item { + id: u32, + name: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)] +struct CreateItem { + name: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)] +struct ItemsResponse { + items: Vec, +} + +rest_service!({ + service_name: Demo, + base_path: "/api", + openapi: false, + serve_docs: false, + endpoints: [ + GET UNAUTHORIZED items() -> ItemsResponse, + GET WITH_PERMISSIONS(["user"]) items/{id: u32}() -> Item, + POST WITH_PERMISSIONS(["admin"]) items(CreateItem) -> Item, + GET UNAUTHORIZED search ? q: String & limit: Option & exact: bool () -> ItemsResponse, + POST WITH_PERMISSIONS(["admin"]) items/batch ? notify: bool (CreateItem) -> Item, + GET WITH_PERMISSIONS(["user"]) items/{id: u32}/related ? tag: Option () -> ItemsResponse, + ] +}); + +struct DemoImpl; + +#[async_trait::async_trait] +impl DemoTrait for DemoImpl { + async fn get_items(&self) -> RestResult { + Ok(RestResponse::ok(ItemsResponse { + items: vec![Item { + id: 1, + name: "alpha".into(), + }], + })) + } + + async fn get_items_by_id(&self, _user: &AuthenticatedUser, id: u32) -> RestResult { + if id == 404 { + Err(RestError::not_found("missing")) + } else { + Ok(RestResponse::ok(Item { + id, + name: format!("item-{id}"), + })) + } + } + + async fn post_items(&self, user: &AuthenticatedUser, body: CreateItem) -> RestResult { + // Use the user_id length so we can verify the user actually arrived. + Ok(RestResponse::created(Item { + id: user.user_id.len() as u32, + name: body.name, + })) + } + + async fn get_search( + &self, + q: String, + limit: Option, + exact: bool, + ) -> RestResult { + let n = limit.unwrap_or(2); + let prefix = if exact { "exact" } else { "fuzzy" }; + let items = (0..n) + .map(|i| Item { + id: i, + name: format!("{prefix}:{q}-{i}"), + }) + .collect(); + Ok(RestResponse::ok(ItemsResponse { items })) + } + + async fn post_items_batch( + &self, + _user: &AuthenticatedUser, + notify: bool, + body: CreateItem, + ) -> RestResult { + // Encode the bool query param into the response so we can assert on it. + let suffix = if notify { "(notified)" } else { "(silent)" }; + Ok(RestResponse::created(Item { + id: 0, + name: format!("{}{suffix}", body.name), + })) + } + + async fn get_items_by_id_related( + &self, + _user: &AuthenticatedUser, + id: u32, + tag: Option, + ) -> RestResult { + let label = tag.unwrap_or_else(|| "none".into()); + Ok(RestResponse::ok(ItemsResponse { + items: vec![Item { + id, + name: format!("related/{label}"), + }], + })) + } +} + +fn router() -> axum::Router { + DemoBuilder::new(DemoImpl) + .auth_provider(MockAuthProvider::default()) + .build() +} + +fn client(base: &str) -> DemoClient { + DemoClient::builder(base).build().expect("client build") +} + +#[tokio::test] +async fn unauth_get_round_trips() { + let server = spawn_http(router()); + let base = server.server_address().unwrap().to_string(); + let resp = client(&base).get_items().await.expect("get_items ok"); + assert_eq!(resp.items.len(), 1); + assert_eq!(resp.items[0].name, "alpha"); +} + +#[tokio::test] +async fn auth_get_with_path_param_succeeds_with_user_token() { + let server = spawn_http(router()); + let base = server.server_address().unwrap().to_string(); + let mut c = client(&base); + c.set_bearer_token(Some("user-token".to_string())); + + let item = c.get_items_by_id(7).await.expect("get_items_by_id ok"); + assert_eq!(item.id, 7); + assert_eq!(item.name, "item-7"); +} + +#[tokio::test] +async fn auth_get_rejected_without_token() { + let server = spawn_http(router()); + let base = server.server_address().unwrap().to_string(); + // No bearer token set on client. + let err = client(&base) + .get_items_by_id(1) + .await + .expect_err("must be rejected"); + let s = err.to_string(); + assert!(s.contains("401") || s.contains("Unauthorized"), "got: {s}"); +} + +#[tokio::test] +async fn auth_post_rejected_with_insufficient_perms() { + let server = spawn_http(router()); + let base = server.server_address().unwrap().to_string(); + let mut c = client(&base); + c.set_bearer_token(Some("user-token".to_string())); // not admin + + let err = c + .post_items(CreateItem { + name: "x".to_string(), + }) + .await + .expect_err("user-token can't POST items"); + let s = err.to_string(); + assert!(s.contains("403") || s.contains("Forbidden"), "got: {s}"); +} + +#[tokio::test] +async fn auth_post_with_admin_succeeds_and_user_id_propagates() { + let server = spawn_http(router()); + let base = server.server_address().unwrap().to_string(); + let mut c = client(&base); + c.set_bearer_token(Some("admin-token".to_string())); + + let item = c + .post_items(CreateItem { name: "foo".into() }) + .await + .expect("post_items ok"); + assert_eq!(item.name, "foo"); + // admin-1 is 7 chars long. + assert_eq!(item.id, 7); +} + +#[tokio::test] +async fn query_params_required_and_optional_serialize_correctly() { + let server = spawn_http(router()); + let base = server.server_address().unwrap().to_string(); + + // Optional `limit` provided, required `q` and `exact` set. + let resp = client(&base) + .get_search("hi".to_string(), Some(3), true) + .await + .expect("search ok"); + assert_eq!(resp.items.len(), 3); + assert_eq!(resp.items[0].name, "exact:hi-0"); + assert_eq!(resp.items[2].name, "exact:hi-2"); + + // Optional `limit` omitted (None) → handler default of 2 applies, and the + // bool flips the prefix. + let resp = client(&base) + .get_search("zz".to_string(), None, false) + .await + .expect("search ok"); + assert_eq!(resp.items.len(), 2); + assert_eq!(resp.items[0].name, "fuzzy:zz-0"); +} + +#[tokio::test] +async fn query_params_with_body_and_auth() { + let server = spawn_http(router()); + let base = server.server_address().unwrap().to_string(); + let mut c = client(&base); + c.set_bearer_token(Some("admin-token".to_string())); + + let item = c + .post_items_batch( + true, + CreateItem { + name: "alpha".into(), + }, + ) + .await + .expect("post_items_batch ok"); + assert_eq!(item.name, "alpha(notified)"); + + let item = c + .post_items_batch( + false, + CreateItem { + name: "beta".into(), + }, + ) + .await + .expect("post_items_batch ok"); + assert_eq!(item.name, "beta(silent)"); +} + +#[tokio::test] +async fn query_params_with_path_param() { + let server = spawn_http(router()); + let base = server.server_address().unwrap().to_string(); + let mut c = client(&base); + c.set_bearer_token(Some("user-token".to_string())); + + let resp = c + .get_items_by_id_related(42, Some("featured".into())) + .await + .expect("related with tag"); + assert_eq!(resp.items[0].id, 42); + assert_eq!(resp.items[0].name, "related/featured"); + + let resp = c + .get_items_by_id_related(42, None) + .await + .expect("related without tag"); + assert_eq!(resp.items[0].name, "related/none"); +} + +#[tokio::test] +async fn handler_error_surfaces_to_client() { + let server = spawn_http(router()); + let base = server.server_address().unwrap().to_string(); + let mut c = client(&base); + c.set_bearer_token(Some("user-token".to_string())); + + let err = c + .get_items_by_id(404) + .await + .expect_err("404 sentinel must error"); + assert!(err.to_string().contains("404"), "got: {err}"); +} From 56f8e7b345cbc25c4c45e21a50a55818494b2d9c Mon Sep 17 00:00:00 2001 From: Mathias Myrland Date: Sat, 25 Apr 2026 09:37:46 +0200 Subject: [PATCH 4/6] Push framework coverage with targeted unit tests Lifts workspace test coverage from 80% to 87% by adding focused unit tests for the framework's runtime types and trait defaults. The bias is toward "every public function called from at least one test" rather than chasing line coverage on data-shape modules. Files moved into the 95-100% range: - ras-rest-core (RestError/RestResponse constructors, IntoRestError) - ras-identity-core (NoopPermissions/StaticPermissions, error variants) - ras-identity-oauth2/config (provider config builder + serde) - ras-jsonrpc-bidirectional-types: error.rs, sender.rs (incl. real-sink WebSocketMessageSender close idempotency), manager.rs default + ext trait helpers via a small in-memory stub - ras-jsonrpc-bidirectional-server: error.rs status codes, connection.rs context + ChannelMessageSender, handler.rs default trait methods, router.rs notification + error wrapping paths, plus a new tests/manager_unit.rs that pins down the DefaultConnectionManager contract (subscriptions, broadcast counts, permission filtering, pending-request lifecycle) without spinning up a real WebSocket - ras-jsonrpc-bidirectional-client: error.rs constructors + From impls + recovery classification, config.rs full builder + URL/header paths, client.rs not-connected guards on call/notify/subscribe and full ClientBuilder setter coverage - ras-jsonrpc-types (canonical error code constructors + serde) Remaining gaps live in genuinely network-bound paths (WebSocket run loop / reconnect loop) and in the openrpc-types data crate, which is spec-shaped and lower runtime risk. Co-Authored-By: Claude Opus 4.7 (1M context) --- crates/core/ras-identity-core/Cargo.toml | 5 +- crates/core/ras-identity-core/src/lib.rs | 70 +++++ .../ras-identity-oauth2/src/config.rs | 69 +++++ crates/rest/ras-rest-core/src/lib.rs | 86 ++++++ .../src/client.rs | 84 +++++ .../src/config.rs | 100 ++++++ .../src/error.rs | 89 ++++++ .../Cargo.toml | 3 +- .../src/connection.rs | 87 ++++++ .../src/error.rs | 63 ++++ .../src/handler.rs | 68 +++++ .../src/router.rs | 154 ++++++++++ .../tests/manager_unit.rs | 262 ++++++++++++++++ .../src/error.rs | 77 +++++ .../src/manager.rs | 288 ++++++++++++++++++ .../src/sender.rs | 120 ++++++++ crates/rpc/ras-jsonrpc-types/src/lib.rs | 73 +++++ 17 files changed, 1696 insertions(+), 2 deletions(-) create mode 100644 crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/tests/manager_unit.rs diff --git a/crates/core/ras-identity-core/Cargo.toml b/crates/core/ras-identity-core/Cargo.toml index ab6de24..748d4e4 100644 --- a/crates/core/ras-identity-core/Cargo.toml +++ b/crates/core/ras-identity-core/Cargo.toml @@ -11,4 +11,7 @@ homepage = "https://github.com/example/rust-agent-stack" async-trait = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } -thiserror = { workspace = true } \ No newline at end of file +thiserror = { workspace = true } + +[dev-dependencies] +tokio = { workspace = true } \ No newline at end of file diff --git a/crates/core/ras-identity-core/src/lib.rs b/crates/core/ras-identity-core/src/lib.rs index 49d7543..ac8490c 100644 --- a/crates/core/ras-identity-core/src/lib.rs +++ b/crates/core/ras-identity-core/src/lib.rs @@ -78,3 +78,73 @@ impl UserPermissions for StaticPermissions { Ok(self.permissions.clone()) } } + +#[cfg(test)] +mod tests { + use super::*; + + fn vi() -> VerifiedIdentity { + VerifiedIdentity { + provider_id: "test".into(), + subject: "alice".into(), + email: Some("a@b.com".into()), + display_name: Some("Alice".into()), + metadata: None, + } + } + + #[test] + fn identity_error_display_per_variant() { + assert_eq!( + IdentityError::InvalidCredentials.to_string(), + "Invalid credentials" + ); + assert_eq!( + IdentityError::ProviderNotFound("foo".into()).to_string(), + "Provider not found: foo" + ); + assert_eq!( + IdentityError::ProviderError("bad".into()).to_string(), + "Provider error: bad" + ); + assert_eq!( + IdentityError::UnsupportedMethod.to_string(), + "Unsupported authentication method" + ); + assert_eq!( + IdentityError::InvalidPayload.to_string(), + "Invalid authentication payload" + ); + assert_eq!( + IdentityError::SessionError("expired".into()).to_string(), + "Session error: expired" + ); + + let parse_err = serde_json::from_str::("not json").unwrap_err(); + let wrapped: IdentityError = parse_err.into(); + assert!(wrapped.to_string().starts_with("Serialization error:")); + } + + #[tokio::test] + async fn noop_permissions_returns_empty() { + let p = NoopPermissions; + let perms = p.get_permissions(&vi()).await.unwrap(); + assert!(perms.is_empty()); + } + + #[tokio::test] + async fn static_permissions_returns_provided_list() { + let p = StaticPermissions::new(vec!["a".into(), "b".into()]); + let perms = p.get_permissions(&vi()).await.unwrap(); + assert_eq!(perms, vec!["a".to_string(), "b".to_string()]); + } + + #[test] + fn verified_identity_serde_round_trips() { + let v = vi(); + let json = serde_json::to_string(&v).unwrap(); + let parsed: VerifiedIdentity = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed.subject, "alice"); + assert_eq!(parsed.provider_id, "test"); + } +} diff --git a/crates/identity/ras-identity-oauth2/src/config.rs b/crates/identity/ras-identity-oauth2/src/config.rs index 14c09ad..01f8a35 100644 --- a/crates/identity/ras-identity-oauth2/src/config.rs +++ b/crates/identity/ras-identity-oauth2/src/config.rs @@ -80,3 +80,72 @@ impl OAuth2Config { self } } + +#[cfg(test)] +mod tests { + use super::*; + + fn provider() -> OAuth2ProviderConfig { + OAuth2ProviderConfig { + provider_id: "google".into(), + client_id: "cid".into(), + client_secret: "secret".into(), + authorization_endpoint: "https://x/auth".into(), + token_endpoint: "https://x/token".into(), + userinfo_endpoint: Some("https://x/info".into()), + redirect_uri: "https://app/cb".into(), + scopes: vec!["openid".into(), "email".into()], + auth_params: HashMap::new(), + use_pkce: true, + user_info_mapping: None, + } + } + + #[test] + fn user_info_mapping_default_uses_oidc_field_names() { + let m = UserInfoMapping::default(); + assert_eq!(m.subject_field.as_deref(), Some("sub")); + assert_eq!(m.email_field.as_deref(), Some("email")); + assert_eq!(m.name_field.as_deref(), Some("name")); + assert_eq!(m.picture_field.as_deref(), Some("picture")); + } + + #[test] + fn oauth2_config_builder_chains_settings() { + let p = provider(); + let cfg = OAuth2Config::new() + .add_provider(p.clone()) + .with_state_ttl(120) + .with_http_timeout(7); + assert_eq!(cfg.state_ttl_seconds, 120); + assert_eq!(cfg.http_timeout_seconds, 7); + assert!(cfg.providers.contains_key("google")); + } + + #[test] + fn provider_config_round_trips_through_serde() { + let p = provider(); + let json = serde_json::to_string(&p).unwrap(); + let parsed: OAuth2ProviderConfig = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed.provider_id, p.provider_id); + assert_eq!(parsed.client_id, p.client_id); + assert_eq!(parsed.scopes, p.scopes); + assert!(parsed.use_pkce); + } + + #[test] + fn user_info_mapping_serde() { + let m = UserInfoMapping::default(); + let json = serde_json::to_string(&m).unwrap(); + let parsed: UserInfoMapping = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed.subject_field, m.subject_field); + } + + #[test] + fn defaults_are_sensible() { + let cfg = OAuth2Config::default(); + assert!(cfg.providers.is_empty()); + assert_eq!(cfg.state_ttl_seconds, 600); + assert_eq!(cfg.http_timeout_seconds, 30); + } +} diff --git a/crates/rest/ras-rest-core/src/lib.rs b/crates/rest/ras-rest-core/src/lib.rs index 63557e2..47bab98 100644 --- a/crates/rest/ras-rest-core/src/lib.rs +++ b/crates/rest/ras-rest-core/src/lib.rs @@ -168,3 +168,89 @@ impl RestResultExt for Resul .map_err(|e| RestError::with_internal(status, msg, e)) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn rest_response_constructors_set_correct_status() { + assert_eq!(RestResponse::ok(1).status, 200); + assert_eq!(RestResponse::created("x").status, 201); + assert_eq!(RestResponse::accepted(true).status, 202); + let nc: RestResponse<()> = RestResponse::no_content(); + assert_eq!(nc.status, 204); + assert_eq!(RestResponse::with_status(418, "tea").status, 418); + // Body is preserved. + assert_eq!(RestResponse::ok(42).body, 42); + } + + #[test] + fn rest_error_constructors_set_correct_status_and_message() { + let cases = [ + (RestError::bad_request("a"), 400), + (RestError::unauthorized("a"), 401), + (RestError::forbidden("a"), 403), + (RestError::not_found("a"), 404), + (RestError::conflict("a"), 409), + (RestError::unprocessable_entity("a"), 422), + (RestError::internal_server_error("a"), 500), + (RestError::bad_gateway("a"), 502), + (RestError::service_unavailable("a"), 503), + ]; + for (err, expected) in cases { + assert_eq!(err.status, expected); + assert_eq!(err.message, "a"); + assert!(err.internal_error.is_none()); + // Display includes the status and message. + let s = err.to_string(); + assert!(s.contains(&expected.to_string())); + assert!(s.contains("a")); + } + } + + #[test] + fn rest_error_with_internal_carries_source() { + #[derive(Debug, thiserror::Error)] + #[error("inner failure")] + struct Inner; + let err = RestError::with_internal(503, "down", Inner); + assert_eq!(err.status, 503); + assert!(err.internal_error.is_some()); + // source() returns the wrapped error. + let src = std::error::Error::source(&err).unwrap(); + assert_eq!(src.to_string(), "inner failure"); + } + + #[test] + fn into_rest_error_blanket_impl() { + let err = std::io::Error::new(std::io::ErrorKind::Other, "io"); + let rest = err.into_rest_error(); + assert_eq!(rest.status, 500); + assert_eq!(rest.message, "Internal server error"); + assert!(rest.internal_error.is_some()); + } + + #[test] + fn rest_result_ext_maps_ok_and_err() { + let ok: Result = Ok(7); + let mapped: RestResult = ok.internal_server_error(); + let resp = mapped.unwrap(); + assert_eq!(resp.status, 200); + assert_eq!(resp.body, 7); + + let err: Result = + Err(std::io::Error::new(std::io::ErrorKind::Other, "x")); + let mapped: RestResult = err.internal_server_error(); + let e = mapped.unwrap_err(); + assert_eq!(e.status, 500); + + // rest_error variant lets callers customize. + let err: Result = + Err(std::io::Error::new(std::io::ErrorKind::Other, "x")); + let mapped: RestResult = err.rest_error(418, "teapot"); + let e = mapped.unwrap_err(); + assert_eq!(e.status, 418); + assert_eq!(e.message, "teapot"); + } +} diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-client/src/client.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-client/src/client.rs index 2a7634f..3fa1de6 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-client/src/client.rs +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-client/src/client.rs @@ -734,4 +734,88 @@ mod tests { assert!(!client.is_connected().await); assert!(client.connection_id().await.is_none()); } + + #[tokio::test] + async fn builder_jwt_in_query_params_and_full_setters() { + // Exercise every with_* setter so each path is colored. We don't + // auto-connect (no server), but the resulting config must reflect + // each option. + let custom = ReconnectConfig::default(); + let client = ClientBuilder::new("ws://localhost:8080") + .with_jwt_token("tok".into()) + .with_jwt_in_header(false) + .with_header("X-Custom", "v") + .with_request_timeout(Duration::from_secs(11)) + .with_reconnect_config(custom) + .with_heartbeat_interval(None) + .with_connection_timeout(Duration::from_secs(7)) + .with_auto_connect(false) + .build() + .await + .expect("build"); + + assert!(matches!(client.config().auth, AuthConfig::JwtParams { .. })); + assert_eq!(client.config().request_timeout, Duration::from_secs(11)); + assert_eq!(client.config().connection_timeout, Duration::from_secs(7)); + assert!(client.config().heartbeat_interval.is_none()); + assert_eq!( + client.config().custom_headers.get("X-Custom"), + Some(&"v".to_string()) + ); + assert!(client.active_subscriptions().is_empty()); + assert_eq!(client.pending_requests_count(), 0); + } + + #[tokio::test] + async fn builder_without_token_yields_no_auth() { + let client = ClientBuilder::new("ws://localhost:8080") + .build() + .await + .expect("build"); + assert!(matches!(client.config().auth, AuthConfig::None)); + } + + #[tokio::test] + async fn call_notify_subscribe_unsubscribe_require_connected_state() { + let client = ClientBuilder::new("ws://localhost:8080") + .build() + .await + .expect("build"); + + // call → NotConnected + let err = client.call("m", None).await.unwrap_err(); + assert!(matches!(err, ClientError::NotConnected)); + + // notify → NotConnected + let err = client.notify("m", None).await.unwrap_err(); + assert!(matches!(err, ClientError::NotConnected)); + + // subscribe → NotConnected + let handler: NotificationHandler = std::sync::Arc::new(|_method: &str, _params: &Value| {}); + let err = client.subscribe("t", handler.clone()).await.unwrap_err(); + assert!(matches!(err, ClientError::NotConnected)); + + // unsubscribe → NotConnected + let err = client.unsubscribe("t").await.unwrap_err(); + assert!(matches!(err, ClientError::NotConnected)); + } + + #[tokio::test] + async fn handler_registration_does_not_require_connected_state() { + let client = ClientBuilder::new("ws://localhost:8080") + .build() + .await + .expect("build"); + + let n: NotificationHandler = std::sync::Arc::new(|_, _| {}); + let e: ConnectionEventHandler = std::sync::Arc::new(|_event| {}); + client.on_notification("evt", n); + client.on_connection_event("named", e); + + // cleanup_expired_requests is callable even with nothing pending. + client.cleanup_expired_requests().await; + + // Disconnect-when-already-disconnected is a no-op success. + client.disconnect().await.expect("disconnect ok"); + } } diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-client/src/config.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-client/src/config.rs index 78cc666..bf73d1f 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-client/src/config.rs +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-client/src/config.rs @@ -337,4 +337,104 @@ mod tests { config.request_timeout = Duration::from_secs(0); assert!(config.validate().is_err()); } + + #[test] + fn validate_rejects_each_invalid_field() { + let base = ClientConfig::new("ws://localhost:8080"); + + let mut c = base.clone(); + c.request_timeout = Duration::ZERO; + assert!(c.validate().unwrap_err().contains("Request timeout")); + + let mut c = base.clone(); + c.connection_timeout = Duration::ZERO; + assert!(c.validate().unwrap_err().contains("Connection timeout")); + + let mut c = base.clone(); + c.message_buffer_size = 0; + assert!(c.validate().unwrap_err().contains("Message buffer size")); + + let mut c = base.clone(); + c.max_pending_requests = 0; + assert!(c.validate().unwrap_err().contains("Max pending requests")); + + let mut c = base.clone(); + c.reconnect.backoff_multiplier = 0.0; + assert!(c.validate().unwrap_err().contains("Backoff multiplier")); + + let mut c = base.clone(); + c.reconnect.jitter = 1.5; + assert!(c.validate().unwrap_err().contains("Jitter")); + + // Native build also rejects an unparseable URL. + let mut c = base.clone(); + c.url = "not a url".to_string(); + assert!(c.validate().is_err()); + } + + #[test] + fn connection_url_appends_amp_when_query_already_present() { + let cfg = ClientConfig { + auth: AuthConfig::JwtParams { + token: "tok".into(), + }, + ..ClientConfig::new("ws://h/ws?x=1") + }; + assert_eq!(cfg.get_connection_url(), "ws://h/ws?x=1&token=tok"); + } + + #[test] + fn connection_url_with_custom_params() { + let mut params = HashMap::new(); + params.insert("foo".to_string(), "bar".to_string()); + let cfg = ClientConfig { + auth: AuthConfig::CustomParams { params }, + ..ClientConfig::new("ws://h/ws") + }; + let url = cfg.get_connection_url(); + assert!(url.starts_with("ws://h/ws?")); + assert!(url.contains("foo=bar")); + } + + #[test] + fn connection_headers_with_custom_headers_variant() { + let mut headers = HashMap::new(); + headers.insert("X-API-Key".to_string(), "k".to_string()); + let cfg = ClientConfig { + auth: AuthConfig::CustomHeaders { + headers: headers.clone(), + }, + ..ClientConfig::new("ws://h/ws") + }; + let h = cfg.get_connection_headers(); + assert_eq!(h.get("X-API-Key"), Some(&"k".to_string())); + } + + #[test] + fn connection_url_falls_through_for_no_param_auth() { + let cfg = ClientConfig { + auth: AuthConfig::JwtHeader { + token: "tok".into(), + }, + ..ClientConfig::new("ws://h/ws") + }; + // Header-based auth must NOT mutate the URL. + assert_eq!(cfg.get_connection_url(), "ws://h/ws"); + } + + #[test] + fn calculate_delay_zero_attempt_returns_initial() { + let cfg = ReconnectConfig { + jitter: 0.0, + ..ReconnectConfig::default() + }; + assert_eq!(cfg.calculate_delay(0), cfg.initial_delay); + } + + #[test] + fn auth_config_default_via_helper() { + // exercises the `Default` impl + AuthConfig::default branches. + let _ = AuthConfig::default(); + let _ = ReconnectConfig::default(); + } } diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-client/src/error.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-client/src/error.rs index 033f852..90f549a 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-client/src/error.rs +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-client/src/error.rs @@ -232,4 +232,93 @@ mod tests { panic!("Expected reconnection failed error"); } } + + #[test] + fn covers_all_constructors_and_display() { + // Stringy constructors → matching variants and messages. + for (err, expected_prefix) in [ + ( + ClientError::invalid_request_id("rid"), + "Invalid request ID:", + ), + (ClientError::invalid_url("not://valid"), "Invalid URL:"), + (ClientError::send_failed("eof"), "Failed to send message:"), + ( + ClientError::receive_failed("eof"), + "Failed to receive message:", + ), + (ClientError::subscription("topic"), "Subscription error:"), + (ClientError::configuration("bad"), "Configuration error:"), + (ClientError::internal("oops"), "Internal error:"), + (ClientError::authentication("nope"), "Authentication error:"), + ] { + let s = err.to_string(); + assert!( + s.starts_with(expected_prefix), + "expected prefix {expected_prefix:?} in {s:?}" + ); + } + + // Bare variants. + assert_eq!( + ClientError::NotConnected.to_string(), + "Client is not connected" + ); + assert_eq!( + ClientError::AlreadyConnected.to_string(), + "Client is already connected" + ); + } + + #[test] + fn from_impls_route_to_correct_variants() { + let json_err = serde_json::from_str::("not json").unwrap_err(); + assert!(matches!(ClientError::from(json_err), ClientError::Json(_))); + + let bidir_err = BidirectionalError::Timeout; + assert!(matches!( + ClientError::from(bidir_err), + ClientError::Bidirectional(_) + )); + + let io_err = std::io::Error::new(std::io::ErrorKind::Other, "io"); + assert!(matches!(ClientError::from(io_err), ClientError::Io(_))); + + let url_err = url::Url::parse("not a url").unwrap_err(); + assert!(matches!( + ClientError::from(url_err), + ClientError::UrlParse(_) + )); + } + + #[test] + fn recovery_classification_is_exhaustive_for_named_buckets() { + // Should reconnect → also recoverable. + for err in [ + ClientError::connection("x"), + ClientError::receive_failed("x"), + ClientError::NotConnected, + ] { + assert!(err.should_reconnect()); + assert!(err.is_recoverable()); + } + + // Recoverable but no reconnect. + for err in [ClientError::timeout(1), ClientError::send_failed("x")] { + assert!(err.is_recoverable()); + assert!(!err.should_reconnect()); + } + + // Neither. + for err in [ + ClientError::authentication("x"), + ClientError::AlreadyConnected, + ClientError::invalid_url("x"), + ClientError::configuration("x"), + ClientError::internal("x"), + ] { + assert!(!err.is_recoverable()); + assert!(!err.should_reconnect()); + } + } } diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/Cargo.toml b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/Cargo.toml index 3bb3dd5..7c6b3f7 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/Cargo.toml +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/Cargo.toml @@ -30,4 +30,5 @@ futures = { workspace = true } dashmap = { workspace = true } [dev-dependencies] -tokio-test = { workspace = true } \ No newline at end of file +tokio-test = { workspace = true } +ras-jsonrpc-types = { path = "../../ras-jsonrpc-types" } \ No newline at end of file diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/connection.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/connection.rs index 9c9e1cc..63c9bc0 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/connection.rs +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/connection.rs @@ -125,3 +125,90 @@ impl ConnectionContext { } } } + +#[cfg(test)] +mod tests { + use super::*; + use ras_auth_core::AuthenticatedUser; + use std::collections::HashSet; + + fn user(id: &str, perms: &[&str]) -> AuthenticatedUser { + AuthenticatedUser { + user_id: id.to_string(), + permissions: perms.iter().map(|s| s.to_string()).collect::>(), + metadata: None, + } + } + + fn ctx() -> ConnectionContext { + let id = ConnectionId::new(); + let (tx, _rx) = mpsc::channel(8); + let sender = ChannelMessageSender::new(id, tx); + ConnectionContext::new(id, sender) + } + + #[tokio::test] + async fn channel_sender_send_propagates_and_id_round_trips() { + let id = ConnectionId::new(); + let (tx, mut rx) = mpsc::channel(2); + let sender = ChannelMessageSender::new(id, tx); + assert_eq!(sender.connection_id(), id); + + sender.send(BidirectionalMessage::Ping).await.unwrap(); + let received = rx.recv().await.unwrap(); + assert!(matches!(received, BidirectionalMessage::Ping)); + } + + #[tokio::test] + async fn channel_sender_returns_string_error_when_closed() { + let id = ConnectionId::new(); + let (tx, rx) = mpsc::channel(1); + drop(rx); + let sender = ChannelMessageSender::new(id, tx); + let err = sender.send(BidirectionalMessage::Ping).await.unwrap_err(); + assert!(!err.is_empty()); + } + + #[tokio::test] + async fn auth_state_round_trips() { + let c = ctx(); + assert!(!c.is_authenticated().await); + assert!(c.get_user().await.is_none()); + assert!(!c.has_permission("admin").await); + + c.set_user(user("alice", &["admin"])).await; + assert!(c.is_authenticated().await); + assert_eq!(c.get_user().await.unwrap().user_id, "alice"); + assert!(c.has_permission("admin").await); + assert!(!c.has_permission("nope").await); + + c.clear_user().await; + assert!(!c.is_authenticated().await); + } + + #[tokio::test] + async fn subscriptions_round_trip() { + let c = ctx(); + assert!(c.get_subscriptions().await.is_empty()); + assert!(!c.is_subscribed_to("t1").await); + + c.subscribe("t1".into()).await; + c.subscribe("t2".into()).await; + assert!(c.is_subscribed_to("t1").await); + assert_eq!(c.get_subscriptions().await.len(), 2); + + assert!(c.unsubscribe("t1").await); + assert!(!c.is_subscribed_to("t1").await); + // Idempotent: removing again returns false. + assert!(!c.unsubscribe("t1").await); + } + + #[tokio::test] + async fn metadata_get_set() { + let c = ctx(); + assert!(c.get_metadata("k").await.is_none()); + c.set_metadata("k", serde_json::json!("v")).await; + assert_eq!(c.get_metadata("k").await.unwrap(), serde_json::json!("v")); + assert!(c.get_metadata("missing").await.is_none()); + } +} diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/error.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/error.rs index 5105325..52fb538 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/error.rs +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/error.rs @@ -74,3 +74,66 @@ impl ServerError { /// Convenience type alias for server operation results pub type ServerResult = Result; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn status_codes_per_variant() { + assert_eq!( + ServerError::AuthenticationFailed(AuthError::InvalidToken).to_status_code(), + StatusCode::UNAUTHORIZED + ); + assert_eq!( + ServerError::PermissionDenied("nope".into()).to_status_code(), + StatusCode::FORBIDDEN + ); + assert_eq!( + ServerError::ConnectionNotFound("abc".into()).to_status_code(), + StatusCode::NOT_FOUND + ); + assert_eq!( + ServerError::InvalidRequest("bad".into()).to_status_code(), + StatusCode::BAD_REQUEST + ); + assert_eq!( + ServerError::HandlerNotFound("m".into()).to_status_code(), + StatusCode::NOT_IMPLEMENTED + ); + for variant in [ + ServerError::UpgradeFailed("x".into()), + ServerError::RoutingFailed("x".into()), + ServerError::WebSocketError("x".into()), + ServerError::Internal("x".into()), + ] { + assert_eq!(variant.to_status_code(), StatusCode::INTERNAL_SERVER_ERROR); + } + + // From impls + let json_err = serde_json::from_str::("not json").unwrap_err(); + let from_json: ServerError = json_err.into(); + assert_eq!( + from_json.to_status_code(), + StatusCode::INTERNAL_SERVER_ERROR + ); + + let bidir_err: BidirectionalError = BidirectionalError::SendError("e".into()); + let from_bidir: ServerError = bidir_err.into(); + assert_eq!( + from_bidir.to_status_code(), + StatusCode::INTERNAL_SERVER_ERROR + ); + + let auth_err: AuthError = AuthError::TokenExpired; + let from_auth: ServerError = auth_err.into(); + assert_eq!(from_auth.to_status_code(), StatusCode::UNAUTHORIZED); + + // Display formatting: spot-check each kind once. + assert!( + ServerError::UpgradeFailed("x".into()) + .to_string() + .starts_with("WebSocket upgrade failed:") + ); + } +} diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/handler.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/handler.rs index 5e46482..b9f07eb 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/handler.rs +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/handler.rs @@ -353,3 +353,71 @@ impl WebSocketHandler { .map_err(|e| ServerError::WebSocketError(e.to_string())) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::connection::ChannelMessageSender; + use ras_jsonrpc_bidirectional_types::ConnectionId; + + /// A minimal MessageHandler that only implements the required method — + /// every other method falls through to the default impl, which is what + /// these tests are verifying. + struct PassThrough; + + #[async_trait] + impl MessageHandler for PassThrough { + async fn handle_request( + &self, + _request: JsonRpcRequest, + _context: Arc, + ) -> ServerResult> { + Ok(None) + } + } + + fn ctx() -> Arc { + let id = ConnectionId::new(); + let (tx, _rx) = mpsc::channel(4); + let sender = ChannelMessageSender::new(id, tx); + Arc::new(ConnectionContext::new(id, sender)) + } + + #[tokio::test] + async fn default_handle_subscribe_writes_to_context() { + let h = PassThrough; + let c = ctx(); + h.handle_subscribe(vec!["a".into(), "b".into()], c.clone()) + .await + .unwrap(); + assert!(c.is_subscribed_to("a").await); + assert!(c.is_subscribed_to("b").await); + } + + #[tokio::test] + async fn default_handle_unsubscribe_removes_from_context() { + let h = PassThrough; + let c = ctx(); + c.subscribe("a".into()).await; + c.subscribe("b".into()).await; + h.handle_unsubscribe(vec!["a".into()], c.clone()) + .await + .unwrap(); + assert!(!c.is_subscribed_to("a").await); + assert!(c.is_subscribed_to("b").await); + } + + #[tokio::test] + async fn default_lifecycle_methods_succeed() { + let h = PassThrough; + let c = ctx(); + h.on_connect(c.clone()).await.unwrap(); + h.on_ping(c.clone()).await.unwrap(); + h.on_pong(c.clone()).await.unwrap(); + h.on_disconnect(c.clone(), Some("bye".into())) + .await + .unwrap(); + // None reason path too. + h.on_disconnect(c, None).await.unwrap(); + } +} diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/router.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/router.rs index 7c0030f..e92293e 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/router.rs +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/router.rs @@ -258,4 +258,158 @@ mod tests { assert!(response.error.is_some()); assert_eq!(response.error.unwrap().code, -32601); // METHOD_NOT_FOUND } + + fn test_context() -> Arc { + let connection_id = ConnectionId::new(); + let (tx, _rx) = mpsc::channel(1); + let sender = crate::connection::ChannelMessageSender::new(connection_id, tx); + Arc::new(ConnectionContext::new(connection_id, sender)) + } + + #[tokio::test] + async fn register_low_level_handler_returns_explicit_response() { + let mut router = MessageRouter::new(); + router.register("low.echo", |req, _ctx| async move { + // Hand-built response — proves the low-level register path is wired. + Ok(req + .id + .clone() + .map(|id| JsonRpcResponse::success(req.params.unwrap_or(json!(null)), Some(id)))) + }); + + let ctx = test_context(); + let req = JsonRpcRequest { + jsonrpc: "2.0".into(), + method: "low.echo".into(), + params: Some(json!(42)), + id: Some(json!(7)), + }; + let resp = router.handle_request(req, ctx).await.unwrap().unwrap(); + assert_eq!(resp.id, Some(json!(7))); + assert_eq!(resp.result.unwrap(), json!(42)); + } + + #[tokio::test] + async fn register_notification_never_returns_response() { + let mut router = MessageRouter::new(); + let counter = Arc::new(std::sync::atomic::AtomicUsize::new(0)); + let c2 = counter.clone(); + router.register_notification("evt.tick", move |_req, _ctx| { + let c = c2.clone(); + async move { + c.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + Ok(()) + } + }); + + let ctx = test_context(); + // Even if the request has an id, the notification handler returns None. + let req = JsonRpcRequest { + jsonrpc: "2.0".into(), + method: "evt.tick".into(), + params: None, + id: Some(json!(1)), + }; + assert!(router.handle_request(req, ctx).await.unwrap().is_none()); + assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn handler_error_with_id_becomes_error_response() { + let mut router = MessageRouter::new(); + router.register("explode", |_req, _ctx| async move { + Err::, _>(ServerError::Internal("kaboom".into())) + }); + let req = JsonRpcRequest { + jsonrpc: "2.0".into(), + method: "explode".into(), + params: None, + id: Some(json!("rid")), + }; + let resp = router.handle_request(req, test_context()).await.unwrap(); + let resp = resp.unwrap(); + assert_eq!(resp.id, Some(json!("rid"))); + // `internal_error` constructor strips details for security; the wire + // message is the canonical "Internal error" string. We confirm the code + // and the absence of leaked detail. + let err = resp.error.unwrap(); + assert_eq!(err.code, -32603); + assert!(!err.message.contains("kaboom")); + } + + #[tokio::test] + async fn handler_error_without_id_propagates_as_err() { + let mut router = MessageRouter::new(); + router.register("explode", |_req, _ctx| async move { + Err::, _>(ServerError::Internal("kaboom".into())) + }); + let req = JsonRpcRequest { + jsonrpc: "2.0".into(), + method: "explode".into(), + params: None, + id: None, + }; + let result = router.handle_request(req, test_context()).await; + assert!(matches!(result, Err(ServerError::Internal(_)))); + } + + #[tokio::test] + async fn register_value_with_notification_request_returns_none() { + let mut router = MessageRouter::new(); + router.register_value("v.echo", |req, _ctx| async move { + Ok::(req.params.unwrap_or(json!(null))) + }); + // No id ⇒ notification. The register_value branch should return Ok(None). + let req = JsonRpcRequest { + jsonrpc: "2.0".into(), + method: "v.echo".into(), + params: Some(json!(true)), + id: None, + }; + assert!( + router + .handle_request(req, test_context()) + .await + .unwrap() + .is_none() + ); + } + + #[tokio::test] + async fn register_value_handler_error_with_id_becomes_error_response() { + let mut router = MessageRouter::new(); + router.register_value("v.fail", |_req, _ctx| async move { + Err::(ServerError::Internal("nope".into())) + }); + let req = JsonRpcRequest { + jsonrpc: "2.0".into(), + method: "v.fail".into(), + params: None, + id: Some(json!(9)), + }; + let resp = router.handle_request(req, test_context()).await.unwrap(); + let resp = resp.unwrap(); + assert_eq!(resp.id, Some(json!(9))); + let err = resp.error.unwrap(); + assert_eq!(err.code, -32603); + assert!(!err.message.contains("nope")); + } + + #[tokio::test] + async fn unknown_method_without_id_returns_no_response() { + let router = MessageRouter::new(); + let req = JsonRpcRequest { + jsonrpc: "2.0".into(), + method: "ghost".into(), + params: None, + id: None, + }; + assert!( + router + .handle_request(req, test_context()) + .await + .unwrap() + .is_none() + ); + } } diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/tests/manager_unit.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/tests/manager_unit.rs new file mode 100644 index 0000000..aa58808 --- /dev/null +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/tests/manager_unit.rs @@ -0,0 +1,262 @@ +//! Direct unit tests for `DefaultConnectionManager`. +//! +//! The end-to-end suite in `examples/bidirectional-chat` and +//! `ras-jsonrpc-bidirectional-macro/tests/bidirectional_integration.rs` +//! covers the manager's happy path indirectly. This file pins down the +//! manager's contract on its own — subscriptions, broadcast counts, +//! permission filtering, and pending-request lifecycle — without spinning +//! up a real WebSocket. + +use std::collections::HashSet; +use std::sync::Arc; + +use ras_auth_core::AuthenticatedUser; +use ras_jsonrpc_bidirectional_server::DefaultConnectionManager; +use ras_jsonrpc_bidirectional_server::connection::ChannelMessageSender; +use ras_jsonrpc_bidirectional_types::{ + BidirectionalMessage, ConnectionId, ConnectionInfo, ConnectionManager, +}; +use ras_jsonrpc_types::JsonRpcResponse; +use tokio::sync::{mpsc, oneshot}; + +fn user(id: &str, perms: &[&str]) -> AuthenticatedUser { + AuthenticatedUser { + user_id: id.to_string(), + permissions: perms.iter().map(|s| s.to_string()).collect::>(), + metadata: None, + } +} + +/// Build a connection paired with a real receiver so we can observe sends. +async fn join( + mgr: &DefaultConnectionManager, +) -> (ConnectionId, mpsc::Receiver) { + let id = ConnectionId::new(); + let (tx, rx) = mpsc::channel(16); + let sender = ChannelMessageSender::new(id, tx); + let info = ConnectionInfo::new(id); + mgr.add_connection_with_sender_direct(info, sender) + .await + .unwrap(); + (id, rx) +} + +#[tokio::test] +async fn add_remove_round_trip_and_inspect() { + let mgr = DefaultConnectionManager::new(); + assert_eq!(mgr.connection_count(), 0); + + let (a, _ra) = join(&mgr).await; + let (b, _rb) = join(&mgr).await; + assert_eq!(mgr.connection_count(), 2); + + let ids = mgr.get_connection_ids(); + assert!(ids.contains(&a) && ids.contains(&b)); + assert!(mgr.connection_exists(a).await.unwrap()); + assert!(mgr.get_sender(a).is_some()); + + // Removing a missing id is logged-and-ignored, not an error. + mgr.remove_connection(ConnectionId::new()).await.unwrap(); + + mgr.remove_connection(a).await.unwrap(); + assert_eq!(mgr.connection_count(), 1); + assert!(!mgr.connection_exists(a).await.unwrap()); + assert!(mgr.get_sender(a).is_none()); +} + +#[tokio::test] +async fn add_connection_with_sender_box_downcasts() { + let mgr = DefaultConnectionManager::new(); + let id = ConnectionId::new(); + let (tx, _rx) = mpsc::channel(1); + let sender = ChannelMessageSender::new(id, tx); + // Round-trip through Box as the trait method requires. + let boxed: Box = Box::new(sender); + mgr.add_connection_with_sender(ConnectionInfo::new(id), boxed) + .await + .unwrap(); + assert!(mgr.connection_exists(id).await.unwrap()); + assert!(mgr.get_sender(id).is_some()); +} + +#[tokio::test] +async fn add_connection_with_unknown_sender_falls_back_to_dummy() { + let mgr = DefaultConnectionManager::new(); + let id = ConnectionId::new(); + let bogus: Box = Box::new(123u32); + mgr.add_connection_with_sender(ConnectionInfo::new(id), bogus) + .await + .unwrap(); + assert!(mgr.connection_exists(id).await.unwrap()); +} + +#[tokio::test] +async fn subscriptions_track_topics_and_clean_up_on_remove() { + let mgr = DefaultConnectionManager::new(); + let (a, _ra) = join(&mgr).await; + let (b, _rb) = join(&mgr).await; + + mgr.add_subscription(a, "room:1".into()).await.unwrap(); + mgr.add_subscription(b, "room:1".into()).await.unwrap(); + mgr.add_subscription(b, "room:2".into()).await.unwrap(); + + let topics = mgr.get_active_topics(); + assert!(topics.contains(&"room:1".to_string())); + assert!(topics.contains(&"room:2".to_string())); + + let r1: HashSet<_> = mgr.get_topic_connections("room:1").into_iter().collect(); + assert!(r1.contains(&a) && r1.contains(&b)); + + let subs_b = mgr.get_subscriptions(b).await.unwrap(); + assert!(subs_b.iter().any(|s| s == "room:1")); + assert!(subs_b.iter().any(|s| s == "room:2")); + + // Direct unsubscribe on the only-non-empty topic frees it. + mgr.remove_subscription(a, "room:1").await.unwrap(); + mgr.remove_subscription(b, "room:1").await.unwrap(); + assert!(!mgr.get_active_topics().contains(&"room:1".to_string())); + + // Removing a connection prunes any remaining subscriptions for it. + mgr.remove_connection(b).await.unwrap(); + // room:2 had only b, so it should be gone. + assert!(!mgr.get_active_topics().contains(&"room:2".to_string())); +} + +#[tokio::test] +async fn subscribed_connections_returns_full_info_for_topic() { + let mgr = DefaultConnectionManager::new(); + let (a, _ra) = join(&mgr).await; + let (_b, _rb) = join(&mgr).await; + mgr.add_subscription(a, "t".into()).await.unwrap(); + let subs = mgr.get_subscribed_connections("t").await.unwrap(); + assert_eq!(subs.len(), 1); + assert_eq!(subs[0].id, a); +} + +#[tokio::test] +async fn user_set_clear_filter_paths_for_broadcasts() { + let mgr = DefaultConnectionManager::new(); + let (auth_id, mut auth_rx) = join(&mgr).await; + let (admin_id, mut admin_rx) = join(&mgr).await; + let (_anon_id, mut anon_rx) = join(&mgr).await; + + mgr.set_connection_user(auth_id, user("u", &["read"])) + .await + .unwrap(); + mgr.set_connection_user(admin_id, user("a", &["read", "admin"])) + .await + .unwrap(); + // anon stays unauthenticated. + + // broadcast_to_authenticated reaches both authenticated peers. + let n = mgr + .broadcast_to_authenticated(BidirectionalMessage::Ping) + .await + .unwrap(); + assert_eq!(n, 2); + assert!(auth_rx.try_recv().is_ok()); + assert!(admin_rx.try_recv().is_ok()); + assert!(anon_rx.try_recv().is_err()); + + // broadcast_to_permission only reaches the admin. + let n = mgr + .broadcast_to_permission("admin", BidirectionalMessage::Ping) + .await + .unwrap(); + assert_eq!(n, 1); + assert!(admin_rx.try_recv().is_ok()); + + // clear_connection_user flips the auth flag back. + mgr.clear_connection_user(auth_id).await.unwrap(); + let n = mgr + .broadcast_to_authenticated(BidirectionalMessage::Pong) + .await + .unwrap(); + assert_eq!(n, 1); + + // Setting/clearing user on missing id is best-effort, not an error. + mgr.set_connection_user(ConnectionId::new(), user("ghost", &[])) + .await + .unwrap(); + mgr.clear_connection_user(ConnectionId::new()) + .await + .unwrap(); +} + +#[tokio::test] +async fn broadcast_to_topic_counts_recipients_and_skips_empty() { + let mgr = DefaultConnectionManager::new(); + let (a, mut ra) = join(&mgr).await; + let (_b, _rb) = join(&mgr).await; + mgr.add_subscription(a, "t".into()).await.unwrap(); + + let n = mgr + .broadcast_to_topic("t", BidirectionalMessage::Ping) + .await + .unwrap(); + assert_eq!(n, 1); + assert!(ra.try_recv().is_ok()); + + // Topic with no subscribers reports zero. + let n = mgr + .broadcast_to_topic("missing", BidirectionalMessage::Ping) + .await + .unwrap(); + assert_eq!(n, 0); +} + +#[tokio::test] +async fn pending_request_lifecycle() { + let mgr = DefaultConnectionManager::new(); + let (id, _rx) = join(&mgr).await; + + let (tx, rx) = oneshot::channel(); + mgr.register_pending_request(id, serde_json::json!("rid"), tx) + .await + .unwrap(); + + // remove_pending_request hands back the sender. + let pulled = mgr + .remove_pending_request(id, &serde_json::json!("rid")) + .await + .unwrap(); + assert!(pulled.is_some()); + drop(pulled); + drop(rx); + + // handle_pending_response with no registered id reports false. + let resp = JsonRpcResponse::success(serde_json::json!("ok"), Some(serde_json::json!("rid"))); + let handled = mgr.handle_pending_response(id, resp).await.unwrap(); + assert!(!handled); + + // Register again, then route a real response through handle_pending_response. + let (tx, rx) = oneshot::channel(); + mgr.register_pending_request(id, serde_json::json!("rid2"), tx) + .await + .unwrap(); + let resp = JsonRpcResponse::success(serde_json::json!(7), Some(serde_json::json!("rid2"))); + assert!(mgr.handle_pending_response(id, resp).await.unwrap()); + let received = rx.await.unwrap(); + assert_eq!(received.result.unwrap(), serde_json::json!(7)); + + // Removing for a connection that never registered any returns None. + let pulled = mgr + .remove_pending_request(ConnectionId::new(), &serde_json::json!("nope")) + .await + .unwrap(); + assert!(pulled.is_none()); +} + +#[tokio::test] +async fn send_to_missing_connection_is_silent_ok() { + let mgr = DefaultConnectionManager::new(); + // Nothing registered — manager logs and returns Ok. + mgr.send_to_connection(ConnectionId::new(), BidirectionalMessage::Ping) + .await + .unwrap(); +} + +#[tokio::test] +async fn default_impl_is_equivalent_to_new() { + let _ = Arc::new(DefaultConnectionManager::default()); +} diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-types/src/error.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-types/src/error.rs index 1bc6a62..2535447 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-types/src/error.rs +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-types/src/error.rs @@ -86,3 +86,80 @@ impl BidirectionalError { Self::Custom(error.to_string()) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::ConnectionId; + + #[test] + fn helpers_wrap_display_into_correct_variant() { + let id = ConnectionId::new(); + assert!( + BidirectionalError::ConnectionNotFound(id) + .to_string() + .contains(&id.to_string()) + ); + assert_eq!( + BidirectionalError::ConnectionAlreadyExists(id).to_string(), + format!("Connection already exists: {id}") + ); + assert_eq!( + BidirectionalError::SendError("oops".into()).to_string(), + "Failed to send message: oops" + ); + assert_eq!( + BidirectionalError::BroadcastError { + topic: "t".into(), + reason: "r".into(), + } + .to_string(), + "Failed to broadcast to topic 't': r" + ); + assert_eq!( + BidirectionalError::AuthenticationRequired.to_string(), + "Authentication required" + ); + assert_eq!( + BidirectionalError::PermissionDenied("admin".into()).to_string(), + "Permission denied: admin" + ); + assert_eq!( + BidirectionalError::InvalidTopic("foo".into()).to_string(), + "Invalid subscription topic: foo" + ); + assert_eq!( + BidirectionalError::ConnectionClosed.to_string(), + "Connection closed" + ); + assert_eq!(BidirectionalError::Timeout.to_string(), "Request timeout"); + assert_eq!( + BidirectionalError::RpcError("nope".into()).to_string(), + "RPC error: nope" + ); + assert_eq!( + BidirectionalError::InvalidResponse("garbage".into()).to_string(), + "Invalid response: garbage" + ); + assert_eq!( + BidirectionalError::ConnectionError("eof".into()).to_string(), + "Connection error: eof" + ); + + // Constructor helpers. + let we = BidirectionalError::websocket("boom"); + assert!(matches!(we, BidirectionalError::WebSocketError(ref s) if s == "boom")); + let ie = BidirectionalError::internal("ugh"); + assert!(matches!(ie, BidirectionalError::InternalError(ref s) if s == "ugh")); + let ce = BidirectionalError::custom("plain"); + assert!(matches!(ce, BidirectionalError::Custom(ref s) if s == "plain")); + } + + #[test] + fn from_serde_json_error() { + let parse_err = serde_json::from_str::("{not json}").unwrap_err(); + let wrapped: BidirectionalError = parse_err.into(); + assert!(matches!(wrapped, BidirectionalError::SerializationError(_))); + assert!(wrapped.to_string().starts_with("Serialization error:")); + } +} diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-types/src/manager.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-types/src/manager.rs index c593806..b78829f 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-types/src/manager.rs +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-types/src/manager.rs @@ -184,3 +184,291 @@ pub trait ConnectionManagerExt: ConnectionManager { // Blanket implementation for all ConnectionManager types impl ConnectionManagerExt for T {} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{BidirectionalError, BroadcastMessage, ConnectionId, ConnectionInfo}; + use ras_auth_core::AuthenticatedUser; + use ras_jsonrpc_types::JsonRpcResponse; + use std::collections::HashMap; + use std::collections::HashSet; + use std::sync::Mutex; + use tokio::sync::oneshot; + + /// Tiny stub manager: enough state to exercise the default `connection_*` + /// methods and the `ConnectionManagerExt` helpers without dragging in the + /// full `DefaultConnectionManager`. Uses sync `Mutex` for simplicity. + #[derive(Default)] + struct StubManager { + conns: Mutex>, + subs: Mutex>>, + sent: Mutex>, + broadcasts: Mutex>, + } + + #[async_trait] + impl ConnectionManager for StubManager { + async fn add_connection(&self, info: ConnectionInfo) -> Result<()> { + self.conns.lock().unwrap().insert(info.id, info); + Ok(()) + } + async fn remove_connection(&self, id: ConnectionId) -> Result<()> { + self.conns + .lock() + .unwrap() + .remove(&id) + .ok_or(BidirectionalError::ConnectionNotFound(id))?; + Ok(()) + } + async fn get_connection(&self, id: ConnectionId) -> Result> { + Ok(self.conns.lock().unwrap().get(&id).cloned()) + } + async fn get_all_connections(&self) -> Result> { + Ok(self.conns.lock().unwrap().values().cloned().collect()) + } + async fn get_subscribed_connections(&self, topic: &str) -> Result> { + let ids = self + .subs + .lock() + .unwrap() + .get(topic) + .cloned() + .unwrap_or_default(); + let conns = self.conns.lock().unwrap(); + Ok(ids.iter().filter_map(|id| conns.get(id).cloned()).collect()) + } + async fn set_connection_user( + &self, + id: ConnectionId, + user: AuthenticatedUser, + ) -> Result<()> { + if let Some(info) = self.conns.lock().unwrap().get_mut(&id) { + info.set_user(user); + Ok(()) + } else { + Err(BidirectionalError::ConnectionNotFound(id)) + } + } + async fn clear_connection_user(&self, id: ConnectionId) -> Result<()> { + if let Some(info) = self.conns.lock().unwrap().get_mut(&id) { + info.clear_user(); + Ok(()) + } else { + Err(BidirectionalError::ConnectionNotFound(id)) + } + } + async fn add_subscription(&self, id: ConnectionId, topic: String) -> Result<()> { + self.subs + .lock() + .unwrap() + .entry(topic.clone()) + .or_default() + .insert(id); + if let Some(info) = self.conns.lock().unwrap().get_mut(&id) { + info.subscribe(topic); + } + Ok(()) + } + async fn remove_subscription(&self, id: ConnectionId, topic: &str) -> Result<()> { + if let Some(set) = self.subs.lock().unwrap().get_mut(topic) { + set.remove(&id); + } + if let Some(info) = self.conns.lock().unwrap().get_mut(&id) { + info.unsubscribe(topic); + } + Ok(()) + } + async fn get_subscriptions(&self, id: ConnectionId) -> Result> { + Ok(self + .conns + .lock() + .unwrap() + .get(&id) + .map(|c| c.subscriptions.iter().cloned().collect()) + .unwrap_or_default()) + } + async fn send_to_connection( + &self, + id: ConnectionId, + message: BidirectionalMessage, + ) -> Result<()> { + self.sent.lock().unwrap().push((id, message)); + Ok(()) + } + async fn broadcast_to_topic( + &self, + topic: &str, + message: BidirectionalMessage, + ) -> Result { + let n = self + .subs + .lock() + .unwrap() + .get(topic) + .map(|s| s.len()) + .unwrap_or(0); + self.broadcasts + .lock() + .unwrap() + .push((topic.to_string(), message)); + Ok(n) + } + async fn broadcast_to_authenticated( + &self, + _message: BidirectionalMessage, + ) -> Result { + Ok(self.authenticated_connection_count().await?) + } + async fn broadcast_to_permission( + &self, + permission: &str, + _message: BidirectionalMessage, + ) -> Result { + Ok(self + .conns + .lock() + .unwrap() + .values() + .filter(|c| c.has_permission(permission)) + .count()) + } + async fn register_pending_request( + &self, + _connection_id: ConnectionId, + _request_id: serde_json::Value, + _response_sender: oneshot::Sender, + ) -> Result<()> { + Ok(()) + } + async fn remove_pending_request( + &self, + _connection_id: ConnectionId, + _request_id: &serde_json::Value, + ) -> Result>> { + Ok(None) + } + async fn handle_pending_response( + &self, + _connection_id: ConnectionId, + _response: JsonRpcResponse, + ) -> Result { + Ok(false) + } + } + + fn user(id: &str, perms: &[&str]) -> AuthenticatedUser { + AuthenticatedUser { + user_id: id.to_string(), + permissions: perms.iter().map(|s| s.to_string()).collect(), + metadata: None, + } + } + + #[tokio::test] + async fn default_methods_delegate_to_required_methods() { + let mgr = StubManager::default(); + + // Initially nothing is registered. + assert_eq!(mgr.connection_count().await.unwrap(), 0); + assert_eq!(mgr.authenticated_connection_count().await.unwrap(), 0); + assert_eq!(mgr.cleanup_stale_connections().await.unwrap(), 0); + + let id1 = ConnectionId::new(); + let id2 = ConnectionId::new(); + mgr.add_connection(ConnectionInfo::new(id1)).await.unwrap(); + mgr.add_connection(ConnectionInfo::new(id2)).await.unwrap(); + + assert!(mgr.connection_exists(id1).await.unwrap()); + assert!(mgr.connection_exists(id2).await.unwrap()); + assert_eq!(mgr.connection_count().await.unwrap(), 2); + + // Authenticate one connection. + mgr.set_connection_user(id1, user("u1", &["read"])) + .await + .unwrap(); + assert_eq!(mgr.authenticated_connection_count().await.unwrap(), 1); + + // The default `add_connection_with_sender` must fall through to + // `add_connection`. + let id3 = ConnectionId::new(); + let dummy: Box = Box::new(()) as _; + mgr.add_connection_with_sender(ConnectionInfo::new(id3), dummy) + .await + .unwrap(); + assert!(mgr.connection_exists(id3).await.unwrap()); + } + + #[tokio::test] + async fn ext_helpers_route_to_correct_messages() { + let mgr = StubManager::default(); + let id = ConnectionId::new(); + mgr.add_connection(ConnectionInfo::new(id)).await.unwrap(); + mgr.add_subscription(id, "room:1".into()).await.unwrap(); + + // notify_connection wraps as ServerNotification. + mgr.notify_connection(id, "evt", serde_json::json!({"k": 1})) + .await + .unwrap(); + let sent = mgr.sent.lock().unwrap(); + assert_eq!(sent.len(), 1); + match &sent[0].1 { + BidirectionalMessage::ServerNotification(n) => assert_eq!(n.method, "evt"), + other => panic!("unexpected: {other:?}"), + } + drop(sent); + + // notify_topic broadcasts to the topic with one subscriber. + let n = mgr + .notify_topic("room:1", "msg", serde_json::json!("hi")) + .await + .unwrap(); + assert_eq!(n, 1); + let bs = mgr.broadcasts.lock().unwrap(); + assert!(matches!( + &bs[0].1, + BidirectionalMessage::Broadcast(BroadcastMessage { method, .. }) if method == "msg" + )); + drop(bs); + + // ping_connection should produce a Ping payload. + mgr.ping_connection(id).await.unwrap(); + let sent = mgr.sent.lock().unwrap(); + assert!(matches!(sent.last().unwrap().1, BidirectionalMessage::Ping)); + } + + #[tokio::test] + async fn user_helpers_filter_and_disconnect() { + let mgr = StubManager::default(); + let alice1 = ConnectionId::new(); + let alice2 = ConnectionId::new(); + let bob = ConnectionId::new(); + mgr.add_connection(ConnectionInfo::new(alice1)) + .await + .unwrap(); + mgr.add_connection(ConnectionInfo::new(alice2)) + .await + .unwrap(); + mgr.add_connection(ConnectionInfo::new(bob)).await.unwrap(); + mgr.set_connection_user(alice1, user("alice", &[])) + .await + .unwrap(); + mgr.set_connection_user(alice2, user("alice", &[])) + .await + .unwrap(); + mgr.set_connection_user(bob, user("bob", &[])) + .await + .unwrap(); + + assert_eq!(mgr.get_user_connections("alice").await.unwrap().len(), 2); + assert_eq!(mgr.get_user_connections("bob").await.unwrap().len(), 1); + assert_eq!(mgr.get_user_connections("nobody").await.unwrap().len(), 0); + + let dropped = mgr.disconnect_user("alice").await.unwrap(); + assert_eq!(dropped, 2); + assert!(!mgr.connection_exists(alice1).await.unwrap()); + assert!(!mgr.connection_exists(alice2).await.unwrap()); + // Bob unaffected. + assert!(mgr.connection_exists(bob).await.unwrap()); + } +} diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-types/src/sender.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-types/src/sender.rs index 9fb0be5..5d26954 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-types/src/sender.rs +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-types/src/sender.rs @@ -245,4 +245,124 @@ mod tests { BidirectionalMessage::ServerNotification(n) if n.method == "test.method" )); } + + #[tokio::test] + async fn message_sender_ext_request_response_subscription() { + struct Recorder { + id: ConnectionId, + sent: Arc>>, + } + #[async_trait] + impl MessageSender for Recorder { + async fn send_message(&self, message: BidirectionalMessage) -> Result<()> { + self.sent.lock().await.push(message); + Ok(()) + } + async fn close(&self) -> Result<()> { + Ok(()) + } + async fn is_connected(&self) -> bool { + true + } + fn connection_id(&self) -> ConnectionId { + self.id + } + } + let r = Recorder { + id: ConnectionId::new(), + sent: Arc::new(Mutex::new(Vec::new())), + }; + + r.send_request(ras_jsonrpc_types::JsonRpcRequest { + jsonrpc: "2.0".into(), + method: "m".into(), + params: None, + id: Some(serde_json::json!(1)), + }) + .await + .unwrap(); + r.send_response(ras_jsonrpc_types::JsonRpcResponse::success( + serde_json::json!("ok"), + Some(serde_json::json!(1)), + )) + .await + .unwrap(); + r.send_subscription_update(vec!["t1".into()], true) + .await + .unwrap(); + r.send_subscription_update(vec!["t1".into()], false) + .await + .unwrap(); + + let s = r.sent.lock().await; + assert!(matches!(s[0], BidirectionalMessage::Request(_))); + assert!(matches!(s[1], BidirectionalMessage::Response(_))); + assert!(matches!(s[2], BidirectionalMessage::Subscribe { .. })); + assert!(matches!(s[3], BidirectionalMessage::Unsubscribe { .. })); + } + + #[tokio::test] + async fn noop_message_sender_round_trip() { + let id = ConnectionId::new(); + let sender = NoOpMessageSender::with_connection_id(id); + assert_eq!(sender.connection_id(), id); + assert!(sender.is_connected().await); + sender + .send_message(BidirectionalMessage::Ping) + .await + .unwrap(); + sender.close().await.unwrap(); + + // Default constructor + Default impl. + let s2 = NoOpMessageSender::new(); + let s3 = NoOpMessageSender::default(); + assert_ne!(s2.connection_id(), s3.connection_id()); + } + + #[tokio::test] + async fn websocket_sender_drives_real_sink() { + use futures::channel::mpsc; + use futures::stream::StreamExt; + + // mpsc::channel's Sender impls Sink, satisfying the SinkExt bound + // on `WebSocketMessageSender::new`. + let (tx, mut rx) = mpsc::channel::(8); + let id = ConnectionId::new(); + let sender = WebSocketMessageSender::new(id, tx); + + assert_eq!(sender.connection_id(), id); + assert!(sender.is_connected().await); + + sender + .send_message(BidirectionalMessage::Ping) + .await + .unwrap(); + // close once → emits a Close frame and flips is_closed. + sender.close().await.unwrap(); + assert!(!sender.is_connected().await); + // close again is idempotent (no panic, no extra send). + sender.close().await.unwrap(); + + // Sending after close yields ConnectionClosed. + let err = sender + .send_message(BidirectionalMessage::Pong) + .await + .unwrap_err(); + assert!(matches!(err, BidirectionalError::ConnectionClosed)); + + // Drain what we actually pushed: a Text(Ping) and a Close. + let mut received: Vec = Vec::new(); + while let Some(m) = rx.next().await { + received.push(m); + if received.len() == 2 { + break; + } + } + assert_eq!(received.len(), 2); + match &received[0] { + WsMessage::Text(t) => assert!(t.contains("ping")), + other => panic!("expected Text(ping), got {other:?}"), + } + assert!(matches!(received[1], WsMessage::Close(_))); + } } diff --git a/crates/rpc/ras-jsonrpc-types/src/lib.rs b/crates/rpc/ras-jsonrpc-types/src/lib.rs index b574757..6ec268a 100644 --- a/crates/rpc/ras-jsonrpc-types/src/lib.rs +++ b/crates/rpc/ras-jsonrpc-types/src/lib.rs @@ -202,3 +202,76 @@ impl JsonRpcError { ) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn jsonrpc_request_constructor_sets_version() { + let r = JsonRpcRequest::new( + "m".into(), + Some(serde_json::json!(1)), + Some(serde_json::json!("rid")), + ); + assert_eq!(r.jsonrpc, "2.0"); + assert_eq!(r.method, "m"); + } + + #[test] + fn jsonrpc_response_success_and_error() { + let s = JsonRpcResponse::success(serde_json::json!("ok"), Some(serde_json::json!(1))); + assert_eq!(s.jsonrpc, "2.0"); + assert!(s.error.is_none()); + assert_eq!(s.result, Some(serde_json::json!("ok"))); + + let e = JsonRpcResponse::error(JsonRpcError::parse_error(), Some(serde_json::json!(1))); + assert!(e.result.is_none()); + assert_eq!(e.error.unwrap().code, error_codes::PARSE_ERROR); + } + + #[test] + fn json_rpc_error_constructors_use_canonical_codes() { + assert_eq!(JsonRpcError::parse_error().code, error_codes::PARSE_ERROR); + assert_eq!( + JsonRpcError::invalid_request().code, + error_codes::INVALID_REQUEST + ); + let nf = JsonRpcError::method_not_found("m"); + assert_eq!(nf.code, error_codes::METHOD_NOT_FOUND); + assert!(nf.message.contains("m")); + assert_eq!( + JsonRpcError::invalid_params("bad".into()).code, + error_codes::INVALID_PARAMS + ); + assert_eq!( + JsonRpcError::internal_error("e".into()).code, + error_codes::INTERNAL_ERROR + ); + assert_eq!( + JsonRpcError::authentication_required().code, + error_codes::AUTHENTICATION_REQUIRED + ); + assert_eq!( + JsonRpcError::token_expired().code, + error_codes::TOKEN_EXPIRED + ); + } + + #[test] + fn insufficient_permissions_carries_data() { + let err = JsonRpcError::insufficient_permissions(vec!["admin".into()], vec!["user".into()]); + assert_eq!(err.code, error_codes::INSUFFICIENT_PERMISSIONS); + let data = err.data.unwrap(); + assert_eq!(data["required"], serde_json::json!(["admin"])); + assert_eq!(data["has"], serde_json::json!(["user"])); + } + + #[test] + fn request_with_no_id_skips_field_in_serialization() { + let req = JsonRpcRequest::new("notify".into(), None, None); + let s = serde_json::to_string(&req).unwrap(); + assert!(!s.contains("\"id\"")); + assert!(!s.contains("\"params\"")); + } +} From 631567c28070d544a02c43bf78c512b81d076176 Mon Sep 17 00:00:00 2001 From: Mathias Myrland Date: Sat, 25 Apr 2026 09:38:01 +0200 Subject: [PATCH 5/6] Fix clippy never_loop in file-service-example The demo's `upload` handler used `while let Some(field) = ...` and then unconditionally returned on the first iteration, which a recent clippy release flags as `never_loop`. Replaced with a single `next_field().await?.ok_or_else(...)` so the control flow matches what the code was actually doing. Behavior is unchanged: the handler still consumes exactly the first multipart field and returns NotFound when none is present. Co-Authored-By: Claude Opus 4.7 (1M context) --- examples/file-service-example/src/main.rs | 55 ++++++++++++----------- 1 file changed, 29 insertions(+), 26 deletions(-) diff --git a/examples/file-service-example/src/main.rs b/examples/file-service-example/src/main.rs index 8d241fc..3b8ae92 100644 --- a/examples/file-service-example/src/main.rs +++ b/examples/file-service-example/src/main.rs @@ -102,34 +102,37 @@ impl DocumentServiceTrait for DocumentServiceImpl { ) -> Result { println!("User {} is uploading a file", user.user_id); - // Process the multipart upload - while let Some(field) = multipart.next_field().await.map_err(|e| { - DocumentServiceFileError::UploadFailed(format!("Failed to get next field: {}", e)) - })? { - let name = field.name().unwrap_or("unknown").to_string(); - let file_name = field.file_name().unwrap_or("unknown").to_string(); - let data = field.bytes().await.map_err(|e| { - DocumentServiceFileError::UploadFailed(format!("Failed to read field data: {}", e)) + // Process the first multipart field — that's the uploaded file in the + // demo's contract. Real implementations would loop and accept several. + let field = multipart + .next_field() + .await + .map_err(|e| { + DocumentServiceFileError::UploadFailed(format!("Failed to get next field: {}", e)) + })? + .ok_or_else(|| { + DocumentServiceFileError::UploadFailed("No file in multipart data".to_string()) })?; - println!( - "Received field '{}' with filename '{}', size: {} bytes", - name, - file_name, - data.len() - ); - - // In a real implementation, you would save this to storage - return Ok(UploadResponse { - file_id: format!("file_{}", Uuid::new_v4()), - size: data.len() as u64, - filename: file_name, - }); - } - - Err(DocumentServiceFileError::UploadFailed( - "No file in multipart data".to_string(), - )) + let name = field.name().unwrap_or("unknown").to_string(); + let file_name = field.file_name().unwrap_or("unknown").to_string(); + let data = field.bytes().await.map_err(|e| { + DocumentServiceFileError::UploadFailed(format!("Failed to read field data: {}", e)) + })?; + + println!( + "Received field '{}' with filename '{}', size: {} bytes", + name, + file_name, + data.len() + ); + + // In a real implementation, you would save this to storage + Ok(UploadResponse { + file_id: format!("file_{}", Uuid::new_v4()), + size: data.len() as u64, + filename: file_name, + }) } async fn info( From 3625e01d152aed92235ccc592e9867c65a17ed76 Mon Sep 17 00:00:00 2001 From: Mathias Myrland Date: Sat, 25 Apr 2026 09:38:49 +0200 Subject: [PATCH 6/6] ci: add fmt/clippy/test/coverage and criterion bench workflows MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replaces the previous test-coverage.yml (which only ran cargo test) with two workflows: ci.yml — runs on every PR and master push, with parallel jobs for: * rustfmt --check * cargo clippy --workspace --all-targets --all-features (errors only; legacy warnings are not yet promoted to deny) * cargo test --workspace --all-features (separate build + run steps so failures point to the right phase) * cargo llvm-cov producing lcov.info as a build artifact and a printed summary in the job log; Codecov upload is left commented for the repo owner to enable with a CODECOV_TOKEN bench.yml — runs on master pushes and on workflow_dispatch only. PRs intentionally skip benches: GitHub runners are noisy and PR-time bench diffs are dominated by that noise. Uploads target/criterion/ as an artifact for trend inspection. Each macro's bench is run individually under a short measurement window so the total runner time stays bounded. Co-Authored-By: Claude Opus 4.7 (1M context) --- .github/workflows/bench.yml | 53 ++++++++++++++++++++ .github/workflows/ci.yml | 78 +++++++++++++++++++++++++++++ .github/workflows/test-coverage.yml | 32 ------------ 3 files changed, 131 insertions(+), 32 deletions(-) create mode 100644 .github/workflows/bench.yml create mode 100644 .github/workflows/ci.yml delete mode 100644 .github/workflows/test-coverage.yml diff --git a/.github/workflows/bench.yml b/.github/workflows/bench.yml new file mode 100644 index 0000000..f2f8d51 --- /dev/null +++ b/.github/workflows/bench.yml @@ -0,0 +1,53 @@ +name: Benchmarks + +# GitHub-hosted runners are noisy — benches here track the *order of magnitude* +# baseline rather than fail on small regressions. They run on master pushes and +# on demand. PRs intentionally do NOT run benches: they're slow and the noise +# floor would just produce false alarms. + +on: + push: + branches: [master, main] + workflow_dispatch: + +env: + CARGO_TERM_COLOR: always + CARGO_INCREMENTAL: 0 + +jobs: + bench: + name: Criterion benches + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + - uses: Swatinem/rust-cache@v2 + with: + # Bench builds use the release profile; cache it under its own key + # so it doesn't fight with the test job's debug cache. + key: bench + - name: Run criterion benches (each macro) + run: | + set -euo pipefail + mkdir -p bench-output + for crate in \ + ras-jsonrpc-macro \ + ras-rest-macro \ + ras-file-macro \ + ras-jsonrpc-bidirectional-macro + do + echo "::group::$crate" + cargo bench -p "$crate" -- \ + --warm-up-time 1 --measurement-time 3 \ + | tee "bench-output/$crate.txt" + echo "::endgroup::" + done + - name: Upload bench artifacts + if: always() + uses: actions/upload-artifact@v4 + with: + name: criterion-results + path: | + bench-output/ + target/criterion/ + retention-days: 30 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..10ae2fe --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,78 @@ +name: CI + +on: + push: + branches: [master, main] + pull_request: + branches: [master, main] + +env: + CARGO_TERM_COLOR: always + CARGO_INCREMENTAL: 0 + +jobs: + fmt: + name: Rustfmt + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + with: + components: rustfmt + - run: cargo fmt --all -- --check + + clippy: + name: Clippy + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + with: + components: clippy + - uses: Swatinem/rust-cache@v2 + # Catch hard errors and lint regressions. We don't enforce -D warnings + # workspace-wide yet (legacy code has standing warnings); the CI gate is + # "compiles cleanly + no clippy ERRORS". Tighten later by promoting + # selected lints to deny. + - run: cargo clippy --workspace --all-targets --all-features + + test: + name: Test (workspace) + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + - uses: Swatinem/rust-cache@v2 + - name: Build tests + run: cargo test --workspace --all-features --no-run --locked + - name: Run tests + run: cargo test --workspace --all-features -- --nocapture --test-threads=4 + + coverage: + name: Coverage report + runs-on: ubuntu-latest + needs: [test] + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + with: + components: llvm-tools-preview + - uses: taiki-e/install-action@cargo-llvm-cov + - uses: Swatinem/rust-cache@v2 + - name: Generate coverage (lcov) + run: cargo llvm-cov --workspace --all-features --lcov --output-path lcov.info + - name: Print summary + run: cargo llvm-cov report --summary-only + - name: Upload coverage artifact + uses: actions/upload-artifact@v4 + with: + name: coverage-lcov + path: lcov.info + retention-days: 30 + # Optional: enable Codecov upload by adding a CODECOV_TOKEN secret and + # uncommenting. Without the token the run still succeeds and the lcov + # artifact above remains the source of truth. + # - uses: codecov/codecov-action@v4 + # with: + # files: lcov.info + # fail_ci_if_error: false diff --git a/.github/workflows/test-coverage.yml b/.github/workflows/test-coverage.yml deleted file mode 100644 index 8346f61..0000000 --- a/.github/workflows/test-coverage.yml +++ /dev/null @@ -1,32 +0,0 @@ -name: Test Coverage - -on: - push: - branches: [master, main] - pull_request: - branches: [master, main] - -env: - CARGO_TERM_COLOR: always - -jobs: - test-coverage: - name: Test and Coverage - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v4 - - - name: Install Rust - uses: dtolnay/rust-toolchain@stable - with: - components: llvm-tools-preview - - - name: Install cargo-llvm-cov - uses: taiki-e/install-action@cargo-llvm-cov - - - name: Cache dependencies - uses: Swatinem/rust-cache@v2 - - - name: Run tests - run: cargo test --all-features --workspace