diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index 96c319dc4..9117db704 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -19,7 +19,7 @@ async-trait = "0.1.89" serde = { version = "1.0", features = ["derive", "rc"] } serde_json = "1.0" thiserror = "2" -tokio = { version = "1", features = ["sync", "macros", "rt", "time"] } +tokio = { version = "1", features = ["sync", "macros", "time"] } futures = "0.3" tracing = { version = "0.1" } tokio-util = { version = "0.7" } @@ -56,7 +56,7 @@ process-wrap = { version = "9.0", features = ["tokio1"], optional = true } # for http-server transport rand = { version = "0.10", optional = true } -tokio-stream = { version = "0.1", optional = true } +tokio-stream = { version = "0.1", optional = true, features = ["sync"] } uuid = { version = "1", features = ["v4"], optional = true } http-body = { version = "1", optional = true } http-body-util = { version = "0.1", optional = true } @@ -77,7 +77,12 @@ chrono = { version = "0.4.38", default-features = false, features = [ [features] default = ["base64", "macros", "server"] client = ["dep:tokio-stream"] -server = ["transport-async-rw", "dep:schemars", "dep:pastey"] +server = [ + "transport-async-rw", + "dep:schemars", + "dep:pastey", + "dep:tokio-stream", +] macros = ["dep:rmcp-macros", "dep:pastey"] elicitation = ["dep:url"] @@ -109,14 +114,18 @@ client-side-sse = ["dep:sse-stream", "dep:http"] # Streamable HTTP client transport-streamable-http-client = ["client-side-sse", "transport-worker"] -transport-streamable-http-client-reqwest = ["transport-streamable-http-client", "__reqwest"] +transport-streamable-http-client-reqwest = [ + "transport-streamable-http-client", + "__reqwest", +] -transport-async-rw = ["tokio/io-util", "tokio-util/codec"] +transport-async-rw = ["tokio/io-util", "tokio-util/codec", "tokio-util/compat"] transport-io = ["transport-async-rw", "tokio/io-std"] -transport-child-process = [ +transport-child-process = ["transport-async-rw", 
"tokio/process"] +transport-child-process-tokio = [ "transport-async-rw", "tokio/process", - "dep:process-wrap", + "tokio/rt", ] transport-streamable-http-server = [ "transport-streamable-http-server-session", @@ -135,7 +144,10 @@ schemars = ["dep:schemars"] [dev-dependencies] tokio = { version = "1", features = ["full"] } schemars = { version = "1.1.0", features = ["chrono04"] } -axum = { version = "0.8", default-features = false, features = ["http1", "tokio"] } +axum = { version = "0.8", default-features = false, features = [ + "http1", + "tokio", +] } anyhow = "1.0" tracing-subscriber = { version = "0.3", features = [ "env-filter", @@ -155,6 +167,7 @@ required-features = [ "server", "client", "transport-child-process", + "transport-child-process-tokio", ] path = "tests/test_with_python.rs" @@ -164,8 +177,10 @@ required-features = [ "server", "client", "transport-child-process", + "transport-child-process-tokio", "transport-streamable-http-server", "transport-streamable-http-client", + "transport-streamable-http-client-reqwest", "__reqwest", ] path = "tests/test_with_js.rs" @@ -207,12 +222,22 @@ path = "tests/test_task.rs" [[test]] name = "test_streamable_http_priming" -required-features = ["server", "client", "transport-streamable-http-server", "reqwest"] +required-features = [ + "server", + "client", + "transport-streamable-http-server", + "reqwest", +] path = "tests/test_streamable_http_priming.rs" [[test]] name = "test_streamable_http_json_response" -required-features = ["server", "client", "transport-streamable-http-server", "reqwest"] +required-features = [ + "server", + "client", + "transport-streamable-http-server", + "reqwest", +] path = "tests/test_streamable_http_json_response.rs" @@ -249,5 +274,11 @@ path = "tests/test_custom_headers.rs" [[test]] name = "test_sse_concurrent_streams" -required-features = ["server", "client", "transport-streamable-http-server", "transport-streamable-http-client", "reqwest"] +required-features = [ + "server", + "client", 
+ "transport-streamable-http-server", + "transport-streamable-http-client", + "reqwest", +] path = "tests/test_sse_concurrent_streams.rs" diff --git a/crates/rmcp/src/error.rs b/crates/rmcp/src/error.rs index c7901f4b5..3bd528f48 100644 --- a/crates/rmcp/src/error.rs +++ b/crates/rmcp/src/error.rs @@ -30,8 +30,6 @@ pub enum RmcpError { #[cfg(feature = "server")] #[error("Server initialization error: {0}")] ServerInitialize(#[from] crate::service::ServerInitializeError), - #[error("Runtime error: {0}")] - Runtime(#[from] tokio::task::JoinError), #[error("Transport creation error: {error}")] // TODO: Maybe we can introduce something like `TryIntoTransport` to auto wrap transport type, // but it could be an breaking change, so we could do it in the future. diff --git a/crates/rmcp/src/handler/client/progress.rs b/crates/rmcp/src/handler/client/progress.rs index 04f31610f..892de197f 100644 --- a/crates/rmcp/src/handler/client/progress.rs +++ b/crates/rmcp/src/handler/client/progress.rs @@ -1,32 +1,53 @@ -use std::{collections::HashMap, sync::Arc}; - use futures::{Stream, StreamExt}; -use tokio::sync::RwLock; -use tokio_stream::wrappers::ReceiverStream; +use tokio::sync::broadcast; +use tokio_stream::wrappers::BroadcastStream; -use crate::model::{ProgressNotificationParam, ProgressToken}; -type Dispatcher = - Arc>>>; +use crate::{ + model::{ProgressNotificationParam, ProgressToken}, + util::PinnedStream, +}; /// A dispatcher for progress notifications. -#[derive(Debug, Clone, Default)] +/// +/// See [ProgressNotificationParam] and [ProgressToken] for more details on +/// how progress is dispatched to a particular listener. +#[derive(Debug, Clone)] pub struct ProgressDispatcher { - pub(crate) dispatcher: Dispatcher, + /// A channel of any progress notification. Subscribers will filter + /// on this channel. 
+ pub(crate) any_progress_notification_tx: broadcast::Sender, + pub(crate) unsubscribe_tx: broadcast::Sender, + pub(crate) unsubscribe_all_tx: broadcast::Sender<()>, } impl ProgressDispatcher { const CHANNEL_SIZE: usize = 16; pub fn new() -> Self { - Self::default() + // Note that channel size is per-receiver for broadcast channel. It is up to the receiver to + // keep up with the notifications to avoid missing any (via proper polling) + let (any_progress_notification_tx, _) = broadcast::channel(Self::CHANNEL_SIZE); + let (unsubscribe_tx, _) = broadcast::channel(Self::CHANNEL_SIZE); + let (unsubscribe_all_tx, _) = broadcast::channel(Self::CHANNEL_SIZE); + Self { + any_progress_notification_tx, + unsubscribe_tx, + unsubscribe_all_tx, + } } /// Handle a progress notification by sending it to the appropriate subscriber pub async fn handle_notification(&self, notification: ProgressNotificationParam) { - let token = &notification.progress_token; - if let Some(sender) = self.dispatcher.read().await.get(token).cloned() { - let send_result = sender.send(notification).await; - if let Err(e) = send_result { - tracing::warn!("Failed to send progress notification: {e}"); + // Broadcast the notification to all subscribers. Interested subscribers + // will filter on their end. + // ! Note that this implementation is very stateless and simple, we cannot + // ! easily inspect which subscribers are interested in which notifications. + // ! However, the stateless-ness and simplicity is also a plus! + // ! Cleanup becomes much easier. Just drop the `ProgressSubscriber`. + match self.any_progress_notification_tx.send(notification) { + Ok(_) => {} + Err(_) => { + // This error only happens if there are no active receivers of the `broadcast` channel. + // Silent error. } } } @@ -35,35 +56,97 @@ impl ProgressDispatcher { /// /// If you drop the returned `ProgressSubscriber`, it will automatically unsubscribe from notifications for that token. 
pub async fn subscribe(&self, progress_token: ProgressToken) -> ProgressSubscriber { - let (sender, receiver) = tokio::sync::mpsc::channel(Self::CHANNEL_SIZE); - self.dispatcher - .write() - .await - .insert(progress_token.clone(), sender); - let receiver = ReceiverStream::new(receiver); + // First, set up the unsubscribe listeners. This will fuse the notification stream below. + let progress_token_clone = progress_token.clone(); + let unsub_this_token_rx = BroadcastStream::new(self.unsubscribe_tx.subscribe()).filter_map( + move |token| { + let progress_token_clone = progress_token_clone.clone(); + async move { + match token { + Ok(token) => { + if token == progress_token_clone { + Some(()) + } else { + None + } + } + Err(e) => { + // An error here means the broadcast stream did not receive values quick enough and + // we missed some notification. This implies there are notifications + // we missed, but we cannot assume they were for us :( + tracing::warn!( + "Error receiving unsubscribe notification from broadcast channel: {e}" + ); + None + } + } + } + }, + ); + let unsub_any_token_tx = + BroadcastStream::new(self.unsubscribe_all_tx.subscribe()).map(|_| { + // Any reception of a result here indicates we should unsubscribe, + // regardless of if we received an `Ok(())` or an `Err(_)` (which + // indicates the broadcast receiver lagged behind) + () + }); + let unsub_fut = futures::stream::select(unsub_this_token_rx, unsub_any_token_tx) + .boxed() + .into_future(); // If the unsub streams end, this will cause unsubscription from the subscriber below. + + // Now setup the notification stream. We will receive all notifications and only forward progress notifications + // for the token we're interested in. 
+ let progress_token_clone = progress_token.clone(); + let receiver = BroadcastStream::new(self.any_progress_notification_tx.subscribe()) + .filter_map(move |notification| { + let progress_token_clone = progress_token_clone.clone(); + async move { + // We need to filter out the broadcast receive error type here. + match notification { + Ok(notification) => { + let token = notification.progress_token.clone(); + if token == progress_token_clone { + Some(notification) + } else { + None + } + } + Err(e) => { + tracing::warn!( + "Error receiving progress notification from broadcast channel: {e}" + ); + None + } + } + } + }) + // Fuse this stream so it stops once we receive an unsubscribe notification from the stream + // created above + .take_until(unsub_fut) + .boxed(); + ProgressSubscriber { progress_token, receiver, - dispatcher: self.dispatcher.clone(), } } /// Unsubscribe from progress notifications for a specific token. - pub async fn unsubscribe(&self, token: &ProgressToken) { - self.dispatcher.write().await.remove(token); + pub fn unsubscribe(&self, token: ProgressToken) { + // The only error defined is if there are no listeners, which is fine. Ignore the result. + let _ = self.unsubscribe_tx.send(token); } /// Clear all dispatcher. - pub async fn clear(&self) { - let mut dispatcher = self.dispatcher.write().await; - dispatcher.clear(); + pub fn clear(&self) { + // The only error defined is if there are no listeners, which is fine. Ignore the result. 
+ let _ = self.unsubscribe_all_tx.send(()); } } pub struct ProgressSubscriber { pub(crate) progress_token: ProgressToken, - pub(crate) receiver: ReceiverStream, - pub(crate) dispatcher: Dispatcher, + pub(crate) receiver: PinnedStream<'static, ProgressNotificationParam>, } impl ProgressSubscriber { @@ -86,15 +169,3 @@ impl Stream for ProgressSubscriber { self.receiver.size_hint() } } - -impl Drop for ProgressSubscriber { - fn drop(&mut self) { - let token = self.progress_token.clone(); - self.receiver.close(); - let dispatcher = self.dispatcher.clone(); - tokio::spawn(async move { - let mut dispatcher = dispatcher.write_owned().await; - dispatcher.remove(&token); - }); - } -} diff --git a/crates/rmcp/src/lib.rs b/crates/rmcp/src/lib.rs index 9ae3f9586..c70e61b5b 100644 --- a/crates/rmcp/src/lib.rs +++ b/crates/rmcp/src/lib.rs @@ -3,6 +3,8 @@ #![doc = include_str!("../README.md")] mod error; +mod util; + #[allow(deprecated)] pub use error::{Error, ErrorData, RmcpError}; diff --git a/crates/rmcp/src/service.rs b/crates/rmcp/src/service.rs index b12839c6f..f4374deaf 100644 --- a/crates/rmcp/src/service.rs +++ b/crates/rmcp/src/service.rs @@ -1,5 +1,10 @@ -use futures::{FutureExt, future::BoxFuture}; +use futures::{ + FutureExt, Stream, StreamExt, + future::{BoxFuture, RemoteHandle}, + stream::FuturesUnordered, +}; use thiserror::Error; +use tokio_stream::wrappers::ReceiverStream; #[cfg(feature = "server")] use crate::model::ServerJsonRpcMessage; @@ -11,6 +16,7 @@ use crate::{ NumberOrString, ProgressToken, RequestId, }, transport::{DynamicTransportError, IntoTransport, Transport}, + util::PinnedFuture, }; #[cfg(feature = "client")] mod client; @@ -111,7 +117,9 @@ pub trait ServiceExt: Service + Sized { fn serve( self, transport: T, - ) -> impl Future, R::InitializeError>> + Send + ) -> impl Future< + Output = Result<(RunningService, impl Future), R::InitializeError>, + > + Send where T: IntoTransport, E: std::error::Error + Send + Sync + 'static, @@ -123,7 +131,9 @@ 
pub trait ServiceExt: Service + Sized { self, transport: T, ct: CancellationToken, - ) -> impl Future, R::InitializeError>> + Send + ) -> impl Future< + Output = Result<(RunningService, impl Future), R::InitializeError>, + > + Send where T: IntoTransport, E: std::error::Error + Send + Sync + 'static, @@ -188,6 +198,7 @@ impl> DynService for S { use std::{ collections::{HashMap, VecDeque}, + fmt::Debug, ops::Deref, sync::{Arc, atomic::AtomicU64}, time::Duration, @@ -246,6 +257,8 @@ impl RequestHandle { pub const REQUEST_TIMEOUT_REASON: &str = "request timeout"; pub async fn await_response(self) -> Result { if let Some(timeout) = self.options.timeout { + // TODO: tokio timeout won't work if not in the tokio RT + // Find an alternative let timeout_result = tokio::time::timeout(timeout, async move { self.rx.await.map_err(|_e| ServiceError::TransportClosed)? }) @@ -426,14 +439,29 @@ impl Peer { } } -#[derive(Debug)] pub struct RunningService> { service: Arc, peer: Peer, - handle: Option>, + handle: Option>, cancellation_token: CancellationToken, dg: DropGuard, } + +impl> Debug for RunningService +where + S: Debug, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RunningService") + .field("service", &self.service) + .field("peer", &self.peer) + .field("handle", &self.handle.as_ref().map(|_| "")) + .field("cancellation_token", &self.cancellation_token) + .field("dg", &self.dg) + .finish() + } +} + impl> Deref for RunningService { type Target = Peer; @@ -467,10 +495,10 @@ impl> RunningService { /// This will block until the service loop terminates (either due to /// cancellation, transport closure, or an error). #[inline] - pub async fn waiting(mut self) -> Result { + pub async fn waiting(mut self) -> QuitReason { match self.handle.take() { Some(handle) => handle.await, - None => Ok(QuitReason::Closed), + None => QuitReason::Closed, } } @@ -491,7 +519,7 @@ impl> RunningService { /// // ... use the client ... 
/// client.close().await?; /// ``` - pub async fn close(&mut self) -> Result { + pub async fn close(&mut self) -> QuitReason { if let Some(handle) = self.handle.take() { // Disarm the drop guard so it doesn't try to cancel again // We need to cancel manually and wait for completion @@ -499,7 +527,7 @@ impl> RunningService { handle.await } else { // Already closed - Ok(QuitReason::Closed) + QuitReason::Closed } } @@ -511,24 +539,22 @@ impl> RunningService { /// /// Returns `Ok(Some(reason))` if shutdown completed within the timeout, /// `Ok(None)` if the timeout was reached, or `Err` if there was a join error. - pub async fn close_with_timeout( - &mut self, - timeout: Duration, - ) -> Result, tokio::task::JoinError> { + pub async fn close_with_timeout(&mut self, timeout: Duration) -> Option { if let Some(handle) = self.handle.take() { self.cancellation_token.cancel(); + // TODO: tokio timeout won't work if not in the tokio RT, find an alternative match tokio::time::timeout(timeout, handle).await { - Ok(result) => result.map(Some), + Ok(reason) => Some(reason), Err(_elapsed) => { tracing::warn!( "close_with_timeout: cleanup did not complete within {:?}", timeout ); - Ok(None) + None } } } else { - Ok(Some(QuitReason::Closed)) + Some(QuitReason::Closed) } } @@ -536,7 +562,7 @@ impl> RunningService { /// /// This consumes the `RunningService` and ensures the connection is properly /// closed. For a non-consuming alternative, see [`close`](Self::close). 
- pub async fn cancel(mut self) -> Result { + pub async fn cancel(mut self) -> QuitReason { // Disarm the drop guard since we're handling cancellation explicitly let _ = std::mem::replace(&mut self.dg, self.cancellation_token.clone().drop_guard()); self.close().await @@ -546,6 +572,9 @@ impl> RunningService { impl> Drop for RunningService { fn drop(&mut self) { if self.handle.is_some() && !self.cancellation_token.is_cancelled() { + // Make sure we don't stop the work itself, the work future should + // handle that via cancellation token or drop guard + self.handle.take().unwrap().forget(); tracing::debug!( "RunningService dropped without explicit close(). \ The connection will be closed asynchronously. \ @@ -569,7 +598,6 @@ impl RunningServiceCancellationToken { pub enum QuitReason { Cancelled, Closed, - JoinError(tokio::task::JoinError), } /// Request execution context @@ -594,11 +622,20 @@ pub struct NotificationContext { } /// Use this function to skip initialization process +/// +/// TODO: What initialization process? Reference that here +/// +/// Creates a handle to the running service, and the async task that runs the service +/// business logic. +/// +/// The caller is responsible for running the business logic task, either by spawning it on +/// a runtime or awaiting it directly. You can use the [RunningService] to cancel the +/// business logic or wait for it to finish. pub fn serve_directly( service: S, transport: T, peer_info: Option, -) -> RunningService +) -> (RunningService, impl Future) where R: ServiceRole, S: Service, @@ -609,12 +646,21 @@ where } /// Use this function to skip initialization process +/// +/// TODO: What initialization process? Reference that here +/// +/// Creates a handle to the running service, and the async task that runs the service +/// business logic. +/// +/// The caller is responsible for running the business logic task, either by spawning it on +/// a runtime or awaiting it directly. 
You can use the [RunningService] to cancel the +/// business logic or wait for it to finish. pub fn serve_directly_with_ct( service: S, transport: T, peer_info: Option, ct: CancellationToken, -) -> RunningService +) -> (RunningService, impl Future) where R: ServiceRole, S: Service, @@ -622,25 +668,30 @@ where E: std::error::Error + Send + Sync + 'static, { let (peer, peer_rx) = Peer::new(Arc::new(AtomicU32RequestIdProvider::default()), peer_info); + let peer_rx = ReceiverStream::new(peer_rx); serve_inner(service, transport.into_transport(), peer, peer_rx, ct) } +/// Creates a handle to the running service, and the async task that runs the service +/// business logic. +/// +/// The caller is responsible for running the business logic task, either by spawning it on +/// a runtime or awaiting it directly. You can use the [RunningService] to cancel the +/// business logic or wait for it to finish. #[instrument(skip_all)] -fn serve_inner( +fn serve_inner( service: S, transport: T, peer: Peer, - mut peer_rx: tokio::sync::mpsc::Receiver>, + peer_rx: PeerStream, ct: CancellationToken, -) -> RunningService +) -> (RunningService, impl Future) where R: ServiceRole, S: Service, T: Transport + 'static, + PeerStream: Stream> + Unpin, { - const SINK_PROXY_BUFFER_SIZE: usize = 64; - let (sink_proxy_tx, mut sink_proxy_rx) = - tokio::sync::mpsc::channel::>(SINK_PROXY_BUFFER_SIZE); let peer_info = peer.peer_info(); if R::IS_CLIENT { tracing::info!(?peer_info, "Service initialized as client"); @@ -648,9 +699,6 @@ where tracing::info!(?peer_info, "Service initialized as server"); } - let mut local_responder_pool = - HashMap::>>::new(); - let mut local_ct_pool = HashMap::::new(); let shared_service = Arc::new(service); // for return let service = shared_service.clone(); @@ -658,283 +706,325 @@ where // let message_sink = tokio::sync:: // let mut stream = std::pin::pin!(stream); let serve_loop_ct = ct.child_token(); - let peer_return: Peer = peer.clone(); + let peer_return = 
peer.clone(); let current_span = tracing::Span::current(); - let handle = tokio::spawn(async move { - let mut transport = transport.into_transport(); - let mut batch_messages = VecDeque::>::new(); - let mut send_task_set = tokio::task::JoinSet::::new(); - #[derive(Debug)] - enum SendTaskResult { - Request { - id: RequestId, - result: Result<(), DynamicTransportError>, - }, - Notification { - responder: Responder>, - cancellation_param: Option, - result: Result<(), DynamicTransportError>, - }, - } - #[derive(Debug)] - enum Event { - ProxyMessage(PeerSinkMessage), - PeerMessage(RxJsonRpcMessage), - ToSink(TxJsonRpcMessage), - SendTaskResult(SendTaskResult), - } - let quit_reason = loop { - let evt = if let Some(m) = batch_messages.pop_front() { - Event::PeerMessage(m) - } else { - tokio::select! { - m = sink_proxy_rx.recv(), if !sink_proxy_rx.is_closed() => { - if let Some(m) = m { - Event::ToSink(m) - } else { - continue - } - } - m = transport.receive() => { - if let Some(m) = m { - Event::PeerMessage(m) - } else { - // input stream closed - tracing::info!("input stream terminated"); - break QuitReason::Closed - } - } - m = peer_rx.recv(), if !peer_rx.is_closed() => { - if let Some(m) = m { - Event::ProxyMessage(m) - } else { - continue - } - } - m = send_task_set.join_next(), if !send_task_set.is_empty() => { - let Some(result) = m else { - continue - }; - match result { - Err(e) => { - // join error, which is serious, we should quit. 
- tracing::error!(%e, "send request task encounter a tokio join error"); - break QuitReason::JoinError(e) - } - Ok(result) => { - Event::SendTaskResult(result) - } - } - } - _ = serve_loop_ct.cancelled() => { - tracing::info!("task cancelled"); - break QuitReason::Cancelled + let work = controller(transport, peer_rx, serve_loop_ct, shared_service, peer) + .instrument(current_span); + + let (work, work_handle) = work.remote_handle(); + + let running_service = RunningService { + service, + peer: peer_return, + handle: Some(work_handle), + cancellation_token: ct.clone(), + dg: ct.drop_guard(), + }; + + (running_service, work) +} + +/// Main business logic for event dispatching and handling. +async fn controller( + transport: T, + peer_rx: PeerStream, + cancel_token: CancellationToken, + shared_service: Arc>, + peer: Peer, +) -> QuitReason +where + R: ServiceRole, + T: Transport + 'static, + PeerStream: Stream> + Unpin, +{ + let mut transport = transport.into_transport(); + let mut batch_messages = VecDeque::>::new(); + let mut send_task_set = FuturesUnordered::>::new(); + let mut side_effects_set = FuturesUnordered::>::new(); + + let mut local_responder_pool = + HashMap::>>::new(); + let mut local_ct_pool = HashMap::::new(); + + const SINK_PROXY_BUFFER_SIZE: usize = 64; + let (sink_proxy_tx, mut rpc_rx) = + tokio::sync::mpsc::channel::>(SINK_PROXY_BUFFER_SIZE); + + // Fuse the stream, so that once it return `None` it is guaranteed to never + // be polled again. Additionally, we can check if it have been fused by checking + // `is_done()`, which we use in the select branches below. 
+ let mut peer_rx = peer_rx.fuse(); + + #[derive(Debug)] + enum SendTaskResult { + Request { + id: RequestId, + result: Result<(), DynamicTransportError>, + }, + Notification { + responder: Responder>, + cancellation_param: Option, + result: Result<(), DynamicTransportError>, + }, + } + #[derive(Debug)] + enum Event { + ProxyMessage(PeerSinkMessage), + PeerMessage(RxJsonRpcMessage), + ToSink(TxJsonRpcMessage), + SendTaskResult(SendTaskResult), + } + + let quit_reason = loop { + // Prioritize processing batch messages before other things + let evt = if let Some(m) = batch_messages.pop_front() { + Event::PeerMessage(m) + } else { + tokio::select! { + m = rpc_rx.recv(), if !rpc_rx.is_closed() => { + if let Some(m) = m { + Event::ToSink(m) + } else { + continue } } - }; - - tracing::trace!(?evt, "new event"); - match evt { - Event::SendTaskResult(SendTaskResult::Request { id, result }) => { - if let Err(e) = result { - if let Some(responder) = local_responder_pool.remove(&id) { - let _ = responder.send(Err(ServiceError::TransportSend(e))); - } + m = transport.receive() => { + if let Some(m) = m { + Event::PeerMessage(m) + } else { + // input stream closed + tracing::info!("input stream terminated"); + break QuitReason::Closed } } - Event::SendTaskResult(SendTaskResult::Notification { - responder, - result, - cancellation_param, - }) => { - let response = if let Err(e) = result { - Err(ServiceError::TransportSend(e)) + m = peer_rx.next(), if !peer_rx.is_done() => { + if let Some(m) = m { + Event::ProxyMessage(m) } else { - Ok(()) - }; - let _ = responder.send(response); - if let Some(param) = cancellation_param { - if let Some(responder) = local_responder_pool.remove(&param.request_id) { - tracing::info!(id = %param.request_id, reason = param.reason, "cancelled"); - let _response_result = responder.send(Err(ServiceError::Cancelled { - reason: param.reason.clone(), - })); - } + continue } } - // response and error - Event::ToSink(m) => { - if let Some(id) = match &m { 
JsonRpcMessage::Response(response) => Some(&response.id), - JsonRpcMessage::Error(error) => Some(&error.id), - _ => None, - } { - if let Some(ct) = local_ct_pool.remove(id) { - ct.cancel(); - } - let send = transport.send(m); - let current_span = tracing::Span::current(); - tokio::spawn(async move { - let send_result = send.await; - if let Err(error) = send_result { - tracing::error!(%error, "fail to response message"); - } - }.instrument(current_span)); + m = send_task_set.next(), if !send_task_set.is_empty() => { + let Some(send_result) = m else { + continue + }; + Event::SendTaskResult(send_result) + } + _ = side_effects_set.next(), if !side_effects_set.is_empty() => { + // just drive the future, we don't care about the result + continue + } + _ = cancel_token.cancelled() => { + tracing::info!("task cancelled"); + break QuitReason::Cancelled + } + } + }; + + tracing::trace!(?evt, "new event"); + match evt { + Event::SendTaskResult(SendTaskResult::Request { id, result }) => { + if let Err(e) = result { + if let Some(responder) = local_responder_pool.remove(&id) { + let _ = responder.send(Err(ServiceError::TransportSend(e))); } } - } - Event::ProxyMessage(PeerSinkMessage::Request { - request, - id, - responder, - }) => { - local_responder_pool.insert(id.clone(), responder); - let send = transport.send(JsonRpcMessage::request(request, id.clone())); - { - let id = id.clone(); - let current_span = tracing::Span::current(); - send_task_set.spawn(send.map(move |r| SendTaskResult::Request { - id, - result: r.map_err(DynamicTransportError::new::), - }).instrument(current_span)); + } + Event::SendTaskResult(SendTaskResult::Notification { + responder, + result, + cancellation_param, + }) => { + let response = if let Err(e) = result { + Err(ServiceError::TransportSend(e)) + } else { + Ok(()) + }; + let _ = responder.send(response); + if let Some(param) = cancellation_param { + if let Some(responder) = local_responder_pool.remove(&param.request_id) { + tracing::info!(id = 
%param.request_id, reason = param.reason, "cancelled"); + let _response_result = responder.send(Err(ServiceError::Cancelled { + reason: param.reason.clone(), + })); } } - Event::ProxyMessage(PeerSinkMessage::Notification { - notification, - responder, - }) => { - // catch cancellation notification - let mut cancellation_param = None; - let notification = match notification.try_into() { - Ok::(cancelled) => { - cancellation_param.replace(cancelled.params.clone()); - cancelled.into() - } - Err(notification) => notification, - }; - let send = transport.send(JsonRpcMessage::notification(notification)); + } + // response and error + Event::ToSink(m) => { + if let Some(id) = match &m { + JsonRpcMessage::Response(response) => Some(&response.id), + JsonRpcMessage::Error(error) => Some(&error.id), + _ => None, + } { + if let Some(ct) = local_ct_pool.remove(id) { + ct.cancel(); + } + let send = transport.send(m); let current_span = tracing::Span::current(); - send_task_set.spawn(send.map(move |result| SendTaskResult::Notification { + let send_work = async move { + let send_result = send.await; + if let Err(error) = send_result { + tracing::error!(%error, "fail to response message"); + } + } + .instrument(current_span) + .boxed(); + side_effects_set.push(send_work); + } + } + Event::ProxyMessage(PeerSinkMessage::Request { + request, + id, + responder, + }) => { + local_responder_pool.insert(id.clone(), responder); + let send = transport.send(JsonRpcMessage::request(request, id.clone())); + let id = id.clone(); + let current_span = tracing::Span::current(); + + let send = send + .map(move |r| SendTaskResult::Request { + id, + result: r.map_err(DynamicTransportError::new::), + }) + .instrument(current_span) + .boxed(); + send_task_set.push(send); + } + Event::ProxyMessage(PeerSinkMessage::Notification { + notification, + responder, + }) => { + // catch cancellation notification + let mut cancellation_param = None; + let notification = match notification.try_into() { + 
Ok::(cancelled) => { + cancellation_param.replace(cancelled.params.clone()); + cancelled.into() + } + Err(notification) => notification, + }; + let send = transport.send(JsonRpcMessage::notification(notification)); + let current_span = tracing::Span::current(); + let send = send + .map(move |result| SendTaskResult::Notification { responder, cancellation_param, result: result.map_err(DynamicTransportError::new::), - }).instrument(current_span)); - } - Event::PeerMessage(JsonRpcMessage::Request(JsonRpcRequest { - id, - mut request, - .. - })) => { - tracing::debug!(%id, ?request, "received request"); - { - let service = shared_service.clone(); - let sink = sink_proxy_tx.clone(); - let request_ct = serve_loop_ct.child_token(); - let context_ct = request_ct.child_token(); - local_ct_pool.insert(id.clone(), request_ct); - let mut extensions = Extensions::new(); - let mut meta = Meta::new(); - // avoid clone - // swap meta firstly, otherwise progress token will be lost - std::mem::swap(&mut meta, request.get_meta_mut()); - std::mem::swap(&mut extensions, request.extensions_mut()); - let context = RequestContext { - ct: context_ct, - id: id.clone(), - peer: peer.clone(), - meta, - extensions, + }) + .instrument(current_span) + .boxed(); + send_task_set.push(send); + } + Event::PeerMessage(JsonRpcMessage::Request(JsonRpcRequest { + id, mut request, .. 
+ })) => { + tracing::debug!(%id, ?request, "received request"); + { + let service = shared_service.clone(); + let sink = sink_proxy_tx.clone(); + let request_ct = cancel_token.child_token(); + let context_ct = request_ct.child_token(); + local_ct_pool.insert(id.clone(), request_ct); + let mut extensions = Extensions::new(); + let mut meta = Meta::new(); + // avoid clone + // swap meta firstly, otherwise progress token will be lost + std::mem::swap(&mut meta, request.get_meta_mut()); + std::mem::swap(&mut extensions, request.extensions_mut()); + let context = RequestContext { + ct: context_ct, + id: id.clone(), + peer: peer.clone(), + meta, + extensions, + }; + let current_span = tracing::Span::current(); + let work = async move { + let result = service.handle_request(request, context).await; + let response = match result { + Ok(result) => { + tracing::debug!(%id, ?result, "response message"); + JsonRpcMessage::response(result, id) + } + Err(error) => { + tracing::warn!(%id, ?error, "response error"); + JsonRpcMessage::error(error, id) + } }; - let current_span = tracing::Span::current(); - tokio::spawn(async move { - let result = service - .handle_request(request, context) - .await; - let response = match result { - Ok(result) => { - tracing::debug!(%id, ?result, "response message"); - JsonRpcMessage::response(result, id) - } - Err(error) => { - tracing::warn!(%id, ?error, "response error"); - JsonRpcMessage::error(error, id) - } - }; - let _send_result = sink.send(response).await; - }.instrument(current_span)); + let _send_result = sink.send(response).await; } + .instrument(current_span) + .boxed(); + side_effects_set.push(work); } - Event::PeerMessage(JsonRpcMessage::Notification(JsonRpcNotification { - notification, - .. 
- })) => { - tracing::info!(?notification, "received notification"); - // catch cancelled notification - let mut notification = match notification.try_into() { - Ok::(cancelled) => { - if let Some(ct) = local_ct_pool.remove(&cancelled.params.request_id) { - tracing::info!(id = %cancelled.params.request_id, reason = cancelled.params.reason, "cancelled"); - ct.cancel(); - } - cancelled.into() + } + Event::PeerMessage(JsonRpcMessage::Notification(JsonRpcNotification { + notification, + .. + })) => { + tracing::info!(?notification, "received notification"); + // catch cancelled notification + let mut notification = match notification.try_into() { + Ok::(cancelled) => { + if let Some(ct) = local_ct_pool.remove(&cancelled.params.request_id) { + tracing::info!(id = %cancelled.params.request_id, reason = cancelled.params.reason, "cancelled"); + ct.cancel(); } - Err(notification) => notification, + cancelled.into() + } + Err(notification) => notification, + }; + { + let service = shared_service.clone(); + let mut extensions = Extensions::new(); + let mut meta = Meta::new(); + // avoid clone + std::mem::swap(&mut extensions, notification.extensions_mut()); + std::mem::swap(&mut meta, notification.get_meta_mut()); + let context = NotificationContext { + peer: peer.clone(), + meta, + extensions, }; - { - let service = shared_service.clone(); - let mut extensions = Extensions::new(); - let mut meta = Meta::new(); - // avoid clone - std::mem::swap(&mut extensions, notification.extensions_mut()); - std::mem::swap(&mut meta, notification.get_meta_mut()); - let context = NotificationContext { - peer: peer.clone(), - meta, - extensions, - }; - let current_span = tracing::Span::current(); - tokio::spawn(async move { - let result = service.handle_notification(notification, context).await; - if let Err(error) = result { - tracing::warn!(%error, "Error sending notification"); - } - }.instrument(current_span)); + let current_span = tracing::Span::current(); + let work = async move { + 
let result = service.handle_notification(notification, context).await; + if let Err(error) = result { + tracing::warn!(%error, "Error sending notification"); + } } + .instrument(current_span) + .boxed(); + side_effects_set.push(work); } - Event::PeerMessage(JsonRpcMessage::Response(JsonRpcResponse { - result, - id, - .. - })) => { - if let Some(responder) = local_responder_pool.remove(&id) { - let response_result = responder.send(Ok(result)); - if let Err(_error) = response_result { - tracing::warn!(%id, "Error sending response"); - } + } + Event::PeerMessage(JsonRpcMessage::Response(JsonRpcResponse { + result, id, .. + })) => { + if let Some(responder) = local_responder_pool.remove(&id) { + let response_result = responder.send(Ok(result)); + if let Err(_error) = response_result { + tracing::warn!(%id, "Error sending response"); } } - Event::PeerMessage(JsonRpcMessage::Error(JsonRpcError { error, id, .. })) => { - if let Some(responder) = local_responder_pool.remove(&id) { - let _response_result = responder.send(Err(ServiceError::McpError(error))); - if let Err(_error) = _response_result { - tracing::warn!(%id, "Error sending response"); - } + } + Event::PeerMessage(JsonRpcMessage::Error(JsonRpcError { error, id, .. 
})) => { + if let Some(responder) = local_responder_pool.remove(&id) { + let _response_result = responder.send(Err(ServiceError::McpError(error))); + if let Err(_error) = _response_result { + tracing::warn!(%id, "Error sending response"); } } } - }; - let sink_close_result = transport.close().await; - if let Err(e) = sink_close_result { - tracing::error!(%e, "fail to close sink"); } - tracing::info!(?quit_reason, "serve finished"); - quit_reason - }.instrument(current_span)); - RunningService { - service, - peer: peer_return, - handle: Some(handle), - cancellation_token: ct.clone(), - dg: ct.drop_guard(), + }; + let sink_close_result = transport.close().await; + if let Err(e) = sink_close_result { + tracing::error!(%e, "fail to close sink"); } + tracing::info!(?quit_reason, "serve finished"); + quit_reason } diff --git a/crates/rmcp/src/service/client.rs b/crates/rmcp/src/service/client.rs index 837fafeff..5c9717913 100644 --- a/crates/rmcp/src/service/client.rs +++ b/crates/rmcp/src/service/client.rs @@ -161,7 +161,12 @@ impl> ServiceExt for S { self, transport: T, ct: CancellationToken, - ) -> impl Future, ClientInitializeError>> + Send + ) -> impl Future< + Output = Result< + (RunningService, impl Future), + ClientInitializeError, + >, + > + Send where T: IntoTransport, E: std::error::Error + Send + Sync + 'static, @@ -174,7 +179,7 @@ impl> ServiceExt for S { pub async fn serve_client( service: S, transport: T, -) -> Result, ClientInitializeError> +) -> Result<(RunningService, impl Future), ClientInitializeError> where S: Service, T: IntoTransport, @@ -187,7 +192,7 @@ pub async fn serve_client_with_ct( service: S, transport: T, ct: CancellationToken, -) -> Result, ClientInitializeError> +) -> Result<(RunningService, impl Future), ClientInitializeError> where S: Service, T: IntoTransport, @@ -205,7 +210,7 @@ async fn serve_client_with_ct_inner( service: S, transport: T, ct: CancellationToken, -) -> Result, ClientInitializeError> +) -> Result<(RunningService, impl 
Future), ClientInitializeError> where S: Service, T: Transport + 'static, @@ -263,6 +268,7 @@ where transport.send(notification).await.map_err(|error| { ClientInitializeError::transport::(error, "send initialized notification") })?; + let peer_rx = ReceiverStream::new(peer_rx); Ok(serve_inner(service, transport, peer, peer_rx, ct)) } diff --git a/crates/rmcp/src/service/server.rs b/crates/rmcp/src/service/server.rs index 5f54f3dcd..b19a16024 100644 --- a/crates/rmcp/src/service/server.rs +++ b/crates/rmcp/src/service/server.rs @@ -94,7 +94,12 @@ impl> ServiceExt for S { self, transport: T, ct: CancellationToken, - ) -> impl Future, ServerInitializeError>> + Send + ) -> impl Future< + Output = Result< + (RunningService, impl Future), + ServerInitializeError, + >, + > + Send where T: IntoTransport, E: std::error::Error + Send + Sync + 'static, @@ -107,7 +112,7 @@ impl> ServiceExt for S { pub async fn serve_server( service: S, transport: T, -) -> Result, ServerInitializeError> +) -> Result<(RunningService, impl Future), ServerInitializeError> where S: Service, T: IntoTransport, @@ -166,7 +171,7 @@ pub async fn serve_server_with_ct( service: S, transport: T, ct: CancellationToken, -) -> Result, ServerInitializeError> +) -> Result<(RunningService, impl Future), ServerInitializeError> where S: Service, T: IntoTransport, @@ -180,11 +185,14 @@ where } } +/// Performs handshake and initial protocol setup through the transport, +/// and returns a [RunningService] with a separate work future that will +/// need polled to run the service. 
async fn serve_server_with_ct_inner( service: S, transport: T, ct: CancellationToken, -) -> Result, ServerInitializeError> +) -> Result<(RunningService, impl Future), ServerInitializeError> where S: Service, T: Transport + 'static, @@ -258,6 +266,7 @@ where peer: peer.clone(), }; let _ = service.handle_notification(notification, context).await; + let peer_rx = ReceiverStream::new(peer_rx); // Continue processing service Ok(serve_inner(service, transport, peer, peer_rx, ct)) } diff --git a/crates/rmcp/src/task_manager.rs b/crates/rmcp/src/task_manager.rs index 774c542f8..7cc815753 100644 --- a/crates/rmcp/src/task_manager.rs +++ b/crates/rmcp/src/task_manager.rs @@ -1,6 +1,10 @@ -use std::{any::Any, collections::HashMap, pin::Pin}; +use std::{any::Any, collections::HashMap}; -use futures::Future; +use futures::{ + Future, FutureExt, StreamExt, + future::abortable, + stream::{AbortHandle, FuturesUnordered}, +}; use tokio::{ sync::mpsc, time::{Duration, timeout}, @@ -11,11 +15,14 @@ use crate::{ error::{ErrorData as McpError, RmcpError as Error}, model::{CallToolResult, ClientRequest}, service::RequestContext, + util::PinnedFuture, }; +/// Result of running an operation +pub type OperationResult = Result, Error>; + /// Boxed future that represents an asynchronous operation managed by the processor. -pub type OperationFuture = - Pin, Error>> + Send>>; +pub type OperationFuture<'a> = PinnedFuture<'a, OperationResult>; /// Describes metadata associated with an enqueued task. #[derive(Debug, Clone)] @@ -57,11 +64,11 @@ impl OperationDescriptor { /// Operation message describing a unit of asynchronous work. 
pub struct OperationMessage { pub descriptor: OperationDescriptor, - pub future: OperationFuture, + pub future: OperationFuture<'static>, } impl OperationMessage { - pub fn new(descriptor: OperationDescriptor, future: OperationFuture) -> Self { + pub fn new(descriptor: OperationDescriptor, future: OperationFuture<'static>) -> Self { Self { descriptor, future } } } @@ -80,17 +87,23 @@ pub struct OperationProcessor { running_tasks: HashMap, /// Completed results waiting to be collected completed_results: Vec, + /// Receiver for asynchronously completed task results. Used + /// to collect back into `completed_results` task_result_receiver: mpsc::UnboundedReceiver, - task_result_sender: mpsc::UnboundedSender, + /// Sender to spawn futures on the worker task associated with this + /// processor. The worker future is created as part of [OperationProcessor::new] + spawn_tx: mpsc::UnboundedSender<(OperationDescriptor, OperationFuture<'static>)>, } +/// A handle to a running operation. struct RunningTask { - task_handle: tokio::task::JoinHandle<()>, + task_handle: AbortHandle, started_at: std::time::Instant, timeout: Option, descriptor: OperationDescriptor, } +/// The result of a running operation. pub struct TaskResult { pub descriptor: OperationDescriptor, pub result: Result, Error>, @@ -126,21 +139,63 @@ impl OperationResultTransport for ToolCallTaskResult { } } -impl Default for OperationProcessor { - fn default() -> Self { - Self::new() - } -} - impl OperationProcessor { - pub fn new() -> Self { + /// Create a new operation processor. + /// + /// This function will return the new [OperationProcessor] + /// facade you can use to queue operations, and also a future + /// that must be polled to handle these operations. + /// + /// Spawn the work function on your runtime of choice, or poll it + /// manually. 
+ pub fn new() -> (Self, impl Future) { let (task_result_sender, task_result_receiver) = mpsc::unbounded_channel(); - Self { + let (spawn_tx, mut spawn_rx) = + mpsc::unbounded_channel::<(OperationDescriptor, OperationFuture)>(); + + let work = async move { + let mut work_set = + FuturesUnordered::>::new(); + + // Loop and listen for new operations incoming that need to be added to the future pool, + // and also listen to operation completions via the future pool. + loop { + tokio::select! { + spawn_req = spawn_rx.recv(), if !spawn_rx.is_closed() => { + if let Some((descriptor, fut)) = spawn_req { + // Map the future back to a descriptor and result tuple + let operation_work = fut.map(|result| (descriptor, result)).boxed(); + // Add it to the set we are polling + work_set.push(operation_work); + } + }, + operation_result = work_set.next(), if !work_set.is_empty() => { + if let Some((descriptor, result)) = operation_result { + match task_result_sender.send(TaskResult { descriptor, result }) { + Err(e) => { + tracing::error!("Failed to send completed task result: {e}"); + } + _ => {} + } + }; + }, + else => { + // Work was empty, and spawn channel was closed. Time + // to break the loop. + break; + } + } + } + }; + + let this = Self { running_tasks: HashMap::new(), completed_results: Vec::new(), task_result_receiver, - task_result_sender, - } + spawn_tx, + }; + + (this, work) } /// Submit an operation for asynchronous execution. @@ -159,12 +214,11 @@ impl OperationProcessor { Ok(()) } + /// Spawns an operation to be executed to completion. 
fn spawn_async_task(&mut self, message: OperationMessage) { let OperationMessage { descriptor, future } = message; let task_id = descriptor.operation_id.clone(); let timeout_secs = descriptor.ttl.or(Some(DEFAULT_TASK_TIMEOUT_SECS)); - let sender = self.task_result_sender.clone(); - let descriptor_for_result = descriptor.clone(); let timed_future = async move { if let Some(secs) = timeout_secs { @@ -177,16 +231,32 @@ impl OperationProcessor { } }; - let handle = tokio::spawn(async move { - let result = timed_future.await; - let task_result = TaskResult { - descriptor: descriptor_for_result, - result, - }; - let _ = sender.send(task_result); + // Below, we want to give the user a handle to the long-running operation, + // but we don't want to send the result to the user's handle. Rather the + // result gets consumed in the worker task created in the `Self::new` + // function. So here we will use the `Abortable` future utility. + let (work, abort_handle) = abortable(timed_future); + + // Map the error type of abortion (for now) + let work = work.map(|result| { + match result { + // Was not aborted, true operation result + Ok(inner_result) => inner_result, + // Was aborted, flatten to expected error type + Err(e) => Err(Error::TaskError(e.to_string())), + } }); + + // Then send the work to be executed + match self.spawn_tx.send((descriptor.clone(), work.boxed())) { + Ok(_) => {} + Err(e) => { + tracing::error!("Failed to spawn task on worker: {e}"); + } + } + let running_task = RunningTask { - task_handle: handle, + task_handle: abort_handle, started_at: std::time::Instant::now(), timeout: timeout_secs, descriptor, diff --git a/crates/rmcp/src/transport.rs b/crates/rmcp/src/transport.rs index d7dfa9790..de12321cf 100644 --- a/crates/rmcp/src/transport.rs +++ b/crates/rmcp/src/transport.rs @@ -84,7 +84,9 @@ pub use worker::WorkerTransport; #[cfg(feature = "transport-child-process")] pub mod child_process; #[cfg(feature = "transport-child-process")] -pub use 
child_process::{ConfigureCommandExt, TokioChildProcess}; +pub use child_process::builder::CommandBuilder; +#[cfg(feature = "transport-child-process")] +pub use child_process::runner::{ChildProcess, ChildProcessInstance, ChildProcessRunner}; #[cfg(feature = "transport-io")] pub mod io; diff --git a/crates/rmcp/src/transport/async_rw.rs b/crates/rmcp/src/transport/async_rw.rs index ff4ecc65b..cb5ea75d7 100644 --- a/crates/rmcp/src/transport/async_rw.rs +++ b/crates/rmcp/src/transport/async_rw.rs @@ -43,9 +43,13 @@ where } pub type TransportWriter = FramedWrite>>; +pub type TransportReader = FramedRead>>; pub struct AsyncRwTransport { - read: FramedRead>>, + read: TransportReader, + /// This is behind a mutex so that concurrent writes can happen. + /// Naturally, the mutex will block parallel writes, but allow + /// multiple futures to be executed at once, even if some are waiting. write: Arc>>>, } diff --git a/crates/rmcp/src/transport/child_process.rs b/crates/rmcp/src/transport/child_process.rs index e33800b18..44db4a45d 100644 --- a/crates/rmcp/src/transport/child_process.rs +++ b/crates/rmcp/src/transport/child_process.rs @@ -1,309 +1,6 @@ -use std::process::Stdio; +pub mod builder; +pub mod runner; +pub mod transport; -use futures::future::Future; -use process_wrap::tokio::{ChildWrapper, CommandWrap}; -use tokio::{ - io::AsyncRead, - process::{ChildStderr, ChildStdin, ChildStdout}, -}; - -use super::{RxJsonRpcMessage, Transport, TxJsonRpcMessage, async_rw::AsyncRwTransport}; -use crate::RoleClient; - -const MAX_WAIT_ON_DROP_SECS: u64 = 3; -/// The parts of a child process. -type ChildProcessParts = ( - Box, - ChildStdout, - ChildStdin, - Option, -); - -/// Extract the stdio handles from a spawned child. -/// Returns `(child, stdout, stdin, stderr)` where `stderr` is `Some` only -/// if the process was spawned with `Stdio::piped()`. 
-#[inline] -fn child_process(mut child: Box) -> std::io::Result { - let child_stdin = match child.inner_mut().stdin().take() { - Some(stdin) => stdin, - None => return Err(std::io::Error::other("stdin was already taken")), - }; - let child_stdout = match child.inner_mut().stdout().take() { - Some(stdout) => stdout, - None => return Err(std::io::Error::other("stdout was already taken")), - }; - let child_stderr = child.inner_mut().stderr().take(); - Ok((child, child_stdout, child_stdin, child_stderr)) -} - -pub struct TokioChildProcess { - child: ChildWithCleanup, - transport: AsyncRwTransport, -} - -pub struct ChildWithCleanup { - inner: Option>, -} - -impl Drop for ChildWithCleanup { - fn drop(&mut self) { - // We should not use start_kill(), instead we should use kill() to avoid zombies - if let Some(mut inner) = self.inner.take() { - // We don't care about the result, just try to kill it - tokio::spawn(async move { - if let Err(e) = Box::into_pin(inner.kill()).await { - tracing::warn!("Error killing child process: {}", e); - } - }); - } - } -} - -// we hold the child process with stdout, for it's easier to implement AsyncRead -pin_project_lite::pin_project! { - pub struct TokioChildProcessOut { - child: ChildWithCleanup, - #[pin] - child_stdout: ChildStdout, - } -} - -impl TokioChildProcessOut { - /// Get the process ID of the child process. - pub fn id(&self) -> Option { - self.child.inner.as_ref()?.id() - } -} - -impl AsyncRead for TokioChildProcessOut { - fn poll_read( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> std::task::Poll> { - self.project().child_stdout.poll_read(cx, buf) - } -} - -impl TokioChildProcess { - /// Convenience: spawn with default `piped` stdio - pub fn new(command: impl Into) -> std::io::Result { - let (proc, _ignored) = TokioChildProcessBuilder::new(command).spawn()?; - Ok(proc) - } - - /// Builder entry-point allowing fine-grained stdio control. 
- pub fn builder(command: impl Into) -> TokioChildProcessBuilder { - TokioChildProcessBuilder::new(command) - } - - /// Get the process ID of the child process. - pub fn id(&self) -> Option { - self.child.inner.as_ref()?.id() - } - - /// Gracefully shutdown the child process - /// - /// This will first close the transport to the child process (the server), - /// and wait for the child process to exit normally with a timeout. - /// If the child process doesn't exit within the timeout, it will be killed. - pub async fn graceful_shutdown(&mut self) -> std::io::Result<()> { - if let Some(mut child) = self.child.inner.take() { - self.transport.close().await?; - - let wait_fut = child.wait(); - tokio::select! { - _ = tokio::time::sleep(std::time::Duration::from_secs(MAX_WAIT_ON_DROP_SECS)) => { - if let Err(e) = Box::into_pin(child.kill()).await { - tracing::warn!("Error killing child: {e}"); - return Err(e); - } - }, - res = wait_fut => { - match res { - Ok(status) => { - tracing::info!("Child exited gracefully {}", status); - } - Err(e) => { - tracing::warn!("Error waiting for child: {e}"); - return Err(e); - } - } - } - } - } - Ok(()) - } - - /// Take ownership of the inner child process - pub fn into_inner(mut self) -> Option> { - self.child.inner.take() - } - - /// Split this helper into a reader (stdout) and writer (stdin). - #[deprecated( - since = "0.5.0", - note = "use the Transport trait implementation instead" - )] - pub fn split(self) -> (TokioChildProcessOut, ChildStdin) { - unimplemented!("This method is deprecated, use the Transport trait implementation instead"); - } -} - -/// Builder for `TokioChildProcess` allowing custom `Stdio` configuration. 
-pub struct TokioChildProcessBuilder { - cmd: CommandWrap, - stdin: Stdio, - stdout: Stdio, - stderr: Stdio, -} - -impl TokioChildProcessBuilder { - fn new(cmd: impl Into) -> Self { - Self { - cmd: cmd.into(), - stdin: Stdio::piped(), - stdout: Stdio::piped(), - stderr: Stdio::inherit(), - } - } - - /// Override the child stdin configuration. - pub fn stdin(mut self, io: impl Into) -> Self { - self.stdin = io.into(); - self - } - /// Override the child stdout configuration. - pub fn stdout(mut self, io: impl Into) -> Self { - self.stdout = io.into(); - self - } - /// Override the child stderr configuration. - pub fn stderr(mut self, io: impl Into) -> Self { - self.stderr = io.into(); - self - } - - /// Spawn the child process. Returns the transport plus an optional captured stderr handle. - pub fn spawn(mut self) -> std::io::Result<(TokioChildProcess, Option)> { - self.cmd - .command_mut() - .stdin(self.stdin) - .stdout(self.stdout) - .stderr(self.stderr); - - let (child, stdout, stdin, stderr_opt) = child_process(self.cmd.spawn()?)?; - - let transport = AsyncRwTransport::new(stdout, stdin); - let proc = TokioChildProcess { - child: ChildWithCleanup { inner: Some(child) }, - transport, - }; - Ok((proc, stderr_opt)) - } -} - -impl Transport for TokioChildProcess { - type Error = std::io::Error; - - fn send( - &mut self, - item: TxJsonRpcMessage, - ) -> impl Future> + Send + 'static { - self.transport.send(item) - } - - fn receive(&mut self) -> impl Future>> + Send { - self.transport.receive() - } - - fn close(&mut self) -> impl Future> + Send { - self.graceful_shutdown() - } -} - -pub trait ConfigureCommandExt { - fn configure(self, f: impl FnOnce(&mut Self)) -> Self; -} - -impl ConfigureCommandExt for tokio::process::Command { - fn configure(mut self, f: impl FnOnce(&mut Self)) -> Self { - f(&mut self); - self - } -} - -#[cfg(unix)] -#[cfg(test)] -mod tests { - use tokio::process::Command; - - use super::*; - - #[tokio::test] - async fn 
test_tokio_child_process_drop() { - let r = TokioChildProcess::new(Command::new("sleep").configure(|cmd| { - cmd.arg("30"); - })); - assert!(r.is_ok()); - let child_process = r.unwrap(); - let id = child_process.id(); - assert!(id.is_some()); - let id = id.unwrap(); - // Drop the child process - drop(child_process); - // Wait a moment to allow the cleanup task to run - tokio::time::sleep(std::time::Duration::from_secs(MAX_WAIT_ON_DROP_SECS + 1)).await; - // Check if the process is still running - let status = Command::new("ps") - .arg("-p") - .arg(id.to_string()) - .status() - .await; - match status { - Ok(status) => { - assert!( - !status.success(), - "Process with PID {} is still running", - id - ); - } - Err(e) => { - panic!("Failed to check process status: {}", e); - } - } - } - - #[tokio::test] - async fn test_tokio_child_process_graceful_shutdown() { - let r = TokioChildProcess::new(Command::new("sleep").configure(|cmd| { - cmd.arg("30"); - })); - assert!(r.is_ok()); - let mut child_process = r.unwrap(); - let id = child_process.id(); - assert!(id.is_some()); - let id = id.unwrap(); - child_process.graceful_shutdown().await.unwrap(); - // Wait a moment to allow the cleanup task to run - tokio::time::sleep(std::time::Duration::from_secs(MAX_WAIT_ON_DROP_SECS + 1)).await; - // Check if the process is still running - let status = Command::new("ps") - .arg("-p") - .arg(id.to_string()) - .status() - .await; - match status { - Ok(status) => { - assert!( - !status.success(), - "Process with PID {} is still running", - id - ); - } - Err(e) => { - panic!("Failed to check process status: {}", e); - } - } - } -} +#[cfg(feature = "transport-child-process-tokio")] +pub mod tokio; diff --git a/crates/rmcp/src/transport/child_process/builder.rs b/crates/rmcp/src/transport/child_process/builder.rs new file mode 100644 index 000000000..e23303ae8 --- /dev/null +++ b/crates/rmcp/src/transport/child_process/builder.rs @@ -0,0 +1,147 @@ +use std::{collections::HashMap, 
path::PathBuf, process::Stdio}; + +use crate::transport::{ChildProcess, ChildProcessRunner, child_process::runner::RunnerSpawnError}; + +/// A builder for constructing a command to spawn a child process, with typical command +/// configuration options like `args` and `current_dir`. +pub struct CommandBuilder { + config: CommandConfig, + _marker: std::marker::PhantomData, +} + +#[derive(Debug, thiserror::Error)] +pub enum CommandBuilderError { + #[error("Command cannot be empty")] + EmptyCommand, +} + +impl CommandBuilder { + /// Create a CommandBuilder from an argv-style list of strings, where the first element is the command, and the rest are the args. + pub fn from_argv(argv: I) -> Result + where + I: IntoIterator, + S: Into, + { + let mut iter = argv.into_iter(); + + // Pop the first element as the command, and use the rest as args + let command = match iter.next() { + Some(cmd) => cmd.into(), + None => return Err(CommandBuilderError::EmptyCommand), + }; + + let args = iter.map(|s| s.into()).collect(); + Ok(Self { + config: CommandConfig { + command, + args, + ..Default::default() + }, + _marker: std::marker::PhantomData, + }) + } + + /// Create a CommandBuilder from a command and an optional list of args. + pub fn new(command: impl Into) -> Self { + Self { + config: CommandConfig { + command: command.into(), + ..Default::default() + }, + _marker: std::marker::PhantomData, + } + } + + /// Add a single argument to the command. + pub fn arg(mut self, arg: impl Into) -> Self { + self.config.args.push(arg.into()); + self + } + + /// Add multiple arguments to the command. + pub fn args(mut self, args: impl IntoIterator>) -> Self { + self.config + .args + .extend(args.into_iter().map(|arg| arg.into())); + self + } + + /// Set an environment variable for the command. + pub fn env(mut self, key: impl Into, value: impl Into) -> Self { + self.config.env.insert(key.into(), value.into()); + self + } + + /// Set multiple environment variables for the command. 
+ pub fn envs( + mut self, + envs: impl IntoIterator, impl Into)>, + ) -> Self { + self.config + .env + .extend(envs.into_iter().map(|(k, v)| (k.into(), v.into()))); + self + } + + /// Sets what happens to stderr for the command. + /// By default if not set, stderr is inherited from the parent process. + pub fn stderr(mut self, _stdio: Stdio) -> Self { + self.config.stdio_config.stderr = _stdio; + self + } + + pub fn current_dir(mut self, cwd: impl Into) -> Self { + self.config.cwd = Some(cwd.into()); + self + } +} + +/// A structure that requests how the child process streams should +/// be configured when spawning. +#[derive(Debug)] +pub struct StdioConfig { + pub stdin: Stdio, + pub stdout: Stdio, + pub stderr: Stdio, +} + +impl Default for StdioConfig { + fn default() -> Self { + Self { + stdin: Stdio::piped(), + stdout: Stdio::piped(), + stderr: Stdio::inherit(), + } + } +} + +/// A structure that requests how the command should be executed +#[derive(Debug, Default)] +pub struct CommandConfig { + pub command: String, + pub args: Vec, + pub cwd: Option, + pub stdio_config: StdioConfig, + pub env: HashMap, +} + +impl CommandBuilder +where + R: ChildProcessRunner, +{ + /// Spawn the command into its typed child process instance type. + pub fn spawn_raw(self) -> Result { + R::spawn(self.config) + } + + /// Spawn a child process struct that erases the specific child process instance type, and only exposes the control methods. + /// + /// Requires `R::Instance` to be [Send] and `'static`. 
+ pub fn spawn_dyn(self) -> Result + where + R::Instance: Send + 'static, + { + let instance = self.spawn_raw()?; + Ok(ChildProcess::new(instance)) + } +} diff --git a/crates/rmcp/src/transport/child_process/runner.rs b/crates/rmcp/src/transport/child_process/runner.rs new file mode 100644 index 000000000..bc457bb21 --- /dev/null +++ b/crates/rmcp/src/transport/child_process/runner.rs @@ -0,0 +1,212 @@ +use futures::{ + FutureExt, + io::{AsyncRead, AsyncWrite}, +}; + +use crate::{transport::child_process::builder::CommandConfig, util::PinnedFuture}; + +/// A simple enum for describing if a stream is available, unused, or already taken. +#[derive(Debug)] +pub enum StreamSlot { + /// The stream is not used in this implementation. + Unused, + /// The stream is available for use, and can be taken. + Available(S), + /// The stream has already been taken, and is no longer available. + Taken, +} + +impl From> for Option { + fn from(slot: StreamSlot) -> Self { + match slot { + StreamSlot::Unused => None, + StreamSlot::Available(s) => Some(s), + StreamSlot::Taken => None, + } + } +} + +/// The contract for what an instance of a child process +/// must provide to be used with a transport. +pub trait ChildProcessInstance { + /// The input stream for the command + type Stdin: AsyncWrite + Unpin + Send; + + /// The output stream of the command + type Stdout: AsyncRead + Unpin + Send; + + /// The error stream of the command + type Stderr: AsyncRead + Unpin + Send; + + fn take_stdin(&mut self) -> StreamSlot; + fn take_stdout(&mut self) -> StreamSlot; + fn take_stderr(&mut self) -> StreamSlot; + + fn pid(&self) -> u32; + fn wait<'s>( + &'s mut self, + ) -> impl Future> + Send + 's; + fn graceful_shutdown<'s>(&'s mut self) + -> impl Future> + Send + 's; + fn kill<'s>(&'s mut self) -> impl Future> + Send + 's; +} + +/// A subset of functionality of [ChildProcessInstance] that only includes the +/// functions used to control or wait for the process. 
+pub trait ChildProcessControl {
+    fn pid(&self) -> u32;
+    fn wait<'s>(&'s mut self) -> PinnedFuture<'s, std::io::Result<std::process::ExitStatus>>;
+    fn graceful_shutdown<'s>(&'s mut self) -> PinnedFuture<'s, std::io::Result<()>>;
+    fn kill<'s>(&'s mut self) -> PinnedFuture<'s, std::io::Result<()>>;
+}
+
+/// Auto-implement ChildProcessControl for any ChildProcessInstance, since it has all the required methods.
+impl<T> ChildProcessControl for T
+where
+    T: ChildProcessInstance,
+{
+    fn pid(&self) -> u32 {
+        ChildProcessInstance::pid(self)
+    }
+
+    fn wait<'s>(&'s mut self) -> PinnedFuture<'s, std::io::Result<std::process::ExitStatus>> {
+        ChildProcessInstance::wait(self).boxed()
+    }
+
+    fn graceful_shutdown<'s>(&'s mut self) -> PinnedFuture<'s, std::io::Result<()>> {
+        ChildProcessInstance::graceful_shutdown(self).boxed()
+    }
+
+    fn kill<'s>(&'s mut self) -> PinnedFuture<'s, std::io::Result<()>> {
+        ChildProcessInstance::kill(self).boxed()
+    }
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum RunnerSpawnError {
+    /// The child process instance failed to spawn.
+    #[error("Failed to spawn child process: {0}")]
+    SpawnError(#[from] std::io::Error),
+    /// The child process instance did not have a PID assigned (this is unexpected for a spawned process).
+    #[error("Child process did not have a PID assigned after spawning")]
+    NoPidAssigned,
+    #[error("Other error: {0}")]
+    Other(Box<dyn std::error::Error + Send + Sync>),
+}
+
+pub trait ChildProcessRunner {
+    /// The implementation of the child process instance that this runner will spawn.
+    type Instance: ChildProcessInstance;
+
+    fn spawn(command_config: CommandConfig) -> Result<Self::Instance, RunnerSpawnError>;
+}
+
+/// A containing wrapper around a child process instance. This struct erases the type
+/// by extracting some parts of the [ChildProcessInstance] trait into a common struct,
+/// and then only exposes the control methods.
+pub struct ChildProcess { + stdin: StreamSlot>, + stdout: StreamSlot>, + stderr: StreamSlot>, + inner: Box, +} + +impl ChildProcess { + pub fn new(mut instance: C) -> Self + where + C: ChildProcessInstance + Send + 'static, + { + Self { + stdin: match instance.take_stdin() { + StreamSlot::Available(s) => StreamSlot::Available(Box::new(s)), + StreamSlot::Unused => StreamSlot::Unused, + StreamSlot::Taken => { + panic!("Stdin stream was already taken during ChildProcess construction") + } + }, + stdout: match instance.take_stdout() { + StreamSlot::Available(s) => StreamSlot::Available(Box::new(s)), + StreamSlot::Unused => StreamSlot::Unused, + StreamSlot::Taken => { + panic!("Stdout stream was already taken during ChildProcess construction") + } + }, + stderr: match instance.take_stderr() { + StreamSlot::Available(s) => StreamSlot::Available(Box::new(s)), + StreamSlot::Unused => StreamSlot::Unused, + StreamSlot::Taken => { + panic!("Stderr stream was already taken during ChildProcess construction") + } + }, + inner: Box::new(instance), + } + } + + pub fn split( + self, + ) -> ( + Option>, + Option>, + Option>, + Box, + ) { + ( + self.stdout.into(), + self.stdin.into(), + self.stderr.into(), + self.inner, + ) + } +} + +impl ChildProcessInstance for ChildProcess { + type Stdin = Box; + + type Stdout = Box; + + type Stderr = Box; + + fn take_stdin(&mut self) -> StreamSlot { + match self.stdin { + StreamSlot::Available(_) => std::mem::replace(&mut self.stdin, StreamSlot::Taken), + StreamSlot::Unused => StreamSlot::Unused, + StreamSlot::Taken => StreamSlot::Taken, + } + } + + fn take_stdout(&mut self) -> StreamSlot { + match self.stdout { + StreamSlot::Available(_) => std::mem::replace(&mut self.stdout, StreamSlot::Taken), + StreamSlot::Unused => StreamSlot::Unused, + StreamSlot::Taken => StreamSlot::Taken, + } + } + + fn take_stderr(&mut self) -> StreamSlot { + match self.stderr { + StreamSlot::Available(_) => std::mem::replace(&mut self.stderr, StreamSlot::Taken), + 
StreamSlot::Unused => StreamSlot::Unused, + StreamSlot::Taken => StreamSlot::Taken, + } + } + + fn pid(&self) -> u32 { + self.inner.pid() + } + + fn wait<'s>( + &'s mut self, + ) -> impl Future> + Send + 's { + self.inner.wait() + } + + fn graceful_shutdown<'s>( + &'s mut self, + ) -> impl Future> + Send + 's { + self.inner.graceful_shutdown() + } + + fn kill<'s>(&'s mut self) -> impl Future> + Send + 's { + self.inner.kill() + } +} diff --git a/crates/rmcp/src/transport/child_process/tokio.rs b/crates/rmcp/src/transport/child_process/tokio.rs new file mode 100644 index 000000000..db835afb9 --- /dev/null +++ b/crates/rmcp/src/transport/child_process/tokio.rs @@ -0,0 +1,158 @@ +use tokio_util::compat::{Compat, TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; + +use crate::transport::child_process::{ + builder::CommandConfig, + runner::{ChildProcessInstance, ChildProcessRunner, RunnerSpawnError}, +}; + +pub struct TokioChildProcessRunner {} + +/// An implementation for the tokio Child Process +pub struct TokioChildProcess { + inner: tokio::process::Child, + /// The PID at the time of spawning. 
+    pid: u32,
+}
+
+impl ChildProcessInstance for TokioChildProcess {
+    type Stdin = Compat<tokio::process::ChildStdin>;
+
+    type Stdout = Compat<tokio::process::ChildStdout>;
+
+    type Stderr = Compat<tokio::process::ChildStderr>;
+
+    fn take_stdin(&mut self) -> super::runner::StreamSlot<Self::Stdin> {
+        match self.inner.stdin.take() {
+            Some(stdin) => super::runner::StreamSlot::Available(stdin.compat_write()),
+            None => super::runner::StreamSlot::Unused,
+        }
+    }
+
+    fn take_stdout(&mut self) -> super::runner::StreamSlot<Self::Stdout> {
+        match self.inner.stdout.take() {
+            Some(stdout) => super::runner::StreamSlot::Available(stdout.compat()),
+            None => super::runner::StreamSlot::Unused,
+        }
+    }
+
+    fn take_stderr(&mut self) -> super::runner::StreamSlot<Self::Stderr> {
+        match self.inner.stderr.take() {
+            Some(stderr) => super::runner::StreamSlot::Available(stderr.compat()),
+            None => super::runner::StreamSlot::Unused,
+        }
+    }
+
+    fn pid(&self) -> u32 {
+        self.pid
+    }
+
+    fn wait<'s>(
+        &'s mut self,
+    ) -> impl Future<Output = std::io::Result<std::process::ExitStatus>> + Send + 's {
+        self.inner.wait()
+    }
+
+    fn graceful_shutdown<'s>(
+        &'s mut self,
+    ) -> impl Future<Output = std::io::Result<()>> + Send + 's {
+        // TODO: Implement graceful shutdown on unix with SIGTERM. And look into graceful shutdown on windows as well.
+ self.inner.kill() + } + + fn kill<'s>(&'s mut self) -> impl Future> + Send + 's { + self.inner.kill() + } +} + +impl ChildProcessRunner for TokioChildProcessRunner { + type Instance = TokioChildProcess; + fn spawn(command_config: CommandConfig) -> Result { + tokio::process::Command::new(command_config.command) + .args(command_config.args) + .envs(command_config.env) + .stdin(command_config.stdio_config.stdin) + .stdout(command_config.stdio_config.stdout) + .stderr(command_config.stdio_config.stderr) + .current_dir( + command_config + .cwd + .unwrap_or_else(|| std::env::current_dir().unwrap()), + ) + .kill_on_drop(true) + .spawn() + .map_err(RunnerSpawnError::SpawnError) + .and_then(|child| { + let pid = child.id().ok_or_else(|| RunnerSpawnError::NoPidAssigned)?; + Ok(TokioChildProcess { inner: child, pid }) + }) + } +} + +#[cfg(test)] +mod test { + + use crate::transport::CommandBuilder; + use tokio::process::Command; + + use super::*; + + async fn check_pid(pid: u32) -> std::io::Result { + // This command will output only process numbers on each line. + let output = Command::new("ps") + .arg("-o") + .arg("pid=") + .arg("-p") + .arg(pid.to_string()) + .output() + .await?; + + let output_str = String::from_utf8_lossy(&output.stdout); + Ok(output_str + .lines() + .any(|line| line.trim() == pid.to_string())) + } + + #[cfg(unix)] + #[tokio::test] + async fn test_kill_on_drop() { + let child = CommandBuilder::::new("sleep") + .args(["10"]) + .spawn_raw() + .expect("Failed to spawn child process"); + + let pid = child.pid(); + + // Drop the child process without waiting for it to exit, which should kill it due to `kill_on_drop(true)`. + drop(child); + + // Wait a moment to ensure the process has been killed. 
+ tokio::time::sleep(std::time::Duration::from_secs(1)).await; + + let pid_found = check_pid(pid).await.expect("Failed to check if PID exists"); + + assert!(!pid_found, "Child process was not killed on drop"); + } + + #[tokio::test] + async fn test_graceful_shutdown() { + let mut child = CommandBuilder::::new("sleep") + .args(["10"]) + .spawn_raw() + .expect("Failed to spawn child process"); + + let pid = child.pid(); + + // Sleep a moment to ensure the process is running before we attempt to shut it down. + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + + child + .graceful_shutdown() + .await + .expect("Failed to gracefully shutdown child process"); + + // We should not need to wait here since we await the graceful shutdown above. + // Graceful shutdown *should* cover waiting for the process to exit. + let pid_found = check_pid(pid).await.expect("Failed to check if PID exists"); + assert!(!pid_found, "Child process was not shutdown"); + } +} diff --git a/crates/rmcp/src/transport/child_process/transport.rs b/crates/rmcp/src/transport/child_process/transport.rs new file mode 100644 index 000000000..9f46d9d93 --- /dev/null +++ b/crates/rmcp/src/transport/child_process/transport.rs @@ -0,0 +1,71 @@ +use tokio::io::{AsyncRead as TokioAsyncRead, AsyncWrite as TokioAsyncWrite}; +use tokio_util::compat::{FuturesAsyncReadCompatExt, FuturesAsyncWriteCompatExt}; + +use crate::{ + service::ServiceRole, + transport::{ + Transport, + async_rw::AsyncRwTransport, + child_process::runner::{ChildProcess, ChildProcessControl}, + }, +}; + +pub struct ChildProcessTransport { + _child: Box, + framed_transport: AsyncRwTransport< + R, + Box, + Box, + >, +} + +impl ChildProcessTransport +where + R: ServiceRole, +{ + pub fn new(child: ChildProcess) -> Result> { + let (stdout, stdin, _stderr, control) = child.split(); + + let framed_transport: AsyncRwTransport = AsyncRwTransport::new( + Box::new( + stdout + .ok_or("Failed to capture stdout of child process")? 
+ .compat(), + ) as Box, + Box::new( + stdin + .ok_or("Failed to capture stdin of child process")? + .compat_write(), + ) as Box, + ); + + Ok(Self { + _child: control, + framed_transport, + }) + } +} + +impl Transport for ChildProcessTransport +where + R: ServiceRole, +{ + type Error = std::io::Error; + + fn send( + &mut self, + item: crate::service::TxJsonRpcMessage, + ) -> impl Future> + Send + 'static { + self.framed_transport.send(item) + } + + fn receive( + &mut self, + ) -> impl Future>> + Send { + self.framed_transport.receive() + } + + fn close(&mut self) -> impl Future> + Send { + self.framed_transport.close() + } +} diff --git a/crates/rmcp/src/transport/streamable_http_server/tower.rs b/crates/rmcp/src/transport/streamable_http_server/tower.rs index 74b1fd79e..f62708e8c 100644 --- a/crates/rmcp/src/transport/streamable_http_server/tower.rs +++ b/crates/rmcp/src/transport/streamable_http_server/tower.rs @@ -515,13 +515,15 @@ where let session_manager = self.session_manager.clone(); let session_id = session_id.clone(); async move { - let service = serve_server::( - service, transport, - ) - .await; - match service { - Ok(service) => { + let serve_result = + serve_server::( + service, transport, + ) + .await; + match serve_result { + Ok((service, work)) => { // on service created + tokio::spawn(work); let _ = service.waiting().await; } Err(e) => { @@ -593,11 +595,8 @@ where request.request.extensions_mut().insert(part); let (transport, mut receiver) = OneshotTransport::::new(ClientJsonRpcMessage::Request(request)); - let service = serve_directly(service, transport, None); - tokio::spawn(async move { - // on service created - let _ = service.waiting().await; - }); + let (_, work) = serve_directly(service, transport, None); + tokio::spawn(work); if self.config.json_response { // JSON-direct mode: await the single response and return as // application/json, eliminating SSE framing overhead. 
diff --git a/crates/rmcp/src/util.rs b/crates/rmcp/src/util.rs new file mode 100644 index 000000000..912e13786 --- /dev/null +++ b/crates/rmcp/src/util.rs @@ -0,0 +1,71 @@ +use futures::{Sink, Stream}; +use std::{pin::Pin, task::Poll}; + +pub type PinnedFuture<'a, T> = Pin + Send + 'a>>; + +pub type PinnedLocalFuture<'a, T> = Pin + 'a>>; + +pub type PinnedStream<'a, T> = Pin + Send + 'a>>; + +pub type PinnedLocalStream<'a, T> = Pin + 'a>>; + +pub enum UnboundedSenderSinkError { + SendError(tokio::sync::mpsc::error::SendError), + Closed, +} + +/// A simple [Sink] wrapper for Tokio's [tokio::sync::mpsc::UnboundedSender] +#[derive(Debug, Clone)] +pub struct UnboundedSenderSink { + sender: tokio::sync::mpsc::UnboundedSender, +} + +impl UnboundedSenderSink { + pub fn new(sender: tokio::sync::mpsc::UnboundedSender) -> Self { + Self { sender } + } +} + +impl Sink for UnboundedSenderSink { + type Error = UnboundedSenderSinkError; + + fn poll_ready( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let this = self.get_mut(); + if this.sender.is_closed() { + Poll::Ready(Err(UnboundedSenderSinkError::Closed)) + } else { + Poll::Ready(Ok(())) + } + } + + fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { + let this = self.get_mut(); + match this.sender.send(item) { + Ok(_) => Ok(()), + Err(e) => Err(UnboundedSenderSinkError::SendError(e)), + } + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + // tokio's unbounded mpsc senders have no flushing required, since the + // receiver is unbounded and will get all messages we send (unless we run + // out of memory) + Poll::Ready(Ok(())) + } + + fn poll_close( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + // Like `poll_flush`, there is nothing to wait on here. 
A single + // call to `mpsc_sender.send(...)` is immediate from the perspective + // of the sender + Poll::Ready(Ok(())) + } +} diff --git a/crates/rmcp/tests/test_progress_subscriber.rs b/crates/rmcp/tests/test_progress_subscriber.rs index 521219a3b..5c5715b9c 100644 --- a/crates/rmcp/tests/test_progress_subscriber.rs +++ b/crates/rmcp/tests/test_progress_subscriber.rs @@ -100,11 +100,13 @@ async fn test_progress_subscriber() -> anyhow::Result<()> { let server = MyServer::new(); let (transport_server, transport_client) = tokio::io::duplex(4096); tokio::spawn(async move { - let service = server.serve(transport_server).await?; - service.waiting().await?; + let (service, work) = server.serve(transport_server).await?; + tokio::spawn(work); + service.waiting().await; anyhow::Ok(()) }); - let client_service = client.serve(transport_client).await?; + let (client_service, client_work) = client.serve(transport_client).await?; + tokio::spawn(client_work); let handle = client_service .send_cancellable_request( ClientRequest::CallToolRequest(Request::new(CallToolRequestParams { diff --git a/crates/rmcp/tests/test_task.rs b/crates/rmcp/tests/test_task.rs index 9ad0b2006..c0f08de01 100644 --- a/crates/rmcp/tests/test_task.rs +++ b/crates/rmcp/tests/test_task.rs @@ -21,7 +21,10 @@ impl OperationResultTransport for DummyTransport { #[tokio::test] async fn executes_enqueued_future() { - let mut processor = OperationProcessor::new(); + let (mut processor, work) = OperationProcessor::new(); + + tokio::spawn(work); + let descriptor = OperationDescriptor::new("op1", "dummy"); let future = Box::pin(async { tokio::time::sleep(Duration::from_millis(10)).await; @@ -50,7 +53,10 @@ async fn executes_enqueued_future() { #[tokio::test] async fn rejects_duplicate_operation_ids() { - let mut processor = OperationProcessor::new(); + let (mut processor, work) = OperationProcessor::new(); + + tokio::spawn(work); + let descriptor = OperationDescriptor::new("dup", "dummy"); let future = 
Box::pin(async { Ok(Box::new(DummyTransport { diff --git a/crates/rmcp/tests/test_with_js.rs b/crates/rmcp/tests/test_with_js.rs index c1e5d81a6..40228b582 100644 --- a/crates/rmcp/tests/test_with_js.rs +++ b/crates/rmcp/tests/test_with_js.rs @@ -2,8 +2,11 @@ use rmcp::{ ServiceExt, service::QuitReason, transport::{ - ConfigureCommandExt, StreamableHttpClientTransport, StreamableHttpServerConfig, - TokioChildProcess, + CommandBuilder, StreamableHttpClientTransport, StreamableHttpServerConfig, + child_process::{ + runner::ChildProcessControl, tokio::TokioChildProcessRunner, + transport::ChildProcessTransport, + }, streamable_http_server::{ session::local::LocalSessionManager, tower::StreamableHttpService, }, @@ -32,18 +35,26 @@ async fn test_with_js_stdio_server() -> anyhow::Result<()> { .spawn()? .wait() .await?; - let transport = - TokioChildProcess::new(tokio::process::Command::new("node").configure(|cmd| { - cmd.arg("tests/test_with_js/server.js"); - }))?; - let client = ().serve(transport).await?; + let node_cmd = CommandBuilder::::new("node") + .args(["tests/test_with_js/server.js"]) + .spawn_dyn()?; + + tracing::info!("Spawned child process with PID: {}", node_cmd.pid()); + + let transport = ChildProcessTransport::new(node_cmd) + .map_err(|e| anyhow::anyhow!("Failed to wrap child process: {e}"))?; + + let (client, work) = ().serve(transport).await?; + + tokio::spawn(work); + let resources = client.list_all_resources().await?; tracing::info!("{:#?}", resources); let tools = client.list_all_tools().await?; tracing::info!("{:#?}", tools); - client.cancel().await?; + client.cancel().await; Ok(()) } @@ -124,12 +135,13 @@ async fn test_with_js_streamable_http_server() -> anyhow::Result<()> { // waiting for server up tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await; - let client = ().serve(transport).await?; + let (client, work) = ().serve(transport).await?; + tokio::spawn(work); let resources = client.list_all_resources().await?; 
tracing::info!("{:#?}", resources); let tools = client.list_all_tools().await?; tracing::info!("{:#?}", tools); - let quit_reason = client.cancel().await?; + let quit_reason = client.cancel().await; server.kill().await?; assert!(matches!(quit_reason, QuitReason::Cancelled)); Ok(()) diff --git a/crates/rmcp/tests/test_with_python.rs b/crates/rmcp/tests/test_with_python.rs index 3f883c96f..074b8447d 100644 --- a/crates/rmcp/tests/test_with_python.rs +++ b/crates/rmcp/tests/test_with_python.rs @@ -1,10 +1,16 @@ use std::process::Stdio; +use futures::AsyncReadExt; use rmcp::{ ServiceExt, - transport::{ConfigureCommandExt, TokioChildProcess}, + transport::{ + ChildProcess, ChildProcessInstance, + child_process::{ + builder::CommandBuilder, tokio::TokioChildProcessRunner, + transport::ChildProcessTransport, + }, + }, }; -use tokio::io::AsyncReadExt; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; mod common; @@ -29,18 +35,21 @@ async fn init() -> anyhow::Result<()> { async fn test_with_python_server() -> anyhow::Result<()> { init().await?; - let transport = TokioChildProcess::new(tokio::process::Command::new("uv").configure(|cmd| { - cmd.arg("run") - .arg("server.py") - .current_dir("tests/test_with_python"); - }))?; + let server_command = CommandBuilder::::new("uv") + .args(["run", "server.py"]) + .current_dir("tests/test_with_python") + .spawn_dyn()?; + + let transport = ChildProcessTransport::new(server_command) + .map_err(|e| anyhow::anyhow!("Failed to wrap child process: {e}"))?; - let client = ().serve(transport).await?; + let (client, work) = ().serve(transport).await?; + tokio::spawn(work); let resources = client.list_all_resources().await?; tracing::info!("{:#?}", resources); let tools = client.list_all_tools().await?; tracing::info!("{:#?}", tools); - client.cancel().await?; + client.cancel().await; Ok(()) } @@ -48,15 +57,14 @@ async fn test_with_python_server() -> anyhow::Result<()> { async fn test_with_python_server_stderr() -> 
anyhow::Result<()> { init().await?; - let (transport, stderr) = - TokioChildProcess::builder(tokio::process::Command::new("uv").configure(|cmd| { - cmd.arg("run") - .arg("server.py") - .current_dir("tests/test_with_python"); - })) + let mut server_command = CommandBuilder::::new("uv") + .args(["run", "server.py"]) + .current_dir("tests/test_with_python") .stderr(Stdio::piped()) - .spawn()?; + .spawn_dyn()?; + let stderr: Option<::Stderr> = + server_command.take_stderr().into(); let mut stderr = stderr.expect("stderr must be piped"); let stderr_task = tokio::spawn(async move { @@ -65,10 +73,14 @@ async fn test_with_python_server_stderr() -> anyhow::Result<()> { Ok::<_, std::io::Error>(buffer) }); - let client = ().serve(transport).await?; + let transport = ChildProcessTransport::new(server_command) + .map_err(|e| anyhow::anyhow!("Failed to wrap child process: {e}"))?; + + let (client, work) = ().serve(transport).await?; + tokio::spawn(work); let _ = client.list_all_resources().await?; let _ = client.list_all_tools().await?; - client.cancel().await?; + client.cancel().await; let stderr_output = stderr_task.await??; assert!(stderr_output.contains("server starting up...")); diff --git a/examples/simple-chat-client/Cargo.toml b/examples/simple-chat-client/Cargo.toml index e382e63c7..ab24878cb 100644 --- a/examples/simple-chat-client/Cargo.toml +++ b/examples/simple-chat-client/Cargo.toml @@ -17,6 +17,7 @@ toml = "1.0" rmcp = { workspace = true, features = [ "client", "transport-child-process", - "transport-streamable-http-client-reqwest" + "transport-child-process-tokio", + "transport-streamable-http-client-reqwest", ] } clap = { version = "4.0", features = ["derive"] } diff --git a/examples/simple-chat-client/src/config.rs b/examples/simple-chat-client/src/config.rs index e469c2808..946436fdf 100644 --- a/examples/simple-chat-client/src/config.rs +++ b/examples/simple-chat-client/src/config.rs @@ -1,7 +1,14 @@ -use std::{collections::HashMap, path::Path, 
process::Stdio}; +use std::{collections::HashMap, path::Path}; use anyhow::Result; -use rmcp::{RoleClient, ServiceExt, service::RunningService, transport::ConfigureCommandExt}; +use rmcp::{ + RoleClient, ServiceExt, + service::RunningService, + transport::{ + CommandBuilder, + child_process::{tokio::TokioChildProcessRunner, transport::ChildProcessTransport}, + }, +}; use serde::{Deserialize, Serialize}; #[derive(Debug, Serialize, Deserialize)] @@ -47,22 +54,25 @@ impl McpServerTransportConfig { McpServerTransportConfig::Streamable { url } => { let transport = rmcp::transport::StreamableHttpClientTransport::from_uri(url.to_string()); - ().serve(transport).await? + let (service, work) = ().serve(transport).await?; + tokio::spawn(work); + service } McpServerTransportConfig::Stdio { command, args, envs, } => { - let transport = rmcp::transport::child_process::TokioChildProcess::new( - tokio::process::Command::new(command).configure(|cmd| { - cmd.args(args) - .envs(envs) - .stderr(Stdio::inherit()) - .stdout(Stdio::inherit()); - }), - )?; - ().serve(transport).await? + let cmd = CommandBuilder::::new(command) + .args(args) + .envs(envs) + .spawn_dyn()?; + + let transport = ChildProcessTransport::new(cmd) + .map_err(|e| anyhow::anyhow!("Failed to wrap child process: {e}"))?; + let (service, work) = ().serve(transport).await?; + tokio::spawn(work); + service } }; Ok(client)