diff --git a/crates/client-api-messages/DEVELOP.md b/crates/client-api-messages/DEVELOP.md index 47868d4d3ce..48341bd80aa 100644 --- a/crates/client-api-messages/DEVELOP.md +++ b/crates/client-api-messages/DEVELOP.md @@ -19,3 +19,12 @@ spacetime generate -p spacetimedb-cli --lang \ --out-dir \ --module-def ws_schema_v2.json ``` + +For the v3 WebSocket transport schema: + +```sh +cargo run --example get_ws_schema_v3 > ws_schema_v3.json +spacetime generate -p spacetimedb-cli --lang \ + --out-dir \ + --module-def ws_schema_v3.json +``` diff --git a/crates/client-api-messages/examples/get_ws_schema_v3.rs b/crates/client-api-messages/examples/get_ws_schema_v3.rs new file mode 100644 index 00000000000..b4a752a5664 --- /dev/null +++ b/crates/client-api-messages/examples/get_ws_schema_v3.rs @@ -0,0 +1,13 @@ +use spacetimedb_client_api_messages::websocket::v3::{ClientFrame, ServerFrame}; +use spacetimedb_lib::ser::serde::SerializeWrapper; +use spacetimedb_lib::{RawModuleDef, RawModuleDefV8}; + +fn main() -> Result<(), serde_json::Error> { + let module = RawModuleDefV8::with_builder(|module| { + module.add_type::(); + module.add_type::(); + }); + let module = RawModuleDef::V8BackCompat(module); + + serde_json::to_writer(std::io::stdout().lock(), SerializeWrapper::from_ref(&module)) +} diff --git a/crates/client-api-messages/src/websocket.rs b/crates/client-api-messages/src/websocket.rs index 0935d2e3c55..14ec394670f 100644 --- a/crates/client-api-messages/src/websocket.rs +++ b/crates/client-api-messages/src/websocket.rs @@ -17,3 +17,4 @@ pub mod common; pub mod v1; pub mod v2; +pub mod v3; diff --git a/crates/client-api-messages/src/websocket/v3.rs b/crates/client-api-messages/src/websocket/v3.rs new file mode 100644 index 00000000000..5be37768299 --- /dev/null +++ b/crates/client-api-messages/src/websocket/v3.rs @@ -0,0 +1,28 @@ +use bytes::Bytes; +pub use spacetimedb_sats::SpacetimeType; + +pub const BIN_PROTOCOL: &str = "v3.bsatn.spacetimedb"; + +/// Transport envelopes sent by the client over the v3 websocket protocol. +/// +/// The inner bytes are BSATN-encoded v2 [`crate::websocket::v2::ClientMessage`] values. +#[derive(SpacetimeType, Debug)] +#[sats(crate = spacetimedb_lib)] +pub enum ClientFrame { + /// A single logical client message. + Single(Bytes), + /// Multiple logical client messages that should be processed in-order. + Batch(Box<[Bytes]>), +} + +/// Transport envelopes sent by the server over the v3 websocket protocol. +/// +/// The inner bytes are BSATN-encoded v2 [`crate::websocket::v2::ServerMessage`] values. +#[derive(SpacetimeType, Debug)] +#[sats(crate = spacetimedb_lib)] +pub enum ServerFrame { + /// A single logical server message. + Single(Bytes), + /// Multiple logical server messages that should be processed in-order. + Batch(Box<[Bytes]>), +} diff --git a/crates/client-api/src/routes/subscribe.rs b/crates/client-api/src/routes/subscribe.rs index d1bb1d2b11f..868e7425eda 100644 --- a/crates/client-api/src/routes/subscribe.rs +++ b/crates/client-api/src/routes/subscribe.rs @@ -23,8 +23,8 @@ use prometheus::{Histogram, IntGauge}; use scopeguard::{defer, ScopeGuard}; use serde::Deserialize; use spacetimedb::client::messages::{ - serialize, serialize_v2, IdentityTokenMessage, InUseSerializeBuffer, SerializeBuffer, SwitchedServerMessage, - ToProtocol, + serialize, serialize_v2, serialize_v3, IdentityTokenMessage, InUseSerializeBuffer, SerializeBuffer, + SwitchedServerMessage, ToProtocol, }; use spacetimedb::client::{ ClientActorId, ClientConfig, ClientConnection, ClientConnectionReceiver, DataMessage, MessageExecutionError, @@ -38,6 +38,7 @@ use spacetimedb::worker_metrics::WORKER_METRICS; use spacetimedb::Identity; use spacetimedb_client_api_messages::websocket::v1 as ws_v1; use spacetimedb_client_api_messages::websocket::v2 as ws_v2; +use spacetimedb_client_api_messages::websocket::v3 as ws_v3; use spacetimedb_datastore::execution_context::WorkloadType; use spacetimedb_lib::connection_id::{ConnectionId, ConnectionIdForUrl}; use tokio::sync::{mpsc, watch}; @@ -62,6 +63,8 @@ pub const TEXT_PROTOCOL: HeaderValue = HeaderValue::from_static(ws_v1::TEXT_PROT pub const BIN_PROTOCOL: HeaderValue = HeaderValue::from_static(ws_v1::BIN_PROTOCOL); #[allow(clippy::declare_interior_mutable_const)] pub const V2_BIN_PROTOCOL: HeaderValue = HeaderValue::from_static(ws_v2::BIN_PROTOCOL); +#[allow(clippy::declare_interior_mutable_const)] +pub const V3_BIN_PROTOCOL: HeaderValue = HeaderValue::from_static(ws_v3::BIN_PROTOCOL); pub trait HasWebSocketOptions { fn websocket_options(&self) -> WebSocketOptions; @@ -101,7 +104,7 @@ fn resolve_confirmed_reads_default(version: WsVersion, confirmed: Option) } match version { WsVersion::V1 => false, - WsVersion::V2 => crate::DEFAULT_CONFIRMED_READS, + WsVersion::V2 | WsVersion::V3 => crate::DEFAULT_CONFIRMED_READS, } } @@ -151,6 +154,13 @@ where } let (res, ws_upgrade, protocol) = ws.select_protocol([ + ( + V3_BIN_PROTOCOL, + NegotiatedProtocol { + protocol: Protocol::Binary, + version: WsVersion::V3, + }, + ), ( V2_BIN_PROTOCOL, NegotiatedProtocol { @@ -284,7 +294,7 @@ where }; client.send_message(None, OutboundMessage::V1(message.into())) } - WsVersion::V2 => { + WsVersion::V2 | WsVersion::V3 => { let message = ws_v2::ServerMessage::InitialConnection(ws_v2::InitialConnection { identity: client_identity, connection_id, @@ -1293,10 +1303,15 @@ async fn ws_encode_task( // copied to the wire. Since we don't know when that will happen, we prepare // for a few messages to be in-flight, i.e. encoded but not yet sent. const BUF_POOL_CAPACITY: usize = 16; + let binary_message_serializer = match config.version { + WsVersion::V1 => None, + WsVersion::V2 => Some(serialize_v2 as BinarySerializeFn), + WsVersion::V3 => Some(serialize_v3 as BinarySerializeFn), + }; let buf_pool = ArrayQueue::new(BUF_POOL_CAPACITY); let mut in_use_bufs: Vec> = Vec::with_capacity(BUF_POOL_CAPACITY); - while let Some(message) = messages.recv().await { + 'send: while let Some(message) = messages.recv().await { // Drop serialize buffers with no external referent, // returning them to the pool. in_use_bufs.retain(|in_use| !in_use.is_unique()); @@ -1306,16 +1321,22 @@ async fn ws_encode_task( let in_use_buf = match message { OutboundWsMessage::Error(message) => { - if config.version == WsVersion::V2 { - log::error!("dropping v1 error message sent to a v2 client: {:?}", message); + if config.version != WsVersion::V1 { + log::error!( + "dropping v1 error message sent to a binary websocket client: {:?}", + message + ); continue; } - let (stats, in_use, mut frames) = ws_encode_message(config, buf, message, false, &bsatn_rlb_pool).await; - metrics.report(None, None, stats); - if frames.try_for_each(|frame| outgoing_frames.send(frame)).is_err() { - break; - } - + let Ok(in_use) = ws_forward_frames( + &metrics, + &outgoing_frames, + None, + None, + ws_encode_message(config, buf, message, false, &bsatn_rlb_pool).await, + ) else { + break 'send; + }; in_use } OutboundWsMessage::Message(message) => { @@ -1323,38 +1344,47 @@ async fn ws_encode_task( let num_rows = message.num_rows(); match message { OutboundMessage::V2(server_message) => { - if config.version != WsVersion::V2 { + if config.version == WsVersion::V1 { log::error!("dropping v2 message on v1 connection"); continue; } - let (stats, in_use, mut frames) = - ws_encode_message_v2(config, buf, server_message, false, &bsatn_rlb_pool).await; - metrics.report(workload, num_rows, stats); - if frames.try_for_each(|frame| outgoing_frames.send(frame)).is_err() { - break; - } - + let Ok(in_use) = ws_forward_frames( + &metrics, + &outgoing_frames, + workload, + num_rows, + ws_encode_binary_message( + config, + buf, + server_message, + binary_message_serializer.expect("v2 message should not be sent on a v1 connection"), + false, + &bsatn_rlb_pool, + ) + .await, + ) else { + break 'send; + }; in_use } OutboundMessage::V1(message) => { - if config.version == WsVersion::V2 { - log::error!( - "dropping v1 message for v2 connection until v2 serialization is implemented: {:?}", - message - ); + if config.version != WsVersion::V1 { + log::error!("dropping v1 message for a binary websocket connection: {:?}", message); continue; } let is_large = num_rows.is_some_and(|n| n > 1024); - let (stats, in_use, mut frames) = - ws_encode_message(config, buf, message, is_large, &bsatn_rlb_pool).await; - metrics.report(workload, num_rows, stats); - if frames.try_for_each(|frame| outgoing_frames.send(frame)).is_err() { - break; - } - + let Ok(in_use) = ws_forward_frames( + &metrics, + &outgoing_frames, + workload, + num_rows, + ws_encode_message(config, buf, message, is_large, &bsatn_rlb_pool).await, + ) else { + break 'send; + }; in_use } } @@ -1370,6 +1400,24 @@ async fn ws_encode_task( } } +/// Reports encode metrics for an already-encoded message and forwards all of +/// its frames to the websocket send task. +fn ws_forward_frames( + metrics: &SendMetrics, + outgoing_frames: &mpsc::UnboundedSender, + workload: Option, + num_rows: Option, + encoded: (EncodeMetrics, InUseSerializeBuffer, I), +) -> Result> +where + I: Iterator, +{ + let (stats, in_use, frames) = encoded; + metrics.report(workload, num_rows, stats); + frames.into_iter().try_for_each(|frame| outgoing_frames.send(frame))?; + Ok(in_use) +} + /// Some stats about serialization and compression. /// /// Returned by [`ws_encode_message`]. @@ -1443,21 +1491,29 @@ async fn ws_encode_message( (metrics, msg_alloc, frames) } -#[allow(dead_code, unused_variables)] -async fn ws_encode_message_v2( +type BinarySerializeFn = fn( + &BsatnRowListBuilderPool, + SerializeBuffer, + ws_v2::ServerMessage, + ws_v1::Compression, +) -> (InUseSerializeBuffer, Bytes); + +async fn ws_encode_binary_message( config: ClientConfig, buf: SerializeBuffer, message: ws_v2::ServerMessage, + serialize_message: BinarySerializeFn, is_large_message: bool, bsatn_rlb_pool: &BsatnRowListBuilderPool, ) -> (EncodeMetrics, InUseSerializeBuffer, impl Iterator + use<>) { let start = Instant::now(); + let compression = config.compression; let (in_use, data) = if is_large_message { let bsatn_rlb_pool = bsatn_rlb_pool.clone(); - spawn_rayon(move || serialize_v2(&bsatn_rlb_pool, buf, message, config.compression)).await + spawn_rayon(move || serialize_message(&bsatn_rlb_pool, buf, message, compression)).await } else { - serialize_v2(bsatn_rlb_pool, buf, message, config.compression) + serialize_message(bsatn_rlb_pool, buf, message, compression) }; let metrics = EncodeMetrics { @@ -2298,9 +2354,11 @@ mod tests { #[test] fn confirmed_reads_default_depends_on_ws_version() { + assert!(resolve_confirmed_reads_default(WsVersion::V3, None)); assert!(resolve_confirmed_reads_default(WsVersion::V2, None)); assert!(!resolve_confirmed_reads_default(WsVersion::V1, None)); assert!(resolve_confirmed_reads_default(WsVersion::V1, Some(true))); + assert!(!resolve_confirmed_reads_default(WsVersion::V3, Some(false))); assert!(!resolve_confirmed_reads_default(WsVersion::V2, Some(false))); } diff --git a/crates/core/src/client.rs b/crates/core/src/client.rs index cad4f79adcf..4411192c625 100644 --- a/crates/core/src/client.rs +++ b/crates/core/src/client.rs @@ -7,6 +7,7 @@ pub mod consume_each_list; mod message_handlers; mod message_handlers_v1; mod message_handlers_v2; +mod message_handlers_v3; pub mod messages; pub use client_connection::{ diff --git a/crates/core/src/client/client_connection.rs b/crates/core/src/client/client_connection.rs index 6fb8d8e1623..0a7a7f1a11b 100644 --- a/crates/core/src/client/client_connection.rs +++ b/crates/core/src/client/client_connection.rs @@ -47,6 +47,7 @@ pub enum Protocol { pub enum WsVersion { V1, V2, + V3, } impl Protocol { @@ -384,7 +385,7 @@ impl ClientConnectionSender { debug_assert!( matches!( (&self.config.version, &message), - (WsVersion::V1, OutboundMessage::V1(_)) | (WsVersion::V2, OutboundMessage::V2(_)) + (WsVersion::V1, OutboundMessage::V1(_)) | (WsVersion::V2 | WsVersion::V3, OutboundMessage::V2(_)) ), "attempted to send message variant that does not match client websocket version" ); diff --git a/crates/core/src/client/message_handlers.rs b/crates/core/src/client/message_handlers.rs index 76f5fa53afa..fb85730c11c 100644 --- a/crates/core/src/client/message_handlers.rs +++ b/crates/core/src/client/message_handlers.rs @@ -23,5 +23,6 @@ pub async fn handle(client: &ClientConnection, message: DataMessage, timer: Inst match client.config.version { WsVersion::V1 => super::message_handlers_v1::handle(client, message, timer).await, WsVersion::V2 => super::message_handlers_v2::handle(client, message, timer).await, + WsVersion::V3 => super::message_handlers_v3::handle(client, message, timer).await, } } diff --git a/crates/core/src/client/message_handlers_v2.rs b/crates/core/src/client/message_handlers_v2.rs index 5dd2f80d01b..2db523e472d 100644 --- a/crates/core/src/client/message_handlers_v2.rs +++ b/crates/core/src/client/message_handlers_v2.rs @@ -20,6 +20,14 @@ pub async fn handle(client: &ClientConnection, message: DataMessage, timer: Inst ))) } }; + handle_decoded_message(client, message, timer).await +} + +pub(super) async fn handle_decoded_message( + client: &ClientConnection, + message: ws_v2::ClientMessage, + timer: Instant, +) -> Result<(), MessageHandleError> { let module = client.module(); let mod_info = module.info(); let mod_metrics = &mod_info.metrics; diff --git a/crates/core/src/client/message_handlers_v3.rs b/crates/core/src/client/message_handlers_v3.rs new file mode 100644 index 00000000000..696e7337ed0 --- /dev/null +++ b/crates/core/src/client/message_handlers_v3.rs @@ -0,0 +1,32 @@ +use super::{ClientConnection, DataMessage, MessageHandleError}; +use serde::de::Error as _; +use spacetimedb_client_api_messages::websocket::{v2 as ws_v2, v3 as ws_v3}; +use spacetimedb_lib::bsatn; +use std::time::Instant; + +pub async fn handle(client: &ClientConnection, message: DataMessage, timer: Instant) -> Result<(), MessageHandleError> { + client.observe_websocket_request_message(&message); + let frame = match message { + DataMessage::Binary(message_buf) => bsatn::from_slice::(&message_buf)?, + DataMessage::Text(_) => { + return Err(MessageHandleError::TextDecode(serde_json::Error::custom( + "v3 websocket does not support text messages", + ))) + } + }; + + match frame { + ws_v3::ClientFrame::Single(message) => { + let message = bsatn::from_slice::(&message)?; + super::message_handlers_v2::handle_decoded_message(client, message, timer).await?; + } + ws_v3::ClientFrame::Batch(messages) => { + for message in messages { + let message = bsatn::from_slice::(&message)?; + super::message_handlers_v2::handle_decoded_message(client, message, timer).await?; + } + } + } + + Ok(()) +} diff --git a/crates/core/src/client/messages.rs b/crates/core/src/client/messages.rs index ed65e092d0e..38c5fadb260 100644 --- a/crates/core/src/client/messages.rs +++ b/crates/core/src/client/messages.rs @@ -10,6 +10,7 @@ use derive_more::From; use spacetimedb_client_api_messages::websocket::common::{self as ws_common, RowListLen as _}; use spacetimedb_client_api_messages::websocket::v1::{self as ws_v1}; use spacetimedb_client_api_messages::websocket::v2 as ws_v2; +use spacetimedb_client_api_messages::websocket::v3 as ws_v3; use spacetimedb_datastore::execution_context::WorkloadType; use spacetimedb_lib::identity::RequestId; use spacetimedb_lib::ser::serde::SerializeWrapper; @@ -97,6 +98,20 @@ impl SerializeBuffer { } } +fn finalize_binary_serialize_buffer( + buffer: SerializeBuffer, + uncompressed_len: usize, + compression: ws_v1::Compression, +) -> (InUseSerializeBuffer, Bytes) { + match decide_compression(uncompressed_len, compression) { + ws_v1::Compression::None => buffer.uncompressed(), + ws_v1::Compression::Brotli => { + buffer.compress_with_tag(ws_common::SERVER_MSG_COMPRESSION_TAG_BROTLI, brotli_compress) + } + ws_v1::Compression::Gzip => buffer.compress_with_tag(ws_common::SERVER_MSG_COMPRESSION_TAG_GZIP, gzip_compress), + } +} + type BytesMutWriter<'a> = bytes::buf::Writer<&'a mut BytesMut>; pub enum InUseSerializeBuffer { @@ -159,21 +174,14 @@ pub fn serialize( let srv_msg = buffer.write_with_tag(ws_common::SERVER_MSG_COMPRESSION_TAG_NONE, |w| { bsatn::to_writer(w.into_inner(), &msg).unwrap() }); + let srv_msg_len = srv_msg.len(); // At this point, we no longer have a use for `msg`, // so try to reclaim its buffers. msg.consume_each_list(&mut |buffer| bsatn_rlb_pool.try_put(buffer)); // Conditionally compress the message. - let (in_use, msg_bytes) = match decide_compression(srv_msg.len(), config.compression) { - ws_v1::Compression::None => buffer.uncompressed(), - ws_v1::Compression::Brotli => { - buffer.compress_with_tag(ws_common::SERVER_MSG_COMPRESSION_TAG_BROTLI, brotli_compress) - } - ws_v1::Compression::Gzip => { - buffer.compress_with_tag(ws_common::SERVER_MSG_COMPRESSION_TAG_GZIP, gzip_compress) - } - }; + let (in_use, msg_bytes) = finalize_binary_serialize_buffer(buffer, srv_msg_len, config.compression); (in_use, msg_bytes.into()) } } @@ -192,18 +200,40 @@ pub fn serialize_v2( let srv_msg = buffer.write_with_tag(ws_common::SERVER_MSG_COMPRESSION_TAG_NONE, |w| { bsatn::to_writer(w.into_inner(), &msg).expect("should be able to bsatn encode v2 message"); }); + let srv_msg_len = srv_msg.len(); // At this point, we no longer have a use for `msg`, // so try to reclaim its buffers. msg.consume_each_list(&mut |buffer| bsatn_rlb_pool.try_put(buffer)); - match decide_compression(srv_msg.len(), compression) { - ws_v1::Compression::None => buffer.uncompressed(), - ws_v1::Compression::Brotli => { - buffer.compress_with_tag(ws_common::SERVER_MSG_COMPRESSION_TAG_BROTLI, brotli_compress) - } - ws_v1::Compression::Gzip => buffer.compress_with_tag(ws_common::SERVER_MSG_COMPRESSION_TAG_GZIP, gzip_compress), - } + finalize_binary_serialize_buffer(buffer, srv_msg_len, compression) +} + +/// Serialize `msg` into a [`DataMessage`] containing a [`ws_v3::ServerFrame::Single`] +/// whose payload is a BSATN-encoded [`ws_v2::ServerMessage`]. +/// +/// This mirrors the v2 framing by prepending the compression tag and applying +/// conditional compression when configured. +pub fn serialize_v3( + bsatn_rlb_pool: &BsatnRowListBuilderPool, + mut buffer: SerializeBuffer, + msg: ws_v2::ServerMessage, + compression: ws_v1::Compression, +) -> (InUseSerializeBuffer, Bytes) { + let mut inner = BytesMut::with_capacity(SERIALIZE_BUFFER_INIT_CAP); + bsatn::to_writer((&mut inner).writer().into_inner(), &msg).expect("should be able to bsatn encode v2 message"); + + // At this point, we no longer have a use for `msg`, + // so try to reclaim its buffers. + msg.consume_each_list(&mut |buffer| bsatn_rlb_pool.try_put(buffer)); + + let frame = ws_v3::ServerFrame::Single(inner.freeze()); + let srv_msg = buffer.write_with_tag(ws_common::SERVER_MSG_COMPRESSION_TAG_NONE, |w| { + bsatn::to_writer(w.into_inner(), &frame).expect("should be able to bsatn encode v3 server frame"); + }); + let srv_msg_len = srv_msg.len(); + + finalize_binary_serialize_buffer(buffer, srv_msg_len, compression) } #[derive(Debug, From)] diff --git a/crates/core/src/subscription/module_subscription_actor.rs b/crates/core/src/subscription/module_subscription_actor.rs index 92f296f3b8c..4ab8b0c28a7 100644 --- a/crates/core/src/subscription/module_subscription_actor.rs +++ b/crates/core/src/subscription/module_subscription_actor.rs @@ -1639,7 +1639,7 @@ impl ModuleSubscriptions { message, ); } - WsVersion::V2 => { + WsVersion::V2 | WsVersion::V3 => { if let Some(request_id) = event.request_id { self.send_reducer_failure_result_v2(client, &event, request_id); } diff --git a/sdks/rust/src/websocket.rs b/sdks/rust/src/websocket.rs index 235ef06138f..8e9cfdaae8a 100644 --- a/sdks/rust/src/websocket.rs +++ b/sdks/rust/src/websocket.rs @@ -2,7 +2,6 @@ //! //! This module is internal, and may incompatibly change without warning. -#[cfg(not(feature = "browser"))] use bytes::Bytes; #[cfg(not(feature = "browser"))] use futures::TryStreamExt; @@ -12,6 +11,8 @@ use http::uri::{InvalidUri, Scheme, Uri}; use spacetimedb_client_api_messages::websocket as ws; use spacetimedb_lib::{bsatn, ConnectionId}; #[cfg(not(feature = "browser"))] +use std::collections::VecDeque; +#[cfg(not(feature = "browser"))] use std::fs::File; #[cfg(not(feature = "browser"))] use std::io::Write; @@ -107,11 +108,51 @@ pub enum WsError { pub(crate) struct WsConnection { db_name: Box, #[cfg(not(feature = "browser"))] + protocol: NegotiatedWsProtocol, + #[cfg(not(feature = "browser"))] sock: WebSocketStream>, #[cfg(feature = "browser")] sock: WebSocketStream, } +#[cfg(not(feature = "browser"))] +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] +enum NegotiatedWsProtocol { + #[default] + V2, + V3, +} + +#[cfg(not(feature = "browser"))] +impl NegotiatedWsProtocol { + /// Maps the negotiated websocket subprotocol string onto the transport + /// framing rules understood by the native SDK. + fn from_negotiated_protocol(protocol: &str) -> Self { + match protocol { + ws::v3::BIN_PROTOCOL => Self::V3, + "" | ws::v2::BIN_PROTOCOL => Self::V2, + unknown => { + log::warn!( + "Unexpected websocket subprotocol \"{unknown}\", falling back to {}", + ws::v2::BIN_PROTOCOL + ); + Self::V2 + } + } + } +} + +#[cfg(not(feature = "browser"))] +#[allow(clippy::declare_interior_mutable_const)] +const V3_PREFERRED_PROTOCOL_HEADER: http::HeaderValue = + http::HeaderValue::from_static("v3.bsatn.spacetimedb, v2.bsatn.spacetimedb"); +#[cfg(not(feature = "browser"))] +const MAX_V3_OUTBOUND_FRAME_BYTES: usize = 256 * 1024; +#[cfg(not(feature = "browser"))] +const BSATN_SUM_TAG_BYTES: usize = 1; +#[cfg(not(feature = "browser"))] +const BSATN_LENGTH_PREFIX_BYTES: usize = 4; + fn parse_scheme(scheme: Option) -> Result { Ok(match scheme { Some(s) => match s.as_str() { @@ -245,10 +286,10 @@ fn make_request( #[cfg(not(feature = "browser"))] fn request_insert_protocol_header(req: &mut http::Request<()>) { - req.headers_mut().insert( - http::header::SEC_WEBSOCKET_PROTOCOL, - const { http::HeaderValue::from_static(ws::v2::BIN_PROTOCOL) }, - ); + // Prefer v3 for transport batching, but continue advertising v2 so older + // servers can negotiate the legacy wire format unchanged. + req.headers_mut() + .insert(http::header::SEC_WEBSOCKET_PROTOCOL, V3_PREFERRED_PROTOCOL_HEADER); } #[cfg(not(feature = "browser"))] @@ -259,6 +300,150 @@ fn request_insert_auth_header(req: &mut http::Request<()>, token: Option<&str>) } } +/// Decodes one logical v2 server message from an already-decompressed payload. +fn decode_v2_server_message(bytes: &[u8]) -> Result { + bsatn::from_slice(bytes).map_err(|source| WsError::DeserializeMessage { source }) +} + +/// Expands a v3 server frame into the ordered sequence of encoded inner v2 +/// server messages it carries. +#[cfg(not(feature = "browser"))] +fn flatten_server_frame(frame: ws::v3::ServerFrame) -> Box<[Bytes]> { + match frame { + ws::v3::ServerFrame::Single(message) => Box::new([message]), + ws::v3::ServerFrame::Batch(messages) => messages, + } +} + +/// Encodes one logical v2 client message into raw BSATN bytes. +fn encode_v2_client_message_bytes(msg: &ws::v2::ClientMessage) -> Bytes { + Bytes::from(bsatn::to_vec(msg).expect("should be able to bsatn encode v2 client message")) +} + +/// Wraps one or more encoded v2 client messages in a v3 transport frame. +#[cfg(not(feature = "browser"))] +fn encode_v3_client_frame(messages: Vec) -> Bytes { + let frame = if messages.len() == 1 { + ws::v3::ClientFrame::Single(messages.into_iter().next().unwrap()) + } else { + ws::v3::ClientFrame::Batch(messages.into_boxed_slice()) + }; + Bytes::from(bsatn::to_vec(&frame).expect("should be able to bsatn encode v3 client frame")) +} + +/// Returns the encoded size of a v3 `Single` frame carrying `message`. +#[cfg(not(feature = "browser"))] +fn encoded_v3_single_frame_size(message: &Bytes) -> usize { + BSATN_SUM_TAG_BYTES + BSATN_LENGTH_PREFIX_BYTES + message.len() +} + +/// Returns the encoded size of a v3 `Batch` frame containing only its first logical message. +#[cfg(not(feature = "browser"))] +fn encoded_v3_batch_frame_size_for_first_message(message: &Bytes) -> usize { + BSATN_SUM_TAG_BYTES + BSATN_LENGTH_PREFIX_BYTES + BSATN_LENGTH_PREFIX_BYTES + message.len() +} + +/// Returns the encoded contribution of one additional logical message inside a v3 `Batch` frame. +#[cfg(not(feature = "browser"))] +fn encoded_v3_batch_element_size(message: &Bytes) -> usize { + BSATN_LENGTH_PREFIX_BYTES + message.len() +} + +/// Builds one bounded v3 transport frame from `first_message` and as many +/// queued logical messages as fit under the configured frame-size cap. +#[cfg(not(feature = "browser"))] +fn encode_v3_outbound_frame( + first_message: ws::v2::ClientMessage, + pending_outgoing: &mut VecDeque, + mut try_next_outgoing_now: F, +) -> Bytes +where + F: FnMut() -> Option, +{ + let first_message = encode_v2_client_message_bytes(&first_message); + // Oversized logical messages are still sent alone so they cannot block the + // queue forever behind the frame-size limit. + if encoded_v3_single_frame_size(&first_message) > MAX_V3_OUTBOUND_FRAME_BYTES { + if pending_outgoing.is_empty() + && let Some(next_message) = try_next_outgoing_now() + { + pending_outgoing.push_front(next_message); + } + + return encode_v3_client_frame(vec![first_message]); + } + + let mut messages = vec![first_message]; + let mut batch_size = encoded_v3_batch_frame_size_for_first_message(messages.first().unwrap()); + + loop { + let Some(next_message) = pending_outgoing.pop_front().or_else(&mut try_next_outgoing_now) else { + break; + }; + let next_message_bytes = encode_v2_client_message_bytes(&next_message); + let next_batch_size = batch_size + encoded_v3_batch_element_size(&next_message_bytes); + if next_batch_size > MAX_V3_OUTBOUND_FRAME_BYTES { + pending_outgoing.push_front(next_message); + break; + } + batch_size = next_batch_size; + messages.push(next_message_bytes); + } + + encode_v3_client_frame(messages) +} + +/// Encodes the next outbound logical message according to the negotiated +/// transport and reports whether a capped v3 flush left queued work behind. +#[cfg(not(feature = "browser"))] +fn encode_outgoing_message( + protocol: NegotiatedWsProtocol, + first_message: ws::v2::ClientMessage, + pending_outgoing: &mut VecDeque, + try_next_outgoing_now: F, +) -> (Bytes, bool) +where + F: FnMut() -> Option, +{ + match protocol { + NegotiatedWsProtocol::V2 => (encode_v2_client_message_bytes(&first_message), false), + NegotiatedWsProtocol::V3 => { + let frame = encode_v3_outbound_frame(first_message, pending_outgoing, try_next_outgoing_now); + (frame, !pending_outgoing.is_empty()) + } + } +} + +/// Parses one native websocket payload and forwards each decoded logical v2 +/// server message to the SDK's inbound queue, logging decode or enqueue +/// failures locally. +#[cfg(not(feature = "browser"))] +fn forward_parsed_responses_native( + protocol: NegotiatedWsProtocol, + incoming_messages: &mpsc::UnboundedSender, + extra_logging: &Option>>, + bytes: &[u8], +) { + match WsConnection::parse_responses(protocol, bytes) { + Err(e) => { + debug_log(extra_logging, |file| { + writeln!(file, "Error decoding WebSocketMessage::Binary payload: {e:?}") + }); + log::warn!("Error decoding WebSocketMessage::Binary payload: {e:?}"); + } + Ok(messages) => { + for msg in messages { + if let Err(e) = incoming_messages.unbounded_send(msg) { + debug_log(extra_logging, |file| { + writeln!(file, "Error sending decoded message to incoming_messages queue: {e:?}") + }); + log::warn!("Error sending decoded message to incoming_messages queue: {e:?}"); + } + } + } + } +} + #[cfg(feature = "browser")] async fn fetch_ws_token(host: &Uri, auth_token: &str) -> Result { use gloo_net::http::{Method, RequestBuilder}; @@ -334,7 +519,7 @@ impl WsConnection { // Grab the URI for error-reporting. let uri = req.uri().clone(); - let (sock, _): (WebSocketStream>, _) = connect_async_with_config( + let (sock, response): (WebSocketStream>, _) = connect_async_with_config( req, // TODO(kim): In order to be able to replicate module WASM blobs, // `cloud-next` cannot have message / frame size limits. That's @@ -347,8 +532,15 @@ impl WsConnection { uri, source: Arc::new(source), })?; + let negotiated_protocol = response + .headers() + .get(http::header::SEC_WEBSOCKET_PROTOCOL) + .and_then(|protocol| protocol.to_str().ok()) + .map(NegotiatedWsProtocol::from_negotiated_protocol) + .unwrap_or_default(); Ok(WsConnection { db_name: db_name.into(), + protocol: negotiated_protocol, sock, }) } @@ -368,6 +560,9 @@ impl WsConnection { }; let uri = make_uri(host, db_name, connection_id, params, token.as_deref())?; + // Browser targets stay on v2 for now. `tokio-tungstenite-wasm` does not + // expose the negotiated subprotocol, so we cannot safely offer v3 with + // a real v2 fallback here without replacing the wrapper entirely. let sock = tokio_tungstenite_wasm::connect_with_protocols(&uri.to_string(), &[ws::v2::BIN_PROTOCOL]) .await .map_err(|source| WsError::Tungstenite { @@ -381,13 +576,30 @@ impl WsConnection { }) } - pub(crate) fn parse_response(bytes: &[u8]) -> Result { + /// Parses one native websocket payload into the ordered logical v2 server + /// messages carried by the negotiated transport. + #[cfg(not(feature = "browser"))] + fn parse_responses(protocol: NegotiatedWsProtocol, bytes: &[u8]) -> Result, WsError> { let bytes = &*decompress_server_message(bytes)?; - bsatn::from_slice(bytes).map_err(|source| WsError::DeserializeMessage { source }) + match protocol { + NegotiatedWsProtocol::V2 => Ok(vec![decode_v2_server_message(bytes)?]), + NegotiatedWsProtocol::V3 => { + let frame: ws::v3::ServerFrame = + bsatn::from_slice(bytes).map_err(|source| WsError::DeserializeMessage { source })?; + flatten_server_frame(frame) + .into_vec() + .into_iter() + .map(|message| decode_v2_server_message(&message)) + .collect() + } + } } - pub(crate) fn encode_message(msg: ws::v2::ClientMessage) -> WebSocketMessage { - WebSocketMessage::Binary(bsatn::to_vec(&msg).unwrap().into()) + /// Parses one browser websocket payload, which always uses legacy v2 framing. + #[cfg(feature = "browser")] + fn parse_v2_response(bytes: &[u8]) -> Result { + let bytes = &*decompress_server_message(bytes)?; + decode_v2_server_message(bytes) } #[cfg(not(feature = "browser"))] @@ -439,7 +651,17 @@ impl WsConnection { let mut want_pong = false; let mut outgoing_messages = Some(outgoing_messages); + let mut pending_outgoing = VecDeque::new(); + let mut yield_after_capped_flush = false; loop { + if yield_after_capped_flush { + // Under v3 we emit at most one bounded frame per flush. If there + // are still queued messages after hitting the cap, yield before + // sending the next frame so inbound socket work is not starved by + // a tight outbound-only drain loop. + yield_after_capped_flush = false; + tokio::task::yield_now().await; + } tokio::select! { incoming = self.sock.try_next() => match incoming { Err(tokio_tungstenite::tungstenite::error::Error::ConnectionClosed) | Ok(None) => { @@ -459,18 +681,7 @@ impl WsConnection { Ok(Some(WebSocketMessage::Binary(bytes))) => { idle = false; record_metrics(bytes.len()); - match Self::parse_response(&bytes) { - Err(e) => maybe_log_error!( - &extra_logging, - "Error decoding WebSocketMessage::Binary payload", - Result::<(), _>::Err(e) - ), - Ok(msg) => maybe_log_error!( - &extra_logging, - "Error sending decoded message to incoming_messages queue", - incoming_messages.unbounded_send(msg) - ), - } + forward_parsed_responses_native(self.protocol, &incoming_messages, &extra_logging, &bytes); } Ok(Some(WebSocketMessage::Ping(payload))) => { @@ -518,14 +729,26 @@ impl WsConnection { }, // this is stupid. we want to handle the channel close *once*, and then disable this branch - Some(outgoing) = async { Some(outgoing_messages.as_mut()?.next().await) } => match outgoing { + Some(outgoing) = async { + Some(if let Some(outgoing) = pending_outgoing.pop_front() { + Some(outgoing) + } else { + outgoing_messages.as_mut()?.next().await + }) + } => match outgoing { Some(outgoing) => { - let msg = Self::encode_message(outgoing); - if let Err(e) = self.sock.send(msg).await { + let (msg, has_leftover_pending_outgoing) = encode_outgoing_message( + self.protocol, + outgoing, + &mut pending_outgoing, + || outgoing_messages.as_mut().and_then(|outgoing| outgoing.try_next().ok().flatten()), + ); + if let Err(e) = self.sock.send(WebSocketMessage::Binary(msg)).await { debug_log(&extra_logging, |file| writeln!(file, "Error sending outgoing message: {e:?}")); log::warn!("Error sending outgoing message: {e:?}"); break; } + yield_after_capped_flush = has_leftover_pending_outgoing; } None => { maybe_log_error!(&extra_logging, "Error sending close frame", SinkExt::close(&mut self.sock).await); @@ -570,7 +793,6 @@ impl WsConnection { let (outgoing_tx, outgoing_rx) = mpsc::unbounded::(); let (incoming_tx, incoming_rx) = mpsc::unbounded::(); - let (mut ws_writer, ws_reader) = self.sock.split(); wasm_bindgen_futures::spawn_local(async move { @@ -588,18 +810,17 @@ impl WsConnection { Some(Ok(WebSocketMessage::Binary(bytes))) => { record_metrics(bytes.len()); - // parse + forward into `incoming_tx` - match Self::parse_response(&bytes) { + match Self::parse_v2_response(&bytes) { Ok(msg) => if let Err(_e) = incoming_tx.unbounded_send(msg) { gloo_console::warn!("Incoming receiver dropped."); break; }, Err(e) => { gloo_console::warn!( - "Error decoding WebSocketMessage::Binay payload: ", + "Error decoding WebSocketMessage::Binary payload: ", format!("{:?}", e) ); - }, + } } }, @@ -623,12 +844,12 @@ impl WsConnection { Some(Ok(other)) => { record_metrics(other.len()); gloo_console::warn!("Unexpected WebSocket message: ", format!("{:?}",other)); - } + }, }, // 2) outbound messages outbound = outgoing.next() => if let Some(client_msg) = outbound { - let raw = Self::encode_message(client_msg); + let raw = WebSocketMessage::Binary(encode_v2_client_message_bytes(&client_msg)); if let Err(e) = ws_writer.send(raw).await { gloo_console::warn!("Error sending outgoing message:", format!("{:?}",e)); break; @@ -647,3 +868,150 @@ impl WsConnection { (incoming_rx, outgoing_tx) } } + +#[cfg(all(test, not(feature = "browser")))] +mod tests { + use super::*; + use spacetimedb_lib::{Identity, TimeDuration, Timestamp}; + + fn reducer_call(request_id: u32, arg_len: usize) -> ws::v2::ClientMessage { + ws::v2::ClientMessage::CallReducer(ws::v2::CallReducer { + request_id, + flags: ws::v2::CallReducerFlags::Default, + reducer: "reducer".into(), + args: Bytes::from(vec![0; arg_len]), + }) + } + + fn procedure_result(request_id: u32) -> ws::v2::ServerMessage { + ws::v2::ServerMessage::ProcedureResult(ws::v2::ProcedureResult { + status: ws::v2::ProcedureStatus::Returned(Bytes::new()), + timestamp: Timestamp::UNIX_EPOCH, + total_host_execution_duration: TimeDuration::ZERO, + request_id, + }) + } + + fn encode_server_message(message: &ws::v2::ServerMessage) -> Vec { + let mut encoded = vec![ws::common::SERVER_MSG_COMPRESSION_TAG_NONE]; + encoded.extend(bsatn::to_vec(message).unwrap()); + encoded + } + + fn encode_server_frame(frame: &ws::v3::ServerFrame) -> Vec { + let mut encoded = vec![ws::common::SERVER_MSG_COMPRESSION_TAG_NONE]; + encoded.extend(bsatn::to_vec(frame).unwrap()); + encoded + } + + #[test] + fn negotiated_protocol_defaults_to_v2() { + assert_eq!( + NegotiatedWsProtocol::from_negotiated_protocol(""), + NegotiatedWsProtocol::V2 + ); + assert_eq!( + NegotiatedWsProtocol::from_negotiated_protocol(ws::v2::BIN_PROTOCOL), + NegotiatedWsProtocol::V2 + ); + assert_eq!( + NegotiatedWsProtocol::from_negotiated_protocol("unexpected-protocol"), + NegotiatedWsProtocol::V2 + ); + } + + #[test] + fn negotiated_protocol_recognizes_v3() { + assert_eq!( + NegotiatedWsProtocol::from_negotiated_protocol(ws::v3::BIN_PROTOCOL), + NegotiatedWsProtocol::V3 + ); + } + + #[test] + fn encode_outgoing_message_batches_small_v3_messages() { + let mut pending = VecDeque::new(); + let (raw, has_leftover_pending_outgoing) = + encode_outgoing_message(NegotiatedWsProtocol::V3, reducer_call(1, 8), &mut pending, { + let mut extra = VecDeque::from([reducer_call(2, 8)]); + move || extra.pop_front() + }); + + assert!(!has_leftover_pending_outgoing); + assert!(pending.is_empty()); + + let frame: ws::v3::ClientFrame = bsatn::from_slice(&raw).unwrap(); + let ws::v3::ClientFrame::Batch(messages) = frame else { + panic!("expected batched v3 client frame"); + }; + assert_eq!(messages.len(), 2); + } + + #[test] + fn encode_outgoing_message_caps_v3_frames_at_256_kib() { + let mut pending = VecDeque::new(); + let oversized = 200 * 1024; + let (raw, has_leftover_pending_outgoing) = + encode_outgoing_message(NegotiatedWsProtocol::V3, reducer_call(1, oversized), &mut pending, { + let mut extra = VecDeque::from([reducer_call(2, oversized)]); + move || extra.pop_front() + }); + + assert!(has_leftover_pending_outgoing); + assert_eq!(pending.len(), 1); + + let frame: ws::v3::ClientFrame = bsatn::from_slice(&raw).unwrap(); + let ws::v3::ClientFrame::Single(message) = frame else { + panic!("expected single v3 client frame"); + }; + let inner: ws::v2::ClientMessage = bsatn::from_slice(&message).unwrap(); + match inner { + ws::v2::ClientMessage::CallReducer(call) => assert_eq!(call.request_id, 1), + _ => panic!("expected CallReducer inner message"), + } + } + + #[test] + fn parse_response_supports_v2_messages() { + let encoded = encode_server_message(&ws::v2::ServerMessage::InitialConnection(ws::v2::InitialConnection { + identity: Identity::ZERO, + connection_id: ConnectionId::ZERO, + token: "token".into(), + })); + + let messages = WsConnection::parse_responses(NegotiatedWsProtocol::V2, &encoded).unwrap(); + assert_eq!(messages.len(), 1); + match &messages[0] { + ws::v2::ServerMessage::InitialConnection(message) => { + assert_eq!(message.identity, Identity::ZERO); + assert_eq!(message.connection_id, ConnectionId::ZERO); + } + other => panic!("unexpected v2 message: {other:?}"), + } + } + + #[test] + fn parse_response_unwraps_v3_batches() { + let first = procedure_result(1); + let second = procedure_result(2); + let frame = ws::v3::ServerFrame::Batch( + vec![ + Bytes::from(bsatn::to_vec(&first).unwrap()), + Bytes::from(bsatn::to_vec(&second).unwrap()), + ] + .into_boxed_slice(), + ); + let encoded = encode_server_frame(&frame); + + let messages = WsConnection::parse_responses(NegotiatedWsProtocol::V3, &encoded).unwrap(); + assert_eq!(messages.len(), 2); + for (expected_request_id, message) in [1, 2].into_iter().zip(messages) { + match message { + ws::v2::ServerMessage::ProcedureResult(result) => { + assert_eq!(result.request_id, expected_request_id); + } + other => panic!("unexpected v3 inner message: {other:?}"), + } + } + } +}