diff --git a/.github/workflows/bench.yml b/.github/workflows/bench.yml new file mode 100644 index 0000000..f2f8d51 --- /dev/null +++ b/.github/workflows/bench.yml @@ -0,0 +1,53 @@ +name: Benchmarks + +# GitHub-hosted runners are noisy — benches here track the *order of magnitude* +# baseline rather than fail on small regressions. They run on master pushes and +# on demand. PRs intentionally do NOT run benches: they're slow and the noise +# floor would just produce false alarms. + +on: + push: + branches: [master, main] + workflow_dispatch: + +env: + CARGO_TERM_COLOR: always + CARGO_INCREMENTAL: 0 + +jobs: + bench: + name: Criterion benches + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + - uses: Swatinem/rust-cache@v2 + with: + # Bench builds use the release profile; cache it under its own key + # so it doesn't fight with the test job's debug cache. + key: bench + - name: Run criterion benches (each macro) + run: | + set -euo pipefail + mkdir -p bench-output + for crate in \ + ras-jsonrpc-macro \ + ras-rest-macro \ + ras-file-macro \ + ras-jsonrpc-bidirectional-macro + do + echo "::group::$crate" + cargo bench -p "$crate" -- \ + --warm-up-time 1 --measurement-time 3 \ + | tee "bench-output/$crate.txt" + echo "::endgroup::" + done + - name: Upload bench artifacts + if: always() + uses: actions/upload-artifact@v4 + with: + name: criterion-results + path: | + bench-output/ + target/criterion/ + retention-days: 30 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..10ae2fe --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,78 @@ +name: CI + +on: + push: + branches: [master, main] + pull_request: + branches: [master, main] + +env: + CARGO_TERM_COLOR: always + CARGO_INCREMENTAL: 0 + +jobs: + fmt: + name: Rustfmt + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + with: + components: rustfmt + - run: cargo fmt --all -- --check + + clippy: + name: Clippy + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + with: + components: clippy + - uses: Swatinem/rust-cache@v2 + # Catch hard errors and lint regressions. We don't enforce -D warnings + # workspace-wide yet (legacy code has standing warnings); the CI gate is + # "compiles cleanly + no clippy ERRORS". Tighten later by promoting + # selected lints to deny. + - run: cargo clippy --workspace --all-targets --all-features + + test: + name: Test (workspace) + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + - uses: Swatinem/rust-cache@v2 + - name: Build tests + run: cargo test --workspace --all-features --no-run --locked + - name: Run tests + run: cargo test --workspace --all-features -- --nocapture --test-threads=4 + + coverage: + name: Coverage report + runs-on: ubuntu-latest + needs: [test] + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + with: + components: llvm-tools-preview + - uses: taiki-e/install-action@cargo-llvm-cov + - uses: Swatinem/rust-cache@v2 + - name: Generate coverage (lcov) + run: cargo llvm-cov --workspace --all-features --lcov --output-path lcov.info + - name: Print summary + run: cargo llvm-cov report --summary-only + - name: Upload coverage artifact + uses: actions/upload-artifact@v4 + with: + name: coverage-lcov + path: lcov.info + retention-days: 30 + # Optional: enable Codecov upload by adding a CODECOV_TOKEN secret and + # uncommenting. Without the token the run still succeeds and the lcov + # artifact above remains the source of truth. + # - uses: codecov/codecov-action@v4 + # with: + # files: lcov.info + # fail_ci_if_error: false diff --git a/.github/workflows/test-coverage.yml b/.github/workflows/test-coverage.yml deleted file mode 100644 index 8346f61..0000000 --- a/.github/workflows/test-coverage.yml +++ /dev/null @@ -1,32 +0,0 @@ -name: Test Coverage - -on: - push: - branches: [master, main] - pull_request: - branches: [master, main] - -env: - CARGO_TERM_COLOR: always - -jobs: - test-coverage: - name: Test and Coverage - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v4 - - - name: Install Rust - uses: dtolnay/rust-toolchain@stable - with: - components: llvm-tools-preview - - - name: Install cargo-llvm-cov - uses: taiki-e/install-action@cargo-llvm-cov - - - name: Cache dependencies - uses: Swatinem/rust-cache@v2 - - - name: Run tests - run: cargo test --all-features --workspace diff --git a/Cargo.lock b/Cargo.lock index ee97d4d..03834dc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -59,6 +59,12 @@ dependencies = [ "libc", ] +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + [[package]] name = "anstream" version = "0.6.19" @@ -322,6 +328,12 @@ dependencies = [ "windows-targets", ] +[[package]] +name = "base16ct" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c7f02d4ea65f2c1853089ffd8d2787bdbc63de2f0d29dedbcf8ccdfa0ccd4cf" + [[package]] name = "base64" version = "0.21.7" @@ -532,6 +544,12 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53" +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + [[package]] name = "castaway" version = "0.2.3" @@ -577,6 +595,33 @@ dependencies = [ "windows-link", ] +[[package]] +name = "ciborium" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" + +[[package]] +name = "ciborium-ll" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" +dependencies = [ + "ciborium-io", + "half", +] + [[package]] name = "clap" version = "4.5.39" @@ -666,6 +711,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "const-oid" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" + [[package]] name = "const-random" version = "0.1.18" @@ -750,6 +801,63 @@ dependencies = [ "libc", ] +[[package]] +name = "criterion" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "futures", + "is-terminal", + "itertools 0.10.5", + "num-traits", + "once_cell", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_derive", + "serde_json", + "tinytemplate", + "tokio", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools 0.10.5", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-utils" version = "0.8.21" @@ -787,6 +895,18 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929" +[[package]] +name = "crypto-bigint" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0dc92fb57ca44df6db8059111ab3af99a63d5d0f8375d9972e319a379c6bab76" +dependencies = [ + "generic-array", + "rand_core 0.6.4", + "subtle", + "zeroize", +] + [[package]] name = "crypto-common" version = "0.1.6" @@ -820,6 +940,33 @@ dependencies = [ "syn", ] +[[package]] +name = "curve25519-dalek" +version = "4.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97fb8b7c4503de7d6ae7b42ab72a5a59857b4c937ec27a3d4539dba95b5ab2be" +dependencies = [ + "cfg-if 1.0.0", + "cpufeatures", + "curve25519-dalek-derive", + "digest", + "fiat-crypto", + "rustc_version", + "subtle", + "zeroize", +] + +[[package]] +name = "curve25519-dalek-derive" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "darling" version = "0.20.11" @@ -893,6 +1040,17 @@ version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "092966b41edc516079bdf31ec78a2e0588d1d0c08f78b91d8307215928642b2b" +[[package]] +name = "der" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb" +dependencies = [ + "const-oid", + "pem-rfc7468", + "zeroize", +] + [[package]] name = "deranged" version = "0.4.0" @@ -915,6 +1073,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer", + "const-oid", "crypto-common", "subtle", ] @@ -1058,12 +1217,71 @@ version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1c7a8fb8a9fbf66c1f703fe16184d10ca0ee9d23be5b4436400408ba54a95005" +[[package]] +name = "ecdsa" +version = "0.16.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee27f32b5c5292967d2d4a9d7f1e0b0aed2c15daded5a60300e4abb9d8020bca" +dependencies = [ + "der", + "digest", + "elliptic-curve", + "rfc6979", + "signature", + "spki", +] + +[[package]] +name = "ed25519" +version = "2.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "115531babc129696a58c64a4fef0a8bf9e9698629fb97e9e40767d235cfbcd53" +dependencies = [ + "pkcs8", + "signature", +] + +[[package]] +name = "ed25519-dalek" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70e796c081cee67dc755e1a36a0a172b897fab85fc3f6bc48307991f64e4eca9" +dependencies = [ + "curve25519-dalek", + "ed25519", + "serde", + "sha2", + "subtle", + "zeroize", +] + [[package]] name = "either" version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +[[package]] +name = "elliptic-curve" +version = "0.13.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5e6043086bf7973472e0c7dff2142ea0b680d30e18d9cc40f267efbf222bd47" +dependencies = [ + "base16ct", + "crypto-bigint", + "digest", + "ff", + "generic-array", + "group", + "hkdf", + "pem-rfc7468", + "pkcs8", + "rand_core 0.6.4", + "sec1", + "subtle", + "zeroize", +] + [[package]] name = "email_address" version = "0.2.9" @@ -1141,6 +1359,22 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" +[[package]] +name = "ff" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0b50bfb653653f9ca9095b427bed08ab8d75a137839d9ad64eb11810d5b6393" +dependencies = [ + "rand_core 0.6.4", + "subtle", +] + +[[package]] +name = "fiat-crypto" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28dea519a9695b9977216879a3ebfddf92f1c08c05d984f8996aecd6ecdc811d" + [[package]] name = "file-service-api" version = "0.1.0" @@ -1360,6 +1594,7 @@ checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" dependencies = [ "typenum", "version_check", + "zeroize", ] [[package]] @@ -1381,10 +1616,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" dependencies = [ "cfg-if 1.0.0", - "js-sys", "libc", "wasi 0.11.0+wasi-snapshot-preview1", - "wasm-bindgen", ] [[package]] @@ -1477,6 +1710,17 @@ dependencies = [ "web-sys", ] +[[package]] +name = "group" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0f9ef7462f7c099f518d754361858f86d8a07af53ba9af0fe635bbccb151a63" +dependencies = [ + "ff", + "rand_core 0.6.4", + "subtle", +] + [[package]] name = "h2" version = "0.4.10" @@ -1496,6 +1740,17 @@ dependencies = [ "tracing", ] +[[package]] +name = "half" +version = "2.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ea2d84b969582b4b1864a92dc5d27cd2b77b622a8d79306834f1be5ba20d84b" +dependencies = [ + "cfg-if 1.0.0", + "crunchy", + "zerocopy", +] + [[package]] name = "hashbrown" version = "0.14.5" @@ -1538,6 +1793,24 @@ version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f154ce46856750ed433c8649605bf7ed2de3bc35fd9d2a9f30cddd873c80cb08" +[[package]] +name = "hkdf" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b5f8eb2ad728638ea2c7d47a21db23b7b58a72ed6a38256b8a1849f15fbbdf7" +dependencies = [ + "hmac", +] + +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + [[package]] name = "http" version = "1.3.1" @@ -1860,12 +2133,32 @@ dependencies = [ "serde", ] +[[package]] +name = "is-terminal" +version = "0.4.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46" +dependencies = [ + "hermit-abi", + "libc", + "windows-sys 0.59.0", +] + [[package]] name = "is_terminal_polyfill" version = "1.70.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" +[[package]] +name = "itertools" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +dependencies = [ + "either", +] + [[package]] name = "itertools" version = "0.13.0" @@ -1904,16 +2197,24 @@ dependencies = [ [[package]] name = "jsonwebtoken" -version = "9.3.1" +version = "10.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a87cc7a48537badeae96744432de36f4be2b4a34a05a5ef32e9dd8a1c169dde" +checksum = "0529410abe238729a60b108898784df8984c87f6054c9c4fcacc47e4803c1ce1" dependencies = [ "base64 0.22.1", + "ed25519-dalek", + "getrandom 0.2.16", + "hmac", "js-sys", + "p256", + "p384", "pem", - "ring", + "rand 0.8.5", + "rsa", "serde", "serde_json", + "sha2", + "signature", "simple_asn1", ] @@ -1922,6 +2223,9 @@ name = "lazy_static" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" +dependencies = [ + "spin", +] [[package]] name = "libc" @@ -1929,6 +2233,12 @@ version = "0.2.172" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa" +[[package]] +name = "libm" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6d2cec3eae94f9f509c767b45932f1ada8350c4bdb85af2fcab4a3c14807981" + [[package]] name = "linux-raw-sys" version = "0.4.15" @@ -2128,6 +2438,22 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-bigint-dig" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e661dda6640fad38e827a6d4a310ff4763082116fe217f279885c97f511bb0b7" +dependencies = [ + "lazy_static", + "libm", + "num-integer", + "num-iter", + "num-traits", + "rand 0.8.5", + "smallvec", + "zeroize", +] + [[package]] name = "num-conv" version = "0.1.0" @@ -2143,6 +2469,17 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-iter" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1429034a0490724d0075ebb2bc9e875d6503c3cf69e235a8941aa757d83ef5bf" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.19" @@ -2150,6 +2487,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" dependencies = [ "autocfg", + "libm", ] [[package]] @@ -2226,6 +2564,12 @@ version = "1.70.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a4895175b425cb1f87721b59f0f286c2092bd4af812243672510e1ac53e2e0ad" +[[package]] +name = "oorandom" +version = "11.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" + [[package]] name = "openrpc-to-bruno" version = "0.1.0" @@ -2362,6 +2706,30 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" +[[package]] +name = "p256" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c9863ad85fa8f4460f9c48cb909d38a0d689dba1f6f6988a5e3e0d31071bcd4b" +dependencies = [ + "ecdsa", + "elliptic-curve", + "primeorder", + "sha2", +] + +[[package]] +name = "p384" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe42f1670a52a47d448f14b6a5c61dd78fce51856e68edaa38f7ae3a46b8d6b6" +dependencies = [ + "ecdsa", + "elliptic-curve", + "primeorder", + "sha2", +] + [[package]] name = "parking_lot" version = "0.12.4" @@ -2418,6 +2786,15 @@ dependencies = [ "serde", ] +[[package]] +name = "pem-rfc7468" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412" +dependencies = [ + "base64ct", +] + [[package]] name = "percent-encoding" version = "2.3.1" @@ -2543,12 +2920,61 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "pkcs1" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8ffb9f10fa047879315e6625af03c164b16962a5368d724ed16323b68ace47f" +dependencies = [ + "der", + "pkcs8", + "spki", +] + +[[package]] +name = "pkcs8" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" +dependencies = [ + "der", + "spki", +] + [[package]] name = "pkg-config" version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" +[[package]] +name = "plotters" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a" + +[[package]] +name = "plotters-svg" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670" +dependencies = [ + "plotters-backend", +] + [[package]] name = "potential_utf" version = "0.1.2" @@ -2593,6 +3019,15 @@ dependencies = [ "syn", ] +[[package]] +name = "primeorder" +version = "0.13.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "353e1ca18966c16d9deb1c69278edbc5f194139612772bd9537af60ac231e1e6" +dependencies = [ + "elliptic-curve", +] + [[package]] name = "proc-macro2" version = "1.0.95" @@ -2712,14 +3147,18 @@ version = "0.1.0" dependencies = [ "async-trait", "axum", + "axum-test", + "criterion", "proc-macro2", "quote", "ras-auth-core", + "ras-test-helpers", "reqwest", "schemars", "serde", "serde_json", "syn", + "tempfile", "thiserror 2.0.12", "tokio", "tokio-util", @@ -2733,6 +3172,7 @@ dependencies = [ "serde", "serde_json", "thiserror 2.0.12", + "tokio", ] [[package]] @@ -2831,6 +3271,7 @@ dependencies = [ "async-trait", "axum", "chrono", + "criterion", "futures", "http", "proc-macro2", @@ -2841,6 +3282,7 @@ dependencies = [ "ras-jsonrpc-bidirectional-server", "ras-jsonrpc-bidirectional-types", "ras-jsonrpc-types", + "ras-test-helpers", "serde", "serde_json", "syn", @@ -2908,7 +3350,9 @@ version = "0.1.1" dependencies = [ "async-trait", "axum", + "axum-test", "bon", + "criterion", "futures", "proc-macro2", "quote", @@ -2917,6 +3361,7 @@ dependencies = [ "ras-identity-session", "ras-jsonrpc-core", "ras-jsonrpc-types", + "ras-test-helpers", "reqwest", "schemars", "serde", @@ -2982,7 +3427,9 @@ dependencies = [ "async-trait", "axum", "axum-extra", + "axum-test", "chrono", + "criterion", "futures", "hyper", "proc-macro2", @@ -2992,6 +3439,7 @@ dependencies = [ "ras-identity-session", "ras-jsonrpc-core", "ras-rest-core", + "ras-test-helpers", "reqwest", "schemars", "serde", @@ -3003,6 +3451,16 @@ dependencies = [ "wiremock", ] +[[package]] +name = "ras-test-helpers" +version = "0.0.0" +dependencies = [ + "axum", + "axum-test", + "ras-auth-core", + "tokio", +] + [[package]] name = "ratatui" version = "0.29.0" @@ -3015,7 +3473,7 @@ dependencies = [ "crossterm", "indoc", "instability", - "itertools", + "itertools 0.13.0", "lru", "paste", "strum", @@ -3024,6 +3482,26 @@ dependencies = [ "unicode-width 0.2.0", ] +[[package]] +name = "rayon" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb39b166781f92d482534ef4b4b1b2568f42613b53e5b6c160e24cfbfa30926d" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + [[package]] name = "redox_syscall" version = "0.5.12" @@ -3200,6 +3678,16 @@ dependencies = [ "uuid", ] +[[package]] +name = "rfc6979" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8dd2a808d456c4a54e300a23e9f5a67e122c3024119acbfd73e3bf664491cb2" +dependencies = [ + "hmac", + "subtle", +] + [[package]] name = "ring" version = "0.17.14" @@ -3226,6 +3714,26 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "rsa" +version = "0.9.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8573f03f5883dcaebdfcf4725caa1ecb9c15b2ef50c43a07b816e06799bb12d" +dependencies = [ + "const-oid", + "digest", + "num-bigint-dig", + "num-integer", + "num-traits", + "pkcs1", + "pkcs8", + "rand_core 0.6.4", + "signature", + "spki", + "subtle", + "zeroize", +] + [[package]] name = "rust-ini" version = "0.20.0" @@ -3257,6 +3765,15 @@ version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" +[[package]] +name = "rustc_version" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" +dependencies = [ + "semver", +] + [[package]] name = "rustix" version = "0.38.44" @@ -3377,6 +3894,20 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "sec1" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3e97a565f76233a6003f9f5c54be1d9c5bdfa3eccfb189469f11ec4901c47dc" +dependencies = [ + "base16ct", + "der", + "generic-array", + "pkcs8", + "subtle", + "zeroize", +] + [[package]] name = "security-framework" version = "2.11.1" @@ -3400,12 +3931,19 @@ dependencies = [ "libc", ] +[[package]] +name = "semver" +version = "1.0.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a7852d02fc848982e0c167ef163aaff9cd91dc640ba85e263cb1ce46fae51cd" + [[package]] name = "serde" -version = "1.0.219" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" dependencies = [ + "serde_core", "serde_derive", ] @@ -3420,11 +3958,20 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + [[package]] name = "serde_derive" -version = "1.0.219" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" dependencies = [ "proc-macro2", "quote", @@ -3565,6 +4112,16 @@ dependencies = [ "libc", ] +[[package]] +name = "signature" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de" +dependencies = [ + "digest", + "rand_core 0.6.4", +] + [[package]] name = "simple_asn1" version = "0.6.3" @@ -3614,6 +4171,16 @@ version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +[[package]] +name = "spki" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d" +dependencies = [ + "base64ct", + "der", +] + [[package]] name = "stable_deref_trait" version = "1.2.0" @@ -3825,6 +4392,16 @@ dependencies = [ "zerovec", ] +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "tokio" version = "1.45.1" @@ -4171,7 +4748,7 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b3644627a5af5fa321c95b9b235a72fd24cd29c648c2c379431e6628655627bf" dependencies = [ - "itertools", + "itertools 0.13.0", "unicode-segmentation", "unicode-width 0.1.14", ] @@ -4722,18 +5299,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.25" +version = "0.8.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1702d9583232ddb9174e01bb7c15a2ab8fb1bc6f227aa1233858c351a3ba0cb" +checksum = "eed437bf9d6692032087e337407a86f04cd8d6a16a37199ed57949d415bd68e9" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.25" +version = "0.8.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28a6e20d751156648aa063f3800b706ee209a32c0b4d9f24be3d980b01be55ef" +checksum = "70e3cd084b1788766f53af483dd21f93881ff30d7320490ec3ef7526d203bad4" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index 6043c3b..a4bf3cd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ members = [ "crates/rpc/ras-jsonrpc-macro", "crates/rpc/ras-jsonrpc-types", "crates/specs/*", + "crates/test-utils/*", "crates/tools/*", "examples/basic-jsonrpc/*", "examples/bidirectional-chat/api", @@ -32,6 +33,7 @@ base64 = "0.22" bon = "3.2" console = "0.15" console_error_panic_hook = "0.1" +criterion = "0.5" crossterm = "0.28" dashmap = "6.1" dialoguer = "0.11" @@ -49,7 +51,7 @@ gloo-net = "0.6" gloo-utils = "0.2" http = "1.0" js-sys = "0.3" -jsonwebtoken = "9.3" +jsonwebtoken = { version = "10.3", features = ["rust_crypto"] } mime_guess = "2.0" once_cell = "1.20" opentelemetry = "0.28" diff --git a/crates/core/ras-identity-core/Cargo.toml b/crates/core/ras-identity-core/Cargo.toml index ab6de24..748d4e4 100644 --- a/crates/core/ras-identity-core/Cargo.toml +++ b/crates/core/ras-identity-core/Cargo.toml @@ -11,4 +11,7 @@ homepage = "https://github.com/example/rust-agent-stack" async-trait = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } -thiserror = { workspace = true } \ No newline at end of file +thiserror = { workspace = true } + +[dev-dependencies] +tokio = { workspace = true } \ No newline at end of file diff --git a/crates/core/ras-identity-core/src/lib.rs b/crates/core/ras-identity-core/src/lib.rs index 49d7543..ac8490c 100644 --- a/crates/core/ras-identity-core/src/lib.rs +++ b/crates/core/ras-identity-core/src/lib.rs @@ -78,3 +78,73 @@ impl UserPermissions for StaticPermissions { Ok(self.permissions.clone()) } } + +#[cfg(test)] +mod tests { + use super::*; + + fn vi() -> VerifiedIdentity { + VerifiedIdentity { + provider_id: "test".into(), + subject: "alice".into(), + email: Some("a@b.com".into()), + display_name: Some("Alice".into()), + metadata: None, + } + } + + #[test] + fn identity_error_display_per_variant() { + assert_eq!( + IdentityError::InvalidCredentials.to_string(), + "Invalid credentials" + ); + assert_eq!( + IdentityError::ProviderNotFound("foo".into()).to_string(), + "Provider not found: foo" + ); + assert_eq!( + IdentityError::ProviderError("bad".into()).to_string(), + "Provider error: bad" + ); + assert_eq!( + IdentityError::UnsupportedMethod.to_string(), + "Unsupported authentication method" + ); + assert_eq!( + IdentityError::InvalidPayload.to_string(), + "Invalid authentication payload" + ); + assert_eq!( + IdentityError::SessionError("expired".into()).to_string(), + "Session error: expired" + ); + + let parse_err = serde_json::from_str::("not json").unwrap_err(); + let wrapped: IdentityError = parse_err.into(); + assert!(wrapped.to_string().starts_with("Serialization error:")); + } + + #[tokio::test] + async fn noop_permissions_returns_empty() { + let p = NoopPermissions; + let perms = p.get_permissions(&vi()).await.unwrap(); + assert!(perms.is_empty()); + } + + #[tokio::test] + async fn static_permissions_returns_provided_list() { + let p = StaticPermissions::new(vec!["a".into(), "b".into()]); + let perms = p.get_permissions(&vi()).await.unwrap(); + assert_eq!(perms, vec!["a".to_string(), "b".to_string()]); + } + + #[test] + fn verified_identity_serde_round_trips() { + let v = vi(); + let json = serde_json::to_string(&v).unwrap(); + let parsed: VerifiedIdentity = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed.subject, "alice"); + assert_eq!(parsed.provider_id, "test"); + } +} diff --git a/crates/identity/ras-identity-oauth2/examples/google_oauth2.rs b/crates/identity/ras-identity-oauth2/examples/google_oauth2.rs index f31f436..098d4ee 100644 --- a/crates/identity/ras-identity-oauth2/examples/google_oauth2.rs +++ b/crates/identity/ras-identity-oauth2/examples/google_oauth2.rs @@ -51,8 +51,9 @@ async fn main() -> Result<(), Box> { let oauth2_provider = OAuth2Provider::new(oauth2_config, state_store); // Create session service - let session_config = SessionConfig::default(); - let session_service = SessionService::new(session_config); + let session_config = + SessionConfig::new("oauth2-example-secret-that-is-long-enough-for-tests").unwrap(); + let session_service = SessionService::new(session_config).unwrap(); // Register OAuth2 provider with session service session_service diff --git a/crates/identity/ras-identity-oauth2/src/config.rs b/crates/identity/ras-identity-oauth2/src/config.rs index 14c09ad..01f8a35 100644 --- a/crates/identity/ras-identity-oauth2/src/config.rs +++ b/crates/identity/ras-identity-oauth2/src/config.rs @@ -80,3 +80,72 @@ impl OAuth2Config { self } } + +#[cfg(test)] +mod tests { + use super::*; + + fn provider() -> OAuth2ProviderConfig { + OAuth2ProviderConfig { + provider_id: "google".into(), + client_id: "cid".into(), + client_secret: "secret".into(), + authorization_endpoint: "https://x/auth".into(), + token_endpoint: "https://x/token".into(), + userinfo_endpoint: Some("https://x/info".into()), + redirect_uri: "https://app/cb".into(), + scopes: vec!["openid".into(), "email".into()], + auth_params: HashMap::new(), + use_pkce: true, + user_info_mapping: None, + } + } + + #[test] + fn user_info_mapping_default_uses_oidc_field_names() { + let m = UserInfoMapping::default(); + assert_eq!(m.subject_field.as_deref(), Some("sub")); + assert_eq!(m.email_field.as_deref(), Some("email")); + assert_eq!(m.name_field.as_deref(), Some("name")); + assert_eq!(m.picture_field.as_deref(), Some("picture")); + } + + #[test] + fn oauth2_config_builder_chains_settings() { + let p = provider(); + let cfg = OAuth2Config::new() + .add_provider(p.clone()) + .with_state_ttl(120) + .with_http_timeout(7); + assert_eq!(cfg.state_ttl_seconds, 120); + assert_eq!(cfg.http_timeout_seconds, 7); + assert!(cfg.providers.contains_key("google")); + } + + #[test] + fn provider_config_round_trips_through_serde() { + let p = provider(); + let json = serde_json::to_string(&p).unwrap(); + let parsed: OAuth2ProviderConfig = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed.provider_id, p.provider_id); + assert_eq!(parsed.client_id, p.client_id); + assert_eq!(parsed.scopes, p.scopes); + assert!(parsed.use_pkce); + } + + #[test] + fn user_info_mapping_serde() { + let m = UserInfoMapping::default(); + let json = serde_json::to_string(&m).unwrap(); + let parsed: UserInfoMapping = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed.subject_field, m.subject_field); + } + + #[test] + fn defaults_are_sensible() { + let cfg = OAuth2Config::default(); + assert!(cfg.providers.is_empty()); + assert_eq!(cfg.state_ttl_seconds, 600); + assert_eq!(cfg.http_timeout_seconds, 30); + } +} diff --git a/crates/identity/ras-identity-oauth2/src/lib.rs b/crates/identity/ras-identity-oauth2/src/lib.rs index 6c860fa..baf54b0 100644 --- a/crates/identity/ras-identity-oauth2/src/lib.rs +++ b/crates/identity/ras-identity-oauth2/src/lib.rs @@ -19,7 +19,9 @@ pub use config::{OAuth2Config, OAuth2ProviderConfig}; pub use error::{OAuth2Error, OAuth2Result}; pub use provider::{OAuth2AuthPayload, OAuth2Provider, OAuth2Response}; pub use state::{InMemoryStateStore, OAuth2State, OAuth2StateStore}; -pub use types::{AuthorizationRequest, AuthorizationResponse, TokenResponse, UserInfoResponse}; +pub use types::{ + AuthorizationRequest, AuthorizationResponse, ProviderMetadata, TokenResponse, UserInfoResponse, +}; // Re-export common types for convenience pub use ras_identity_core::{IdentityProvider, VerifiedIdentity}; diff --git a/crates/identity/ras-identity-session/src/lib.rs b/crates/identity/ras-identity-session/src/lib.rs index d9da41b..45589ff 100644 --- a/crates/identity/ras-identity-session/src/lib.rs +++ b/crates/identity/ras-identity-session/src/lib.rs @@ -25,6 +25,9 @@ pub enum SessionError { #[error("Invalid session")] InvalidSession, + + #[error("Invalid session configuration: {0}")] + InvalidConfig(String), } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -49,17 +52,60 @@ pub struct SessionConfig { pub algorithm: Algorithm, } -impl Default for SessionConfig { - fn default() -> Self { - Self { - jwt_secret: "change-me-in-production".to_string(), +impl SessionConfig { + pub fn new(jwt_secret: impl Into) -> Result { + let config = Self { + jwt_secret: jwt_secret.into(), jwt_ttl: Duration::hours(24), refresh_enabled: true, enforce_active_sessions: true, algorithm: Algorithm::HS256, + }; + config.validate()?; + Ok(config) + } + + pub fn validate(&self) -> Result<(), SessionError> { + validate_jwt_secret(&self.jwt_secret)?; + + if self.jwt_ttl <= Duration::zero() { + return Err(SessionError::InvalidConfig( + "jwt_ttl must be positive".to_string(), + )); } + + Ok(()) } } + +fn validate_jwt_secret(secret: &str) -> Result<(), SessionError> { + let trimmed = secret.trim(); + let insecure_placeholders = [ + "change-me-in-production", + "change-me", + "secret", + "test-secret", + "test-secret-key", + ]; + + if trimmed.len() < 32 { + return Err(SessionError::InvalidConfig( + "jwt_secret must be at least 32 bytes".to_string(), + )); + } + + if insecure_placeholders + .iter() + .any(|placeholder| trimmed.eq_ignore_ascii_case(placeholder)) + { + return Err(SessionError::InvalidConfig( + "jwt_secret must not use a placeholder value".to_string(), + )); + } + + Ok(()) +} + pub struct SessionService { config: SessionConfig, providers: Arc>>>, @@ -67,13 +113,14 @@ pub struct SessionService { permissions_provider: Option>, } impl SessionService { - pub fn new(config: SessionConfig) -> Self { - Self { + pub fn new(config: SessionConfig) -> Result { + config.validate()?; + Ok(Self { config, providers: Arc::new(RwLock::new(HashMap::new())), active_sessions: Arc::new(RwLock::new(HashMap::new())), permissions_provider: None, - } + }) } pub fn with_permissions(mut self, provider: Arc) -> Self { @@ -95,6 +142,10 @@ impl SessionService { provider_id: &str, auth_payload: serde_json::Value, ) -> Result { + if self.config.enforce_active_sessions { + self.cleanup_expired_sessions().await; + } + let providers = self.providers.read().await; let provider = providers .get(provider_id) @@ -139,10 +190,18 @@ impl SessionService { } pub async fn verify_session(&self, token: &str) -> Result { + if self.config.enforce_active_sessions { + self.cleanup_expired_sessions().await; + } + + let mut validation = Validation::new(self.config.algorithm); + validation.set_required_spec_claims(&["exp"]); + validation.validate_exp = true; + let token_data = decode::( token, &DecodingKey::from_secret(self.config.jwt_secret.as_bytes()), - &Validation::new(self.config.algorithm), + &validation, )?; if self.config.enforce_active_sessions { @@ -159,6 +218,14 @@ impl SessionService { let mut sessions = self.active_sessions.write().await; sessions.remove(jti) } + + pub async fn cleanup_expired_sessions(&self) -> usize { + let now = Utc::now().timestamp(); + let mut sessions = self.active_sessions.write().await; + let before = sessions.len(); + sessions.retain(|_, claims| claims.exp > now); + before - sessions.len() + } } #[derive(Clone)] @@ -206,10 +273,12 @@ mod tests { use ras_identity_core::StaticPermissions; use ras_identity_local::LocalUserProvider; + const TEST_SECRET: &str = "test-secret-that-is-long-enough-for-hs256"; + #[tokio::test] async fn test_session_lifecycle() { - let config = SessionConfig::default(); - let session_service = SessionService::new(config); + let config = SessionConfig::new(TEST_SECRET).unwrap(); + let session_service = SessionService::new(config).unwrap(); let local_provider = LocalUserProvider::new(); local_provider @@ -248,12 +317,14 @@ mod tests { #[tokio::test] async fn test_session_with_permissions() { - let config = SessionConfig::default(); + let config = SessionConfig::new(TEST_SECRET).unwrap(); let permissions_provider = Arc::new(StaticPermissions::new(vec![ "read".to_string(), "write".to_string(), ])); - let session_service = SessionService::new(config).with_permissions(permissions_provider); + let session_service = SessionService::new(config) + .unwrap() + .with_permissions(permissions_provider); let local_provider = LocalUserProvider::new(); local_provider @@ -286,4 +357,58 @@ mod tests { assert!(claims.permissions.contains("read")); assert!(claims.permissions.contains("write")); } + + #[test] + fn test_rejects_placeholder_secret() { + let result = SessionConfig::new("change-me-in-production"); + assert!(matches!(result, Err(SessionError::InvalidConfig(_)))); + } + + #[tokio::test] + async fn test_cleanup_expired_sessions() { + let config = SessionConfig::new(TEST_SECRET).unwrap(); + let service = SessionService::new(config).unwrap(); + + { + let mut sessions = service.active_sessions.write().await; + sessions.insert( + "expired".to_string(), + JwtClaims { + sub: "user".to_string(), + exp: Utc::now().timestamp() - 1, + iat: Utc::now().timestamp() - 10, + jti: "expired".to_string(), + provider_id: "local".to_string(), + email: None, + display_name: None, + permissions: HashSet::new(), + metadata: None, + }, + ); + } + + assert_eq!(service.cleanup_expired_sessions().await, 1); + } + + #[tokio::test] + async fn test_malformed_exp_claim_is_rejected() { + let config = SessionConfig::new(TEST_SECRET).unwrap(); + let service = SessionService::new(config).unwrap(); + + let token = encode( + &Header::new(Algorithm::HS256), + &serde_json::json!({ + "sub": "user", + "exp": "not-a-number", + "iat": Utc::now().timestamp(), + "jti": "malformed", + "provider_id": "local", + "permissions": [], + }), + &EncodingKey::from_secret(TEST_SECRET.as_bytes()), + ) + .unwrap(); + + assert!(service.verify_session(&token).await.is_err()); + } } diff --git a/crates/rest/ras-file-macro/Cargo.toml b/crates/rest/ras-file-macro/Cargo.toml index 545450f..243e3a9 100644 --- a/crates/rest/ras-file-macro/Cargo.toml +++ b/crates/rest/ras-file-macro/Cargo.toml @@ -22,4 +22,12 @@ serde = { workspace = true } serde_json = { workspace = true } ras-auth-core = { path = "../../core/ras-auth-core" } thiserror = { workspace = true } -async-trait = { workspace = true } \ No newline at end of file +async-trait = { workspace = true } +ras-test-helpers = { path = "../../test-utils/ras-test-helpers" } +axum-test = { workspace = true } +tempfile = { workspace = true } +criterion = { workspace = true, features = ["async_tokio"] } + +[[bench]] +name = "streaming" +harness = false diff --git a/crates/rest/ras-file-macro/benches/streaming.rs b/crates/rest/ras-file-macro/benches/streaming.rs new file mode 100644 index 0000000..d6981a8 --- /dev/null +++ b/crates/rest/ras-file-macro/benches/streaming.rs @@ -0,0 +1,131 @@ +//! Criterion bench measuring 1 MiB upload + download through the file_service! +//! generated client and router. + +use std::io::Write; +use std::sync::{Arc, Mutex}; + +use axum::{ + body::Body, + http::StatusCode, + response::{IntoResponse, Response}, +}; +use criterion::{Criterion, criterion_group, criterion_main}; +use ras_auth_core::AuthenticatedUser; +use ras_file_macro::file_service; +use ras_test_helpers::{MockAuthProvider, spawn_http}; +use serde::{Deserialize, Serialize}; +use tokio::runtime::Runtime; + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct UploadResponse { + file_id: String, + size: u64, +} + +file_service!({ + service_name: BenchSvc, + base_path: "/files", + endpoints: [ + DOWNLOAD UNAUTHORIZED download/{file_id: String}(), + UPLOAD WITH_PERMISSIONS(["user"]) upload() -> UploadResponse, + ] +}); + +type Storage = Arc)>>>; + +#[derive(Clone)] +struct BenchImpl { + storage: Storage, +} + +#[async_trait::async_trait] +impl BenchSvcTrait for BenchImpl { + async fn download(&self, file_id: String) -> Result { + let bytes = self + .storage + .lock() + .unwrap() + .iter() + .find_map(|(id, data)| (id == &file_id).then(|| data.clone())) + .ok_or(BenchSvcFileError::NotFound)?; + Ok(Response::builder() + .status(StatusCode::OK) + .body(Body::from(bytes)) + .unwrap()) + } + + async fn upload( + &self, + _user: &AuthenticatedUser, + mut multipart: axum::extract::Multipart, + ) -> Result { + let field = multipart + .next_field() + .await + .map_err(|e| BenchSvcFileError::UploadFailed(e.to_string()))? + .ok_or_else(|| BenchSvcFileError::UploadFailed("no field".into()))?; + let data = field + .bytes() + .await + .map_err(|e| BenchSvcFileError::UploadFailed(e.to_string()))?; + let id = format!("file-{}", self.storage.lock().unwrap().len()); + let size = data.len() as u64; + self.storage + .lock() + .unwrap() + .push((id.clone(), data.to_vec())); + Ok(UploadResponse { file_id: id, size }) + } +} + +fn build_router() -> (axum::Router, Storage) { + let storage: Storage = Arc::new(Mutex::new(Vec::new())); + let router = BenchSvcBuilder::::new(BenchImpl { + storage: storage.clone(), + }) + .auth_provider(MockAuthProvider::default()) + .build(); + (router, storage) +} + +fn bench_streaming(c: &mut Criterion) { + let rt = Runtime::new().unwrap(); + + // Prepare 1 MiB payload on disk and a live server. + let mut tmp = tempfile::NamedTempFile::new().expect("tempfile"); + let payload: Vec = (0u8..=255).cycle().take(1024 * 1024).collect(); + tmp.write_all(&payload).unwrap(); + tmp.flush().unwrap(); + let path = tmp.path().to_path_buf(); + + let (client, _server) = rt.block_on(async { + let (router, _storage) = build_router(); + let server = spawn_http(router); + let base = server.server_address().unwrap(); + let base_str = base.as_str().trim_end_matches('/').to_string(); + let client = BenchSvcClient::builder(base_str) + .build() + .expect("client build"); + client.set_bearer_token(Some("user-token".to_string())); + (client, server) + }); + + c.bench_function("file_upload_download_1mib", |b| { + b.to_async(&rt).iter(|| { + let client = &client; + let path = path.clone(); + async move { + let r = client + .upload(&path, Some("blob.bin"), Some("application/octet-stream")) + .await + .expect("upload"); + let resp = client.download(r.file_id).await.expect("download"); + let bytes = resp.bytes().await.expect("body"); + std::hint::black_box(bytes); + } + }); + }); +} + +criterion_group!(benches, bench_streaming); +criterion_main!(benches); diff --git a/crates/rest/ras-file-macro/src/client.rs b/crates/rest/ras-file-macro/src/client.rs index 3d71a38..b687c98 100644 --- a/crates/rest/ras-file-macro/src/client.rs +++ b/crates/rest/ras-file-macro/src/client.rs @@ -292,7 +292,7 @@ fn generate_wasm_client(definition: &FileServiceDefinition) -> TokenStream { let wasm_methods = generate_wasm_methods(&definition.endpoints); quote! { - #[cfg(all(target_arch = "wasm32", feature = "wasm-client"))] + #[cfg(target_arch = "wasm32")] pub mod wasm_client { use super::*; use wasm_bindgen::prelude::*; diff --git a/crates/rest/ras-file-macro/src/server.rs b/crates/rest/ras-file-macro/src/server.rs index 1d7573b..8e1c92b 100644 --- a/crates/rest/ras-file-macro/src/server.rs +++ b/crates/rest/ras-file-macro/src/server.rs @@ -106,9 +106,9 @@ pub fn generate_server(definition: &FileServiceDefinition) -> TokenStream { #error_name::NotFound => (StatusCode::NOT_FOUND, self.to_string()), #error_name::InvalidFormat => (StatusCode::BAD_REQUEST, self.to_string()), #error_name::FileTooLarge => (StatusCode::PAYLOAD_TOO_LARGE, self.to_string()), - #error_name::UploadFailed(msg) => (StatusCode::BAD_REQUEST, msg), - #error_name::DownloadFailed(msg) => (StatusCode::INTERNAL_SERVER_ERROR, msg), - #error_name::Internal(msg) => (StatusCode::INTERNAL_SERVER_ERROR, msg), + #error_name::UploadFailed(_) => (StatusCode::BAD_REQUEST, "Upload failed".to_string()), + #error_name::DownloadFailed(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Download failed".to_string()), + #error_name::Internal(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Internal server error".to_string()), }; <(::axum::http::StatusCode, String) as ::axum::response::IntoResponse>::into_response((status, message)) diff --git a/crates/rest/ras-file-macro/tests/e2e.rs b/crates/rest/ras-file-macro/tests/e2e.rs new file mode 100644 index 0000000..c0c64f8 --- /dev/null +++ b/crates/rest/ras-file-macro/tests/e2e.rs @@ -0,0 +1,163 @@ +//! End-to-end test for the file_service! macro: generated reqwest client → +//! axum router → handler. Exercises upload + download with byte-equality and +//! a missing-token rejection. + +use std::io::Write; +use std::sync::{Arc, Mutex}; + +use axum::{ + body::Body, + http::StatusCode, + response::{IntoResponse, Response}, +}; +use ras_auth_core::AuthenticatedUser; +use ras_file_macro::file_service; +use ras_test_helpers::{MockAuthProvider, spawn_http}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct UploadResponse { + file_id: String, + size: u64, +} + +file_service!({ + service_name: Demo, + base_path: "/files", + endpoints: [ + DOWNLOAD UNAUTHORIZED download/{file_id: String}(), + UPLOAD WITH_PERMISSIONS(["user"]) upload() -> UploadResponse, + ] +}); + +type Storage = Arc)>>>; + +#[derive(Clone)] +struct DemoImpl { + storage: Storage, +} + +#[async_trait::async_trait] +impl DemoTrait for DemoImpl { + async fn download(&self, file_id: String) -> Result { + let store = self.storage.lock().unwrap(); + let bytes = store + .iter() + .find_map(|(id, data)| (id == &file_id).then(|| data.clone())) + .ok_or(DemoFileError::NotFound)?; + + Ok(Response::builder() + .status(StatusCode::OK) + .header("content-type", "application/octet-stream") + .body(Body::from(bytes)) + .unwrap()) + } + + async fn upload( + &self, + _user: &AuthenticatedUser, + mut multipart: axum::extract::Multipart, + ) -> Result { + let field = multipart + .next_field() + .await + .map_err(|e| DemoFileError::UploadFailed(e.to_string()))? + .ok_or_else(|| DemoFileError::UploadFailed("no field".into()))?; + let data = field + .bytes() + .await + .map_err(|e| DemoFileError::UploadFailed(e.to_string()))?; + let id = format!("file-{}", self.storage.lock().unwrap().len()); + let size = data.len() as u64; + self.storage + .lock() + .unwrap() + .push((id.clone(), data.to_vec())); + Ok(UploadResponse { file_id: id, size }) + } +} + +fn router(storage: Storage) -> axum::Router { + DemoBuilder::::new(DemoImpl { storage }) + .auth_provider(MockAuthProvider::default()) + .build() +} + +fn write_tempfile(bytes: &[u8]) -> tempfile::NamedTempFile { + let mut f = tempfile::NamedTempFile::new().expect("tempfile"); + f.write_all(bytes).expect("write tempfile"); + f.flush().expect("flush tempfile"); + f +} + +#[tokio::test] +async fn upload_and_download_round_trips_bytes() { + let storage: Storage = Arc::new(Mutex::new(Vec::new())); + let server = spawn_http(router(storage.clone())); + let base = server.server_address().unwrap(); + let base_str = base.as_str().trim_end_matches('/').to_string(); + + let payload: Vec = (0u8..=255).cycle().take(64 * 1024).collect(); + let tmp = write_tempfile(&payload); + + let client = DemoClient::builder(base_str.clone()) + .build() + .expect("client build"); + client.set_bearer_token(Some("user-token".to_string())); + + let upload = client + .upload( + tmp.path(), + Some("blob.bin"), + Some("application/octet-stream"), + ) + .await + .expect("upload ok"); + assert_eq!(upload.size, payload.len() as u64); + + let resp = client.download(upload.file_id).await.expect("download ok"); + let bytes = resp.bytes().await.expect("read body"); + assert_eq!(bytes.as_ref(), payload.as_slice()); +} + +#[tokio::test] +async fn upload_rejected_without_token() { + let storage: Storage = Arc::new(Mutex::new(Vec::new())); + let server = spawn_http(router(storage)); + let base = server.server_address().unwrap(); + let base_str = base.as_str().trim_end_matches('/').to_string(); + + let payload = b"hello world"; + let tmp = write_tempfile(payload); + + let client = DemoClient::builder(base_str).build().expect("client build"); + // No bearer token. + + let result = client + .upload(tmp.path(), Some("hi.txt"), Some("text/plain")) + .await; + // The server short-circuits with 401 before consuming the multipart body, + // so reqwest may surface that either as the parsed status or as a generic + // connection error depending on how the upload stream was cut. Either is a + // valid signal of rejection — the only outcome we want to rule out is + // success. + assert!( + result.is_err(), + "upload must be rejected without a bearer token, got: {result:?}" + ); +} + +#[tokio::test] +async fn download_unknown_file_returns_404() { + let storage: Storage = Arc::new(Mutex::new(Vec::new())); + let server = spawn_http(router(storage)); + let base = server.server_address().unwrap(); + let base_str = base.as_str().trim_end_matches('/').to_string(); + + let client = DemoClient::builder(base_str).build().expect("client build"); + let err = client + .download("does-not-exist".to_string()) + .await + .expect_err("missing file must error"); + assert!(err.to_string().contains("404"), "got: {err}"); +} diff --git a/crates/rest/ras-file-macro/tests/integration.rs b/crates/rest/ras-file-macro/tests/integration.rs index 84e3947..abe7983 100644 --- a/crates/rest/ras-file-macro/tests/integration.rs +++ b/crates/rest/ras-file-macro/tests/integration.rs @@ -179,13 +179,13 @@ mod tests { client.set_bearer_token(Some("test-token")); // Test that client methods exist - assert!(true); // Basic compilation test + let _ = client; } #[test] fn test_file_error_variants() { // Test that all error variants exist - let _errors = vec![ + let _errors = [ TestFileServiceFileError::NotFound, TestFileServiceFileError::UploadFailed("test".to_string()), TestFileServiceFileError::DownloadFailed("test".to_string()), @@ -195,8 +195,8 @@ mod tests { ]; } - #[test] - fn test_error_into_response() { + #[tokio::test] + async fn test_error_into_response() { use axum::http::StatusCode; use axum::response::IntoResponse; @@ -229,5 +229,14 @@ mod tests { let (parts, _) = response.into_parts(); assert_eq!(parts.status, expected_status); } + + let response = TestFileServiceFileError::Internal("database password leaked".to_string()) + .into_response(); + let body = axum::body::to_bytes(response.into_body(), usize::MAX) + .await + .unwrap(); + let body = String::from_utf8(body.to_vec()).unwrap(); + assert!(body.contains("Internal server error")); + assert!(!body.contains("database password leaked")); } } diff --git a/crates/rest/ras-file-macro/tests/minimal_test.rs b/crates/rest/ras-file-macro/tests/minimal_test.rs index 7fe190b..e76413e 100644 --- a/crates/rest/ras-file-macro/tests/minimal_test.rs +++ b/crates/rest/ras-file-macro/tests/minimal_test.rs @@ -46,5 +46,4 @@ fn test_compilation() { let service = MyService; let auth = DummyAuth; let _builder = MinimalServiceBuilder::new(service).auth_provider(auth); - assert!(true); } diff --git a/crates/rest/ras-file-macro/tests/simple_test.rs b/crates/rest/ras-file-macro/tests/simple_test.rs index f696099..aaa3f0e 100644 --- a/crates/rest/ras-file-macro/tests/simple_test.rs +++ b/crates/rest/ras-file-macro/tests/simple_test.rs @@ -12,5 +12,4 @@ file_service!({ #[test] fn test_compilation() { // If it compiles, the test passes - assert!(true); } diff --git a/crates/rest/ras-rest-core/src/lib.rs b/crates/rest/ras-rest-core/src/lib.rs index 63557e2..47bab98 100644 --- a/crates/rest/ras-rest-core/src/lib.rs +++ b/crates/rest/ras-rest-core/src/lib.rs @@ -168,3 +168,89 @@ impl RestResultExt for Resul .map_err(|e| RestError::with_internal(status, msg, e)) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn rest_response_constructors_set_correct_status() { + assert_eq!(RestResponse::ok(1).status, 200); + assert_eq!(RestResponse::created("x").status, 201); + assert_eq!(RestResponse::accepted(true).status, 202); + let nc: RestResponse<()> = RestResponse::no_content(); + assert_eq!(nc.status, 204); + assert_eq!(RestResponse::with_status(418, "tea").status, 418); + // Body is preserved. + assert_eq!(RestResponse::ok(42).body, 42); + } + + #[test] + fn rest_error_constructors_set_correct_status_and_message() { + let cases = [ + (RestError::bad_request("a"), 400), + (RestError::unauthorized("a"), 401), + (RestError::forbidden("a"), 403), + (RestError::not_found("a"), 404), + (RestError::conflict("a"), 409), + (RestError::unprocessable_entity("a"), 422), + (RestError::internal_server_error("a"), 500), + (RestError::bad_gateway("a"), 502), + (RestError::service_unavailable("a"), 503), + ]; + for (err, expected) in cases { + assert_eq!(err.status, expected); + assert_eq!(err.message, "a"); + assert!(err.internal_error.is_none()); + // Display includes the status and message. + let s = err.to_string(); + assert!(s.contains(&expected.to_string())); + assert!(s.contains("a")); + } + } + + #[test] + fn rest_error_with_internal_carries_source() { + #[derive(Debug, thiserror::Error)] + #[error("inner failure")] + struct Inner; + let err = RestError::with_internal(503, "down", Inner); + assert_eq!(err.status, 503); + assert!(err.internal_error.is_some()); + // source() returns the wrapped error. + let src = std::error::Error::source(&err).unwrap(); + assert_eq!(src.to_string(), "inner failure"); + } + + #[test] + fn into_rest_error_blanket_impl() { + let err = std::io::Error::new(std::io::ErrorKind::Other, "io"); + let rest = err.into_rest_error(); + assert_eq!(rest.status, 500); + assert_eq!(rest.message, "Internal server error"); + assert!(rest.internal_error.is_some()); + } + + #[test] + fn rest_result_ext_maps_ok_and_err() { + let ok: Result = Ok(7); + let mapped: RestResult = ok.internal_server_error(); + let resp = mapped.unwrap(); + assert_eq!(resp.status, 200); + assert_eq!(resp.body, 7); + + let err: Result = + Err(std::io::Error::new(std::io::ErrorKind::Other, "x")); + let mapped: RestResult = err.internal_server_error(); + let e = mapped.unwrap_err(); + assert_eq!(e.status, 500); + + // rest_error variant lets callers customize. + let err: Result = + Err(std::io::Error::new(std::io::ErrorKind::Other, "x")); + let mapped: RestResult = err.rest_error(418, "teapot"); + let e = mapped.unwrap_err(); + assert_eq!(e.status, 418); + assert_eq!(e.message, "teapot"); + } +} diff --git a/crates/rest/ras-rest-macro/Cargo.toml b/crates/rest/ras-rest-macro/Cargo.toml index 45da79c..820c33c 100644 --- a/crates/rest/ras-rest-macro/Cargo.toml +++ b/crates/rest/ras-rest-macro/Cargo.toml @@ -49,4 +49,12 @@ async-trait = { workspace = true } # Server dependencies for tests axum = { workspace = true } ras-auth-core = { path = "../../core/ras-auth-core" } -ras-rest-core = { path = "../ras-rest-core" } \ No newline at end of file +ras-rest-core = { path = "../ras-rest-core" } +ras-test-helpers = { path = "../../test-utils/ras-test-helpers" } +axum-test = { workspace = true } +schemars = { workspace = true } +criterion = { workspace = true, features = ["async_tokio"] } + +[[bench]] +name = "dispatch" +harness = false \ No newline at end of file diff --git a/crates/rest/ras-rest-macro/benches/dispatch.rs b/crates/rest/ras-rest-macro/benches/dispatch.rs new file mode 100644 index 0000000..6e15e7c --- /dev/null +++ b/crates/rest/ras-rest-macro/benches/dispatch.rs @@ -0,0 +1,72 @@ +//! Criterion bench measuring per-call latency of an authenticated REST GET +//! through the full stack: generated client → axum router → handler. +//! +//! Run with `cargo bench -p ras-rest-macro`. + +use criterion::{Criterion, criterion_group, criterion_main}; +use ras_auth_core::AuthenticatedUser; +use ras_rest_core::{RestResponse, RestResult}; +use ras_rest_macro::rest_service; +use ras_test_helpers::{MockAuthProvider, spawn_http}; +use serde::{Deserialize, Serialize}; +use tokio::runtime::Runtime; + +#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)] +struct Item { + id: u32, + name: String, +} + +rest_service!({ + service_name: BenchSvc, + base_path: "/api", + openapi: false, + serve_docs: false, + endpoints: [ + GET WITH_PERMISSIONS(["user"]) items/{id: u32}() -> Item, + ] +}); + +struct BenchImpl; + +#[async_trait::async_trait] +impl BenchSvcTrait for BenchImpl { + async fn get_items_by_id(&self, _user: &AuthenticatedUser, id: u32) -> RestResult { + Ok(RestResponse::ok(Item { + id, + name: "x".into(), + })) + } +} + +fn build_router() -> axum::Router { + BenchSvcBuilder::new(BenchImpl) + .auth_provider(MockAuthProvider::default()) + .build() +} + +fn bench_dispatch(c: &mut Criterion) { + let rt = Runtime::new().unwrap(); + let (client, _server) = rt.block_on(async { + let server = spawn_http(build_router()); + let base = server.server_address().unwrap().to_string(); + let mut client = BenchSvcClient::builder(&base) + .build() + .expect("client build"); + client.set_bearer_token(Some("user-token".to_string())); + (client, server) + }); + + c.bench_function("rest_get_dispatch", |b| { + b.to_async(&rt).iter(|| { + let client = client.clone(); + async move { + let r = client.get_items_by_id(1).await.expect("get ok"); + std::hint::black_box(r); + } + }); + }); +} + +criterion_group!(benches, bench_dispatch); +criterion_main!(benches); diff --git a/crates/rest/ras-rest-macro/src/client.rs b/crates/rest/ras-rest-macro/src/client.rs index dbc0525..bd2f3cd 100644 --- a/crates/rest/ras-rest-macro/src/client.rs +++ b/crates/rest/ras-rest-macro/src/client.rs @@ -1,5 +1,18 @@ use crate::{EndpointDefinition, HttpMethod, ServiceDefinition}; use quote::quote; +use syn::Type; + +/// True if `ty` is syntactically `Option<...>`. Matches the bare `Option` +/// segment as well as fully-qualified forms like `std::option::Option` / +/// `core::option::Option` — anything whose last path segment is `Option`. +fn is_option_type(ty: &Type) -> bool { + if let Type::Path(type_path) = ty { + if let Some(last) = type_path.path.segments.last() { + return last.ident == "Option"; + } + } + false +} /// Generate client code for REST service pub fn generate_client_code(service_def: &ServiceDefinition) -> proc_macro2::TokenStream { @@ -148,6 +161,14 @@ fn generate_client_method(endpoint: &EndpointDefinition) -> proc_macro2::TokenSt call_args.push(quote! { #param_name }); } + // Add query parameters (mirroring the macro syntax order: path → query → body). + for query_param in endpoint.query_params.iter() { + let param_name = &query_param.name; + let param_type = &query_param.param_type; + params.push(quote! { #param_name: #param_type }); + call_args.push(quote! { #param_name }); + } + // Add request body parameter if present if endpoint.request_type.is_some() { let request_type = endpoint.request_type.as_ref().unwrap(); @@ -211,6 +232,44 @@ fn generate_client_method_with_timeout(endpoint: &EndpointDefinition) -> proc_ma }; } + // Build query-string handling. Required params are always serialized; + // `Option` params are skipped when `None`. Values are converted with + // `ToString` and url-encoded by reqwest's `.query()` helper. + let query_handling = if endpoint.query_params.is_empty() { + quote! {} + } else { + let pushes = endpoint.query_params.iter().map(|qp| { + let param_name = &qp.name; + let param_str = qp.name.to_string(); + if is_option_type(&qp.param_type) { + quote! { + if let Some(__v) = &#param_name { + __query_pairs.push((#param_str, __v.to_string())); + } + } + } else { + quote! { + __query_pairs.push((#param_str, #param_name.to_string())); + } + } + }); + quote! { + let mut __query_pairs: Vec<(&'static str, String)> = Vec::new(); + #(#pushes)* + if !__query_pairs.is_empty() { + request_builder = request_builder.query(&__query_pairs); + } + } + }; + + // Add query parameters to the function signature (after path params, + // before the body — matches macro syntax order). + for query_param in endpoint.query_params.iter() { + let param_name = &query_param.name; + let param_type = &query_param.param_type; + params.push(quote! { #param_name: #param_type }); + } + // Add request body parameter if present let request_body_handling = if let Some(request_type) = &endpoint.request_type { params.push(quote! { body: #request_type }); @@ -266,6 +325,8 @@ fn generate_client_method_with_timeout(endpoint: &EndpointDefinition) -> proc_ma request_builder = request_builder.header("Authorization", format!("Bearer {}", token)); } + #query_handling + #request_body_handling // Override timeout if provided (not supported in WASM builds) diff --git a/crates/rest/ras-rest-macro/src/lib.rs b/crates/rest/ras-rest-macro/src/lib.rs index 2e7725c..d7b97eb 100644 --- a/crates/rest/ras-rest-macro/src/lib.rs +++ b/crates/rest/ras-rest-macro/src/lib.rs @@ -542,7 +542,7 @@ fn generate_service_code(service_def: ServiceDefinition) -> syn::ResultJWT Token +
@@ -705,7 +709,7 @@ pub fn generate_static_hosting_code( // Global state let apiSpec = null; let currentEndpoint = null; - let jwtToken = localStorage.getItem('jwt-token') || ''; + let jwtToken = ''; // Initialize the application document.addEventListener('DOMContentLoaded', async () => {{ @@ -732,9 +736,13 @@ pub fn generate_static_hosting_code( function initializeAuth() {{ const tokenInput = document.getElementById('jwt-token'); const authStatus = document.getElementById('auth-status'); + const rememberToken = document.getElementById('remember-token'); + const rememberedToken = sessionStorage.getItem('jwt-token') || ''; - if (jwtToken) {{ + if (rememberedToken) {{ + jwtToken = rememberedToken; tokenInput.value = jwtToken; + rememberToken.checked = true; authStatus.classList.add('authenticated'); }} }} @@ -773,13 +781,17 @@ pub fn generate_static_hosting_code( function saveToken() {{ const tokenInput = document.getElementById('jwt-token'); const authStatus = document.getElementById('auth-status'); + const rememberToken = document.getElementById('remember-token'); jwtToken = tokenInput.value.trim(); - localStorage.setItem('jwt-token', jwtToken); + sessionStorage.removeItem('jwt-token'); + if (jwtToken && rememberToken.checked) {{ + sessionStorage.setItem('jwt-token', jwtToken); + }} if (jwtToken) {{ authStatus.classList.add('authenticated'); - showSuccess('Token saved successfully'); + showSuccess(rememberToken.checked ? 'Token saved for this tab' : 'Token ready for this page'); }} else {{ authStatus.classList.remove('authenticated'); }} @@ -788,8 +800,9 @@ pub fn generate_static_hosting_code( // Clear JWT token function clearToken() {{ jwtToken = ''; - localStorage.removeItem('jwt-token'); + sessionStorage.removeItem('jwt-token'); document.getElementById('jwt-token').value = ''; + document.getElementById('remember-token').checked = false; document.getElementById('auth-status').classList.remove('authenticated'); showSuccess('Token cleared'); }} diff --git a/crates/rest/ras-rest-macro/tests/e2e.rs b/crates/rest/ras-rest-macro/tests/e2e.rs new file mode 100644 index 0000000..2c95e29 --- /dev/null +++ b/crates/rest/ras-rest-macro/tests/e2e.rs @@ -0,0 +1,285 @@ +//! End-to-end test: generated reqwest client → axum router → trait impl → +//! response → client. Covers GET, POST with body, path params, query params, +//! and auth-related rejection paths. + +use ras_auth_core::AuthenticatedUser; +use ras_rest_core::{RestError, RestResponse, RestResult}; +use ras_rest_macro::rest_service; +use ras_test_helpers::{MockAuthProvider, spawn_http}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)] +struct Item { + id: u32, + name: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)] +struct CreateItem { + name: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)] +struct ItemsResponse { + items: Vec, +} + +rest_service!({ + service_name: Demo, + base_path: "/api", + openapi: false, + serve_docs: false, + endpoints: [ + GET UNAUTHORIZED items() -> ItemsResponse, + GET WITH_PERMISSIONS(["user"]) items/{id: u32}() -> Item, + POST WITH_PERMISSIONS(["admin"]) items(CreateItem) -> Item, + GET UNAUTHORIZED search ? q: String & limit: Option & exact: bool () -> ItemsResponse, + POST WITH_PERMISSIONS(["admin"]) items/batch ? notify: bool (CreateItem) -> Item, + GET WITH_PERMISSIONS(["user"]) items/{id: u32}/related ? tag: Option () -> ItemsResponse, + ] +}); + +struct DemoImpl; + +#[async_trait::async_trait] +impl DemoTrait for DemoImpl { + async fn get_items(&self) -> RestResult { + Ok(RestResponse::ok(ItemsResponse { + items: vec![Item { + id: 1, + name: "alpha".into(), + }], + })) + } + + async fn get_items_by_id(&self, _user: &AuthenticatedUser, id: u32) -> RestResult { + if id == 404 { + Err(RestError::not_found("missing")) + } else { + Ok(RestResponse::ok(Item { + id, + name: format!("item-{id}"), + })) + } + } + + async fn post_items(&self, user: &AuthenticatedUser, body: CreateItem) -> RestResult { + // Use the user_id length so we can verify the user actually arrived. + Ok(RestResponse::created(Item { + id: user.user_id.len() as u32, + name: body.name, + })) + } + + async fn get_search( + &self, + q: String, + limit: Option, + exact: bool, + ) -> RestResult { + let n = limit.unwrap_or(2); + let prefix = if exact { "exact" } else { "fuzzy" }; + let items = (0..n) + .map(|i| Item { + id: i, + name: format!("{prefix}:{q}-{i}"), + }) + .collect(); + Ok(RestResponse::ok(ItemsResponse { items })) + } + + async fn post_items_batch( + &self, + _user: &AuthenticatedUser, + notify: bool, + body: CreateItem, + ) -> RestResult { + // Encode the bool query param into the response so we can assert on it. + let suffix = if notify { "(notified)" } else { "(silent)" }; + Ok(RestResponse::created(Item { + id: 0, + name: format!("{}{suffix}", body.name), + })) + } + + async fn get_items_by_id_related( + &self, + _user: &AuthenticatedUser, + id: u32, + tag: Option, + ) -> RestResult { + let label = tag.unwrap_or_else(|| "none".into()); + Ok(RestResponse::ok(ItemsResponse { + items: vec![Item { + id, + name: format!("related/{label}"), + }], + })) + } +} + +fn router() -> axum::Router { + DemoBuilder::new(DemoImpl) + .auth_provider(MockAuthProvider::default()) + .build() +} + +fn client(base: &str) -> DemoClient { + DemoClient::builder(base).build().expect("client build") +} + +#[tokio::test] +async fn unauth_get_round_trips() { + let server = spawn_http(router()); + let base = server.server_address().unwrap().to_string(); + let resp = client(&base).get_items().await.expect("get_items ok"); + assert_eq!(resp.items.len(), 1); + assert_eq!(resp.items[0].name, "alpha"); +} + +#[tokio::test] +async fn auth_get_with_path_param_succeeds_with_user_token() { + let server = spawn_http(router()); + let base = server.server_address().unwrap().to_string(); + let mut c = client(&base); + c.set_bearer_token(Some("user-token".to_string())); + + let item = c.get_items_by_id(7).await.expect("get_items_by_id ok"); + assert_eq!(item.id, 7); + assert_eq!(item.name, "item-7"); +} + +#[tokio::test] +async fn auth_get_rejected_without_token() { + let server = spawn_http(router()); + let base = server.server_address().unwrap().to_string(); + // No bearer token set on client. + let err = client(&base) + .get_items_by_id(1) + .await + .expect_err("must be rejected"); + let s = err.to_string(); + assert!(s.contains("401") || s.contains("Unauthorized"), "got: {s}"); +} + +#[tokio::test] +async fn auth_post_rejected_with_insufficient_perms() { + let server = spawn_http(router()); + let base = server.server_address().unwrap().to_string(); + let mut c = client(&base); + c.set_bearer_token(Some("user-token".to_string())); // not admin + + let err = c + .post_items(CreateItem { + name: "x".to_string(), + }) + .await + .expect_err("user-token can't POST items"); + let s = err.to_string(); + assert!(s.contains("403") || s.contains("Forbidden"), "got: {s}"); +} + +#[tokio::test] +async fn auth_post_with_admin_succeeds_and_user_id_propagates() { + let server = spawn_http(router()); + let base = server.server_address().unwrap().to_string(); + let mut c = client(&base); + c.set_bearer_token(Some("admin-token".to_string())); + + let item = c + .post_items(CreateItem { name: "foo".into() }) + .await + .expect("post_items ok"); + assert_eq!(item.name, "foo"); + // admin-1 is 7 chars long. + assert_eq!(item.id, 7); +} + +#[tokio::test] +async fn query_params_required_and_optional_serialize_correctly() { + let server = spawn_http(router()); + let base = server.server_address().unwrap().to_string(); + + // Optional `limit` provided, required `q` and `exact` set. + let resp = client(&base) + .get_search("hi".to_string(), Some(3), true) + .await + .expect("search ok"); + assert_eq!(resp.items.len(), 3); + assert_eq!(resp.items[0].name, "exact:hi-0"); + assert_eq!(resp.items[2].name, "exact:hi-2"); + + // Optional `limit` omitted (None) → handler default of 2 applies, and the + // bool flips the prefix. + let resp = client(&base) + .get_search("zz".to_string(), None, false) + .await + .expect("search ok"); + assert_eq!(resp.items.len(), 2); + assert_eq!(resp.items[0].name, "fuzzy:zz-0"); +} + +#[tokio::test] +async fn query_params_with_body_and_auth() { + let server = spawn_http(router()); + let base = server.server_address().unwrap().to_string(); + let mut c = client(&base); + c.set_bearer_token(Some("admin-token".to_string())); + + let item = c + .post_items_batch( + true, + CreateItem { + name: "alpha".into(), + }, + ) + .await + .expect("post_items_batch ok"); + assert_eq!(item.name, "alpha(notified)"); + + let item = c + .post_items_batch( + false, + CreateItem { + name: "beta".into(), + }, + ) + .await + .expect("post_items_batch ok"); + assert_eq!(item.name, "beta(silent)"); +} + +#[tokio::test] +async fn query_params_with_path_param() { + let server = spawn_http(router()); + let base = server.server_address().unwrap().to_string(); + let mut c = client(&base); + c.set_bearer_token(Some("user-token".to_string())); + + let resp = c + .get_items_by_id_related(42, Some("featured".into())) + .await + .expect("related with tag"); + assert_eq!(resp.items[0].id, 42); + assert_eq!(resp.items[0].name, "related/featured"); + + let resp = c + .get_items_by_id_related(42, None) + .await + .expect("related without tag"); + assert_eq!(resp.items[0].name, "related/none"); +} + +#[tokio::test] +async fn handler_error_surfaces_to_client() { + let server = spawn_http(router()); + let base = server.server_address().unwrap().to_string(); + let mut c = client(&base); + c.set_bearer_token(Some("user-token".to_string())); + + let err = c + .get_items_by_id(404) + .await + .expect_err("404 sentinel must error"); + assert!(err.to_string().contains("404"), "got: {err}"); +} diff --git a/crates/rest/ras-rest-macro/tests/http_integration.rs b/crates/rest/ras-rest-macro/tests/http_integration.rs index a1f3051..c2b53fb 100644 --- a/crates/rest/ras-rest-macro/tests/http_integration.rs +++ b/crates/rest/ras-rest-macro/tests/http_integration.rs @@ -2,7 +2,6 @@ use rand::Rng; use ras_jsonrpc_core::{AuthError, AuthFuture, AuthProvider, AuthenticatedUser}; use ras_rest_core::{RestError, RestResponse}; use ras_rest_macro::rest_service; -use reqwest; use serde::{Deserialize, Serialize}; use serde_json::{Value, json}; use std::collections::HashSet; @@ -59,12 +58,6 @@ struct PostsResponse { total: usize, } -#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)] -struct ErrorResponse { - error: String, - details: Option, -} - // Simple test auth provider struct TestRestAuthProvider { valid_tokens: HashSet, @@ -940,7 +933,6 @@ async fn test_openapi_generation() { // The fact that this compiles means the REST service macro generated the builder correctly // with OpenAPI configuration enabled - assert!(true, "OpenAPI generation compiled successfully"); } #[tokio::test] @@ -951,7 +943,6 @@ async fn test_missing_dependencies() { // This test ensures that our future handling is working correctly let handles: Vec> = vec![]; let _results = join_all(handles).await; - assert!(true, "Futures dependency is working"); } #[tokio::test] @@ -1049,7 +1040,7 @@ async fn test_generated_rest_client() { assert_eq!(resp.total, 2); - let _resp = client + client .delete_users_by_id_with_timeout(resp.users[0].id.unwrap(), None) .await .expect("failed to get users"); @@ -1118,7 +1109,7 @@ async fn test_query_parameters_with_auth() { assert_eq!(response.status(), 200); let posts_response: PostsResponse = response.json().await.unwrap(); assert!(posts_response.posts[0].tags.contains(&"test".to_string())); - assert_eq!(posts_response.posts[0].published, true); + assert!(posts_response.posts[0].published); // Test with no query parameters - all optional let response = make_rest_request( diff --git a/crates/rest/ras-rest-macro/tests/xss_protection_test.rs b/crates/rest/ras-rest-macro/tests/xss_protection_test.rs index d4b284f..c1b1784 100644 --- a/crates/rest/ras-rest-macro/tests/xss_protection_test.rs +++ b/crates/rest/ras-rest-macro/tests/xss_protection_test.rs @@ -17,6 +17,15 @@ fn test_xss_protection_in_generated_html() { } } +#[test] +fn test_generated_docs_do_not_store_jwt_in_local_storage() { + let source = include_str!("../src/static_hosting.rs"); + assert!(!source.contains("localStorage.getItem('jwt-token')")); + assert!(!source.contains("localStorage.setItem('jwt-token'")); + assert!(!source.contains("localStorage.removeItem('jwt-token'")); + assert!(source.contains("sessionStorage.setItem('jwt-token'")); +} + fn escape_html(unsafe_str: &str) -> String { unsafe_str .replace('&', "&") diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-client/src/client.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-client/src/client.rs index 2a7634f..3fa1de6 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-client/src/client.rs +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-client/src/client.rs @@ -734,4 +734,88 @@ mod tests { assert!(!client.is_connected().await); assert!(client.connection_id().await.is_none()); } + + #[tokio::test] + async fn builder_jwt_in_query_params_and_full_setters() { + // Exercise every with_* setter so each path is colored. We don't + // auto-connect (no server), but the resulting config must reflect + // each option. + let custom = ReconnectConfig::default(); + let client = ClientBuilder::new("ws://localhost:8080") + .with_jwt_token("tok".into()) + .with_jwt_in_header(false) + .with_header("X-Custom", "v") + .with_request_timeout(Duration::from_secs(11)) + .with_reconnect_config(custom) + .with_heartbeat_interval(None) + .with_connection_timeout(Duration::from_secs(7)) + .with_auto_connect(false) + .build() + .await + .expect("build"); + + assert!(matches!(client.config().auth, AuthConfig::JwtParams { .. })); + assert_eq!(client.config().request_timeout, Duration::from_secs(11)); + assert_eq!(client.config().connection_timeout, Duration::from_secs(7)); + assert!(client.config().heartbeat_interval.is_none()); + assert_eq!( + client.config().custom_headers.get("X-Custom"), + Some(&"v".to_string()) + ); + assert!(client.active_subscriptions().is_empty()); + assert_eq!(client.pending_requests_count(), 0); + } + + #[tokio::test] + async fn builder_without_token_yields_no_auth() { + let client = ClientBuilder::new("ws://localhost:8080") + .build() + .await + .expect("build"); + assert!(matches!(client.config().auth, AuthConfig::None)); + } + + #[tokio::test] + async fn call_notify_subscribe_unsubscribe_require_connected_state() { + let client = ClientBuilder::new("ws://localhost:8080") + .build() + .await + .expect("build"); + + // call → NotConnected + let err = client.call("m", None).await.unwrap_err(); + assert!(matches!(err, ClientError::NotConnected)); + + // notify → NotConnected + let err = client.notify("m", None).await.unwrap_err(); + assert!(matches!(err, ClientError::NotConnected)); + + // subscribe → NotConnected + let handler: NotificationHandler = std::sync::Arc::new(|_method: &str, _params: &Value| {}); + let err = client.subscribe("t", handler.clone()).await.unwrap_err(); + assert!(matches!(err, ClientError::NotConnected)); + + // unsubscribe → NotConnected + let err = client.unsubscribe("t").await.unwrap_err(); + assert!(matches!(err, ClientError::NotConnected)); + } + + #[tokio::test] + async fn handler_registration_does_not_require_connected_state() { + let client = ClientBuilder::new("ws://localhost:8080") + .build() + .await + .expect("build"); + + let n: NotificationHandler = std::sync::Arc::new(|_, _| {}); + let e: ConnectionEventHandler = std::sync::Arc::new(|_event| {}); + client.on_notification("evt", n); + client.on_connection_event("named", e); + + // cleanup_expired_requests is callable even with nothing pending. + client.cleanup_expired_requests().await; + + // Disconnect-when-already-disconnected is a no-op success. + client.disconnect().await.expect("disconnect ok"); + } } diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-client/src/config.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-client/src/config.rs index 78cc666..bf73d1f 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-client/src/config.rs +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-client/src/config.rs @@ -337,4 +337,104 @@ mod tests { config.request_timeout = Duration::from_secs(0); assert!(config.validate().is_err()); } + + #[test] + fn validate_rejects_each_invalid_field() { + let base = ClientConfig::new("ws://localhost:8080"); + + let mut c = base.clone(); + c.request_timeout = Duration::ZERO; + assert!(c.validate().unwrap_err().contains("Request timeout")); + + let mut c = base.clone(); + c.connection_timeout = Duration::ZERO; + assert!(c.validate().unwrap_err().contains("Connection timeout")); + + let mut c = base.clone(); + c.message_buffer_size = 0; + assert!(c.validate().unwrap_err().contains("Message buffer size")); + + let mut c = base.clone(); + c.max_pending_requests = 0; + assert!(c.validate().unwrap_err().contains("Max pending requests")); + + let mut c = base.clone(); + c.reconnect.backoff_multiplier = 0.0; + assert!(c.validate().unwrap_err().contains("Backoff multiplier")); + + let mut c = base.clone(); + c.reconnect.jitter = 1.5; + assert!(c.validate().unwrap_err().contains("Jitter")); + + // Native build also rejects an unparseable URL. + let mut c = base.clone(); + c.url = "not a url".to_string(); + assert!(c.validate().is_err()); + } + + #[test] + fn connection_url_appends_amp_when_query_already_present() { + let cfg = ClientConfig { + auth: AuthConfig::JwtParams { + token: "tok".into(), + }, + ..ClientConfig::new("ws://h/ws?x=1") + }; + assert_eq!(cfg.get_connection_url(), "ws://h/ws?x=1&token=tok"); + } + + #[test] + fn connection_url_with_custom_params() { + let mut params = HashMap::new(); + params.insert("foo".to_string(), "bar".to_string()); + let cfg = ClientConfig { + auth: AuthConfig::CustomParams { params }, + ..ClientConfig::new("ws://h/ws") + }; + let url = cfg.get_connection_url(); + assert!(url.starts_with("ws://h/ws?")); + assert!(url.contains("foo=bar")); + } + + #[test] + fn connection_headers_with_custom_headers_variant() { + let mut headers = HashMap::new(); + headers.insert("X-API-Key".to_string(), "k".to_string()); + let cfg = ClientConfig { + auth: AuthConfig::CustomHeaders { + headers: headers.clone(), + }, + ..ClientConfig::new("ws://h/ws") + }; + let h = cfg.get_connection_headers(); + assert_eq!(h.get("X-API-Key"), Some(&"k".to_string())); + } + + #[test] + fn connection_url_falls_through_for_no_param_auth() { + let cfg = ClientConfig { + auth: AuthConfig::JwtHeader { + token: "tok".into(), + }, + ..ClientConfig::new("ws://h/ws") + }; + // Header-based auth must NOT mutate the URL. + assert_eq!(cfg.get_connection_url(), "ws://h/ws"); + } + + #[test] + fn calculate_delay_zero_attempt_returns_initial() { + let cfg = ReconnectConfig { + jitter: 0.0, + ..ReconnectConfig::default() + }; + assert_eq!(cfg.calculate_delay(0), cfg.initial_delay); + } + + #[test] + fn auth_config_default_via_helper() { + // exercises the `Default` impl + AuthConfig::default branches. + let _ = AuthConfig::default(); + let _ = ReconnectConfig::default(); + } } diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-client/src/error.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-client/src/error.rs index 033f852..90f549a 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-client/src/error.rs +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-client/src/error.rs @@ -232,4 +232,93 @@ mod tests { panic!("Expected reconnection failed error"); } } + + #[test] + fn covers_all_constructors_and_display() { + // Stringy constructors → matching variants and messages. + for (err, expected_prefix) in [ + ( + ClientError::invalid_request_id("rid"), + "Invalid request ID:", + ), + (ClientError::invalid_url("not://valid"), "Invalid URL:"), + (ClientError::send_failed("eof"), "Failed to send message:"), + ( + ClientError::receive_failed("eof"), + "Failed to receive message:", + ), + (ClientError::subscription("topic"), "Subscription error:"), + (ClientError::configuration("bad"), "Configuration error:"), + (ClientError::internal("oops"), "Internal error:"), + (ClientError::authentication("nope"), "Authentication error:"), + ] { + let s = err.to_string(); + assert!( + s.starts_with(expected_prefix), + "expected prefix {expected_prefix:?} in {s:?}" + ); + } + + // Bare variants. + assert_eq!( + ClientError::NotConnected.to_string(), + "Client is not connected" + ); + assert_eq!( + ClientError::AlreadyConnected.to_string(), + "Client is already connected" + ); + } + + #[test] + fn from_impls_route_to_correct_variants() { + let json_err = serde_json::from_str::("not json").unwrap_err(); + assert!(matches!(ClientError::from(json_err), ClientError::Json(_))); + + let bidir_err = BidirectionalError::Timeout; + assert!(matches!( + ClientError::from(bidir_err), + ClientError::Bidirectional(_) + )); + + let io_err = std::io::Error::new(std::io::ErrorKind::Other, "io"); + assert!(matches!(ClientError::from(io_err), ClientError::Io(_))); + + let url_err = url::Url::parse("not a url").unwrap_err(); + assert!(matches!( + ClientError::from(url_err), + ClientError::UrlParse(_) + )); + } + + #[test] + fn recovery_classification_is_exhaustive_for_named_buckets() { + // Should reconnect → also recoverable. + for err in [ + ClientError::connection("x"), + ClientError::receive_failed("x"), + ClientError::NotConnected, + ] { + assert!(err.should_reconnect()); + assert!(err.is_recoverable()); + } + + // Recoverable but no reconnect. + for err in [ClientError::timeout(1), ClientError::send_failed("x")] { + assert!(err.is_recoverable()); + assert!(!err.should_reconnect()); + } + + // Neither. + for err in [ + ClientError::authentication("x"), + ClientError::AlreadyConnected, + ClientError::invalid_url("x"), + ClientError::configuration("x"), + ClientError::internal("x"), + ] { + assert!(!err.is_recoverable()); + assert!(!err.should_reconnect()); + } + } } diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-macro/Cargo.toml b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-macro/Cargo.toml index 123ca6b..175447b 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-macro/Cargo.toml +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-macro/Cargo.toml @@ -24,6 +24,7 @@ ras-jsonrpc-bidirectional-server = { path = "../ras-jsonrpc-bidirectional-server ras-jsonrpc-bidirectional-client = { path = "../ras-jsonrpc-bidirectional-client" } ras-auth-core = { path = "../../../core/ras-auth-core" } ras-jsonrpc-types = { path = "../../ras-jsonrpc-types" } +ras-test-helpers = { path = "../../../test-utils/ras-test-helpers" } tokio = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } @@ -37,4 +38,9 @@ async-trait = { workspace = true } uuid = { workspace = true } anyhow = { workspace = true } thiserror = { workspace = true } -chrono = { workspace = true } \ No newline at end of file +chrono = { workspace = true } +criterion = { workspace = true, features = ["async_tokio"] } + +[[bench]] +name = "roundtrip" +harness = false \ No newline at end of file diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-macro/benches/roundtrip.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-macro/benches/roundtrip.rs new file mode 100644 index 0000000..c3ce50a --- /dev/null +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-macro/benches/roundtrip.rs @@ -0,0 +1,118 @@ +//! Criterion bench measuring c2s call round-trip latency through a real +//! WebSocket connection (tokio-tungstenite client → axum server). + +use std::sync::Arc; +use std::time::Duration; + +use async_trait::async_trait; +use axum::{Router, routing::get}; +use criterion::{Criterion, criterion_group, criterion_main}; +use ras_auth_core::AuthenticatedUser; +use ras_jsonrpc_bidirectional_macro::jsonrpc_bidirectional_service; +use ras_jsonrpc_bidirectional_server::DefaultConnectionManager; +use ras_jsonrpc_bidirectional_server::service::{BuiltWebSocketService, websocket_handler}; +use ras_jsonrpc_bidirectional_types::ConnectionId; +use ras_test_helpers::{MockAuthProvider, spawn_tcp}; +use serde::{Deserialize, Serialize}; +use tokio::runtime::Runtime; + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct EchoIn { + msg: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct EchoOut { + msg: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Ignored; + +jsonrpc_bidirectional_service!({ + service_name: BenchSvc, + client_to_server: [ + WITH_PERMISSIONS(["user"]) echo(EchoIn) -> EchoOut, + ], + server_to_client: [ + unused(Ignored), + ], + server_to_client_calls: [ + ] +}); + +#[derive(Clone)] +struct BenchImpl; + +#[async_trait] +impl BenchSvcService for BenchImpl { + async fn echo( + &self, + _client: ConnectionId, + _conns: &dyn ras_jsonrpc_bidirectional_types::ConnectionManager, + _user: &AuthenticatedUser, + req: EchoIn, + ) -> Result> { + Ok(EchoOut { msg: req.msg }) + } + + async fn notify_unused( + &self, + _connection_id: ConnectionId, + _params: Ignored, + ) -> ras_jsonrpc_bidirectional_types::Result<()> { + Ok(()) + } +} + +async fn start_server() -> String { + let cm = Arc::new(DefaultConnectionManager::new()); + let handler = Arc::new(BenchSvcHandler::new(Arc::new(BenchImpl), cm.clone())); + let svc = ras_jsonrpc_bidirectional_server::WebSocketServiceBuilder::builder() + .handler(handler) + .auth_provider(Arc::new(MockAuthProvider::default())) + .require_auth(false) + .build() + .build_with_manager(cm); + + type SvcType = BuiltWebSocketService< + BenchSvcHandler, + MockAuthProvider, + DefaultConnectionManager, + >; + let app: Router = Router::new() + .route("/ws", get(websocket_handler::)) + .with_state(svc); + let (addr, _h) = spawn_tcp(app).await; + tokio::time::sleep(Duration::from_millis(50)).await; + format!("ws://{addr}/ws") +} + +fn bench_roundtrip(c: &mut Criterion) { + let rt = Runtime::new().unwrap(); + + let client = rt.block_on(async { + let url = start_server().await; + let client = BenchSvcClientBuilder::new(url) + .with_jwt_token("user-token".to_string()) + .build() + .await + .expect("client build"); + client.connect().await.expect("connect"); + client + }); + + c.bench_function("ws_echo_roundtrip", |b| { + b.to_async(&rt).iter(|| async { + let r = client.echo(EchoIn { msg: "x".into() }).await.expect("echo"); + std::hint::black_box(r); + }); + }); + + rt.block_on(async { + let _ = client.disconnect().await; + }); +} + +criterion_group!(benches, bench_roundtrip); +criterion_main!(benches); diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-macro/tests/e2e.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-macro/tests/e2e.rs new file mode 100644 index 0000000..8d3d917 --- /dev/null +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-macro/tests/e2e.rs @@ -0,0 +1,202 @@ +//! End-to-end test for `jsonrpc_bidirectional_service!`: +//! generated client → real WebSocket → server handler → response/notification. +//! +//! Existing `bidirectional_integration.rs` exercises this thoroughly. This file +//! is a slim companion test that uses the shared `MockAuthProvider` and proves +//! the helper integration works. + +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::time::Duration; + +use async_trait::async_trait; +use axum::{Router, routing::get}; +use ras_auth_core::AuthenticatedUser; +use ras_jsonrpc_bidirectional_macro::jsonrpc_bidirectional_service; +use ras_jsonrpc_bidirectional_server::DefaultConnectionManager; +use ras_jsonrpc_bidirectional_server::service::{BuiltWebSocketService, websocket_handler}; +use ras_jsonrpc_bidirectional_types::ConnectionId; +use ras_test_helpers::{MockAuthProvider, spawn_tcp}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct EchoIn { + pub msg: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct EchoOut { + pub msg: String, + pub user: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PushNote { + pub kind: String, +} + +jsonrpc_bidirectional_service!({ + service_name: Demo, + client_to_server: [ + UNAUTHORIZED hello(String) -> String, + WITH_PERMISSIONS(["user"]) echo(EchoIn) -> EchoOut, + ], + server_to_client: [ + ping(PushNote), + ], + server_to_client_calls: [ + ] +}); + +#[derive(Clone)] +struct DemoImpl; + +#[async_trait] +impl DemoService for DemoImpl { + async fn hello( + &self, + _client: ConnectionId, + _conns: &dyn ras_jsonrpc_bidirectional_types::ConnectionManager, + name: String, + ) -> Result> { + Ok(format!("hello, {name}")) + } + + async fn echo( + &self, + client: ConnectionId, + conns: &dyn ras_jsonrpc_bidirectional_types::ConnectionManager, + user: &AuthenticatedUser, + req: EchoIn, + ) -> Result> { + // Also push a server→client notification so the test can observe it. + let note = ras_jsonrpc_bidirectional_types::ServerNotification { + method: "ping".to_string(), + params: serde_json::to_value(PushNote { + kind: "after-echo".into(), + }) + .unwrap(), + metadata: None, + }; + let _ = conns + .send_to_connection( + client, + ras_jsonrpc_bidirectional_types::BidirectionalMessage::ServerNotification(note), + ) + .await; + + Ok(EchoOut { + msg: req.msg, + user: user.user_id.clone(), + }) + } + + async fn notify_ping( + &self, + _connection_id: ConnectionId, + _params: PushNote, + ) -> ras_jsonrpc_bidirectional_types::Result<()> { + Ok(()) + } +} + +async fn start_server() -> String { + let connection_manager = Arc::new(DefaultConnectionManager::new()); + let handler = Arc::new(DemoHandler::new( + Arc::new(DemoImpl), + connection_manager.clone(), + )); + + let ws_service = ras_jsonrpc_bidirectional_server::WebSocketServiceBuilder::builder() + .handler(handler) + .auth_provider(Arc::new(MockAuthProvider::default())) + .require_auth(false) + .build() + .build_with_manager(connection_manager); + + type SvcType = BuiltWebSocketService< + DemoHandler, + MockAuthProvider, + DefaultConnectionManager, + >; + let app: Router = Router::new() + .route("/ws", get(websocket_handler::)) + .with_state(ws_service); + + let (addr, _handle) = spawn_tcp(app).await; + // Give axum a tick to start serving. + tokio::time::sleep(Duration::from_millis(50)).await; + format!("ws://{addr}/ws") +} + +#[tokio::test] +async fn unauthorized_method_round_trips() { + let url = start_server().await; + let client = DemoClientBuilder::new(url) + .build() + .await + .expect("client build"); + client.connect().await.expect("connect"); + + let resp = client.hello("alice".to_string()).await.expect("hello ok"); + assert_eq!(resp, "hello, alice"); + + client.disconnect().await.expect("disconnect"); +} + +#[tokio::test] +async fn auth_method_succeeds_and_pushes_notification() { + let url = start_server().await; + let mut client = DemoClientBuilder::new(url) + .with_jwt_token("user-token".to_string()) + .build() + .await + .expect("client build"); + client.connect().await.expect("connect"); + + let pushed = Arc::new(AtomicBool::new(false)); + let pushed_flag = pushed.clone(); + client.on_ping(move |_n: PushNote| { + pushed_flag.store(true, Ordering::SeqCst); + }); + + let resp = client + .echo(EchoIn { + msg: "hi".to_string(), + }) + .await + .expect("echo ok"); + assert_eq!(resp.msg, "hi"); + assert_eq!(resp.user, "user-1"); + + // Wait briefly for the push to land. + let deadline = std::time::Instant::now() + Duration::from_secs(2); + while !pushed.load(Ordering::SeqCst) && std::time::Instant::now() < deadline { + tokio::time::sleep(Duration::from_millis(20)).await; + } + assert!( + pushed.load(Ordering::SeqCst), + "expected ping notification to arrive" + ); + + client.disconnect().await.expect("disconnect"); +} + +#[tokio::test] +async fn auth_method_rejected_for_readonly_user() { + let url = start_server().await; + let client = DemoClientBuilder::new(url) + .with_jwt_token("readonly-token".to_string()) + .build() + .await + .expect("client build"); + client.connect().await.expect("connect"); + + let result = client.echo(EchoIn { msg: "nope".into() }).await; + assert!( + result.is_err(), + "readonly token must not be able to call echo" + ); + + client.disconnect().await.expect("disconnect"); +} diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/Cargo.toml b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/Cargo.toml index 3bb3dd5..7c6b3f7 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/Cargo.toml +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/Cargo.toml @@ -30,4 +30,5 @@ futures = { workspace = true } dashmap = { workspace = true } [dev-dependencies] -tokio-test = { workspace = true } \ No newline at end of file +tokio-test = { workspace = true } +ras-jsonrpc-types = { path = "../../ras-jsonrpc-types" } \ No newline at end of file diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/connection.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/connection.rs index 7ea454c..63c9bc0 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/connection.rs +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/connection.rs @@ -9,15 +9,12 @@ use tokio::sync::{RwLock, mpsc}; #[derive(Debug, Clone)] pub struct ChannelMessageSender { connection_id: ConnectionId, - sender: mpsc::UnboundedSender, + sender: mpsc::Sender, } impl ChannelMessageSender { /// Create a new channel message sender - pub fn new( - connection_id: ConnectionId, - sender: mpsc::UnboundedSender, - ) -> Self { + pub fn new(connection_id: ConnectionId, sender: mpsc::Sender) -> Self { Self { connection_id, sender, @@ -26,7 +23,7 @@ impl ChannelMessageSender { /// Send a message through the channel pub async fn send(&self, message: BidirectionalMessage) -> Result<(), String> { - self.sender.send(message).map_err(|e| e.to_string()) + self.sender.send(message).await.map_err(|e| e.to_string()) } /// Get the connection ID @@ -128,3 +125,90 @@ impl ConnectionContext { } } } + +#[cfg(test)] +mod tests { + use super::*; + use ras_auth_core::AuthenticatedUser; + use std::collections::HashSet; + + fn user(id: &str, perms: &[&str]) -> AuthenticatedUser { + AuthenticatedUser { + user_id: id.to_string(), + permissions: perms.iter().map(|s| s.to_string()).collect::>(), + metadata: None, + } + } + + fn ctx() -> ConnectionContext { + let id = ConnectionId::new(); + let (tx, _rx) = mpsc::channel(8); + let sender = ChannelMessageSender::new(id, tx); + ConnectionContext::new(id, sender) + } + + #[tokio::test] + async fn channel_sender_send_propagates_and_id_round_trips() { + let id = ConnectionId::new(); + let (tx, mut rx) = mpsc::channel(2); + let sender = ChannelMessageSender::new(id, tx); + assert_eq!(sender.connection_id(), id); + + sender.send(BidirectionalMessage::Ping).await.unwrap(); + let received = rx.recv().await.unwrap(); + assert!(matches!(received, BidirectionalMessage::Ping)); + } + + #[tokio::test] + async fn channel_sender_returns_string_error_when_closed() { + let id = ConnectionId::new(); + let (tx, rx) = mpsc::channel(1); + drop(rx); + let sender = ChannelMessageSender::new(id, tx); + let err = sender.send(BidirectionalMessage::Ping).await.unwrap_err(); + assert!(!err.is_empty()); + } + + #[tokio::test] + async fn auth_state_round_trips() { + let c = ctx(); + assert!(!c.is_authenticated().await); + assert!(c.get_user().await.is_none()); + assert!(!c.has_permission("admin").await); + + c.set_user(user("alice", &["admin"])).await; + assert!(c.is_authenticated().await); + assert_eq!(c.get_user().await.unwrap().user_id, "alice"); + assert!(c.has_permission("admin").await); + assert!(!c.has_permission("nope").await); + + c.clear_user().await; + assert!(!c.is_authenticated().await); + } + + #[tokio::test] + async fn subscriptions_round_trip() { + let c = ctx(); + assert!(c.get_subscriptions().await.is_empty()); + assert!(!c.is_subscribed_to("t1").await); + + c.subscribe("t1".into()).await; + c.subscribe("t2".into()).await; + assert!(c.is_subscribed_to("t1").await); + assert_eq!(c.get_subscriptions().await.len(), 2); + + assert!(c.unsubscribe("t1").await); + assert!(!c.is_subscribed_to("t1").await); + // Idempotent: removing again returns false. + assert!(!c.unsubscribe("t1").await); + } + + #[tokio::test] + async fn metadata_get_set() { + let c = ctx(); + assert!(c.get_metadata("k").await.is_none()); + c.set_metadata("k", serde_json::json!("v")).await; + assert_eq!(c.get_metadata("k").await.unwrap(), serde_json::json!("v")); + assert!(c.get_metadata("missing").await.is_none()); + } +} diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/error.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/error.rs index 5105325..52fb538 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/error.rs +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/error.rs @@ -74,3 +74,66 @@ impl ServerError { /// Convenience type alias for server operation results pub type ServerResult = Result; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn status_codes_per_variant() { + assert_eq!( + ServerError::AuthenticationFailed(AuthError::InvalidToken).to_status_code(), + StatusCode::UNAUTHORIZED + ); + assert_eq!( + ServerError::PermissionDenied("nope".into()).to_status_code(), + StatusCode::FORBIDDEN + ); + assert_eq!( + ServerError::ConnectionNotFound("abc".into()).to_status_code(), + StatusCode::NOT_FOUND + ); + assert_eq!( + ServerError::InvalidRequest("bad".into()).to_status_code(), + StatusCode::BAD_REQUEST + ); + assert_eq!( + ServerError::HandlerNotFound("m".into()).to_status_code(), + StatusCode::NOT_IMPLEMENTED + ); + for variant in [ + ServerError::UpgradeFailed("x".into()), + ServerError::RoutingFailed("x".into()), + ServerError::WebSocketError("x".into()), + ServerError::Internal("x".into()), + ] { + assert_eq!(variant.to_status_code(), StatusCode::INTERNAL_SERVER_ERROR); + } + + // From impls + let json_err = serde_json::from_str::("not json").unwrap_err(); + let from_json: ServerError = json_err.into(); + assert_eq!( + from_json.to_status_code(), + StatusCode::INTERNAL_SERVER_ERROR + ); + + let bidir_err: BidirectionalError = BidirectionalError::SendError("e".into()); + let from_bidir: ServerError = bidir_err.into(); + assert_eq!( + from_bidir.to_status_code(), + StatusCode::INTERNAL_SERVER_ERROR + ); + + let auth_err: AuthError = AuthError::TokenExpired; + let from_auth: ServerError = auth_err.into(); + assert_eq!(from_auth.to_status_code(), StatusCode::UNAUTHORIZED); + + // Display formatting: spot-check each kind once. + assert!( + ServerError::UpgradeFailed("x".into()) + .to_string() + .starts_with("WebSocket upgrade failed:") + ); + } +} diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/handler.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/handler.rs index dbae66f..b9f07eb 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/handler.rs +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/handler.rs @@ -93,7 +93,8 @@ pub struct WebSocketHandler { /// Connection context context: Arc, /// Channel for receiving messages to send to client - message_rx: mpsc::UnboundedReceiver, + message_rx: mpsc::Receiver, + max_message_size: usize, } impl WebSocketHandler { @@ -101,12 +102,14 @@ impl WebSocketHandler { pub fn new( handler: Arc, context: Arc, - message_rx: mpsc::UnboundedReceiver, + message_rx: mpsc::Receiver, + max_message_size: usize, ) -> Self { Self { handler, context, message_rx, + max_message_size, } } @@ -205,10 +208,22 @@ impl WebSocketHandler { ) -> ServerResult<()> { match msg { Message::Text(text) => { - debug!("Received text message: {}", text); + if text.len() > self.max_message_size { + warn!("Received oversized text message: {} bytes", text.len()); + return Err(ServerError::InvalidRequest( + "Message exceeds maximum size".to_string(), + )); + } + debug!("Received text message ({} bytes)", text.len()); self.handle_text_message(text.to_string(), socket).await } Message::Binary(data) => { + if data.len() > self.max_message_size { + warn!("Received oversized binary message: {} bytes", data.len()); + return Err(ServerError::InvalidRequest( + "Message exceeds maximum size".to_string(), + )); + } debug!("Received binary message ({} bytes)", data.len()); // Try to parse as UTF-8 text match String::from_utf8(data.to_vec()) { @@ -259,10 +274,9 @@ impl WebSocketHandler { } // If neither worked, return error - Err(ServerError::InvalidRequest(format!( - "Could not parse message as JSON-RPC or bidirectional message: {}", - text - ))) + Err(ServerError::InvalidRequest( + "Could not parse message as JSON-RPC or bidirectional message".to_string(), + )) } /// Handle bidirectional messages @@ -339,3 +353,71 @@ impl WebSocketHandler { .map_err(|e| ServerError::WebSocketError(e.to_string())) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::connection::ChannelMessageSender; + use ras_jsonrpc_bidirectional_types::ConnectionId; + + /// A minimal MessageHandler that only implements the required method — + /// every other method falls through to the default impl, which is what + /// these tests are verifying. + struct PassThrough; + + #[async_trait] + impl MessageHandler for PassThrough { + async fn handle_request( + &self, + _request: JsonRpcRequest, + _context: Arc, + ) -> ServerResult> { + Ok(None) + } + } + + fn ctx() -> Arc { + let id = ConnectionId::new(); + let (tx, _rx) = mpsc::channel(4); + let sender = ChannelMessageSender::new(id, tx); + Arc::new(ConnectionContext::new(id, sender)) + } + + #[tokio::test] + async fn default_handle_subscribe_writes_to_context() { + let h = PassThrough; + let c = ctx(); + h.handle_subscribe(vec!["a".into(), "b".into()], c.clone()) + .await + .unwrap(); + assert!(c.is_subscribed_to("a").await); + assert!(c.is_subscribed_to("b").await); + } + + #[tokio::test] + async fn default_handle_unsubscribe_removes_from_context() { + let h = PassThrough; + let c = ctx(); + c.subscribe("a".into()).await; + c.subscribe("b".into()).await; + h.handle_unsubscribe(vec!["a".into()], c.clone()) + .await + .unwrap(); + assert!(!c.is_subscribed_to("a").await); + assert!(c.is_subscribed_to("b").await); + } + + #[tokio::test] + async fn default_lifecycle_methods_succeed() { + let h = PassThrough; + let c = ctx(); + h.on_connect(c.clone()).await.unwrap(); + h.on_ping(c.clone()).await.unwrap(); + h.on_pong(c.clone()).await.unwrap(); + h.on_disconnect(c.clone(), Some("bye".into())) + .await + .unwrap(); + // None reason path too. + h.on_disconnect(c, None).await.unwrap(); + } +} diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/manager.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/manager.rs index 1d9cbf2..ea9ddf0 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/manager.rs +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/manager.rs @@ -85,7 +85,7 @@ impl DefaultConnectionManager { impl ConnectionManager for DefaultConnectionManager { async fn add_connection(&self, info: ConnectionInfo) -> Result<()> { // Create a dummy sender - real senders should be added via add_connection_with_sender - let (tx, _rx) = mpsc::unbounded_channel(); + let (tx, _rx) = mpsc::channel(1); let sender = ChannelMessageSender::new(info.id, tx); self.connections.insert(info.id, (info.clone(), sender)); info!("Added connection: {}", info.id); @@ -185,7 +185,7 @@ impl ConnectionManager for DefaultConnectionManager { // Update topic subscriptions self.subscriptions .entry(topic.clone()) - .or_insert_with(Vec::new) + .or_default() .push(id); // Update connection subscriptions @@ -234,7 +234,7 @@ impl ConnectionManager for DefaultConnectionManager { .1 .send(message) .await - .map_err(|e| ras_jsonrpc_bidirectional_types::BidirectionalError::SendError(e))?; + .map_err(ras_jsonrpc_bidirectional_types::BidirectionalError::SendError)?; } else { warn!("Attempted to send to non-existent connection: {}", id); } @@ -340,7 +340,7 @@ impl ConnectionManager for DefaultConnectionManager { ) -> Result<()> { self.pending_requests .entry(connection_id) - .or_insert_with(HashMap::new) + .or_default() .insert(request_id, response_sender); debug!( @@ -373,17 +373,16 @@ impl ConnectionManager for DefaultConnectionManager { connection_id: ConnectionId, response: ras_jsonrpc_types::JsonRpcResponse, ) -> Result { - if let Some(request_id) = &response.id { - if let Some(sender) = self + if let Some(request_id) = &response.id + && let Some(sender) = self .remove_pending_request(connection_id, request_id) .await? - { - if let Err(_) = sender.send(response) { - warn!("Failed to send response to pending request - receiver dropped"); - } - debug!("Handled pending response for connection: {}", connection_id); - return Ok(true); + { + if sender.send(response).is_err() { + warn!("Failed to send response to pending request - receiver dropped"); } + debug!("Handled pending response for connection: {}", connection_id); + return Ok(true); } Ok(false) } diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/router.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/router.rs index 407046b..e92293e 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/router.rs +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/router.rs @@ -219,7 +219,7 @@ mod tests { // Create test context let connection_id = ConnectionId::new(); - let (tx, _rx) = mpsc::unbounded_channel(); + let (tx, _rx) = mpsc::channel(1); let sender = crate::connection::ChannelMessageSender::new(connection_id, tx); let context = Arc::new(ConnectionContext::new(connection_id, sender)); @@ -258,4 +258,158 @@ mod tests { assert!(response.error.is_some()); assert_eq!(response.error.unwrap().code, -32601); // METHOD_NOT_FOUND } + + fn test_context() -> Arc { + let connection_id = ConnectionId::new(); + let (tx, _rx) = mpsc::channel(1); + let sender = crate::connection::ChannelMessageSender::new(connection_id, tx); + Arc::new(ConnectionContext::new(connection_id, sender)) + } + + #[tokio::test] + async fn register_low_level_handler_returns_explicit_response() { + let mut router = MessageRouter::new(); + router.register("low.echo", |req, _ctx| async move { + // Hand-built response — proves the low-level register path is wired. + Ok(req + .id + .clone() + .map(|id| JsonRpcResponse::success(req.params.unwrap_or(json!(null)), Some(id)))) + }); + + let ctx = test_context(); + let req = JsonRpcRequest { + jsonrpc: "2.0".into(), + method: "low.echo".into(), + params: Some(json!(42)), + id: Some(json!(7)), + }; + let resp = router.handle_request(req, ctx).await.unwrap().unwrap(); + assert_eq!(resp.id, Some(json!(7))); + assert_eq!(resp.result.unwrap(), json!(42)); + } + + #[tokio::test] + async fn register_notification_never_returns_response() { + let mut router = MessageRouter::new(); + let counter = Arc::new(std::sync::atomic::AtomicUsize::new(0)); + let c2 = counter.clone(); + router.register_notification("evt.tick", move |_req, _ctx| { + let c = c2.clone(); + async move { + c.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + Ok(()) + } + }); + + let ctx = test_context(); + // Even if the request has an id, the notification handler returns None. + let req = JsonRpcRequest { + jsonrpc: "2.0".into(), + method: "evt.tick".into(), + params: None, + id: Some(json!(1)), + }; + assert!(router.handle_request(req, ctx).await.unwrap().is_none()); + assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn handler_error_with_id_becomes_error_response() { + let mut router = MessageRouter::new(); + router.register("explode", |_req, _ctx| async move { + Err::, _>(ServerError::Internal("kaboom".into())) + }); + let req = JsonRpcRequest { + jsonrpc: "2.0".into(), + method: "explode".into(), + params: None, + id: Some(json!("rid")), + }; + let resp = router.handle_request(req, test_context()).await.unwrap(); + let resp = resp.unwrap(); + assert_eq!(resp.id, Some(json!("rid"))); + // `internal_error` constructor strips details for security; the wire + // message is the canonical "Internal error" string. We confirm the code + // and the absence of leaked detail. + let err = resp.error.unwrap(); + assert_eq!(err.code, -32603); + assert!(!err.message.contains("kaboom")); + } + + #[tokio::test] + async fn handler_error_without_id_propagates_as_err() { + let mut router = MessageRouter::new(); + router.register("explode", |_req, _ctx| async move { + Err::, _>(ServerError::Internal("kaboom".into())) + }); + let req = JsonRpcRequest { + jsonrpc: "2.0".into(), + method: "explode".into(), + params: None, + id: None, + }; + let result = router.handle_request(req, test_context()).await; + assert!(matches!(result, Err(ServerError::Internal(_)))); + } + + #[tokio::test] + async fn register_value_with_notification_request_returns_none() { + let mut router = MessageRouter::new(); + router.register_value("v.echo", |req, _ctx| async move { + Ok::(req.params.unwrap_or(json!(null))) + }); + // No id ⇒ notification. The register_value branch should return Ok(None). + let req = JsonRpcRequest { + jsonrpc: "2.0".into(), + method: "v.echo".into(), + params: Some(json!(true)), + id: None, + }; + assert!( + router + .handle_request(req, test_context()) + .await + .unwrap() + .is_none() + ); + } + + #[tokio::test] + async fn register_value_handler_error_with_id_becomes_error_response() { + let mut router = MessageRouter::new(); + router.register_value("v.fail", |_req, _ctx| async move { + Err::(ServerError::Internal("nope".into())) + }); + let req = JsonRpcRequest { + jsonrpc: "2.0".into(), + method: "v.fail".into(), + params: None, + id: Some(json!(9)), + }; + let resp = router.handle_request(req, test_context()).await.unwrap(); + let resp = resp.unwrap(); + assert_eq!(resp.id, Some(json!(9))); + let err = resp.error.unwrap(); + assert_eq!(err.code, -32603); + assert!(!err.message.contains("nope")); + } + + #[tokio::test] + async fn unknown_method_without_id_returns_no_response() { + let router = MessageRouter::new(); + let req = JsonRpcRequest { + jsonrpc: "2.0".into(), + method: "ghost".into(), + params: None, + id: None, + }; + assert!( + router + .handle_request(req, test_context()) + .await + .unwrap() + .is_none() + ); + } } diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/service.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/service.rs index 3b85587..2ac6df4 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/service.rs +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/service.rs @@ -16,7 +16,11 @@ use std::sync::Arc; use tokio::sync::mpsc; use tracing::{error, info}; +const DEFAULT_MESSAGE_CHANNEL_CAPACITY: usize = 1024; +const DEFAULT_MAX_MESSAGE_SIZE: usize = 1024 * 1024; + /// Trait for services that handle WebSocket JSON-RPC communication +#[allow(async_fn_in_trait)] pub trait WebSocketService: Clone + Send + Sync + 'static { /// The message handler type type Handler: MessageHandler; @@ -37,6 +41,16 @@ pub trait WebSocketService: Clone + Send + Sync + 'static { /// Check if authentication is required fn require_auth(&self) -> bool; + /// Maximum queued outbound messages per connection. + fn message_channel_capacity(&self) -> usize { + DEFAULT_MESSAGE_CHANNEL_CAPACITY + } + + /// Maximum accepted inbound WebSocket message size in bytes. + fn max_message_size(&self) -> usize { + DEFAULT_MAX_MESSAGE_SIZE + } + /// Handle WebSocket upgrade async fn handle_upgrade( &self, @@ -73,7 +87,8 @@ pub trait WebSocketService: Clone + Send + Sync + 'static { info!("New WebSocket connection: {}", connection_id); // Create message channel for this connection - let (message_tx, message_rx) = mpsc::unbounded_channel(); + let channel_capacity = service.message_channel_capacity().max(1); + let (message_tx, message_rx) = mpsc::channel(channel_capacity); let sender = ChannelMessageSender::new(connection_id, message_tx); // Create connection info and add to manager @@ -93,10 +108,15 @@ pub trait WebSocketService: Clone + Send + Sync + 'static { .connection_manager() .add_connection_with_sender(info, Box::new(sender.clone())) .await - .map_err(|e| ServerError::ConnectionError(e))?; + .map_err(ServerError::ConnectionError)?; // Create and run WebSocket handler - let handler = WebSocketHandler::new(service.handler(), context.clone(), message_rx); + let handler = WebSocketHandler::new( + service.handler(), + context.clone(), + message_rx, + service.max_message_size(), + ); // Handle the connection (this will block until connection closes) let result = handler.run(socket).await; @@ -127,6 +147,12 @@ pub struct WebSocketServiceBuilder { /// Whether authentication is required #[builder(default = false)] require_auth: bool, + /// Maximum queued outbound messages per connection + #[builder(default = DEFAULT_MESSAGE_CHANNEL_CAPACITY)] + message_channel_capacity: usize, + /// Maximum accepted inbound WebSocket message size in bytes + #[builder(default = DEFAULT_MAX_MESSAGE_SIZE)] + max_message_size: usize, } impl WebSocketServiceBuilder @@ -143,6 +169,8 @@ where .connection_manager .unwrap_or_else(|| Arc::new(DefaultConnectionManager::new())), require_auth: self.require_auth, + message_channel_capacity: self.message_channel_capacity, + max_message_size: self.max_message_size, } } } @@ -160,6 +188,8 @@ where auth_provider: self.auth_provider, connection_manager: manager, require_auth: self.require_auth, + message_channel_capacity: self.message_channel_capacity, + max_message_size: self.max_message_size, } } } @@ -170,6 +200,8 @@ pub struct BuiltWebSocketService { auth_provider: Arc, connection_manager: Arc, require_auth: bool, + message_channel_capacity: usize, + max_message_size: usize, } impl Clone for BuiltWebSocketService { @@ -179,6 +211,8 @@ impl Clone for BuiltWebSocketService { auth_provider: self.auth_provider.clone(), connection_manager: self.connection_manager.clone(), require_auth: self.require_auth, + message_channel_capacity: self.message_channel_capacity, + max_message_size: self.max_message_size, } } } @@ -208,6 +242,14 @@ where fn require_auth(&self) -> bool { self.require_auth } + + fn message_channel_capacity(&self) -> usize { + self.message_channel_capacity + } + + fn max_message_size(&self) -> usize { + self.max_message_size + } } /// Convenience function to create a simple router-based service diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/upgrade.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/upgrade.rs index e51017d..98b34a4 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/upgrade.rs +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/src/upgrade.rs @@ -26,31 +26,31 @@ impl WebSocketUpgrade { /// Extract authentication token from headers pub fn extract_auth_token(&self) -> Option { // Try Authorization header first (Bearer token) - if let Some(auth_header) = self.headers.get("authorization") { - if let Ok(auth_str) = auth_header.to_str() { - if auth_str.starts_with("Bearer ") { - return Some(auth_str[7..].to_string()); - } - // Also support just the token without "Bearer " prefix - return Some(auth_str.to_string()); + if let Some(auth_header) = self.headers.get("authorization") + && let Ok(auth_str) = auth_header.to_str() + { + if let Some(token) = auth_str.strip_prefix("Bearer ") { + return Some(token.to_string()); } + // Also support just the token without "Bearer " prefix + return Some(auth_str.to_string()); } // Try custom WebSocket auth headers - if let Some(token_header) = self.headers.get("sec-websocket-protocol") { - if let Ok(token_str) = token_header.to_str() { - // Support protocols like "token.{jwt_token}" - if token_str.starts_with("token.") { - return Some(token_str[6..].to_string()); - } + if let Some(token_header) = self.headers.get("sec-websocket-protocol") + && let Ok(token_str) = token_header.to_str() + { + // Support protocols like "token.{jwt_token}" + if let Some(token) = token_str.strip_prefix("token.") { + return Some(token.to_string()); } } // Try X-Auth-Token header - if let Some(token_header) = self.headers.get("x-auth-token") { - if let Ok(token_str) = token_header.to_str() { - return Some(token_str.to_string()); - } + if let Some(token_header) = self.headers.get("x-auth-token") + && let Ok(token_str) = token_header.to_str() + { + return Some(token_str.to_string()); } None @@ -167,7 +167,7 @@ impl WebSocketUpgrade { for header_name in &ip_headers { if let Some(value) = self.get_header(header_name) { // For X-Forwarded-For, take the first IP - let ip = value.split(',').next().unwrap_or(&value).trim(); + let ip = value.split(',').next().unwrap_or(value.as_str()).trim(); if !ip.is_empty() { return Some(ip.to_string()); } @@ -220,22 +220,21 @@ mod tests { // Test Bearer token extraction logic headers.insert("authorization", "Bearer abc123".parse().unwrap()); - if let Some(auth_header) = headers.get("authorization") { - if let Ok(auth_str) = auth_header.to_str() { - if auth_str.starts_with("Bearer ") { - assert_eq!(&auth_str[7..], "abc123"); - } - } + if let Some(auth_header) = headers.get("authorization") + && let Ok(auth_str) = auth_header.to_str() + && let Some(token) = auth_str.strip_prefix("Bearer ") + { + assert_eq!(token, "abc123"); } // Test X-Forwarded-For parsing logic headers.clear(); headers.insert("x-forwarded-for", "192.168.1.1, 10.0.0.1".parse().unwrap()); - if let Some(header_value) = headers.get("x-forwarded-for") { - if let Ok(value) = header_value.to_str() { - let ip = value.split(',').next().unwrap_or(&value).trim(); - assert_eq!(ip, "192.168.1.1"); - } + if let Some(header_value) = headers.get("x-forwarded-for") + && let Ok(value) = header_value.to_str() + { + let ip = value.split(',').next().unwrap_or(value).trim(); + assert_eq!(ip, "192.168.1.1"); } } diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/tests/manager_unit.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/tests/manager_unit.rs new file mode 100644 index 0000000..aa58808 --- /dev/null +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-server/tests/manager_unit.rs @@ -0,0 +1,262 @@ +//! Direct unit tests for `DefaultConnectionManager`. +//! +//! The end-to-end suite in `examples/bidirectional-chat` and +//! `ras-jsonrpc-bidirectional-macro/tests/bidirectional_integration.rs` +//! covers the manager's happy path indirectly. This file pins down the +//! manager's contract on its own — subscriptions, broadcast counts, +//! permission filtering, and pending-request lifecycle — without spinning +//! up a real WebSocket. + +use std::collections::HashSet; +use std::sync::Arc; + +use ras_auth_core::AuthenticatedUser; +use ras_jsonrpc_bidirectional_server::DefaultConnectionManager; +use ras_jsonrpc_bidirectional_server::connection::ChannelMessageSender; +use ras_jsonrpc_bidirectional_types::{ + BidirectionalMessage, ConnectionId, ConnectionInfo, ConnectionManager, +}; +use ras_jsonrpc_types::JsonRpcResponse; +use tokio::sync::{mpsc, oneshot}; + +fn user(id: &str, perms: &[&str]) -> AuthenticatedUser { + AuthenticatedUser { + user_id: id.to_string(), + permissions: perms.iter().map(|s| s.to_string()).collect::>(), + metadata: None, + } +} + +/// Build a connection paired with a real receiver so we can observe sends. +async fn join( + mgr: &DefaultConnectionManager, +) -> (ConnectionId, mpsc::Receiver) { + let id = ConnectionId::new(); + let (tx, rx) = mpsc::channel(16); + let sender = ChannelMessageSender::new(id, tx); + let info = ConnectionInfo::new(id); + mgr.add_connection_with_sender_direct(info, sender) + .await + .unwrap(); + (id, rx) +} + +#[tokio::test] +async fn add_remove_round_trip_and_inspect() { + let mgr = DefaultConnectionManager::new(); + assert_eq!(mgr.connection_count(), 0); + + let (a, _ra) = join(&mgr).await; + let (b, _rb) = join(&mgr).await; + assert_eq!(mgr.connection_count(), 2); + + let ids = mgr.get_connection_ids(); + assert!(ids.contains(&a) && ids.contains(&b)); + assert!(mgr.connection_exists(a).await.unwrap()); + assert!(mgr.get_sender(a).is_some()); + + // Removing a missing id is logged-and-ignored, not an error. + mgr.remove_connection(ConnectionId::new()).await.unwrap(); + + mgr.remove_connection(a).await.unwrap(); + assert_eq!(mgr.connection_count(), 1); + assert!(!mgr.connection_exists(a).await.unwrap()); + assert!(mgr.get_sender(a).is_none()); +} + +#[tokio::test] +async fn add_connection_with_sender_box_downcasts() { + let mgr = DefaultConnectionManager::new(); + let id = ConnectionId::new(); + let (tx, _rx) = mpsc::channel(1); + let sender = ChannelMessageSender::new(id, tx); + // Round-trip through Box as the trait method requires. + let boxed: Box = Box::new(sender); + mgr.add_connection_with_sender(ConnectionInfo::new(id), boxed) + .await + .unwrap(); + assert!(mgr.connection_exists(id).await.unwrap()); + assert!(mgr.get_sender(id).is_some()); +} + +#[tokio::test] +async fn add_connection_with_unknown_sender_falls_back_to_dummy() { + let mgr = DefaultConnectionManager::new(); + let id = ConnectionId::new(); + let bogus: Box = Box::new(123u32); + mgr.add_connection_with_sender(ConnectionInfo::new(id), bogus) + .await + .unwrap(); + assert!(mgr.connection_exists(id).await.unwrap()); +} + +#[tokio::test] +async fn subscriptions_track_topics_and_clean_up_on_remove() { + let mgr = DefaultConnectionManager::new(); + let (a, _ra) = join(&mgr).await; + let (b, _rb) = join(&mgr).await; + + mgr.add_subscription(a, "room:1".into()).await.unwrap(); + mgr.add_subscription(b, "room:1".into()).await.unwrap(); + mgr.add_subscription(b, "room:2".into()).await.unwrap(); + + let topics = mgr.get_active_topics(); + assert!(topics.contains(&"room:1".to_string())); + assert!(topics.contains(&"room:2".to_string())); + + let r1: HashSet<_> = mgr.get_topic_connections("room:1").into_iter().collect(); + assert!(r1.contains(&a) && r1.contains(&b)); + + let subs_b = mgr.get_subscriptions(b).await.unwrap(); + assert!(subs_b.iter().any(|s| s == "room:1")); + assert!(subs_b.iter().any(|s| s == "room:2")); + + // Direct unsubscribe on the only-non-empty topic frees it. + mgr.remove_subscription(a, "room:1").await.unwrap(); + mgr.remove_subscription(b, "room:1").await.unwrap(); + assert!(!mgr.get_active_topics().contains(&"room:1".to_string())); + + // Removing a connection prunes any remaining subscriptions for it. + mgr.remove_connection(b).await.unwrap(); + // room:2 had only b, so it should be gone. + assert!(!mgr.get_active_topics().contains(&"room:2".to_string())); +} + +#[tokio::test] +async fn subscribed_connections_returns_full_info_for_topic() { + let mgr = DefaultConnectionManager::new(); + let (a, _ra) = join(&mgr).await; + let (_b, _rb) = join(&mgr).await; + mgr.add_subscription(a, "t".into()).await.unwrap(); + let subs = mgr.get_subscribed_connections("t").await.unwrap(); + assert_eq!(subs.len(), 1); + assert_eq!(subs[0].id, a); +} + +#[tokio::test] +async fn user_set_clear_filter_paths_for_broadcasts() { + let mgr = DefaultConnectionManager::new(); + let (auth_id, mut auth_rx) = join(&mgr).await; + let (admin_id, mut admin_rx) = join(&mgr).await; + let (_anon_id, mut anon_rx) = join(&mgr).await; + + mgr.set_connection_user(auth_id, user("u", &["read"])) + .await + .unwrap(); + mgr.set_connection_user(admin_id, user("a", &["read", "admin"])) + .await + .unwrap(); + // anon stays unauthenticated. + + // broadcast_to_authenticated reaches both authenticated peers. + let n = mgr + .broadcast_to_authenticated(BidirectionalMessage::Ping) + .await + .unwrap(); + assert_eq!(n, 2); + assert!(auth_rx.try_recv().is_ok()); + assert!(admin_rx.try_recv().is_ok()); + assert!(anon_rx.try_recv().is_err()); + + // broadcast_to_permission only reaches the admin. + let n = mgr + .broadcast_to_permission("admin", BidirectionalMessage::Ping) + .await + .unwrap(); + assert_eq!(n, 1); + assert!(admin_rx.try_recv().is_ok()); + + // clear_connection_user flips the auth flag back. + mgr.clear_connection_user(auth_id).await.unwrap(); + let n = mgr + .broadcast_to_authenticated(BidirectionalMessage::Pong) + .await + .unwrap(); + assert_eq!(n, 1); + + // Setting/clearing user on missing id is best-effort, not an error. + mgr.set_connection_user(ConnectionId::new(), user("ghost", &[])) + .await + .unwrap(); + mgr.clear_connection_user(ConnectionId::new()) + .await + .unwrap(); +} + +#[tokio::test] +async fn broadcast_to_topic_counts_recipients_and_skips_empty() { + let mgr = DefaultConnectionManager::new(); + let (a, mut ra) = join(&mgr).await; + let (_b, _rb) = join(&mgr).await; + mgr.add_subscription(a, "t".into()).await.unwrap(); + + let n = mgr + .broadcast_to_topic("t", BidirectionalMessage::Ping) + .await + .unwrap(); + assert_eq!(n, 1); + assert!(ra.try_recv().is_ok()); + + // Topic with no subscribers reports zero. + let n = mgr + .broadcast_to_topic("missing", BidirectionalMessage::Ping) + .await + .unwrap(); + assert_eq!(n, 0); +} + +#[tokio::test] +async fn pending_request_lifecycle() { + let mgr = DefaultConnectionManager::new(); + let (id, _rx) = join(&mgr).await; + + let (tx, rx) = oneshot::channel(); + mgr.register_pending_request(id, serde_json::json!("rid"), tx) + .await + .unwrap(); + + // remove_pending_request hands back the sender. + let pulled = mgr + .remove_pending_request(id, &serde_json::json!("rid")) + .await + .unwrap(); + assert!(pulled.is_some()); + drop(pulled); + drop(rx); + + // handle_pending_response with no registered id reports false. + let resp = JsonRpcResponse::success(serde_json::json!("ok"), Some(serde_json::json!("rid"))); + let handled = mgr.handle_pending_response(id, resp).await.unwrap(); + assert!(!handled); + + // Register again, then route a real response through handle_pending_response. + let (tx, rx) = oneshot::channel(); + mgr.register_pending_request(id, serde_json::json!("rid2"), tx) + .await + .unwrap(); + let resp = JsonRpcResponse::success(serde_json::json!(7), Some(serde_json::json!("rid2"))); + assert!(mgr.handle_pending_response(id, resp).await.unwrap()); + let received = rx.await.unwrap(); + assert_eq!(received.result.unwrap(), serde_json::json!(7)); + + // Removing for a connection that never registered any returns None. + let pulled = mgr + .remove_pending_request(ConnectionId::new(), &serde_json::json!("nope")) + .await + .unwrap(); + assert!(pulled.is_none()); +} + +#[tokio::test] +async fn send_to_missing_connection_is_silent_ok() { + let mgr = DefaultConnectionManager::new(); + // Nothing registered — manager logs and returns Ok. + mgr.send_to_connection(ConnectionId::new(), BidirectionalMessage::Ping) + .await + .unwrap(); +} + +#[tokio::test] +async fn default_impl_is_equivalent_to_new() { + let _ = Arc::new(DefaultConnectionManager::default()); +} diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-types/src/error.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-types/src/error.rs index 1bc6a62..2535447 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-types/src/error.rs +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-types/src/error.rs @@ -86,3 +86,80 @@ impl BidirectionalError { Self::Custom(error.to_string()) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::ConnectionId; + + #[test] + fn helpers_wrap_display_into_correct_variant() { + let id = ConnectionId::new(); + assert!( + BidirectionalError::ConnectionNotFound(id) + .to_string() + .contains(&id.to_string()) + ); + assert_eq!( + BidirectionalError::ConnectionAlreadyExists(id).to_string(), + format!("Connection already exists: {id}") + ); + assert_eq!( + BidirectionalError::SendError("oops".into()).to_string(), + "Failed to send message: oops" + ); + assert_eq!( + BidirectionalError::BroadcastError { + topic: "t".into(), + reason: "r".into(), + } + .to_string(), + "Failed to broadcast to topic 't': r" + ); + assert_eq!( + BidirectionalError::AuthenticationRequired.to_string(), + "Authentication required" + ); + assert_eq!( + BidirectionalError::PermissionDenied("admin".into()).to_string(), + "Permission denied: admin" + ); + assert_eq!( + BidirectionalError::InvalidTopic("foo".into()).to_string(), + "Invalid subscription topic: foo" + ); + assert_eq!( + BidirectionalError::ConnectionClosed.to_string(), + "Connection closed" + ); + assert_eq!(BidirectionalError::Timeout.to_string(), "Request timeout"); + assert_eq!( + BidirectionalError::RpcError("nope".into()).to_string(), + "RPC error: nope" + ); + assert_eq!( + BidirectionalError::InvalidResponse("garbage".into()).to_string(), + "Invalid response: garbage" + ); + assert_eq!( + BidirectionalError::ConnectionError("eof".into()).to_string(), + "Connection error: eof" + ); + + // Constructor helpers. + let we = BidirectionalError::websocket("boom"); + assert!(matches!(we, BidirectionalError::WebSocketError(ref s) if s == "boom")); + let ie = BidirectionalError::internal("ugh"); + assert!(matches!(ie, BidirectionalError::InternalError(ref s) if s == "ugh")); + let ce = BidirectionalError::custom("plain"); + assert!(matches!(ce, BidirectionalError::Custom(ref s) if s == "plain")); + } + + #[test] + fn from_serde_json_error() { + let parse_err = serde_json::from_str::("{not json}").unwrap_err(); + let wrapped: BidirectionalError = parse_err.into(); + assert!(matches!(wrapped, BidirectionalError::SerializationError(_))); + assert!(wrapped.to_string().starts_with("Serialization error:")); + } +} diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-types/src/manager.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-types/src/manager.rs index c593806..b78829f 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-types/src/manager.rs +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-types/src/manager.rs @@ -184,3 +184,291 @@ pub trait ConnectionManagerExt: ConnectionManager { // Blanket implementation for all ConnectionManager types impl ConnectionManagerExt for T {} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{BidirectionalError, BroadcastMessage, ConnectionId, ConnectionInfo}; + use ras_auth_core::AuthenticatedUser; + use ras_jsonrpc_types::JsonRpcResponse; + use std::collections::HashMap; + use std::collections::HashSet; + use std::sync::Mutex; + use tokio::sync::oneshot; + + /// Tiny stub manager: enough state to exercise the default `connection_*` + /// methods and the `ConnectionManagerExt` helpers without dragging in the + /// full `DefaultConnectionManager`. Uses sync `Mutex` for simplicity. + #[derive(Default)] + struct StubManager { + conns: Mutex>, + subs: Mutex>>, + sent: Mutex>, + broadcasts: Mutex>, + } + + #[async_trait] + impl ConnectionManager for StubManager { + async fn add_connection(&self, info: ConnectionInfo) -> Result<()> { + self.conns.lock().unwrap().insert(info.id, info); + Ok(()) + } + async fn remove_connection(&self, id: ConnectionId) -> Result<()> { + self.conns + .lock() + .unwrap() + .remove(&id) + .ok_or(BidirectionalError::ConnectionNotFound(id))?; + Ok(()) + } + async fn get_connection(&self, id: ConnectionId) -> Result> { + Ok(self.conns.lock().unwrap().get(&id).cloned()) + } + async fn get_all_connections(&self) -> Result> { + Ok(self.conns.lock().unwrap().values().cloned().collect()) + } + async fn get_subscribed_connections(&self, topic: &str) -> Result> { + let ids = self + .subs + .lock() + .unwrap() + .get(topic) + .cloned() + .unwrap_or_default(); + let conns = self.conns.lock().unwrap(); + Ok(ids.iter().filter_map(|id| conns.get(id).cloned()).collect()) + } + async fn set_connection_user( + &self, + id: ConnectionId, + user: AuthenticatedUser, + ) -> Result<()> { + if let Some(info) = self.conns.lock().unwrap().get_mut(&id) { + info.set_user(user); + Ok(()) + } else { + Err(BidirectionalError::ConnectionNotFound(id)) + } + } + async fn clear_connection_user(&self, id: ConnectionId) -> Result<()> { + if let Some(info) = self.conns.lock().unwrap().get_mut(&id) { + info.clear_user(); + Ok(()) + } else { + Err(BidirectionalError::ConnectionNotFound(id)) + } + } + async fn add_subscription(&self, id: ConnectionId, topic: String) -> Result<()> { + self.subs + .lock() + .unwrap() + .entry(topic.clone()) + .or_default() + .insert(id); + if let Some(info) = self.conns.lock().unwrap().get_mut(&id) { + info.subscribe(topic); + } + Ok(()) + } + async fn remove_subscription(&self, id: ConnectionId, topic: &str) -> Result<()> { + if let Some(set) = self.subs.lock().unwrap().get_mut(topic) { + set.remove(&id); + } + if let Some(info) = self.conns.lock().unwrap().get_mut(&id) { + info.unsubscribe(topic); + } + Ok(()) + } + async fn get_subscriptions(&self, id: ConnectionId) -> Result> { + Ok(self + .conns + .lock() + .unwrap() + .get(&id) + .map(|c| c.subscriptions.iter().cloned().collect()) + .unwrap_or_default()) + } + async fn send_to_connection( + &self, + id: ConnectionId, + message: BidirectionalMessage, + ) -> Result<()> { + self.sent.lock().unwrap().push((id, message)); + Ok(()) + } + async fn broadcast_to_topic( + &self, + topic: &str, + message: BidirectionalMessage, + ) -> Result { + let n = self + .subs + .lock() + .unwrap() + .get(topic) + .map(|s| s.len()) + .unwrap_or(0); + self.broadcasts + .lock() + .unwrap() + .push((topic.to_string(), message)); + Ok(n) + } + async fn broadcast_to_authenticated( + &self, + _message: BidirectionalMessage, + ) -> Result { + Ok(self.authenticated_connection_count().await?) + } + async fn broadcast_to_permission( + &self, + permission: &str, + _message: BidirectionalMessage, + ) -> Result { + Ok(self + .conns + .lock() + .unwrap() + .values() + .filter(|c| c.has_permission(permission)) + .count()) + } + async fn register_pending_request( + &self, + _connection_id: ConnectionId, + _request_id: serde_json::Value, + _response_sender: oneshot::Sender, + ) -> Result<()> { + Ok(()) + } + async fn remove_pending_request( + &self, + _connection_id: ConnectionId, + _request_id: &serde_json::Value, + ) -> Result>> { + Ok(None) + } + async fn handle_pending_response( + &self, + _connection_id: ConnectionId, + _response: JsonRpcResponse, + ) -> Result { + Ok(false) + } + } + + fn user(id: &str, perms: &[&str]) -> AuthenticatedUser { + AuthenticatedUser { + user_id: id.to_string(), + permissions: perms.iter().map(|s| s.to_string()).collect(), + metadata: None, + } + } + + #[tokio::test] + async fn default_methods_delegate_to_required_methods() { + let mgr = StubManager::default(); + + // Initially nothing is registered. + assert_eq!(mgr.connection_count().await.unwrap(), 0); + assert_eq!(mgr.authenticated_connection_count().await.unwrap(), 0); + assert_eq!(mgr.cleanup_stale_connections().await.unwrap(), 0); + + let id1 = ConnectionId::new(); + let id2 = ConnectionId::new(); + mgr.add_connection(ConnectionInfo::new(id1)).await.unwrap(); + mgr.add_connection(ConnectionInfo::new(id2)).await.unwrap(); + + assert!(mgr.connection_exists(id1).await.unwrap()); + assert!(mgr.connection_exists(id2).await.unwrap()); + assert_eq!(mgr.connection_count().await.unwrap(), 2); + + // Authenticate one connection. + mgr.set_connection_user(id1, user("u1", &["read"])) + .await + .unwrap(); + assert_eq!(mgr.authenticated_connection_count().await.unwrap(), 1); + + // The default `add_connection_with_sender` must fall through to + // `add_connection`. + let id3 = ConnectionId::new(); + let dummy: Box = Box::new(()) as _; + mgr.add_connection_with_sender(ConnectionInfo::new(id3), dummy) + .await + .unwrap(); + assert!(mgr.connection_exists(id3).await.unwrap()); + } + + #[tokio::test] + async fn ext_helpers_route_to_correct_messages() { + let mgr = StubManager::default(); + let id = ConnectionId::new(); + mgr.add_connection(ConnectionInfo::new(id)).await.unwrap(); + mgr.add_subscription(id, "room:1".into()).await.unwrap(); + + // notify_connection wraps as ServerNotification. + mgr.notify_connection(id, "evt", serde_json::json!({"k": 1})) + .await + .unwrap(); + let sent = mgr.sent.lock().unwrap(); + assert_eq!(sent.len(), 1); + match &sent[0].1 { + BidirectionalMessage::ServerNotification(n) => assert_eq!(n.method, "evt"), + other => panic!("unexpected: {other:?}"), + } + drop(sent); + + // notify_topic broadcasts to the topic with one subscriber. + let n = mgr + .notify_topic("room:1", "msg", serde_json::json!("hi")) + .await + .unwrap(); + assert_eq!(n, 1); + let bs = mgr.broadcasts.lock().unwrap(); + assert!(matches!( + &bs[0].1, + BidirectionalMessage::Broadcast(BroadcastMessage { method, .. }) if method == "msg" + )); + drop(bs); + + // ping_connection should produce a Ping payload. + mgr.ping_connection(id).await.unwrap(); + let sent = mgr.sent.lock().unwrap(); + assert!(matches!(sent.last().unwrap().1, BidirectionalMessage::Ping)); + } + + #[tokio::test] + async fn user_helpers_filter_and_disconnect() { + let mgr = StubManager::default(); + let alice1 = ConnectionId::new(); + let alice2 = ConnectionId::new(); + let bob = ConnectionId::new(); + mgr.add_connection(ConnectionInfo::new(alice1)) + .await + .unwrap(); + mgr.add_connection(ConnectionInfo::new(alice2)) + .await + .unwrap(); + mgr.add_connection(ConnectionInfo::new(bob)).await.unwrap(); + mgr.set_connection_user(alice1, user("alice", &[])) + .await + .unwrap(); + mgr.set_connection_user(alice2, user("alice", &[])) + .await + .unwrap(); + mgr.set_connection_user(bob, user("bob", &[])) + .await + .unwrap(); + + assert_eq!(mgr.get_user_connections("alice").await.unwrap().len(), 2); + assert_eq!(mgr.get_user_connections("bob").await.unwrap().len(), 1); + assert_eq!(mgr.get_user_connections("nobody").await.unwrap().len(), 0); + + let dropped = mgr.disconnect_user("alice").await.unwrap(); + assert_eq!(dropped, 2); + assert!(!mgr.connection_exists(alice1).await.unwrap()); + assert!(!mgr.connection_exists(alice2).await.unwrap()); + // Bob unaffected. + assert!(mgr.connection_exists(bob).await.unwrap()); + } +} diff --git a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-types/src/sender.rs b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-types/src/sender.rs index 9fb0be5..5d26954 100644 --- a/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-types/src/sender.rs +++ b/crates/rpc/bidirectional/ras-jsonrpc-bidirectional-types/src/sender.rs @@ -245,4 +245,124 @@ mod tests { BidirectionalMessage::ServerNotification(n) if n.method == "test.method" )); } + + #[tokio::test] + async fn message_sender_ext_request_response_subscription() { + struct Recorder { + id: ConnectionId, + sent: Arc>>, + } + #[async_trait] + impl MessageSender for Recorder { + async fn send_message(&self, message: BidirectionalMessage) -> Result<()> { + self.sent.lock().await.push(message); + Ok(()) + } + async fn close(&self) -> Result<()> { + Ok(()) + } + async fn is_connected(&self) -> bool { + true + } + fn connection_id(&self) -> ConnectionId { + self.id + } + } + let r = Recorder { + id: ConnectionId::new(), + sent: Arc::new(Mutex::new(Vec::new())), + }; + + r.send_request(ras_jsonrpc_types::JsonRpcRequest { + jsonrpc: "2.0".into(), + method: "m".into(), + params: None, + id: Some(serde_json::json!(1)), + }) + .await + .unwrap(); + r.send_response(ras_jsonrpc_types::JsonRpcResponse::success( + serde_json::json!("ok"), + Some(serde_json::json!(1)), + )) + .await + .unwrap(); + r.send_subscription_update(vec!["t1".into()], true) + .await + .unwrap(); + r.send_subscription_update(vec!["t1".into()], false) + .await + .unwrap(); + + let s = r.sent.lock().await; + assert!(matches!(s[0], BidirectionalMessage::Request(_))); + assert!(matches!(s[1], BidirectionalMessage::Response(_))); + assert!(matches!(s[2], BidirectionalMessage::Subscribe { .. })); + assert!(matches!(s[3], BidirectionalMessage::Unsubscribe { .. })); + } + + #[tokio::test] + async fn noop_message_sender_round_trip() { + let id = ConnectionId::new(); + let sender = NoOpMessageSender::with_connection_id(id); + assert_eq!(sender.connection_id(), id); + assert!(sender.is_connected().await); + sender + .send_message(BidirectionalMessage::Ping) + .await + .unwrap(); + sender.close().await.unwrap(); + + // Default constructor + Default impl. + let s2 = NoOpMessageSender::new(); + let s3 = NoOpMessageSender::default(); + assert_ne!(s2.connection_id(), s3.connection_id()); + } + + #[tokio::test] + async fn websocket_sender_drives_real_sink() { + use futures::channel::mpsc; + use futures::stream::StreamExt; + + // mpsc::channel's Sender impls Sink, satisfying the SinkExt bound + // on `WebSocketMessageSender::new`. + let (tx, mut rx) = mpsc::channel::(8); + let id = ConnectionId::new(); + let sender = WebSocketMessageSender::new(id, tx); + + assert_eq!(sender.connection_id(), id); + assert!(sender.is_connected().await); + + sender + .send_message(BidirectionalMessage::Ping) + .await + .unwrap(); + // close once → emits a Close frame and flips is_closed. + sender.close().await.unwrap(); + assert!(!sender.is_connected().await); + // close again is idempotent (no panic, no extra send). + sender.close().await.unwrap(); + + // Sending after close yields ConnectionClosed. + let err = sender + .send_message(BidirectionalMessage::Pong) + .await + .unwrap_err(); + assert!(matches!(err, BidirectionalError::ConnectionClosed)); + + // Drain what we actually pushed: a Text(Ping) and a Close. + let mut received: Vec = Vec::new(); + while let Some(m) = rx.next().await { + received.push(m); + if received.len() == 2 { + break; + } + } + assert_eq!(received.len(), 2); + match &received[0] { + WsMessage::Text(t) => assert!(t.contains("ping")), + other => panic!("expected Text(ping), got {other:?}"), + } + assert!(matches!(received[1], WsMessage::Close(_))); + } } diff --git a/crates/rpc/ras-jsonrpc-macro/Cargo.toml b/crates/rpc/ras-jsonrpc-macro/Cargo.toml index 542d528..a7630eb 100644 --- a/crates/rpc/ras-jsonrpc-macro/Cargo.toml +++ b/crates/rpc/ras-jsonrpc-macro/Cargo.toml @@ -46,3 +46,12 @@ axum = { workspace = true } ras-jsonrpc-core = { path = "../ras-jsonrpc-core" } ras-auth-core = { path = "../../core/ras-auth-core" } async-trait = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +ras-test-helpers = { path = "../../test-utils/ras-test-helpers" } +axum-test = { workspace = true } +criterion = { workspace = true, features = ["async_tokio"] } + +[[bench]] +name = "dispatch" +harness = false diff --git a/crates/rpc/ras-jsonrpc-macro/benches/dispatch.rs b/crates/rpc/ras-jsonrpc-macro/benches/dispatch.rs new file mode 100644 index 0000000..ff87b8f --- /dev/null +++ b/crates/rpc/ras-jsonrpc-macro/benches/dispatch.rs @@ -0,0 +1,66 @@ +//! Criterion bench measuring per-call latency of an authenticated JSON-RPC +//! method through the full stack: generated client → axum router → handler. +//! +//! Run with `cargo bench -p ras-jsonrpc-macro`. + +use criterion::{Criterion, criterion_group, criterion_main}; +use ras_jsonrpc_macro::jsonrpc_service; +use ras_test_helpers::{MockAuthProvider, spawn_http}; +use serde::{Deserialize, Serialize}; +use tokio::runtime::Runtime; + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct AddRequest { + a: i64, + b: i64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct AddResponse { + sum: i64, +} + +jsonrpc_service!({ + service_name: BenchSvc, + openrpc: false, + methods: [ + WITH_PERMISSIONS(["user"]) add(AddRequest) -> AddResponse, + ] +}); + +fn build_router() -> axum::Router { + BenchSvcBuilder::new("/rpc") + .auth_provider(MockAuthProvider::default()) + .add_handler(|_user, req: AddRequest| async move { Ok(AddResponse { sum: req.a + req.b }) }) + .build() + .expect("router build") +} + +fn bench_dispatch(c: &mut Criterion) { + let rt = Runtime::new().unwrap(); + + // Spin the server up once and reuse across every iteration. + let (client, _server) = rt.block_on(async { + let server = spawn_http(build_router()); + let url = server.server_url("/rpc").unwrap().to_string(); + let mut client = BenchSvcClientBuilder::new() + .server_url(url) + .build() + .expect("client build"); + client.set_bearer_token(Some("user-token".to_string())); + (client, server) + }); + + c.bench_function("jsonrpc_add_dispatch", |b| { + b.to_async(&rt).iter(|| { + let client = client.clone(); + async move { + let r = client.add(AddRequest { a: 1, b: 2 }).await.expect("add ok"); + std::hint::black_box(r); + } + }); + }); +} + +criterion_group!(benches, bench_dispatch); +criterion_main!(benches); diff --git a/crates/rpc/ras-jsonrpc-macro/examples/explorer_params_demo.rs b/crates/rpc/ras-jsonrpc-macro/examples/explorer_params_demo.rs index aa5a710..180cd42 100644 --- a/crates/rpc/ras-jsonrpc-macro/examples/explorer_params_demo.rs +++ b/crates/rpc/ras-jsonrpc-macro/examples/explorer_params_demo.rs @@ -178,15 +178,15 @@ impl UserManagementServiceTrait for UserManagementServiceImpl { let filtered_users: Vec = users .into_iter() .filter(|u| { - if let Some(pattern) = &req.username_pattern { - if !u.username.contains(pattern) { - return false; - } + if let Some(pattern) = &req.username_pattern + && !u.username.contains(pattern) + { + return false; } - if let Some(pattern) = &req.email_pattern { - if !u.email.contains(pattern) { - return false; - } + if let Some(pattern) = &req.email_pattern + && !u.email.contains(pattern) + { + return false; } true }) diff --git a/crates/rpc/ras-jsonrpc-macro/src/jsonrpc_explorer_template.html b/crates/rpc/ras-jsonrpc-macro/src/jsonrpc_explorer_template.html index 6194b56..9d0eb48 100644 --- a/crates/rpc/ras-jsonrpc-macro/src/jsonrpc_explorer_template.html +++ b/crates/rpc/ras-jsonrpc-macro/src/jsonrpc_explorer_template.html @@ -571,6 +571,10 @@
+
@@ -709,7 +713,7 @@

Select a method

// Global state let openrpcDoc = null; let currentMethod = null; - let jwtToken = localStorage.getItem('jwt-token') || ''; + let jwtToken = ''; // Initialize the application document.addEventListener('DOMContentLoaded', async () => { @@ -736,9 +740,13 @@

Select a method

function initializeAuth() { const tokenInput = document.getElementById('jwt-token'); const authStatus = document.getElementById('auth-status'); + const rememberToken = document.getElementById('remember-token'); + const rememberedToken = sessionStorage.getItem('jwt-token') || ''; - if (jwtToken) { + if (rememberedToken) { + jwtToken = rememberedToken; tokenInput.value = jwtToken; + rememberToken.checked = true; authStatus.classList.add('authenticated'); } } @@ -772,13 +780,17 @@

Select a method

function saveToken() { const tokenInput = document.getElementById('jwt-token'); const authStatus = document.getElementById('auth-status'); + const rememberToken = document.getElementById('remember-token'); jwtToken = tokenInput.value.trim(); - localStorage.setItem('jwt-token', jwtToken); + sessionStorage.removeItem('jwt-token'); + if (jwtToken && rememberToken.checked) { + sessionStorage.setItem('jwt-token', jwtToken); + } if (jwtToken) { authStatus.classList.add('authenticated'); - showToast('Token saved successfully', 'success'); + showToast(rememberToken.checked ? 'Token saved for this tab' : 'Token ready for this page', 'success'); } else { authStatus.classList.remove('authenticated'); } @@ -787,8 +799,9 @@

Select a method

// Clear JWT token function clearToken() { jwtToken = ''; - localStorage.removeItem('jwt-token'); + sessionStorage.removeItem('jwt-token'); document.getElementById('jwt-token').value = ''; + document.getElementById('remember-token').checked = false; document.getElementById('auth-status').classList.remove('authenticated'); showToast('Token cleared', 'success'); } @@ -1196,4 +1209,4 @@

Authentication Required

} - \ No newline at end of file + diff --git a/crates/rpc/ras-jsonrpc-macro/tests/e2e.rs b/crates/rpc/ras-jsonrpc-macro/tests/e2e.rs new file mode 100644 index 0000000..8652a56 --- /dev/null +++ b/crates/rpc/ras-jsonrpc-macro/tests/e2e.rs @@ -0,0 +1,188 @@ +//! End-to-end test that exercises the full chain: +//! generated reqwest client → axum router → handler → response → client. +//! +//! Covers: success path, missing-permission rejection, malformed input. + +use ras_jsonrpc_macro::jsonrpc_service; +use ras_test_helpers::{MockAuthProvider, spawn_http}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +struct EchoRequest { + msg: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +struct EchoResponse { + msg: String, + user_id: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +struct AddRequest { + a: i64, + b: i64, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +struct AddResponse { + sum: i64, +} + +jsonrpc_service!({ + service_name: Demo, + openrpc: false, + methods: [ + UNAUTHORIZED ping(EchoRequest) -> EchoResponse, + WITH_PERMISSIONS(["user"]) add(AddRequest) -> AddResponse, + WITH_PERMISSIONS(["admin"]) admin_only(EchoRequest) -> EchoResponse, + ] +}); + +fn router() -> axum::Router { + DemoBuilder::new("/rpc") + .auth_provider(MockAuthProvider::default()) + .ping_handler(|req: EchoRequest| async move { + Ok(EchoResponse { + msg: req.msg, + user_id: None, + }) + }) + .add_handler(|_user, req: AddRequest| async move { Ok(AddResponse { sum: req.a + req.b }) }) + .admin_only_handler(|user, req: EchoRequest| async move { + Ok(EchoResponse { + msg: req.msg, + user_id: Some(user.user_id), + }) + }) + .build() + .expect("build router") +} + +fn client(url: String) -> DemoClient { + DemoClientBuilder::new() + .server_url(url) + .build() + .expect("client build") +} + +#[tokio::test] +async fn unauth_method_round_trips() { + let server = spawn_http(router()); + let url = server.server_url("/rpc").expect("server url").to_string(); + + let mut c = client(url); + c.set_bearer_token(Option::::None); + + let resp = c + .ping(EchoRequest { + msg: "hello".to_string(), + }) + .await + .expect("ping ok"); + + assert_eq!(resp.msg, "hello"); + assert_eq!(resp.user_id, None); +} + +#[tokio::test] +async fn permission_required_method_rejects_anonymous() { + let server = spawn_http(router()); + let url = server.server_url("/rpc").unwrap().to_string(); + + let mut c = client(url); + c.set_bearer_token(Option::::None); + + let err = c + .add(AddRequest { a: 2, b: 3 }) + .await + .expect_err("anonymous add must be rejected"); + + let s = err.to_string(); + assert!( + s.contains("Authentication") || s.contains("AUTH") || s.contains("auth"), + "expected auth-related error, got: {s}" + ); +} + +#[tokio::test] +async fn permission_required_method_rejects_wrong_perms() { + let server = spawn_http(router()); + let url = server.server_url("/rpc").unwrap().to_string(); + + let mut c = client(url); + c.set_bearer_token(Some("readonly-token".to_string())); + + let err = c + .add(AddRequest { a: 2, b: 3 }) + .await + .expect_err("readonly user must not be allowed to call add"); + let s = err.to_string(); + assert!( + s.contains("permission") || s.contains("Permission") || s.contains("PERMISSION"), + "expected permission-related error, got: {s}" + ); +} + +#[tokio::test] +async fn permission_required_method_succeeds_with_correct_perms() { + let server = spawn_http(router()); + let url = server.server_url("/rpc").unwrap().to_string(); + + let mut c = client(url); + c.set_bearer_token(Some("user-token".to_string())); + + let resp = c.add(AddRequest { a: 7, b: 35 }).await.expect("add ok"); + assert_eq!(resp.sum, 42); +} + +#[tokio::test] +async fn admin_method_succeeds_with_admin_token() { + let server = spawn_http(router()); + let url = server.server_url("/rpc").unwrap().to_string(); + + let mut c = client(url); + c.set_bearer_token(Some("admin-token".to_string())); + + let resp = c + .admin_only(EchoRequest { + msg: "secret".to_string(), + }) + .await + .expect("admin call ok"); + + assert_eq!(resp.msg, "secret"); + assert_eq!(resp.user_id.as_deref(), Some("admin-1")); +} + +#[tokio::test] +async fn malformed_params_yield_jsonrpc_error() { + // Bypass the typed client to send a malformed body and confirm the + // server returns a JSON-RPC `invalid_params` error rather than a panic. + let server = spawn_http(router()); + let url = server.server_url("/rpc").unwrap().to_string(); + + let body = serde_json::json!({ + "jsonrpc": "2.0", + "method": "ping", + "params": { "bogus": 1 }, + "id": 1, + }); + + let resp: serde_json::Value = reqwest::Client::new() + .post(url) + .json(&body) + .send() + .await + .unwrap() + .json() + .await + .unwrap(); + + assert!( + resp.get("error").is_some(), + "expected error in response: {resp}" + ); + let code = resp["error"]["code"].as_i64().unwrap(); + assert_eq!(code, -32602, "expected invalid_params (-32602), got {code}"); +} diff --git a/crates/rpc/ras-jsonrpc-macro/tests/explorer_token_storage_test.rs b/crates/rpc/ras-jsonrpc-macro/tests/explorer_token_storage_test.rs new file mode 100644 index 0000000..cd5704a --- /dev/null +++ b/crates/rpc/ras-jsonrpc-macro/tests/explorer_token_storage_test.rs @@ -0,0 +1,8 @@ +#[test] +fn test_generated_explorer_does_not_store_jwt_in_local_storage() { + let template = include_str!("../src/jsonrpc_explorer_template.html"); + assert!(!template.contains("localStorage.getItem('jwt-token')")); + assert!(!template.contains("localStorage.setItem('jwt-token'")); + assert!(!template.contains("localStorage.removeItem('jwt-token'")); + assert!(template.contains("sessionStorage.setItem('jwt-token'")); +} diff --git a/crates/rpc/ras-jsonrpc-macro/tests/http_integration.rs b/crates/rpc/ras-jsonrpc-macro/tests/http_integration.rs index 79f4cc2..3d7b370 100644 --- a/crates/rpc/ras-jsonrpc-macro/tests/http_integration.rs +++ b/crates/rpc/ras-jsonrpc-macro/tests/http_integration.rs @@ -1,7 +1,6 @@ use rand::Rng; use ras_jsonrpc_core::{AuthError, AuthFuture, AuthProvider, AuthenticatedUser}; use ras_jsonrpc_macro::jsonrpc_service; -use reqwest; use serde::{Deserialize, Serialize}; use serde_json::{Value, json}; use std::collections::HashSet; @@ -61,12 +60,6 @@ struct ProcessingResult { success: bool, } -#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)] -struct ErrorResponse { - error: String, - details: Option, -} - // Simple test auth provider struct TestAuthProvider { valid_tokens: HashSet, diff --git a/crates/rpc/ras-jsonrpc-types/src/lib.rs b/crates/rpc/ras-jsonrpc-types/src/lib.rs index b574757..6ec268a 100644 --- a/crates/rpc/ras-jsonrpc-types/src/lib.rs +++ b/crates/rpc/ras-jsonrpc-types/src/lib.rs @@ -202,3 +202,76 @@ impl JsonRpcError { ) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn jsonrpc_request_constructor_sets_version() { + let r = JsonRpcRequest::new( + "m".into(), + Some(serde_json::json!(1)), + Some(serde_json::json!("rid")), + ); + assert_eq!(r.jsonrpc, "2.0"); + assert_eq!(r.method, "m"); + } + + #[test] + fn jsonrpc_response_success_and_error() { + let s = JsonRpcResponse::success(serde_json::json!("ok"), Some(serde_json::json!(1))); + assert_eq!(s.jsonrpc, "2.0"); + assert!(s.error.is_none()); + assert_eq!(s.result, Some(serde_json::json!("ok"))); + + let e = JsonRpcResponse::error(JsonRpcError::parse_error(), Some(serde_json::json!(1))); + assert!(e.result.is_none()); + assert_eq!(e.error.unwrap().code, error_codes::PARSE_ERROR); + } + + #[test] + fn json_rpc_error_constructors_use_canonical_codes() { + assert_eq!(JsonRpcError::parse_error().code, error_codes::PARSE_ERROR); + assert_eq!( + JsonRpcError::invalid_request().code, + error_codes::INVALID_REQUEST + ); + let nf = JsonRpcError::method_not_found("m"); + assert_eq!(nf.code, error_codes::METHOD_NOT_FOUND); + assert!(nf.message.contains("m")); + assert_eq!( + JsonRpcError::invalid_params("bad".into()).code, + error_codes::INVALID_PARAMS + ); + assert_eq!( + JsonRpcError::internal_error("e".into()).code, + error_codes::INTERNAL_ERROR + ); + assert_eq!( + JsonRpcError::authentication_required().code, + error_codes::AUTHENTICATION_REQUIRED + ); + assert_eq!( + JsonRpcError::token_expired().code, + error_codes::TOKEN_EXPIRED + ); + } + + #[test] + fn insufficient_permissions_carries_data() { + let err = JsonRpcError::insufficient_permissions(vec!["admin".into()], vec!["user".into()]); + assert_eq!(err.code, error_codes::INSUFFICIENT_PERMISSIONS); + let data = err.data.unwrap(); + assert_eq!(data["required"], serde_json::json!(["admin"])); + assert_eq!(data["has"], serde_json::json!(["user"])); + } + + #[test] + fn request_with_no_id_skips_field_in_serialization() { + let req = JsonRpcRequest::new("notify".into(), None, None); + let s = serde_json::to_string(&req).unwrap(); + assert!(!s.contains("\"id\"")); + assert!(!s.contains("\"params\"")); + } +} diff --git a/crates/specs/openrpc-types/src/schema.rs b/crates/specs/openrpc-types/src/schema.rs index 7b776b4..68c6431 100644 --- a/crates/specs/openrpc-types/src/schema.rs +++ b/crates/specs/openrpc-types/src/schema.rs @@ -484,45 +484,45 @@ impl Default for Schema { impl Validate for Schema { fn validate(&self) -> OpenRpcResult<()> { // Validate numeric constraints - if let (Some(min), Some(max)) = (self.minimum, self.maximum) { - if min > max { - return Err(crate::error::OpenRpcError::validation( - "minimum cannot be greater than maximum", - )); - } + if let (Some(min), Some(max)) = (self.minimum, self.maximum) + && min > max + { + return Err(crate::error::OpenRpcError::validation( + "minimum cannot be greater than maximum", + )); } - if let (Some(min), Some(max)) = (self.min_length, self.max_length) { - if min > max { - return Err(crate::error::OpenRpcError::validation( - "minLength cannot be greater than maxLength", - )); - } + if let (Some(min), Some(max)) = (self.min_length, self.max_length) + && min > max + { + return Err(crate::error::OpenRpcError::validation( + "minLength cannot be greater than maxLength", + )); } - if let (Some(min), Some(max)) = (self.min_items, self.max_items) { - if min > max { - return Err(crate::error::OpenRpcError::validation( - "minItems cannot be greater than maxItems", - )); - } + if let (Some(min), Some(max)) = (self.min_items, self.max_items) + && min > max + { + return Err(crate::error::OpenRpcError::validation( + "minItems cannot be greater than maxItems", + )); } - if let (Some(min), Some(max)) = (self.min_properties, self.max_properties) { - if min > max { - return Err(crate::error::OpenRpcError::validation( - "minProperties cannot be greater than maxProperties", - )); - } + if let (Some(min), Some(max)) = (self.min_properties, self.max_properties) + && min > max + { + return Err(crate::error::OpenRpcError::validation( + "minProperties cannot be greater than maxProperties", + )); } // Validate multipleOf - if let Some(multiple_of) = self.multiple_of { - if multiple_of <= 0.0 { - return Err(crate::error::OpenRpcError::validation( - "multipleOf must be greater than 0", - )); - } + if let Some(multiple_of) = self.multiple_of + && multiple_of <= 0.0 + { + return Err(crate::error::OpenRpcError::validation( + "multipleOf must be greater than 0", + )); } // Validate pattern if present diff --git a/crates/test-utils/ras-test-helpers/Cargo.toml b/crates/test-utils/ras-test-helpers/Cargo.toml new file mode 100644 index 0000000..68bab77 --- /dev/null +++ b/crates/test-utils/ras-test-helpers/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "ras-test-helpers" +version = "0.0.0" +edition = "2024" +publish = false +description = "Internal test helpers for the rust-agent-stack workspace (dev-dependency only)" + +[dependencies] +ras-auth-core = { path = "../../core/ras-auth-core" } +axum = { workspace = true } +axum-test = { workspace = true } +tokio = { workspace = true } diff --git a/crates/test-utils/ras-test-helpers/src/auth.rs b/crates/test-utils/ras-test-helpers/src/auth.rs new file mode 100644 index 0000000..5c18644 --- /dev/null +++ b/crates/test-utils/ras-test-helpers/src/auth.rs @@ -0,0 +1,133 @@ +use std::collections::{HashMap, HashSet}; + +use ras_auth_core::{AuthError, AuthFuture, AuthProvider, AuthenticatedUser}; + +/// A small fixed-token auth provider for tests. +/// +/// The default token table: +/// - `"user-token"` → user `user-1`, perms `["user"]` +/// - `"admin-token"` → user `admin-1`, perms `["admin", "user"]` +/// - `"readonly-token"` → user `ro-1`, perms `["read"]` +/// +/// Any other (or empty) token returns [`AuthError::InvalidToken`]. +#[derive(Clone, Debug)] +pub struct MockAuthProvider { + table: HashMap, +} + +impl Default for MockAuthProvider { + fn default() -> Self { + let mut table = HashMap::new(); + table.insert("user-token".to_string(), mock_user("user-1", &["user"])); + table.insert( + "admin-token".to_string(), + mock_user("admin-1", &["admin", "user"]), + ); + table.insert("readonly-token".to_string(), mock_user("ro-1", &["read"])); + Self { table } + } +} + +impl MockAuthProvider { + /// New empty auth provider with no recognized tokens. + pub fn empty() -> Self { + Self { + table: HashMap::new(), + } + } + + /// Insert or replace a token → user mapping. Useful for adding bespoke + /// fixtures on top of the default table. + pub fn with_token(mut self, token: impl Into, user: AuthenticatedUser) -> Self { + self.table.insert(token.into(), user); + self + } +} + +impl AuthProvider for MockAuthProvider { + fn authenticate(&self, token: String) -> AuthFuture<'_> { + let result = self + .table + .get(&token) + .cloned() + .ok_or(AuthError::InvalidToken); + Box::pin(async move { result }) + } +} + +/// Build an [`AuthenticatedUser`] from a string id and a slice of permission +/// names. Convenience for tests that need to construct a user by hand. +pub fn mock_user(user_id: &str, perms: &[&str]) -> AuthenticatedUser { + AuthenticatedUser { + user_id: user_id.to_string(), + permissions: perms + .iter() + .map(|p| (*p).to_string()) + .collect::>(), + metadata: None, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn mock_user_builds_expected_fields() { + let u = mock_user("alice", &["a", "b"]); + assert_eq!(u.user_id, "alice"); + assert!(u.permissions.contains("a")); + assert!(u.permissions.contains("b")); + assert!(u.metadata.is_none()); + } + + #[tokio::test] + async fn default_provider_resolves_well_known_tokens() { + let p = MockAuthProvider::default(); + let user = p.authenticate("user-token".to_string()).await.unwrap(); + assert_eq!(user.user_id, "user-1"); + assert!(user.permissions.contains("user")); + + let admin = p.authenticate("admin-token".to_string()).await.unwrap(); + assert!(admin.permissions.contains("admin")); + assert!(admin.permissions.contains("user")); + + let ro = p.authenticate("readonly-token".to_string()).await.unwrap(); + assert!(ro.permissions.contains("read")); + + let err = p + .authenticate("totally-bogus".to_string()) + .await + .unwrap_err(); + assert!(matches!(err, ras_auth_core::AuthError::InvalidToken)); + } + + #[tokio::test] + async fn empty_provider_rejects_everything() { + let p = MockAuthProvider::empty(); + let err = p.authenticate("user-token".to_string()).await.unwrap_err(); + assert!(matches!(err, ras_auth_core::AuthError::InvalidToken)); + } + + #[tokio::test] + async fn with_token_extends_table() { + let p = MockAuthProvider::empty().with_token("custom", mock_user("zed", &["god"])); + let user = p.authenticate("custom".to_string()).await.unwrap(); + assert_eq!(user.user_id, "zed"); + assert!(user.permissions.contains("god")); + } + + #[test] + fn check_permissions_returns_specific_error() { + let p = MockAuthProvider::default(); + let user = mock_user("u", &["read"]); + // Has the permission → ok. + p.check_permissions(&user, &["read".into()]).unwrap(); + // Missing → InsufficientPermissions. + let err = p.check_permissions(&user, &["admin".into()]).unwrap_err(); + assert!(matches!( + err, + ras_auth_core::AuthError::InsufficientPermissions { .. } + )); + } +} diff --git a/crates/test-utils/ras-test-helpers/src/lib.rs b/crates/test-utils/ras-test-helpers/src/lib.rs new file mode 100644 index 0000000..222ceb8 --- /dev/null +++ b/crates/test-utils/ras-test-helpers/src/lib.rs @@ -0,0 +1,11 @@ +//! Internal test helpers shared across the rust-agent-stack workspace. +//! +//! This crate is `publish = false` and intended only as a `dev-dependency` for +//! integration tests and benches. It exists to avoid duplicating mock auth +//! providers and server-spawn boilerplate across crates. + +mod auth; +mod server; + +pub use auth::{MockAuthProvider, mock_user}; +pub use server::{spawn_http, spawn_tcp}; diff --git a/crates/test-utils/ras-test-helpers/src/server.rs b/crates/test-utils/ras-test-helpers/src/server.rs new file mode 100644 index 0000000..397ac16 --- /dev/null +++ b/crates/test-utils/ras-test-helpers/src/server.rs @@ -0,0 +1,40 @@ +use std::net::SocketAddr; + +use axum::Router; +use axum_test::TestServer; +use tokio::net::TcpListener; +use tokio::task::JoinHandle; + +/// Spawn the given router behind an `axum-test::TestServer` configured with a +/// real TCP listener on a random port. The returned [`TestServer`] exposes a +/// real `http://127.0.0.1:PORT` URL via [`TestServer::server_address`], which +/// lets generated reqwest-based clients talk to it. +/// +/// Use this for HTTP / JSON-RPC over HTTP / file service tests. +pub fn spawn_http(router: Router) -> TestServer { + TestServer::builder() + .http_transport() + .build(router) + .expect("failed to start axum-test TestServer with http transport") +} + +/// Spawn the given router on a freshly-bound `127.0.0.1` port using a real +/// `axum::serve` task. Returns the bound address and the join handle for the +/// server task. Drop the handle to abort the server. +/// +/// Use this for WebSocket tests where the generated client uses +/// `tokio-tungstenite` and needs a genuine TCP socket. +pub async fn spawn_tcp(router: Router) -> (SocketAddr, JoinHandle<()>) { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("failed to bind ephemeral test port"); + let addr = listener + .local_addr() + .expect("failed to read local addr from test listener"); + + let handle = tokio::spawn(async move { + let _ = axum::serve(listener, router).await; + }); + + (addr, handle) +} diff --git a/crates/tools/openrpc-to-bruno/src/converter.rs b/crates/tools/openrpc-to-bruno/src/converter.rs index 6a500f5..9a02f56 100644 --- a/crates/tools/openrpc-to-bruno/src/converter.rs +++ b/crates/tools/openrpc-to-bruno/src/converter.rs @@ -6,6 +6,7 @@ use openrpc_types::{ Schema, SchemaType, }; use serde_json::{Map, Value}; +use std::path::Path; use tokio::fs; pub struct OpenRpcToBrunoConverter { @@ -240,7 +241,8 @@ impl OpenRpcToBrunoConverter { } let content = bruno_request.to_bru_format(); - let path = self.args.output.join(format!("{}.bru", method.name)); + let file_name = safe_method_file_name(&method.name, sequence)?; + let path = self.safe_output_path(&file_name).await?; fs::write(&path, content) .await @@ -252,6 +254,31 @@ impl OpenRpcToBrunoConverter { Ok(()) } + async fn safe_output_path(&self, file_name: &str) -> Result { + let output_dir = + fs::canonicalize(&self.args.output) + .await + .map_err(|e| ToolError::OutputDirCreate { + path: self.args.output.clone(), + source: e, + })?; + + let candidate = output_dir.join(file_name); + let parent = candidate.parent().unwrap_or(&output_dir); + let parent = fs::canonicalize(parent) + .await + .map_err(|e| ToolError::OutputDirCreate { + path: parent.to_path_buf(), + source: e, + })?; + + if !parent.starts_with(&output_dir) { + return Err(ToolError::UnsafeMethodName(file_name.to_string())); + } + + Ok(candidate) + } + fn create_jsonrpc_request_body(&self, method: &Method) -> Result { let mut request = Map::new(); request.insert("jsonrpc".to_string(), Value::String("2.0".to_string())); @@ -331,3 +358,34 @@ impl OpenRpcToBrunoConverter { } } } + +fn safe_method_file_name(method_name: &str, sequence: u32) -> Result { + let name = method_name.trim(); + if name.is_empty() + || name == "." + || name == ".." + || name.contains('\0') + || name.contains('/') + || name.contains('\\') + || Path::new(name).is_absolute() + { + return Err(ToolError::UnsafeMethodName(method_name.to_string())); + } + + let sanitized: String = name + .chars() + .map(|ch| { + if ch.is_ascii_alphanumeric() || matches!(ch, '.' | '_' | '-') { + ch + } else { + '_' + } + }) + .collect(); + + if sanitized.is_empty() || sanitized == "." || sanitized == ".." { + return Err(ToolError::UnsafeMethodName(method_name.to_string())); + } + + Ok(format!("{sequence:03}_{sanitized}.bru")) +} diff --git a/crates/tools/openrpc-to-bruno/src/error.rs b/crates/tools/openrpc-to-bruno/src/error.rs index f2001a5..7f3309a 100644 --- a/crates/tools/openrpc-to-bruno/src/error.rs +++ b/crates/tools/openrpc-to-bruno/src/error.rs @@ -32,6 +32,9 @@ pub enum ToolError { source: std::io::Error, }, + #[error("Unsafe OpenRPC method name for file generation: {0}")] + UnsafeMethodName(String), + #[error("Invalid base URL: {0}")] #[allow(dead_code)] InvalidBaseUrl(String), diff --git a/crates/tools/openrpc-to-bruno/tests/integration.rs b/crates/tools/openrpc-to-bruno/tests/integration.rs index 7430aad..15f62a7 100644 --- a/crates/tools/openrpc-to-bruno/tests/integration.rs +++ b/crates/tools/openrpc-to-bruno/tests/integration.rs @@ -35,8 +35,10 @@ async fn test_conversion( assert!(env_file.exists(), "environment file should be created"); // Check that method files were created - for method in expected_methods { - let method_file = output_dir.path().join(format!("{}.bru", method)); + for (index, method) in expected_methods.iter().enumerate() { + let method_file = output_dir + .path() + .join(format!("{:03}_{}.bru", index + 1, method)); assert!( method_file.exists(), "method file {} should be created", @@ -62,6 +64,92 @@ async fn test_conversion( Ok(()) } +#[tokio::test] +async fn test_rejects_path_traversal_method_name() { + use clap::Parser; + use openrpc_to_bruno::{cli::Args, error::ToolError}; + + let temp = tempdir().unwrap(); + let input_path = temp.path().join("openrpc.json"); + let output_dir = temp.path().join("out"); + let escaped = temp.path().join("evil.bru"); + + let spec = serde_json::json!({ + "openrpc": "1.3.2", + "info": { + "title": "Unsafe API", + "version": "1.0.0" + }, + "methods": [{ + "name": "../evil", + "params": [], + "result": { + "name": "result", + "schema": { "type": "string" } + } + }] + }); + fs::write(&input_path, serde_json::to_vec(&spec).unwrap()) + .await + .unwrap(); + + let args = Args::try_parse_from(vec![ + "openrpc-to-bruno", + "--input", + input_path.to_str().unwrap(), + "--output", + output_dir.to_str().unwrap(), + "--force", + ]) + .unwrap(); + + let err = args.run().await.unwrap_err(); + assert!(matches!(err, ToolError::UnsafeMethodName(_))); + assert!(!escaped.exists()); +} + +#[tokio::test] +async fn test_sanitizes_safe_method_filename() { + use clap::Parser; + use openrpc_to_bruno::cli::Args; + + let temp = tempdir().unwrap(); + let input_path = temp.path().join("openrpc.json"); + let output_dir = temp.path().join("out"); + + let spec = serde_json::json!({ + "openrpc": "1.3.2", + "info": { + "title": "Safe API", + "version": "1.0.0" + }, + "methods": [{ + "name": "system status", + "params": [], + "result": { + "name": "result", + "schema": { "type": "string" } + } + }] + }); + fs::write(&input_path, serde_json::to_vec(&spec).unwrap()) + .await + .unwrap(); + + let args = Args::try_parse_from(vec![ + "openrpc-to-bruno", + "--input", + input_path.to_str().unwrap(), + "--output", + output_dir.to_str().unwrap(), + "--force", + ]) + .unwrap(); + + args.run().await.unwrap(); + assert!(output_dir.join("001_system_status.bru").exists()); +} + #[tokio::test] async fn test_simple_conversion() { test_conversion("simple-api-basic.json", &["hello"]) diff --git a/examples/bidirectional-chat/server/src/main.rs b/examples/bidirectional-chat/server/src/main.rs index 52cf3a1..840cb04 100644 --- a/examples/bidirectional-chat/server/src/main.rs +++ b/examples/bidirectional-chat/server/src/main.rs @@ -1500,6 +1500,7 @@ async fn main() -> Result<()> { ); let session_service = Arc::new( SessionService::new(session_config) + .map_err(anyhow::Error::from)? .with_permissions(Arc::new(ChatPermissions::new(config.admin.users.clone()))), ); diff --git a/examples/bidirectional-chat/server/tests/server_tests.rs b/examples/bidirectional-chat/server/tests/server_tests.rs index 29f7a15..d784052 100644 --- a/examples/bidirectional-chat/server/tests/server_tests.rs +++ b/examples/bidirectional-chat/server/tests/server_tests.rs @@ -123,7 +123,7 @@ async fn create_test_config() -> Result<(Config, TempDir)> { cors: Default::default(), }, auth: AuthConfig { - jwt_secret: "test-secret-key".to_string(), + jwt_secret: "test-secret-key-that-is-long-enough".to_string(), jwt_ttl_seconds: 3600, refresh_enabled: true, jwt_algorithm: "HS256".to_string(), diff --git a/examples/bidirectional-chat/server/tests/websocket_tests.rs b/examples/bidirectional-chat/server/tests/websocket_tests.rs index f9e7b9a..d983b4c 100644 --- a/examples/bidirectional-chat/server/tests/websocket_tests.rs +++ b/examples/bidirectional-chat/server/tests/websocket_tests.rs @@ -58,7 +58,7 @@ impl TestChatServer { cors: Default::default(), }, auth: AuthConfig { - jwt_secret: "test-secret-key".to_string(), + jwt_secret: "test-secret-key-that-is-long-enough".to_string(), jwt_ttl_seconds: 3600, refresh_enabled: true, jwt_algorithm: "HS256".to_string(), @@ -145,9 +145,13 @@ impl TestChatServer { algorithm: jsonwebtoken::Algorithm::HS256, }; - let session_service = Arc::new(SessionService::new(session_config).with_permissions( - Arc::new(TestChatPermissions::new(config.admin.users.clone())), - )); + let session_service = Arc::new( + SessionService::new(session_config) + .unwrap() + .with_permissions(Arc::new(TestChatPermissions::new( + config.admin.users.clone(), + ))), + ); // Register identity provider with session service let session_identity_provider = LocalUserProvider::new(); diff --git a/examples/file-service-example/src/main.rs b/examples/file-service-example/src/main.rs index 8d241fc..3b8ae92 100644 --- a/examples/file-service-example/src/main.rs +++ b/examples/file-service-example/src/main.rs @@ -102,34 +102,37 @@ impl DocumentServiceTrait for DocumentServiceImpl { ) -> Result { println!("User {} is uploading a file", user.user_id); - // Process the multipart upload - while let Some(field) = multipart.next_field().await.map_err(|e| { - DocumentServiceFileError::UploadFailed(format!("Failed to get next field: {}", e)) - })? { - let name = field.name().unwrap_or("unknown").to_string(); - let file_name = field.file_name().unwrap_or("unknown").to_string(); - let data = field.bytes().await.map_err(|e| { - DocumentServiceFileError::UploadFailed(format!("Failed to read field data: {}", e)) + // Process the first multipart field — that's the uploaded file in the + // demo's contract. Real implementations would loop and accept several. + let field = multipart + .next_field() + .await + .map_err(|e| { + DocumentServiceFileError::UploadFailed(format!("Failed to get next field: {}", e)) + })? + .ok_or_else(|| { + DocumentServiceFileError::UploadFailed("No file in multipart data".to_string()) })?; - println!( - "Received field '{}' with filename '{}', size: {} bytes", - name, - file_name, - data.len() - ); - - // In a real implementation, you would save this to storage - return Ok(UploadResponse { - file_id: format!("file_{}", Uuid::new_v4()), - size: data.len() as u64, - filename: file_name, - }); - } - - Err(DocumentServiceFileError::UploadFailed( - "No file in multipart data".to_string(), - )) + let name = field.name().unwrap_or("unknown").to_string(); + let file_name = field.file_name().unwrap_or("unknown").to_string(); + let data = field.bytes().await.map_err(|e| { + DocumentServiceFileError::UploadFailed(format!("Failed to read field data: {}", e)) + })?; + + println!( + "Received field '{}' with filename '{}', size: {} bytes", + name, + file_name, + data.len() + ); + + // In a real implementation, you would save this to storage + Ok(UploadResponse { + file_id: format!("file_{}", Uuid::new_v4()), + size: data.len() as u64, + filename: file_name, + }) } async fn info( diff --git a/examples/oauth2-demo/server/src/main.rs b/examples/oauth2-demo/server/src/main.rs index 971e37c..038be11 100644 --- a/examples/oauth2-demo/server/src/main.rs +++ b/examples/oauth2-demo/server/src/main.rs @@ -45,7 +45,7 @@ impl AppConfig { redirect_uri: std::env::var("REDIRECT_URI") .unwrap_or_else(|_| "http://localhost:3000/auth/callback".to_string()), jwt_secret: std::env::var("JWT_SECRET") - .unwrap_or_else(|_| "change-me-in-production-please".to_string()), + .context("JWT_SECRET environment variable is required")?, server_host: std::env::var("SERVER_HOST").unwrap_or_else(|_| "0.0.0.0".to_string()), server_port: std::env::var("SERVER_PORT") .unwrap_or_else(|_| "3000".to_string()) @@ -164,8 +164,9 @@ fn create_session_service(config: &AppConfig) -> Result { }; let permissions_provider = Arc::new(GoogleOAuth2Permissions::new()); - let session_service = - SessionService::new(session_config).with_permissions(permissions_provider); + let session_service = SessionService::new(session_config) + .map_err(anyhow::Error::from)? + .with_permissions(permissions_provider); Ok(session_service) } diff --git a/examples/rest-wasm-example/rest-api/Cargo.toml b/examples/rest-wasm-example/rest-api/Cargo.toml index a7ac0a6..fa68cfe 100644 --- a/examples/rest-wasm-example/rest-api/Cargo.toml +++ b/examples/rest-wasm-example/rest-api/Cargo.toml @@ -4,7 +4,7 @@ version = "0.1.0" edition = "2024" [lib] -crate-type = ["cdylib", "rlib"] +crate-type = ["rlib"] [dependencies] ras-rest-macro = { path = "../../../crates/rest/ras-rest-macro" }