Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 42 additions & 11 deletions crates/rmcp/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ async-trait = "0.1.89"
serde = { version = "1.0", features = ["derive", "rc"] }
serde_json = "1.0"
thiserror = "2"
tokio = { version = "1", features = ["sync", "macros", "rt", "time"] }
tokio = { version = "1", features = ["sync", "macros", "time"] }
futures = "0.3"
tracing = { version = "0.1" }
tokio-util = { version = "0.7" }
Expand Down Expand Up @@ -56,7 +56,7 @@ process-wrap = { version = "9.0", features = ["tokio1"], optional = true }

# for http-server transport
rand = { version = "0.10", optional = true }
tokio-stream = { version = "0.1", optional = true }
tokio-stream = { version = "0.1", optional = true, features = ["sync"] }
uuid = { version = "1", features = ["v4"], optional = true }
http-body = { version = "1", optional = true }
http-body-util = { version = "0.1", optional = true }
Expand All @@ -77,7 +77,12 @@ chrono = { version = "0.4.38", default-features = false, features = [
[features]
default = ["base64", "macros", "server"]
client = ["dep:tokio-stream"]
server = ["transport-async-rw", "dep:schemars", "dep:pastey"]
server = [
"transport-async-rw",
"dep:schemars",
"dep:pastey",
"dep:tokio-stream",
]
macros = ["dep:rmcp-macros", "dep:pastey"]
elicitation = ["dep:url"]

Expand Down Expand Up @@ -109,14 +114,18 @@ client-side-sse = ["dep:sse-stream", "dep:http"]

# Streamable HTTP client
transport-streamable-http-client = ["client-side-sse", "transport-worker"]
transport-streamable-http-client-reqwest = ["transport-streamable-http-client", "__reqwest"]
transport-streamable-http-client-reqwest = [
"transport-streamable-http-client",
"__reqwest",
]

transport-async-rw = ["tokio/io-util", "tokio-util/codec"]
transport-async-rw = ["tokio/io-util", "tokio-util/codec", "tokio-util/compat"]
transport-io = ["transport-async-rw", "tokio/io-std"]
transport-child-process = [
transport-child-process = ["transport-async-rw", "tokio/process"]
transport-child-process-tokio = [
"transport-async-rw",
"tokio/process",
"dep:process-wrap",
"tokio/rt",
]
transport-streamable-http-server = [
"transport-streamable-http-server-session",
Expand All @@ -135,7 +144,10 @@ schemars = ["dep:schemars"]
[dev-dependencies]
tokio = { version = "1", features = ["full"] }
schemars = { version = "1.1.0", features = ["chrono04"] }
axum = { version = "0.8", default-features = false, features = ["http1", "tokio"] }
axum = { version = "0.8", default-features = false, features = [
"http1",
"tokio",
] }
anyhow = "1.0"
tracing-subscriber = { version = "0.3", features = [
"env-filter",
Expand All @@ -155,6 +167,7 @@ required-features = [
"server",
"client",
"transport-child-process",
"transport-child-process-tokio",
]
path = "tests/test_with_python.rs"

Expand All @@ -164,8 +177,10 @@ required-features = [
"server",
"client",
"transport-child-process",
"transport-child-process-tokio",
"transport-streamable-http-server",
"transport-streamable-http-client",
"transport-streamable-http-client-reqwest",
"__reqwest",
]
path = "tests/test_with_js.rs"
Expand Down Expand Up @@ -207,12 +222,22 @@ path = "tests/test_task.rs"

[[test]]
name = "test_streamable_http_priming"
required-features = ["server", "client", "transport-streamable-http-server", "reqwest"]
required-features = [
"server",
"client",
"transport-streamable-http-server",
"reqwest",
]
path = "tests/test_streamable_http_priming.rs"

[[test]]
name = "test_streamable_http_json_response"
required-features = ["server", "client", "transport-streamable-http-server", "reqwest"]
required-features = [
"server",
"client",
"transport-streamable-http-server",
"reqwest",
]
path = "tests/test_streamable_http_json_response.rs"


Expand Down Expand Up @@ -249,5 +274,11 @@ path = "tests/test_custom_headers.rs"

[[test]]
name = "test_sse_concurrent_streams"
required-features = ["server", "client", "transport-streamable-http-server", "transport-streamable-http-client", "reqwest"]
required-features = [
"server",
"client",
"transport-streamable-http-server",
"transport-streamable-http-client",
"reqwest",
]
path = "tests/test_sse_concurrent_streams.rs"
2 changes: 0 additions & 2 deletions crates/rmcp/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@ pub enum RmcpError {
#[cfg(feature = "server")]
#[error("Server initialization error: {0}")]
ServerInitialize(#[from] crate::service::ServerInitializeError),
#[error("Runtime error: {0}")]
Runtime(#[from] tokio::task::JoinError),
#[error("Transport creation error: {error}")]
// TODO: Maybe we can introduce something like `TryIntoTransport` to auto wrap transport type,
// but it could be an breaking change, so we could do it in the future.
Expand Down
153 changes: 112 additions & 41 deletions crates/rmcp/src/handler/client/progress.rs
Original file line number Diff line number Diff line change
@@ -1,32 +1,53 @@
use std::{collections::HashMap, sync::Arc};

use futures::{Stream, StreamExt};
use tokio::sync::RwLock;
use tokio_stream::wrappers::ReceiverStream;
use tokio::sync::broadcast;
use tokio_stream::wrappers::BroadcastStream;

use crate::model::{ProgressNotificationParam, ProgressToken};
type Dispatcher =
Arc<RwLock<HashMap<ProgressToken, tokio::sync::mpsc::Sender<ProgressNotificationParam>>>>;
use crate::{
model::{ProgressNotificationParam, ProgressToken},
util::PinnedStream,
};

/// A dispatcher for progress notifications.
#[derive(Debug, Clone, Default)]
///
/// See [ProgressNotificationParam] and [ProgressToken] for more details on
/// how progress is dispatched to a particular listener.
#[derive(Debug, Clone)]
pub struct ProgressDispatcher {
pub(crate) dispatcher: Dispatcher,
/// A channel of any progress notification. Subscribers will filter
/// on this channel.
pub(crate) any_progress_notification_tx: broadcast::Sender<ProgressNotificationParam>,
pub(crate) unsubscribe_tx: broadcast::Sender<ProgressToken>,
pub(crate) unsubscribe_all_tx: broadcast::Sender<()>,
}

impl ProgressDispatcher {
const CHANNEL_SIZE: usize = 16;
pub fn new() -> Self {
Self::default()
// Note that channel size is per-receiver for broadcast channel. It is up to the receiver to
// keep up with the notifications to avoid missing any (via propper polling)
let (any_progress_notification_tx, _) = broadcast::channel(Self::CHANNEL_SIZE);
let (unsubscribe_tx, _) = broadcast::channel(Self::CHANNEL_SIZE);
let (unsubscribe_all_tx, _) = broadcast::channel(Self::CHANNEL_SIZE);
Self {
any_progress_notification_tx,
unsubscribe_tx,
unsubscribe_all_tx,
}
}

/// Handle a progress notification by sending it to the appropriate subscriber
pub async fn handle_notification(&self, notification: ProgressNotificationParam) {
let token = &notification.progress_token;
if let Some(sender) = self.dispatcher.read().await.get(token).cloned() {
let send_result = sender.send(notification).await;
if let Err(e) = send_result {
tracing::warn!("Failed to send progress notification: {e}");
// Broadcast the notification to all subscribers. Interested subscribers
// will filter on their end.
// ! Note that this implementaiton is very stateless and simple, we cannot
// ! easily inspect which subscribers are interested in which notifications.
// ! However, the stateless-ness and simplicity is also a plus!
// ! Cleanup becomes much easier. Just drop the `ProgressSubscriber`.
match self.any_progress_notification_tx.send(notification) {
Ok(_) => {}
Err(_) => {
// This error only happens if there are no active receivers of the `broadcast` channel.
// Silent error.
}
}
}
Expand All @@ -35,35 +56,97 @@ impl ProgressDispatcher {
///
/// If you drop the returned `ProgressSubscriber`, it will automatically unsubscribe from notifications for that token.
pub async fn subscribe(&self, progress_token: ProgressToken) -> ProgressSubscriber {
let (sender, receiver) = tokio::sync::mpsc::channel(Self::CHANNEL_SIZE);
self.dispatcher
.write()
.await
.insert(progress_token.clone(), sender);
let receiver = ReceiverStream::new(receiver);
// First, set up the unsubscribe listeners. This will fuse the notifiaction stream below.
let progress_token_clone = progress_token.clone();
let unsub_this_token_rx = BroadcastStream::new(self.unsubscribe_tx.subscribe()).filter_map(
move |token| {
let progress_token_clone = progress_token_clone.clone();
async move {
match token {
Ok(token) => {
if token == progress_token_clone {
Some(())
} else {
None
}
}
Err(e) => {
// An error here means the broadcast stream did not receive values quick enough and
// and we missed some notification. This implies there are notifications
// we missed, but we cannot assume they were for us :(
tracing::warn!(
"Error receiving unsubscribe notification from broadcast channel: {e}"
);
None
}
}
}
},
);
let unsub_any_token_tx =
BroadcastStream::new(self.unsubscribe_all_tx.subscribe()).map(|_| {
// Any reception of a result here indicates we should unsubscribe,
// regardless of if we received an `Ok(())` or an `Err(_)` (which
// indicates the broadcast receiver lagged behind)
()
});
let unsub_fut = futures::stream::select(unsub_this_token_rx, unsub_any_token_tx)
.boxed()
.into_future(); // If the unsub streams end, this will cause unsubscription from the subscriber below.

// Now setup the notification stream. We will receive all notifications and only forward progress notifications
// for the token we're interested in.
let progress_token_clone = progress_token.clone();
let receiver = BroadcastStream::new(self.any_progress_notification_tx.subscribe())
.filter_map(move |notification| {
let progress_token_clone = progress_token_clone.clone();
async move {
// We need to kneed-out the broadcast receive error type here.
match notification {
Ok(notification) => {
let token = notification.progress_token.clone();
if token == progress_token_clone {
Some(notification)
} else {
None
}
}
Err(e) => {
tracing::warn!(
"Error receiving progress notification from broadcast channel: {e}"
);
None
}
}
}
})
// Fuse this stream so it stops once we receive an unsubscribe notification from the stream
// created above
.take_until(unsub_fut)
.boxed();

ProgressSubscriber {
progress_token,
receiver,
dispatcher: self.dispatcher.clone(),
}
}

/// Unsubscribe from progress notifications for a specific token.
pub async fn unsubscribe(&self, token: &ProgressToken) {
self.dispatcher.write().await.remove(token);
pub fn unsubscribe(&self, token: ProgressToken) {
// The only error defined is if there are no listeners, which is fine. Ignore the result.
let _ = self.unsubscribe_tx.send(token);
}

/// Clear all dispatcher.
pub async fn clear(&self) {
let mut dispatcher = self.dispatcher.write().await;
dispatcher.clear();
pub fn clear(&self) {
// The only error defined is if there are no listeners, which is fine. Ignore the result.
let _ = self.unsubscribe_all_tx.send(());
}
}

pub struct ProgressSubscriber {
pub(crate) progress_token: ProgressToken,
pub(crate) receiver: ReceiverStream<ProgressNotificationParam>,
pub(crate) dispatcher: Dispatcher,
pub(crate) receiver: PinnedStream<'static, ProgressNotificationParam>,
}

impl ProgressSubscriber {
Expand All @@ -86,15 +169,3 @@ impl Stream for ProgressSubscriber {
self.receiver.size_hint()
}
}

impl Drop for ProgressSubscriber {
fn drop(&mut self) {
let token = self.progress_token.clone();
self.receiver.close();
let dispatcher = self.dispatcher.clone();
tokio::spawn(async move {
let mut dispatcher = dispatcher.write_owned().await;
dispatcher.remove(&token);
});
}
}
2 changes: 2 additions & 0 deletions crates/rmcp/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#![doc = include_str!("../README.md")]

mod error;
mod util;

#[allow(deprecated)]
pub use error::{Error, ErrorData, RmcpError};

Expand Down
Loading