From 4939e5de41e938af3400407f51545ab8f17b5622 Mon Sep 17 00:00:00 2001 From: Aadam Zocolo Date: Sat, 28 Feb 2026 17:44:16 +0000 Subject: [PATCH 01/11] refactor(service): remove uses of tokio::spawn --- crates/rmcp/Cargo.toml | 41 +- crates/rmcp/src/lib.rs | 1 + crates/rmcp/src/service.rs | 633 ++++++++++-------- .../transport/streamable_http_server/tower.rs | 7 +- crates/rmcp/src/util.rs | 5 + 5 files changed, 396 insertions(+), 291 deletions(-) create mode 100644 crates/rmcp/src/util.rs diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index 96c319dc..76820c81 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -19,7 +19,7 @@ async-trait = "0.1.89" serde = { version = "1.0", features = ["derive", "rc"] } serde_json = "1.0" thiserror = "2" -tokio = { version = "1", features = ["sync", "macros", "rt", "time"] } +tokio = { version = "1", features = ["sync", "macros", "time"] } futures = "0.3" tracing = { version = "0.1" } tokio-util = { version = "0.7" } @@ -109,7 +109,10 @@ client-side-sse = ["dep:sse-stream", "dep:http"] # Streamable HTTP client transport-streamable-http-client = ["client-side-sse", "transport-worker"] -transport-streamable-http-client-reqwest = ["transport-streamable-http-client", "__reqwest"] +transport-streamable-http-client-reqwest = [ + "transport-streamable-http-client", + "__reqwest", +] transport-async-rw = ["tokio/io-util", "tokio-util/codec"] transport-io = ["transport-async-rw", "tokio/io-std"] @@ -135,7 +138,10 @@ schemars = ["dep:schemars"] [dev-dependencies] tokio = { version = "1", features = ["full"] } schemars = { version = "1.1.0", features = ["chrono04"] } -axum = { version = "0.8", default-features = false, features = ["http1", "tokio"] } +axum = { version = "0.8", default-features = false, features = [ + "http1", + "tokio", +] } anyhow = "1.0" tracing-subscriber = { version = "0.3", features = [ "env-filter", @@ -150,12 +156,7 @@ path = "tests/test_tool_macros.rs" [[test]] name = "test_with_python" 
-required-features = [ - "reqwest", - "server", - "client", - "transport-child-process", -] +required-features = ["reqwest", "server", "client", "transport-child-process"] path = "tests/test_with_python.rs" [[test]] @@ -207,12 +208,22 @@ path = "tests/test_task.rs" [[test]] name = "test_streamable_http_priming" -required-features = ["server", "client", "transport-streamable-http-server", "reqwest"] +required-features = [ + "server", + "client", + "transport-streamable-http-server", + "reqwest", +] path = "tests/test_streamable_http_priming.rs" [[test]] name = "test_streamable_http_json_response" -required-features = ["server", "client", "transport-streamable-http-server", "reqwest"] +required-features = [ + "server", + "client", + "transport-streamable-http-server", + "reqwest", +] path = "tests/test_streamable_http_json_response.rs" @@ -249,5 +260,11 @@ path = "tests/test_custom_headers.rs" [[test]] name = "test_sse_concurrent_streams" -required-features = ["server", "client", "transport-streamable-http-server", "transport-streamable-http-client", "reqwest"] +required-features = [ + "server", + "client", + "transport-streamable-http-server", + "transport-streamable-http-client", + "reqwest", +] path = "tests/test_sse_concurrent_streams.rs" diff --git a/crates/rmcp/src/lib.rs b/crates/rmcp/src/lib.rs index 9ae3f958..456bc3ea 100644 --- a/crates/rmcp/src/lib.rs +++ b/crates/rmcp/src/lib.rs @@ -3,6 +3,7 @@ #![doc = include_str!("../README.md")] mod error; +mod util; #[allow(deprecated)] pub use error::{Error, ErrorData, RmcpError}; diff --git a/crates/rmcp/src/service.rs b/crates/rmcp/src/service.rs index b12839c6..0e8e95ea 100644 --- a/crates/rmcp/src/service.rs +++ b/crates/rmcp/src/service.rs @@ -1,5 +1,6 @@ -use futures::{FutureExt, future::BoxFuture}; +use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::FuturesUnordered}; use thiserror::Error; +use tokio_stream::wrappers::ReceiverStream; #[cfg(feature = "server")] use 
crate::model::ServerJsonRpcMessage; @@ -11,6 +12,7 @@ use crate::{ NumberOrString, ProgressToken, RequestId, }, transport::{DynamicTransportError, IntoTransport, Transport}, + util::PinnedFuture, }; #[cfg(feature = "client")] mod client; @@ -188,6 +190,7 @@ impl> DynService for S { use std::{ collections::{HashMap, VecDeque}, + fmt::Debug, ops::Deref, sync::{Arc, atomic::AtomicU64}, time::Duration, @@ -246,6 +249,8 @@ impl RequestHandle { pub const REQUEST_TIMEOUT_REASON: &str = "request timeout"; pub async fn await_response(self) -> Result { if let Some(timeout) = self.options.timeout { + // TODO: tokio timeout won't work if not in the tokio RT + // Find an alternative let timeout_result = tokio::time::timeout(timeout, async move { self.rx.await.map_err(|_e| ServiceError::TransportClosed)? }) @@ -426,14 +431,29 @@ impl Peer { } } -#[derive(Debug)] pub struct RunningService> { service: Arc, peer: Peer, - handle: Option>, + handle: Option>, cancellation_token: CancellationToken, dg: DropGuard, } + +impl> Debug for RunningService +where + S: Debug, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RunningService") + .field("service", &self.service) + .field("peer", &self.peer) + .field("handle", &self.handle.as_ref().map(|_| "")) + .field("cancellation_token", &self.cancellation_token) + .field("dg", &self.dg) + .finish() + } +} + impl> Deref for RunningService { type Target = Peer; @@ -467,10 +487,10 @@ impl> RunningService { /// This will block until the service loop terminates (either due to /// cancellation, transport closure, or an error). #[inline] - pub async fn waiting(mut self) -> Result { + pub async fn waiting(mut self) -> QuitReason { match self.handle.take() { Some(handle) => handle.await, - None => Ok(QuitReason::Closed), + None => QuitReason::Closed, } } @@ -491,7 +511,7 @@ impl> RunningService { /// // ... use the client ... 
/// client.close().await?; /// ``` - pub async fn close(&mut self) -> Result { + pub async fn close(&mut self) -> QuitReason { if let Some(handle) = self.handle.take() { // Disarm the drop guard so it doesn't try to cancel again // We need to cancel manually and wait for completion @@ -499,7 +519,7 @@ impl> RunningService { handle.await } else { // Already closed - Ok(QuitReason::Closed) + QuitReason::Closed } } @@ -511,24 +531,22 @@ impl> RunningService { /// /// Returns `Ok(Some(reason))` if shutdown completed within the timeout, /// `Ok(None)` if the timeout was reached, or `Err` if there was a join error. - pub async fn close_with_timeout( - &mut self, - timeout: Duration, - ) -> Result, tokio::task::JoinError> { + pub async fn close_with_timeout(&mut self, timeout: Duration) -> Option { if let Some(handle) = self.handle.take() { self.cancellation_token.cancel(); + // TODO: tokio timeout won't work if not in the tokio RT, find an alternative match tokio::time::timeout(timeout, handle).await { - Ok(result) => result.map(Some), + Ok(reason) => Some(reason), Err(_elapsed) => { tracing::warn!( "close_with_timeout: cleanup did not complete within {:?}", timeout ); - Ok(None) + None } } } else { - Ok(Some(QuitReason::Closed)) + Some(QuitReason::Closed) } } @@ -536,7 +554,7 @@ impl> RunningService { /// /// This consumes the `RunningService` and ensures the connection is properly /// closed. For a non-consuming alternative, see [`close`](Self::close). - pub async fn cancel(mut self) -> Result { + pub async fn cancel(mut self) -> QuitReason { // Disarm the drop guard since we're handling cancellation explicitly let _ = std::mem::replace(&mut self.dg, self.cancellation_token.clone().drop_guard()); self.close().await @@ -594,11 +612,20 @@ pub struct NotificationContext { } /// Use this function to skip initialization process +/// +/// TODO: What initialization process? 
Reference that here +/// +/// Creates a handle to the running service, and the async task that runs the service +/// business logic. +/// +/// The caller is responsible for running the business logic task, either by spawning it on +/// a runtime or awaiting it directly. You can use the [RunningService] to cancel the +/// business logic or wait for it to finish. pub fn serve_directly( service: S, transport: T, peer_info: Option, -) -> RunningService +) -> (RunningService, impl Future) where R: ServiceRole, S: Service, @@ -609,12 +636,21 @@ where } /// Use this function to skip initialization process +/// +/// TODO: What initialization process? Reference that here +/// +/// Creates a handle to the running service, and the async task that runs the service +/// business logic. +/// +/// The caller is responsible for running the business logic task, either by spawning it on +/// a runtime or awaiting it directly. You can use the [RunningService] to cancel the +/// business logic or wait for it to finish. pub fn serve_directly_with_ct( service: S, transport: T, peer_info: Option, ct: CancellationToken, -) -> RunningService +) -> (RunningService, impl Future) where R: ServiceRole, S: Service, @@ -622,25 +658,30 @@ where E: std::error::Error + Send + Sync + 'static, { let (peer, peer_rx) = Peer::new(Arc::new(AtomicU32RequestIdProvider::default()), peer_info); + let peer_rx = ReceiverStream::new(peer_rx); serve_inner(service, transport.into_transport(), peer, peer_rx, ct) } +/// Creates a handle to the running service, and the async task that runs the service +/// business logic. +/// +/// The caller is responsible for running the business logic task, either by spawning it on +/// a runtime or awaiting it directly. You can use the [RunningService] to cancel the +/// business logic or wait for it to finish. 
#[instrument(skip_all)] -fn serve_inner( +fn serve_inner( service: S, transport: T, peer: Peer, - mut peer_rx: tokio::sync::mpsc::Receiver>, + peer_rx: PeerStream, ct: CancellationToken, -) -> RunningService +) -> (RunningService, impl Future) where R: ServiceRole, S: Service, T: Transport + 'static, + PeerStream: Stream> + Unpin, { - const SINK_PROXY_BUFFER_SIZE: usize = 64; - let (sink_proxy_tx, mut sink_proxy_rx) = - tokio::sync::mpsc::channel::>(SINK_PROXY_BUFFER_SIZE); let peer_info = peer.peer_info(); if R::IS_CLIENT { tracing::info!(?peer_info, "Service initialized as client"); @@ -648,9 +689,6 @@ where tracing::info!(?peer_info, "Service initialized as server"); } - let mut local_responder_pool = - HashMap::>>::new(); - let mut local_ct_pool = HashMap::::new(); let shared_service = Arc::new(service); // for return let service = shared_service.clone(); @@ -658,283 +696,330 @@ where // let message_sink = tokio::sync:: // let mut stream = std::pin::pin!(stream); let serve_loop_ct = ct.child_token(); - let peer_return: Peer = peer.clone(); + let peer_return = peer.clone(); let current_span = tracing::Span::current(); - let handle = tokio::spawn(async move { - let mut transport = transport.into_transport(); - let mut batch_messages = VecDeque::>::new(); - let mut send_task_set = tokio::task::JoinSet::::new(); - #[derive(Debug)] - enum SendTaskResult { - Request { - id: RequestId, - result: Result<(), DynamicTransportError>, - }, - Notification { - responder: Responder>, - cancellation_param: Option, - result: Result<(), DynamicTransportError>, - }, - } - #[derive(Debug)] - enum Event { - ProxyMessage(PeerSinkMessage), - PeerMessage(RxJsonRpcMessage), - ToSink(TxJsonRpcMessage), - SendTaskResult(SendTaskResult), - } - let quit_reason = loop { - let evt = if let Some(m) = batch_messages.pop_front() { - Event::PeerMessage(m) - } else { - tokio::select! 
{ - m = sink_proxy_rx.recv(), if !sink_proxy_rx.is_closed() => { - if let Some(m) = m { - Event::ToSink(m) - } else { - continue - } - } - m = transport.receive() => { - if let Some(m) = m { - Event::PeerMessage(m) - } else { - // input stream closed - tracing::info!("input stream terminated"); - break QuitReason::Closed - } - } - m = peer_rx.recv(), if !peer_rx.is_closed() => { - if let Some(m) = m { - Event::ProxyMessage(m) - } else { - continue - } - } - m = send_task_set.join_next(), if !send_task_set.is_empty() => { - let Some(result) = m else { - continue - }; - match result { - Err(e) => { - // join error, which is serious, we should quit. - tracing::error!(%e, "send request task encounter a tokio join error"); - break QuitReason::JoinError(e) - } - Ok(result) => { - Event::SendTaskResult(result) - } - } - } - _ = serve_loop_ct.cancelled() => { - tracing::info!("task cancelled"); - break QuitReason::Cancelled + let work = controller(transport, peer_rx, serve_loop_ct, shared_service, peer) + .instrument(current_span); + + let (work, work_handle) = work.remote_handle(); + // If the handle is dropped, don't stop the work. + // We don't want to force the user to keep the `RunningService` + // struct alive just to keep the work running (since the work + // future will be explicitly managed by the caller) + work_handle.forget(); + + let running_service = RunningService { + service, + peer: peer_return, + handle: Some(work_handle.boxed()), + cancellation_token: ct.clone(), + dg: ct.drop_guard(), + }; + + (running_service, work) +} + +/// Main business logic for event dispatching and handling. 
+async fn controller( + transport: T, + peer_rx: PeerStream, + cancel_token: CancellationToken, + shared_service: Arc>, + peer: Peer, +) -> QuitReason +where + R: ServiceRole, + T: Transport + 'static, + PeerStream: Stream> + Unpin, +{ + let mut transport = transport.into_transport(); + let mut batch_messages = VecDeque::>::new(); + let mut send_task_set = FuturesUnordered::>::new(); + let mut side_effects_set = FuturesUnordered::>::new(); + + let mut local_responder_pool = + HashMap::>>::new(); + let mut local_ct_pool = HashMap::::new(); + + const SINK_PROXY_BUFFER_SIZE: usize = 64; + let (sink_proxy_tx, mut rpc_rx) = + tokio::sync::mpsc::channel::>(SINK_PROXY_BUFFER_SIZE); + + // Fuse the stream, so that once it return `None` it is guaranteed to never + // be polled again. Additionally, we can check if it have been fused by checking + // `is_done()`, which we use in the select branches below. + let mut peer_rx = peer_rx.fuse(); + + #[derive(Debug)] + enum SendTaskResult { + Request { + id: RequestId, + result: Result<(), DynamicTransportError>, + }, + Notification { + responder: Responder>, + cancellation_param: Option, + result: Result<(), DynamicTransportError>, + }, + } + #[derive(Debug)] + enum Event { + ProxyMessage(PeerSinkMessage), + PeerMessage(RxJsonRpcMessage), + ToSink(TxJsonRpcMessage), + SendTaskResult(SendTaskResult), + } + + let quit_reason = loop { + // Prioritize processing batch messages before other things + let evt = if let Some(m) = batch_messages.pop_front() { + Event::PeerMessage(m) + } else { + tokio::select! 
{ + m = rpc_rx.recv(), if !rpc_rx.is_closed() => { + if let Some(m) = m { + Event::ToSink(m) + } else { + continue } } - }; - - tracing::trace!(?evt, "new event"); - match evt { - Event::SendTaskResult(SendTaskResult::Request { id, result }) => { - if let Err(e) = result { - if let Some(responder) = local_responder_pool.remove(&id) { - let _ = responder.send(Err(ServiceError::TransportSend(e))); - } + m = transport.receive() => { + if let Some(m) = m { + Event::PeerMessage(m) + } else { + // input stream closed + tracing::info!("input stream terminated"); + break QuitReason::Closed } } - Event::SendTaskResult(SendTaskResult::Notification { - responder, - result, - cancellation_param, - }) => { - let response = if let Err(e) = result { - Err(ServiceError::TransportSend(e)) + m = peer_rx.next(), if !peer_rx.is_done() => { + if let Some(m) = m { + Event::ProxyMessage(m) } else { - Ok(()) - }; - let _ = responder.send(response); - if let Some(param) = cancellation_param { - if let Some(responder) = local_responder_pool.remove(¶m.request_id) { - tracing::info!(id = %param.request_id, reason = param.reason, "cancelled"); - let _response_result = responder.send(Err(ServiceError::Cancelled { - reason: param.reason.clone(), - })); - } + continue } } - // response and error - Event::ToSink(m) => { - if let Some(id) = match &m { - JsonRpcMessage::Response(response) => Some(&response.id), - JsonRpcMessage::Error(error) => Some(&error.id), - _ => None, - } { - if let Some(ct) = local_ct_pool.remove(id) { - ct.cancel(); - } - let send = transport.send(m); - let current_span = tracing::Span::current(); - tokio::spawn(async move { - let send_result = send.await; - if let Err(error) = send_result { - tracing::error!(%error, "fail to response message"); - } - }.instrument(current_span)); + m = send_task_set.next(), if !send_task_set.is_empty() => { + let Some(send_result) = m else { + continue + }; + Event::SendTaskResult(send_result) + } + _ = side_effects_set.next(), if 
!side_effects_set.is_empty() => { + // just drive the future, we don't care about the result + continue + } + _ = cancel_token.cancelled() => { + tracing::info!("task cancelled"); + break QuitReason::Cancelled + } + } + }; + + tracing::trace!(?evt, "new event"); + match evt { + Event::SendTaskResult(SendTaskResult::Request { id, result }) => { + if let Err(e) = result { + if let Some(responder) = local_responder_pool.remove(&id) { + let _ = responder.send(Err(ServiceError::TransportSend(e))); } } - Event::ProxyMessage(PeerSinkMessage::Request { - request, - id, - responder, - }) => { - local_responder_pool.insert(id.clone(), responder); - let send = transport.send(JsonRpcMessage::request(request, id.clone())); - { - let id = id.clone(); - let current_span = tracing::Span::current(); - send_task_set.spawn(send.map(move |r| SendTaskResult::Request { - id, - result: r.map_err(DynamicTransportError::new::), - }).instrument(current_span)); + } + Event::SendTaskResult(SendTaskResult::Notification { + responder, + result, + cancellation_param, + }) => { + let response = if let Err(e) = result { + Err(ServiceError::TransportSend(e)) + } else { + Ok(()) + }; + let _ = responder.send(response); + if let Some(param) = cancellation_param { + if let Some(responder) = local_responder_pool.remove(¶m.request_id) { + tracing::info!(id = %param.request_id, reason = param.reason, "cancelled"); + let _response_result = responder.send(Err(ServiceError::Cancelled { + reason: param.reason.clone(), + })); } } - Event::ProxyMessage(PeerSinkMessage::Notification { - notification, - responder, - }) => { - // catch cancellation notification - let mut cancellation_param = None; - let notification = match notification.try_into() { - Ok::(cancelled) => { - cancellation_param.replace(cancelled.params.clone()); - cancelled.into() - } - Err(notification) => notification, - }; - let send = transport.send(JsonRpcMessage::notification(notification)); + } + // response and error + Event::ToSink(m) => { 
+ if let Some(id) = match &m { + JsonRpcMessage::Response(response) => Some(&response.id), + JsonRpcMessage::Error(error) => Some(&error.id), + _ => None, + } { + if let Some(ct) = local_ct_pool.remove(id) { + ct.cancel(); + } + let send = transport.send(m); let current_span = tracing::Span::current(); - send_task_set.spawn(send.map(move |result| SendTaskResult::Notification { + let send_work = async move { + let send_result = send.await; + if let Err(error) = send_result { + tracing::error!(%error, "fail to response message"); + } + } + .instrument(current_span) + .boxed(); + side_effects_set.push(send_work); + } + } + Event::ProxyMessage(PeerSinkMessage::Request { + request, + id, + responder, + }) => { + local_responder_pool.insert(id.clone(), responder); + let send = transport.send(JsonRpcMessage::request(request, id.clone())); + let id = id.clone(); + let current_span = tracing::Span::current(); + + let send = send + .map(move |r| SendTaskResult::Request { + id, + result: r.map_err(DynamicTransportError::new::), + }) + .instrument(current_span) + .boxed(); + send_task_set.push(send); + } + Event::ProxyMessage(PeerSinkMessage::Notification { + notification, + responder, + }) => { + // catch cancellation notification + let mut cancellation_param = None; + let notification = match notification.try_into() { + Ok::(cancelled) => { + cancellation_param.replace(cancelled.params.clone()); + cancelled.into() + } + Err(notification) => notification, + }; + let send = transport.send(JsonRpcMessage::notification(notification)); + let current_span = tracing::Span::current(); + let send = send + .map(move |result| SendTaskResult::Notification { responder, cancellation_param, result: result.map_err(DynamicTransportError::new::), - }).instrument(current_span)); - } - Event::PeerMessage(JsonRpcMessage::Request(JsonRpcRequest { - id, - mut request, - .. 
- })) => { - tracing::debug!(%id, ?request, "received request"); - { - let service = shared_service.clone(); - let sink = sink_proxy_tx.clone(); - let request_ct = serve_loop_ct.child_token(); - let context_ct = request_ct.child_token(); - local_ct_pool.insert(id.clone(), request_ct); - let mut extensions = Extensions::new(); - let mut meta = Meta::new(); - // avoid clone - // swap meta firstly, otherwise progress token will be lost - std::mem::swap(&mut meta, request.get_meta_mut()); - std::mem::swap(&mut extensions, request.extensions_mut()); - let context = RequestContext { - ct: context_ct, - id: id.clone(), - peer: peer.clone(), - meta, - extensions, + }) + .instrument(current_span) + .boxed(); + send_task_set.push(send); + } + Event::PeerMessage(JsonRpcMessage::Request(JsonRpcRequest { + id, mut request, .. + })) => { + tracing::debug!(%id, ?request, "received request"); + { + let service = shared_service.clone(); + let sink = sink_proxy_tx.clone(); + let request_ct = cancel_token.child_token(); + let context_ct = request_ct.child_token(); + local_ct_pool.insert(id.clone(), request_ct); + let mut extensions = Extensions::new(); + let mut meta = Meta::new(); + // avoid clone + // swap meta firstly, otherwise progress token will be lost + std::mem::swap(&mut meta, request.get_meta_mut()); + std::mem::swap(&mut extensions, request.extensions_mut()); + let context = RequestContext { + ct: context_ct, + id: id.clone(), + peer: peer.clone(), + meta, + extensions, + }; + let current_span = tracing::Span::current(); + let work = async move { + let result = service.handle_request(request, context).await; + let response = match result { + Ok(result) => { + tracing::debug!(%id, ?result, "response message"); + JsonRpcMessage::response(result, id) + } + Err(error) => { + tracing::warn!(%id, ?error, "response error"); + JsonRpcMessage::error(error, id) + } }; - let current_span = tracing::Span::current(); - tokio::spawn(async move { - let result = service - 
.handle_request(request, context) - .await; - let response = match result { - Ok(result) => { - tracing::debug!(%id, ?result, "response message"); - JsonRpcMessage::response(result, id) - } - Err(error) => { - tracing::warn!(%id, ?error, "response error"); - JsonRpcMessage::error(error, id) - } - }; - let _send_result = sink.send(response).await; - }.instrument(current_span)); + let _send_result = sink.send(response).await; } + .instrument(current_span) + .boxed(); + side_effects_set.push(work); } - Event::PeerMessage(JsonRpcMessage::Notification(JsonRpcNotification { - notification, - .. - })) => { - tracing::info!(?notification, "received notification"); - // catch cancelled notification - let mut notification = match notification.try_into() { - Ok::(cancelled) => { - if let Some(ct) = local_ct_pool.remove(&cancelled.params.request_id) { - tracing::info!(id = %cancelled.params.request_id, reason = cancelled.params.reason, "cancelled"); - ct.cancel(); - } - cancelled.into() + } + Event::PeerMessage(JsonRpcMessage::Notification(JsonRpcNotification { + notification, + .. 
+ })) => { + tracing::info!(?notification, "received notification"); + // catch cancelled notification + let mut notification = match notification.try_into() { + Ok::(cancelled) => { + if let Some(ct) = local_ct_pool.remove(&cancelled.params.request_id) { + tracing::info!(id = %cancelled.params.request_id, reason = cancelled.params.reason, "cancelled"); + ct.cancel(); } - Err(notification) => notification, + cancelled.into() + } + Err(notification) => notification, + }; + { + let service = shared_service.clone(); + let mut extensions = Extensions::new(); + let mut meta = Meta::new(); + // avoid clone + std::mem::swap(&mut extensions, notification.extensions_mut()); + std::mem::swap(&mut meta, notification.get_meta_mut()); + let context = NotificationContext { + peer: peer.clone(), + meta, + extensions, }; - { - let service = shared_service.clone(); - let mut extensions = Extensions::new(); - let mut meta = Meta::new(); - // avoid clone - std::mem::swap(&mut extensions, notification.extensions_mut()); - std::mem::swap(&mut meta, notification.get_meta_mut()); - let context = NotificationContext { - peer: peer.clone(), - meta, - extensions, - }; - let current_span = tracing::Span::current(); - tokio::spawn(async move { - let result = service.handle_notification(notification, context).await; - if let Err(error) = result { - tracing::warn!(%error, "Error sending notification"); - } - }.instrument(current_span)); + let current_span = tracing::Span::current(); + let work = async move { + let result = service.handle_notification(notification, context).await; + if let Err(error) = result { + tracing::warn!(%error, "Error sending notification"); + } } + .instrument(current_span) + .boxed(); + side_effects_set.push(work); } - Event::PeerMessage(JsonRpcMessage::Response(JsonRpcResponse { - result, - id, - .. 
- })) => { - if let Some(responder) = local_responder_pool.remove(&id) { - let response_result = responder.send(Ok(result)); - if let Err(_error) = response_result { - tracing::warn!(%id, "Error sending response"); - } + } + Event::PeerMessage(JsonRpcMessage::Response(JsonRpcResponse { + result, id, .. + })) => { + if let Some(responder) = local_responder_pool.remove(&id) { + let response_result = responder.send(Ok(result)); + if let Err(_error) = response_result { + tracing::warn!(%id, "Error sending response"); } } - Event::PeerMessage(JsonRpcMessage::Error(JsonRpcError { error, id, .. })) => { - if let Some(responder) = local_responder_pool.remove(&id) { - let _response_result = responder.send(Err(ServiceError::McpError(error))); - if let Err(_error) = _response_result { - tracing::warn!(%id, "Error sending response"); - } + } + Event::PeerMessage(JsonRpcMessage::Error(JsonRpcError { error, id, .. })) => { + if let Some(responder) = local_responder_pool.remove(&id) { + let _response_result = responder.send(Err(ServiceError::McpError(error))); + if let Err(_error) = _response_result { + tracing::warn!(%id, "Error sending response"); } } } - }; - let sink_close_result = transport.close().await; - if let Err(e) = sink_close_result { - tracing::error!(%e, "fail to close sink"); } - tracing::info!(?quit_reason, "serve finished"); - quit_reason - }.instrument(current_span)); - RunningService { - service, - peer: peer_return, - handle: Some(handle), - cancellation_token: ct.clone(), - dg: ct.drop_guard(), + }; + let sink_close_result = transport.close().await; + if let Err(e) = sink_close_result { + tracing::error!(%e, "fail to close sink"); } + tracing::info!(?quit_reason, "serve finished"); + quit_reason } diff --git a/crates/rmcp/src/transport/streamable_http_server/tower.rs b/crates/rmcp/src/transport/streamable_http_server/tower.rs index 74b1fd79..ba28a8bb 100644 --- a/crates/rmcp/src/transport/streamable_http_server/tower.rs +++ 
b/crates/rmcp/src/transport/streamable_http_server/tower.rs @@ -593,11 +593,8 @@ where request.request.extensions_mut().insert(part); let (transport, mut receiver) = OneshotTransport::::new(ClientJsonRpcMessage::Request(request)); - let service = serve_directly(service, transport, None); - tokio::spawn(async move { - // on service created - let _ = service.waiting().await; - }); + let (_, work) = serve_directly(service, transport, None); + tokio::spawn(work); if self.config.json_response { // JSON-direct mode: await the single response and return as // application/json, eliminating SSE framing overhead. diff --git a/crates/rmcp/src/util.rs b/crates/rmcp/src/util.rs new file mode 100644 index 00000000..06a563e3 --- /dev/null +++ b/crates/rmcp/src/util.rs @@ -0,0 +1,5 @@ +use std::pin::Pin; + +pub type PinnedFuture<'a, T> = Pin + Send + 'a>>; + +pub type PinnedLocalFuture<'a, T> = Pin + 'a>>; From 3b79366999a4131546e755ff2d20a4ae031b5e1e Mon Sep 17 00:00:00 2001 From: Aadam Zocolo Date: Sat, 28 Feb 2026 22:09:59 +0000 Subject: [PATCH 02/11] refactor(operation-processor): remove uses of tokio::spawn Refactor by using a worker future and bubling that up to the top-level of the API. The callee is now responsible for polling the worker task, or else no work will get done. 
--- crates/rmcp/Cargo.toml | 7 +- crates/rmcp/src/service.rs | 26 +++--- crates/rmcp/src/service/client.rs | 14 +++- crates/rmcp/src/service/server.rs | 17 +++- crates/rmcp/src/task_manager.rs | 126 +++++++++++++++++++++++------- crates/rmcp/tests/test_task.rs | 10 ++- 6 files changed, 151 insertions(+), 49 deletions(-) diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index 76820c81..e5489986 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -77,7 +77,12 @@ chrono = { version = "0.4.38", default-features = false, features = [ [features] default = ["base64", "macros", "server"] client = ["dep:tokio-stream"] -server = ["transport-async-rw", "dep:schemars", "dep:pastey"] +server = [ + "transport-async-rw", + "dep:schemars", + "dep:pastey", + "dep:tokio-stream", +] macros = ["dep:rmcp-macros", "dep:pastey"] elicitation = ["dep:url"] diff --git a/crates/rmcp/src/service.rs b/crates/rmcp/src/service.rs index 0e8e95ea..e3fae794 100644 --- a/crates/rmcp/src/service.rs +++ b/crates/rmcp/src/service.rs @@ -1,4 +1,8 @@ -use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::FuturesUnordered}; +use futures::{ + FutureExt, Stream, StreamExt, + future::{BoxFuture, RemoteHandle}, + stream::FuturesUnordered, +}; use thiserror::Error; use tokio_stream::wrappers::ReceiverStream; @@ -113,7 +117,9 @@ pub trait ServiceExt: Service + Sized { fn serve( self, transport: T, - ) -> impl Future, R::InitializeError>> + Send + ) -> impl Future< + Output = Result<(RunningService, impl Future), R::InitializeError>, + > + Send where T: IntoTransport, E: std::error::Error + Send + Sync + 'static, @@ -125,7 +131,9 @@ pub trait ServiceExt: Service + Sized { self, transport: T, ct: CancellationToken, - ) -> impl Future, R::InitializeError>> + Send + ) -> impl Future< + Output = Result<(RunningService, impl Future), R::InitializeError>, + > + Send where T: IntoTransport, E: std::error::Error + Send + Sync + 'static, @@ -434,7 +442,7 @@ impl Peer { pub 
struct RunningService> { service: Arc, peer: Peer, - handle: Option>, + handle: Option>, cancellation_token: CancellationToken, dg: DropGuard, } @@ -564,6 +572,9 @@ impl> RunningService { impl> Drop for RunningService { fn drop(&mut self) { if self.handle.is_some() && !self.cancellation_token.is_cancelled() { + // Make sure we don't stop the work itself, the work future should + // handle that via cancellation token or drop guard + self.handle.take().unwrap().forget(); tracing::debug!( "RunningService dropped without explicit close(). \ The connection will be closed asynchronously. \ @@ -703,16 +714,11 @@ where .instrument(current_span); let (work, work_handle) = work.remote_handle(); - // If the handle is dropped, don't stop the work. - // We don't want to force the user to keep the `RunningService` - // struct alive just to keep the work running (since the work - // future will be explicitly managed by the caller) - work_handle.forget(); let running_service = RunningService { service, peer: peer_return, - handle: Some(work_handle.boxed()), + handle: Some(work_handle), cancellation_token: ct.clone(), dg: ct.drop_guard(), }; diff --git a/crates/rmcp/src/service/client.rs b/crates/rmcp/src/service/client.rs index 837fafef..5c971791 100644 --- a/crates/rmcp/src/service/client.rs +++ b/crates/rmcp/src/service/client.rs @@ -161,7 +161,12 @@ impl> ServiceExt for S { self, transport: T, ct: CancellationToken, - ) -> impl Future, ClientInitializeError>> + Send + ) -> impl Future< + Output = Result< + (RunningService, impl Future), + ClientInitializeError, + >, + > + Send where T: IntoTransport, E: std::error::Error + Send + Sync + 'static, @@ -174,7 +179,7 @@ impl> ServiceExt for S { pub async fn serve_client( service: S, transport: T, -) -> Result, ClientInitializeError> +) -> Result<(RunningService, impl Future), ClientInitializeError> where S: Service, T: IntoTransport, @@ -187,7 +192,7 @@ pub async fn serve_client_with_ct( service: S, transport: T, ct: 
CancellationToken, -) -> Result, ClientInitializeError> +) -> Result<(RunningService, impl Future), ClientInitializeError> where S: Service, T: IntoTransport, @@ -205,7 +210,7 @@ async fn serve_client_with_ct_inner( service: S, transport: T, ct: CancellationToken, -) -> Result, ClientInitializeError> +) -> Result<(RunningService, impl Future), ClientInitializeError> where S: Service, T: Transport + 'static, @@ -263,6 +268,7 @@ where transport.send(notification).await.map_err(|error| { ClientInitializeError::transport::(error, "send initialized notification") })?; + let peer_rx = ReceiverStream::new(peer_rx); Ok(serve_inner(service, transport, peer, peer_rx, ct)) } diff --git a/crates/rmcp/src/service/server.rs b/crates/rmcp/src/service/server.rs index 5f54f3dc..b19a1602 100644 --- a/crates/rmcp/src/service/server.rs +++ b/crates/rmcp/src/service/server.rs @@ -94,7 +94,12 @@ impl> ServiceExt for S { self, transport: T, ct: CancellationToken, - ) -> impl Future, ServerInitializeError>> + Send + ) -> impl Future< + Output = Result< + (RunningService, impl Future), + ServerInitializeError, + >, + > + Send where T: IntoTransport, E: std::error::Error + Send + Sync + 'static, @@ -107,7 +112,7 @@ impl> ServiceExt for S { pub async fn serve_server( service: S, transport: T, -) -> Result, ServerInitializeError> +) -> Result<(RunningService, impl Future), ServerInitializeError> where S: Service, T: IntoTransport, @@ -166,7 +171,7 @@ pub async fn serve_server_with_ct( service: S, transport: T, ct: CancellationToken, -) -> Result, ServerInitializeError> +) -> Result<(RunningService, impl Future), ServerInitializeError> where S: Service, T: IntoTransport, @@ -180,11 +185,14 @@ where } } +/// Performs handshake and initial protocol setup through the transport, +/// and returns a [RunningService] with a separate work future that will +/// need polled to run the service. 
async fn serve_server_with_ct_inner( service: S, transport: T, ct: CancellationToken, -) -> Result, ServerInitializeError> +) -> Result<(RunningService, impl Future), ServerInitializeError> where S: Service, T: Transport + 'static, @@ -258,6 +266,7 @@ where peer: peer.clone(), }; let _ = service.handle_notification(notification, context).await; + let peer_rx = ReceiverStream::new(peer_rx); // Continue processing service Ok(serve_inner(service, transport, peer, peer_rx, ct)) } diff --git a/crates/rmcp/src/task_manager.rs b/crates/rmcp/src/task_manager.rs index 774c542f..c1b6edc3 100644 --- a/crates/rmcp/src/task_manager.rs +++ b/crates/rmcp/src/task_manager.rs @@ -1,6 +1,10 @@ -use std::{any::Any, collections::HashMap, pin::Pin}; +use std::{any::Any, collections::HashMap}; -use futures::Future; +use futures::{ + Future, FutureExt, StreamExt, + future::abortable, + stream::{AbortHandle, FuturesUnordered}, +}; use tokio::{ sync::mpsc, time::{Duration, timeout}, @@ -11,11 +15,14 @@ use crate::{ error::{ErrorData as McpError, RmcpError as Error}, model::{CallToolResult, ClientRequest}, service::RequestContext, + util::PinnedFuture, }; +/// Result of running an operation +pub type OperationResult = Result, Error>; + /// Boxed future that represents an asynchronous operation managed by the processor. -pub type OperationFuture = - Pin, Error>> + Send>>; +pub type OperationFuture<'a> = PinnedFuture<'a, OperationResult>; /// Describes metadata associated with an enqueued task. #[derive(Debug, Clone)] @@ -57,11 +64,11 @@ impl OperationDescriptor { /// Operation message describing a unit of asynchronous work. 
pub struct OperationMessage { pub descriptor: OperationDescriptor, - pub future: OperationFuture, + pub future: OperationFuture<'static>, } impl OperationMessage { - pub fn new(descriptor: OperationDescriptor, future: OperationFuture) -> Self { + pub fn new(descriptor: OperationDescriptor, future: OperationFuture<'static>) -> Self { Self { descriptor, future } } } @@ -80,17 +87,23 @@ pub struct OperationProcessor { running_tasks: HashMap, /// Completed results waiting to be collected completed_results: Vec, + /// Receiver for asynchronously completed task results. Used + /// to collect back into `completed_results` task_result_receiver: mpsc::UnboundedReceiver, - task_result_sender: mpsc::UnboundedSender, + /// Sender to spawn futures on the worker task associated with this + /// processor. The worker future is created as part of [OperationProcessor::new] + spawn_tx: mpsc::UnboundedSender<(OperationDescriptor, OperationFuture<'static>)>, } +/// A handle to a running operation. struct RunningTask { - task_handle: tokio::task::JoinHandle<()>, + task_handle: AbortHandle, started_at: std::time::Instant, timeout: Option, descriptor: OperationDescriptor, } +/// The result of a running operation. pub struct TaskResult { pub descriptor: OperationDescriptor, pub result: Result, Error>, @@ -126,21 +139,63 @@ impl OperationResultTransport for ToolCallTaskResult { } } -impl Default for OperationProcessor { - fn default() -> Self { - Self::new() - } -} - impl OperationProcessor { - pub fn new() -> Self { + /// Create a new operation processor. + /// + /// This function will return the new [OperationProcessor] + /// facade you can use to queue operations, and also a future + /// that must be polled to handle these operations. + /// + /// Spawn the work function on your runtime of choice, or poll it + /// manually. 
+ pub fn new() -> (Self, impl Future) { let (task_result_sender, task_result_receiver) = mpsc::unbounded_channel(); - Self { + let (spawn_tx, mut spawn_rx) = + mpsc::unbounded_channel::<(OperationDescriptor, OperationFuture)>(); + + let work = async move { + let mut work_set = + FuturesUnordered::>::new(); + + // Loop and listen for new operations incoming that need to be added to the future pool, + // and also listen to operation completions via the future pool. + loop { + tokio::select! { + spawn_req = spawn_rx.recv(), if !spawn_rx.is_closed() => { + if let Some((descriptor, fut)) = spawn_req { + // Map the future back to a descriptor and result tuple + let operation_work = fut.map(|result| (descriptor, result)).boxed(); + // Add it to the set we are polling + work_set.push(operation_work); + } + }, + operation_result = work_set.next(), if !work_set.is_empty() => { + if let Some((descriptor, result)) = operation_result { + match task_result_sender.send(TaskResult { descriptor, result }) { + Err(e) => { + // TODO: Produce an error message here! + } + _ => {} + } + }; + }, + else => { + // Work was empty, and spawn channel was closed. Time + // to break the loop. + break; + } + } + } + }; + + let this = Self { running_tasks: HashMap::new(), completed_results: Vec::new(), task_result_receiver, - task_result_sender, - } + spawn_tx, + }; + + (this, work) } /// Submit an operation for asynchronous execution. @@ -159,12 +214,11 @@ impl OperationProcessor { Ok(()) } + /// Spawns an operation to be executed to completion. 
fn spawn_async_task(&mut self, message: OperationMessage) { let OperationMessage { descriptor, future } = message; let task_id = descriptor.operation_id.clone(); let timeout_secs = descriptor.ttl.or(Some(DEFAULT_TASK_TIMEOUT_SECS)); - let sender = self.task_result_sender.clone(); - let descriptor_for_result = descriptor.clone(); let timed_future = async move { if let Some(secs) = timeout_secs { @@ -177,16 +231,32 @@ impl OperationProcessor { } }; - let handle = tokio::spawn(async move { - let result = timed_future.await; - let task_result = TaskResult { - descriptor: descriptor_for_result, - result, - }; - let _ = sender.send(task_result); + // Below, we want to give the user a handle to the long-running operation, + // but we don't want to send the result to the user's handle. Rather the + // result gets consumed in the worker task created in the `Self::new` + // function. So here we will use the `Abortable` future utility. + let (work, abort_handle) = abortable(timed_future); + + // Map the error type of abortion (for now) + let work = work.map(|result| { + match result { + // Was not aborted, true operation result + Ok(inner_result) => inner_result, + // Was aborted, flatten to expected error type + Err(e) => Err(Error::TaskError(e.to_string())), + } }); + + // Then send the work to be executed + match self.spawn_tx.send((descriptor.clone(), work.boxed())) { + Ok(_) => {} + Err(e) => { + // TODO: Produce an error message! 
+ } + } + let running_task = RunningTask { - task_handle: handle, + task_handle: abort_handle, started_at: std::time::Instant::now(), timeout: timeout_secs, descriptor, diff --git a/crates/rmcp/tests/test_task.rs b/crates/rmcp/tests/test_task.rs index 9ad0b200..c0f08de0 100644 --- a/crates/rmcp/tests/test_task.rs +++ b/crates/rmcp/tests/test_task.rs @@ -21,7 +21,10 @@ impl OperationResultTransport for DummyTransport { #[tokio::test] async fn executes_enqueued_future() { - let mut processor = OperationProcessor::new(); + let (mut processor, work) = OperationProcessor::new(); + + tokio::spawn(work); + let descriptor = OperationDescriptor::new("op1", "dummy"); let future = Box::pin(async { tokio::time::sleep(Duration::from_millis(10)).await; @@ -50,7 +53,10 @@ async fn executes_enqueued_future() { #[tokio::test] async fn rejects_duplicate_operation_ids() { - let mut processor = OperationProcessor::new(); + let (mut processor, work) = OperationProcessor::new(); + + tokio::spawn(work); + let descriptor = OperationDescriptor::new("dup", "dummy"); let future = Box::pin(async { Ok(Box::new(DummyTransport { From 8d6655b802663477f6d506ccf657d920f5fddff5 Mon Sep 17 00:00:00 2001 From: Aadam Zocolo Date: Sun, 1 Mar 2026 00:14:16 +0000 Subject: [PATCH 03/11] refactor(progress): remove need for spawning on drop Larger refactor of the way progress is multiplexed. This needed a redesign of the broadcast multiplex logic to a more stateless design. This design removes the need for mutating any state on drop; the stream dropping implicitly removes broadcast listeners. This design also allows for multiple subscribers of the same progress token.
--- crates/rmcp/Cargo.toml | 2 +- crates/rmcp/src/handler/client/progress.rs | 151 +++++++++++++----- crates/rmcp/src/util.rs | 5 + crates/rmcp/tests/test_progress_subscriber.rs | 8 +- 4 files changed, 123 insertions(+), 43 deletions(-) diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index e5489986..35049780 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -56,7 +56,7 @@ process-wrap = { version = "9.0", features = ["tokio1"], optional = true } # for http-server transport rand = { version = "0.10", optional = true } -tokio-stream = { version = "0.1", optional = true } +tokio-stream = { version = "0.1", optional = true, features = ["sync"] } uuid = { version = "1", features = ["v4"], optional = true } http-body = { version = "1", optional = true } http-body-util = { version = "0.1", optional = true } diff --git a/crates/rmcp/src/handler/client/progress.rs b/crates/rmcp/src/handler/client/progress.rs index 04f31610..7dd84f08 100644 --- a/crates/rmcp/src/handler/client/progress.rs +++ b/crates/rmcp/src/handler/client/progress.rs @@ -1,32 +1,55 @@ use std::{collections::HashMap, sync::Arc}; use futures::{Stream, StreamExt}; -use tokio::sync::RwLock; -use tokio_stream::wrappers::ReceiverStream; +use tokio::sync::{RwLock, broadcast}; +use tokio_stream::wrappers::BroadcastStream; -use crate::model::{ProgressNotificationParam, ProgressToken}; -type Dispatcher = - Arc>>>; +use crate::{ + model::{ProgressNotificationParam, ProgressToken}, + util::PinnedStream, +}; /// A dispatcher for progress notifications. -#[derive(Debug, Clone, Default)] +/// +/// See [ProgressNotificationParam] and [ProgressToken] for more details on +/// how progress is dispatched to a particular listener. +#[derive(Debug, Clone)] pub struct ProgressDispatcher { - pub(crate) dispatcher: Dispatcher, + /// A channel of any progress notification. Subscribers will filter + /// on this channel. 
+ pub(crate) any_progress_notification_tx: broadcast::Sender, + pub(crate) unsubscribe_tx: broadcast::Sender, + pub(crate) unsubscribe_all_tx: broadcast::Sender<()>, } impl ProgressDispatcher { const CHANNEL_SIZE: usize = 16; pub fn new() -> Self { - Self::default() + // Note that channel size is per-receiver for broadcast channel. It is up to the receiver to + // keep up with the notifications to avoid missing any (via proper polling) + let (any_progress_notification_tx, _) = broadcast::channel(Self::CHANNEL_SIZE); + let (unsubscribe_tx, _) = broadcast::channel(Self::CHANNEL_SIZE); + let (unsubscribe_all_tx, _) = broadcast::channel(Self::CHANNEL_SIZE); + Self { + any_progress_notification_tx, + unsubscribe_tx, + unsubscribe_all_tx, + } } /// Handle a progress notification by sending it to the appropriate subscriber pub async fn handle_notification(&self, notification: ProgressNotificationParam) { - let token = &notification.progress_token; - if let Some(sender) = self.dispatcher.read().await.get(token).cloned() { - let send_result = sender.send(notification).await; - if let Err(e) = send_result { - tracing::warn!("Failed to send progress notification: {e}"); + // Broadcast the notification to all subscribers. Interested subscribers + // will filter on their end. + // ! Note that this implementation is very stateless and simple, we cannot + // ! easily inspect which subscribers are interested in which notifications. + // ! However, the stateless-ness and simplicity is also a plus! + // ! Cleanup becomes much easier. Just drop the `ProgressSubscriber`. + match self.any_progress_notification_tx.send(notification) { + Ok(_) => {} + Err(_) => { + // This error only happens if there are no active receivers of the `broadcast` channel. + // Silent error. } } } @@ -35,35 +58,97 @@ impl ProgressDispatcher { /// /// If you drop the returned `ProgressSubscriber`, it will automatically unsubscribe from notifications for that token.
pub async fn subscribe(&self, progress_token: ProgressToken) -> ProgressSubscriber { - let (sender, receiver) = tokio::sync::mpsc::channel(Self::CHANNEL_SIZE); - self.dispatcher - .write() - .await - .insert(progress_token.clone(), sender); - let receiver = ReceiverStream::new(receiver); + // First, set up the unsubscribe listeners. This will fuse the notification stream below. + let progress_token_clone = progress_token.clone(); + let unsub_this_token_rx = BroadcastStream::new(self.unsubscribe_tx.subscribe()).filter_map( + move |token| { + let progress_token_clone = progress_token_clone.clone(); + async move { + match token { + Ok(token) => { + if token == progress_token_clone { + Some(()) + } else { + None + } + } + Err(e) => { + // An error here means the broadcast stream did not receive values quickly enough + // and we missed some notification. This implies there are notifications + // we missed, but we cannot assume they were for us :( + tracing::warn!( + "Error receiving unsubscribe notification from broadcast channel: {e}" + ); + None + } + } + } + }, + ); + let unsub_any_token_tx = + BroadcastStream::new(self.unsubscribe_all_tx.subscribe()).map(|_| { + // Any reception of a result here indicates we should unsubscribe, + // regardless of if we received an `Ok(())` or an `Err(_)` (which + // indicates the broadcast receiver lagged behind) + () + }); + let unsub_fut = futures::stream::select(unsub_this_token_rx, unsub_any_token_tx) + .boxed() + .into_future(); // If the unsub streams end, this will cause unsubscription from the subscriber below. + + // Now setup the notification stream. We will receive all notifications and only forward progress notifications + // for the token we're interested in.
+ let progress_token_clone = progress_token.clone(); + let receiver = BroadcastStream::new(self.any_progress_notification_tx.subscribe()) + .filter_map(move |notification| { + let progress_token_clone = progress_token_clone.clone(); + async move { + // We need to weed out the broadcast receive error type here. + match notification { + Ok(notification) => { + let token = notification.progress_token.clone(); + if token == progress_token_clone { + Some(notification) + } else { + None + } + } + Err(e) => { + tracing::warn!( + "Error receiving progress notification from broadcast channel: {e}" + ); + None + } + } + } + }) + // Fuse this stream so it stops once we receive an unsubscribe notification from the stream + // created above + .take_until(unsub_fut) + .boxed(); + ProgressSubscriber { progress_token, receiver, - dispatcher: self.dispatcher.clone(), } } /// Unsubscribe from progress notifications for a specific token. - pub async fn unsubscribe(&self, token: &ProgressToken) { - self.dispatcher.write().await.remove(token); + pub fn unsubscribe(&self, token: ProgressToken) { + // The only error defined is if there are no listeners, which is fine. Ignore the result. + let _ = self.unsubscribe_tx.send(token); } /// Clear all dispatcher. - pub async fn clear(&self) { - let mut dispatcher = self.dispatcher.write().await; - dispatcher.clear(); + pub fn clear(&self) { + // The only error defined is if there are no listeners, which is fine. Ignore the result.
+ let _ = self.unsubscribe_all_tx.send(()); } } pub struct ProgressSubscriber { pub(crate) progress_token: ProgressToken, - pub(crate) receiver: ReceiverStream, - pub(crate) dispatcher: Dispatcher, + pub(crate) receiver: PinnedStream<'static, ProgressNotificationParam>, } impl ProgressSubscriber { @@ -86,15 +171,3 @@ impl Stream for ProgressSubscriber { self.receiver.size_hint() } } - -impl Drop for ProgressSubscriber { - fn drop(&mut self) { - let token = self.progress_token.clone(); - self.receiver.close(); - let dispatcher = self.dispatcher.clone(); - tokio::spawn(async move { - let mut dispatcher = dispatcher.write_owned().await; - dispatcher.remove(&token); - }); - } -} diff --git a/crates/rmcp/src/util.rs b/crates/rmcp/src/util.rs index 06a563e3..33b273f8 100644 --- a/crates/rmcp/src/util.rs +++ b/crates/rmcp/src/util.rs @@ -1,5 +1,10 @@ +use futures::Stream; use std::pin::Pin; pub type PinnedFuture<'a, T> = Pin + Send + 'a>>; pub type PinnedLocalFuture<'a, T> = Pin + 'a>>; + +pub type PinnedStream<'a, T> = Pin + Send + 'a>>; + +pub type PinnedLocalStream<'a, T> = Pin + 'a>>; diff --git a/crates/rmcp/tests/test_progress_subscriber.rs b/crates/rmcp/tests/test_progress_subscriber.rs index 521219a3..5c5715b9 100644 --- a/crates/rmcp/tests/test_progress_subscriber.rs +++ b/crates/rmcp/tests/test_progress_subscriber.rs @@ -100,11 +100,13 @@ async fn test_progress_subscriber() -> anyhow::Result<()> { let server = MyServer::new(); let (transport_server, transport_client) = tokio::io::duplex(4096); tokio::spawn(async move { - let service = server.serve(transport_server).await?; - service.waiting().await?; + let (service, work) = server.serve(transport_server).await?; + tokio::spawn(work); + service.waiting().await; anyhow::Ok(()) }); - let client_service = client.serve(transport_client).await?; + let (client_service, client_work) = client.serve(transport_client).await?; + tokio::spawn(client_work); let handle = client_service .send_cancellable_request( 
ClientRequest::CallToolRequest(Request::new(CallToolRequestParams { From dc4204ca270f4f7fd0c730b9d30e54a585c7047f Mon Sep 17 00:00:00 2001 From: Aadam Zocolo Date: Sun, 1 Mar 2026 05:16:26 +0000 Subject: [PATCH 04/11] refactor(child-process): wip experiment with new child process transport --- crates/rmcp/Cargo.toml | 2 +- crates/rmcp/src/lib.rs | 1 + crates/rmcp/src/transport.rs | 2 + crates/rmcp/src/transport/async_rw.rs | 6 +- crates/rmcp/src/transport/child_process.rs | 4 +- crates/rmcp/src/transport/child_process2.rs | 2 + .../src/transport/child_process2/runner.rs | 314 ++++++++++++++++++ .../src/transport/child_process2/transport.rs | 71 ++++ crates/rmcp/src/util.rs | 65 +++- 9 files changed, 461 insertions(+), 6 deletions(-) create mode 100644 crates/rmcp/src/transport/child_process2.rs create mode 100644 crates/rmcp/src/transport/child_process2/runner.rs create mode 100644 crates/rmcp/src/transport/child_process2/transport.rs diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index 35049780..421bd07a 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -119,7 +119,7 @@ transport-streamable-http-client-reqwest = [ "__reqwest", ] -transport-async-rw = ["tokio/io-util", "tokio-util/codec"] +transport-async-rw = ["tokio/io-util", "tokio-util/codec", "tokio-util/compat"] transport-io = ["transport-async-rw", "tokio/io-std"] transport-child-process = [ "transport-async-rw", diff --git a/crates/rmcp/src/lib.rs b/crates/rmcp/src/lib.rs index 456bc3ea..c70e61b5 100644 --- a/crates/rmcp/src/lib.rs +++ b/crates/rmcp/src/lib.rs @@ -4,6 +4,7 @@ mod error; mod util; + #[allow(deprecated)] pub use error::{Error, ErrorData, RmcpError}; diff --git a/crates/rmcp/src/transport.rs b/crates/rmcp/src/transport.rs index d7dfa979..f22b63d8 100644 --- a/crates/rmcp/src/transport.rs +++ b/crates/rmcp/src/transport.rs @@ -85,6 +85,8 @@ pub use worker::WorkerTransport; pub mod child_process; #[cfg(feature = "transport-child-process")] pub use 
child_process::{ConfigureCommandExt, TokioChildProcess}; +#[cfg(feature = "transport-child-process")] +pub mod child_process2; #[cfg(feature = "transport-io")] pub mod io; diff --git a/crates/rmcp/src/transport/async_rw.rs b/crates/rmcp/src/transport/async_rw.rs index ff4ecc65..cb5ea75d 100644 --- a/crates/rmcp/src/transport/async_rw.rs +++ b/crates/rmcp/src/transport/async_rw.rs @@ -43,9 +43,13 @@ where } pub type TransportWriter = FramedWrite>>; +pub type TransportReader = FramedRead>>; pub struct AsyncRwTransport { - read: FramedRead>>, + read: TransportReader, + /// This is behind a mutex so that concurrent writes can happen. + /// Naturally, the mutex will block parallel writes, but allow + /// multiple futures to be executed at once, even if some are waiting. write: Arc>>>, } diff --git a/crates/rmcp/src/transport/child_process.rs b/crates/rmcp/src/transport/child_process.rs index e33800b1..58e8b9a8 100644 --- a/crates/rmcp/src/transport/child_process.rs +++ b/crates/rmcp/src/transport/child_process.rs @@ -23,7 +23,7 @@ type ChildProcessParts = ( /// Returns `(child, stdout, stdin, stderr)` where `stderr` is `Some` only /// if the process was spawned with `Stdio::piped()`. 
#[inline] -fn child_process(mut child: Box) -> std::io::Result { +fn split_child_process(mut child: Box) -> std::io::Result { let child_stdin = match child.inner_mut().stdin().take() { Some(stdin) => stdin, None => return Err(std::io::Error::other("stdin was already taken")), @@ -192,7 +192,7 @@ impl TokioChildProcessBuilder { .stdout(self.stdout) .stderr(self.stderr); - let (child, stdout, stdin, stderr_opt) = child_process(self.cmd.spawn()?)?; + let (child, stdout, stdin, stderr_opt) = split_child_process(self.cmd.spawn()?)?; let transport = AsyncRwTransport::new(stdout, stdin); let proc = TokioChildProcess { diff --git a/crates/rmcp/src/transport/child_process2.rs b/crates/rmcp/src/transport/child_process2.rs new file mode 100644 index 00000000..27f1eb46 --- /dev/null +++ b/crates/rmcp/src/transport/child_process2.rs @@ -0,0 +1,2 @@ +pub mod runner; +pub mod transport; diff --git a/crates/rmcp/src/transport/child_process2/runner.rs b/crates/rmcp/src/transport/child_process2/runner.rs new file mode 100644 index 00000000..d169970f --- /dev/null +++ b/crates/rmcp/src/transport/child_process2/runner.rs @@ -0,0 +1,314 @@ +use futures::{ + FutureExt, + io::{AsyncRead, AsyncWrite}, +}; +use std::process::Stdio; + +use crate::util::PinnedFuture; + +/// A simple enum for describing if a stream is available, unused, or already taken. +#[derive(Debug)] +pub enum StreamSlot { + /// The stream is not used in this implementation. + Unused, + /// The stream is available for use, and can be taken. + Available(S), + /// The stream has already been taken, and is no longer available. + Taken, +} + +impl From> for Option { + fn from(slot: StreamSlot) -> Self { + match slot { + StreamSlot::Unused => None, + StreamSlot::Available(s) => Some(s), + StreamSlot::Taken => None, + } + } +} + +/// A structure that requests how the child process streams should +/// be configured when spawning. 
+pub struct StdioConfig { + pub stdin: Stdio, + pub stdout: Stdio, + pub stderr: Stdio, +} + +/// The contract for what an instance of a child process +/// must provide to be used with a transport. +pub trait ChildProcessInstance { + /// The input stream for the command + type Stdin: AsyncWrite + Unpin + Send; + + /// The output stream of the command + type Stdout: AsyncRead + Unpin + Send; + + /// The error stream of the command + type Stderr: AsyncRead + Unpin + Send; + + fn take_stdin(&mut self) -> StreamSlot; + fn take_stdout(&mut self) -> StreamSlot; + fn take_stderr(&mut self) -> StreamSlot; + + fn pid(&self) -> u32; + fn wait( + &mut self, + ) -> impl Future> + Send + 'static; + fn graceful_shutdown(&mut self) -> impl Future> + Send + 'static; + fn kill(&mut self) -> impl Future> + Send + 'static; +} + +/// A subset of functionality of [ChildProcessInstance] that only includes the +/// functions used to control or wait for the process. +pub trait ChildProcessControl { + fn pid(&self) -> u32; + fn wait(&mut self) -> PinnedFuture<'static, std::io::Result>; + fn graceful_shutdown(&mut self) -> PinnedFuture<'static, std::io::Result<()>>; + fn kill(&mut self) -> PinnedFuture<'static, std::io::Result<()>>; +} + +/// Auto-implement ChildProcessControl for any ChildProcessInstance, since it has all the required methods. +impl ChildProcessControl for T +where + T: ChildProcessInstance, +{ + fn pid(&self) -> u32 { + ChildProcessInstance::pid(self) + } + + fn wait(&mut self) -> PinnedFuture<'static, std::io::Result> { + ChildProcessInstance::wait(self).boxed() + } + + fn graceful_shutdown(&mut self) -> PinnedFuture<'static, std::io::Result<()>> { + ChildProcessInstance::graceful_shutdown(self).boxed() + } + + fn kill(&mut self) -> PinnedFuture<'static, std::io::Result<()>> { + ChildProcessInstance::kill(self).boxed() + } +} + +#[derive(Debug)] +pub enum RunnerSpawnError { + /// The child process instance failed to spawn. 
+ SpawnError(std::io::Error), + Other(Box), +} + +pub trait ChildProcessRunner { + /// The implementation of the child process instance that this runner will spawn. + type Instance: ChildProcessInstance; + + fn spawn( + command: &str, + args: &[&str], + stdio_config: StdioConfig, + ) -> Result; +} + +/// A containing wrapper around a child process instance. This struct erases the type +/// by extracting some parts of the [ChildProcessInstance] trait into a common struct, +/// and then only exposes the control methods. +pub struct ChildProcess { + stdin: StreamSlot>, + stdout: StreamSlot>, + stderr: StreamSlot>, + inner: Box, +} + +impl ChildProcess { + pub fn new(mut instance: C) -> Self + where + C: ChildProcessInstance + Send + 'static, + { + Self { + stdin: match instance.take_stdin() { + StreamSlot::Available(s) => StreamSlot::Available(Box::new(s)), + StreamSlot::Unused => StreamSlot::Unused, + StreamSlot::Taken => { + panic!("Stdin stream was already taken during ChildProcess construction") + } + }, + stdout: match instance.take_stdout() { + StreamSlot::Available(s) => StreamSlot::Available(Box::new(s)), + StreamSlot::Unused => StreamSlot::Unused, + StreamSlot::Taken => { + panic!("Stdout stream was already taken during ChildProcess construction") + } + }, + stderr: match instance.take_stderr() { + StreamSlot::Available(s) => StreamSlot::Available(Box::new(s)), + StreamSlot::Unused => StreamSlot::Unused, + StreamSlot::Taken => { + panic!("Stderr stream was already taken during ChildProcess construction") + } + }, + inner: Box::new(instance), + } + } + + pub fn split( + self, + ) -> ( + Option>, + Option>, + Option>, + Box, + ) { + ( + self.stdout.into(), + self.stdin.into(), + self.stderr.into(), + self.inner, + ) + } +} + +impl ChildProcessInstance for ChildProcess { + type Stdin = Box; + + type Stdout = Box; + + type Stderr = Box; + + fn take_stdin(&mut self) -> StreamSlot { + match self.stdin { + StreamSlot::Available(_) => std::mem::replace(&mut 
self.stdin, StreamSlot::Taken), + StreamSlot::Unused => StreamSlot::Unused, + StreamSlot::Taken => StreamSlot::Taken, + } + } + + fn take_stdout(&mut self) -> StreamSlot { + match self.stdout { + StreamSlot::Available(_) => std::mem::replace(&mut self.stdout, StreamSlot::Taken), + StreamSlot::Unused => StreamSlot::Unused, + StreamSlot::Taken => StreamSlot::Taken, + } + } + + fn take_stderr(&mut self) -> StreamSlot { + match self.stderr { + StreamSlot::Available(_) => std::mem::replace(&mut self.stderr, StreamSlot::Taken), + StreamSlot::Unused => StreamSlot::Unused, + StreamSlot::Taken => StreamSlot::Taken, + } + } + + fn pid(&self) -> u32 { + self.inner.pid() + } + + fn wait( + &mut self, + ) -> impl Future> + Send + 'static { + self.inner.wait() + } + + fn graceful_shutdown(&mut self) -> impl Future> + Send + 'static { + self.inner.graceful_shutdown() + } + + fn kill(&mut self) -> impl Future> + Send + 'static { + self.inner.kill() + } +} + +pub struct CommandBuilder { + command: String, + args: Vec, + _marker: std::marker::PhantomData, + stderr: Stdio, +} + +pub enum CommandBuilderError { + EmptyCommand, +} + +impl CommandBuilder { + /// Create a CommandBuilder from an argv-style list of strings, where the first element is the command, and the rest are the args. + pub fn from_argv(argv: I) -> Result + where + I: IntoIterator, + S: Into, + { + let mut iter = argv.into_iter(); + + // Pop the first element as the command, and use the rest as args + let command = match iter.next() { + Some(cmd) => cmd.into(), + None => return Err(CommandBuilderError::EmptyCommand), + }; + + let args = iter.map(|s| s.into()).collect(); + Ok(Self { + command, + args, + _marker: std::marker::PhantomData, + stderr: Stdio::inherit(), + }) + } + + /// Create a CommandBuilder from a command and an optional list of args. 
+ pub fn new(command: impl Into) -> Self { + Self { + command: command.into(), + args: Vec::new(), + _marker: std::marker::PhantomData, + stderr: Stdio::inherit(), + } + } + + /// Add a single argument to the command. + pub fn arg(mut self, arg: impl Into) -> Self { + self.args.push(arg.into()); + self + } + + /// Add multiple arguments to the command. + pub fn args(mut self, args: impl IntoIterator>) -> Self { + self.args.extend(args.into_iter().map(|arg| arg.into())); + self + } + + /// Sets what happens to stderr for the command. + /// By default if not set, stderr is inherited from the parent process. + pub fn stderr(mut self, _stdio: Stdio) -> Self { + self.stderr = _stdio; + self + } +} + +impl CommandBuilder +where + R: ChildProcessRunner, +{ + /// Spawn the command into its typed child process instance type. + pub fn spawn_raw(self) -> Result { + // We should always pipe stdin and stdout. + let stdio_config = StdioConfig { + stdin: Stdio::piped(), + stdout: Stdio::piped(), + stderr: self.stderr, + }; + + R::spawn( + &self.command, + &self.args.iter().map(|s| s.as_str()).collect::>(), + stdio_config, + ) + } + + /// Spawn a child process struct that erases the specific child process instance type, and only exposes the control methods. + /// + /// Requires `R::Instance` to be [Send] and `'static`. 
+ pub fn spawn_dyn(self) -> Result + where + R::Instance: Send + 'static, + { + let instance = self.spawn_raw()?; + Ok(ChildProcess::new(instance)) + } +} diff --git a/crates/rmcp/src/transport/child_process2/transport.rs b/crates/rmcp/src/transport/child_process2/transport.rs new file mode 100644 index 00000000..b378309a --- /dev/null +++ b/crates/rmcp/src/transport/child_process2/transport.rs @@ -0,0 +1,71 @@ +use tokio::io::{AsyncRead as TokioAsyncRead, AsyncWrite as TokioAsyncWrite}; +use tokio_util::compat::{FuturesAsyncReadCompatExt, FuturesAsyncWriteCompatExt}; + +use crate::{ + service::ServiceRole, + transport::{ + Transport, + async_rw::AsyncRwTransport, + child_process2::runner::{ChildProcess, ChildProcessControl}, + }, +}; + +pub struct ChildProcessTransport { + child: Box, + framed_transport: AsyncRwTransport< + R, + Box, + Box, + >, +} + +impl ChildProcessTransport +where + R: ServiceRole, +{ + pub fn new(child: ChildProcess) -> Result> { + let (stdout, stdin, stderr, control) = child.split(); + + let framed_transport: AsyncRwTransport = AsyncRwTransport::new( + Box::new( + stdout + .ok_or("Failed to capture stdout of child process")? + .compat(), + ) as Box, + Box::new( + stdin + .ok_or("Failed to capture stdin of child process")? 
+ .compat_write(), + ) as Box, + ); + + Ok(Self { + child: control, + framed_transport, + }) + } +} + +impl Transport for ChildProcessTransport +where + R: ServiceRole, +{ + type Error = std::io::Error; + + fn send( + &mut self, + item: crate::service::TxJsonRpcMessage, + ) -> impl Future> + Send + 'static { + self.framed_transport.send(item) + } + + fn receive( + &mut self, + ) -> impl Future>> + Send { + self.framed_transport.receive() + } + + fn close(&mut self) -> impl Future> + Send { + self.framed_transport.close() + } +} diff --git a/crates/rmcp/src/util.rs b/crates/rmcp/src/util.rs index 33b273f8..97121ac3 100644 --- a/crates/rmcp/src/util.rs +++ b/crates/rmcp/src/util.rs @@ -1,5 +1,5 @@ -use futures::Stream; -use std::pin::Pin; +use futures::{Sink, Stream}; +use std::{pin::Pin, task::Poll}; pub type PinnedFuture<'a, T> = Pin + Send + 'a>>; @@ -8,3 +8,64 @@ pub type PinnedLocalFuture<'a, T> = Pin + 'a>>; pub type PinnedStream<'a, T> = Pin + Send + 'a>>; pub type PinnedLocalStream<'a, T> = Pin + 'a>>; + +pub enum UnboundedSenderSinkError { + SendError(tokio::sync::mpsc::error::SendError), + Closed, +} + +/// A simple [Sink] wrapper for Tokio's [tokio::sync::mpsc::UnboundedSender] +#[derive(Debug, Clone)] +pub struct UnboundedSenderSink { + sender: tokio::sync::mpsc::UnboundedSender, +} + +impl UnboundedSenderSink { + pub fn new(sender: tokio::sync::mpsc::UnboundedSender) -> Self { + Self { sender } + } +} + +impl Sink for UnboundedSenderSink { + type Error = UnboundedSenderSinkError; + + fn poll_ready( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let this = self.get_mut(); + if this.sender.is_closed() { + Poll::Ready(Err(UnboundedSenderSinkError::Closed)) + } else { + Poll::Ready(Ok(())) + } + } + + fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { + let this = self.get_mut(); + match this.sender.send(item) { + Ok(_) => Ok(()), + Err(e) => Err(UnboundedSenderSinkError::SendError(e)), + } + } + 
+ fn poll_flush( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + // tokio's unbounded mpsc senders have no flushing required, since the + // receiver is unbounded and will get all messages we send (unless we run + // out of memory) + Poll::Ready(Ok(())) + } + + fn poll_close( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + // Like `poll_flush`, there is nothing to wait on here. A single + // call to `mpsc_sender.send(...)` is immediate from the perspective + // of the sender + Poll::Ready(Ok(())) + } +} From 1db16c8ddae7fdbbd94eb192d3e24123bc514238 Mon Sep 17 00:00:00 2001 From: Aadam Zocolo Date: Sun, 1 Mar 2026 06:28:21 +0000 Subject: [PATCH 05/11] refactor(child-process): implement tokio child process and use in test --- crates/rmcp/Cargo.toml | 2 + crates/rmcp/src/task_manager.rs | 4 +- crates/rmcp/src/transport/child_process2.rs | 1 + .../src/transport/child_process2/runner.rs | 41 +++++----- .../src/transport/child_process2/tokio.rs | 81 +++++++++++++++++++ .../src/transport/child_process2/transport.rs | 2 +- crates/rmcp/src/transport/common.rs | 2 +- crates/rmcp/src/transport/common/reqwest.rs | 2 +- .../transport/streamable_http_server/tower.rs | 14 ++-- crates/rmcp/src/util.rs | 6 +- crates/rmcp/tests/test_with_js.rs | 31 ++++--- 11 files changed, 145 insertions(+), 41 deletions(-) create mode 100644 crates/rmcp/src/transport/child_process2/tokio.rs diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index 421bd07a..f8e8011e 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -74,6 +74,8 @@ chrono = { version = "0.4.38", default-features = false, features = [ "oldtime", ] } +[target.'cfg(test)'] + [features] default = ["base64", "macros", "server"] client = ["dep:tokio-stream"] diff --git a/crates/rmcp/src/task_manager.rs b/crates/rmcp/src/task_manager.rs index c1b6edc3..7cc81575 100644 --- a/crates/rmcp/src/task_manager.rs +++ 
b/crates/rmcp/src/task_manager.rs @@ -173,7 +173,7 @@ impl OperationProcessor { if let Some((descriptor, result)) = operation_result { match task_result_sender.send(TaskResult { descriptor, result }) { Err(e) => { - // TODO: Produce an error message here! + tracing::error!("Failed to send completed task result: {e}"); } _ => {} } @@ -251,7 +251,7 @@ impl OperationProcessor { match self.spawn_tx.send((descriptor.clone(), work.boxed())) { Ok(_) => {} Err(e) => { - // TODO: Produce an error message! + tracing::error!("Failed to spawn task on worker: {e}"); } } diff --git a/crates/rmcp/src/transport/child_process2.rs b/crates/rmcp/src/transport/child_process2.rs index 27f1eb46..c82e16db 100644 --- a/crates/rmcp/src/transport/child_process2.rs +++ b/crates/rmcp/src/transport/child_process2.rs @@ -1,2 +1,3 @@ pub mod runner; +pub mod tokio; pub mod transport; diff --git a/crates/rmcp/src/transport/child_process2/runner.rs b/crates/rmcp/src/transport/child_process2/runner.rs index d169970f..6f062065 100644 --- a/crates/rmcp/src/transport/child_process2/runner.rs +++ b/crates/rmcp/src/transport/child_process2/runner.rs @@ -52,20 +52,21 @@ pub trait ChildProcessInstance { fn take_stderr(&mut self) -> StreamSlot; fn pid(&self) -> u32; - fn wait( - &mut self, - ) -> impl Future> + Send + 'static; - fn graceful_shutdown(&mut self) -> impl Future> + Send + 'static; - fn kill(&mut self) -> impl Future> + Send + 'static; + fn wait<'s>( + &'s mut self, + ) -> impl Future> + Send + 's; + fn graceful_shutdown<'s>(&'s mut self) + -> impl Future> + Send + 's; + fn kill<'s>(&'s mut self) -> impl Future> + Send + 's; } /// A subset of functionality of [ChildProcessInstance] that only includes the /// functions used to control or wait for the process. 
pub trait ChildProcessControl { fn pid(&self) -> u32; - fn wait(&mut self) -> PinnedFuture<'static, std::io::Result>; - fn graceful_shutdown(&mut self) -> PinnedFuture<'static, std::io::Result<()>>; - fn kill(&mut self) -> PinnedFuture<'static, std::io::Result<()>>; + fn wait<'s>(&'s mut self) -> PinnedFuture<'s, std::io::Result>; + fn graceful_shutdown<'s>(&'s mut self) -> PinnedFuture<'s, std::io::Result<()>>; + fn kill<'s>(&'s mut self) -> PinnedFuture<'s, std::io::Result<()>>; } /// Auto-implement ChildProcessControl for any ChildProcessInstance, since it has all the required methods. @@ -77,23 +78,25 @@ where ChildProcessInstance::pid(self) } - fn wait(&mut self) -> PinnedFuture<'static, std::io::Result> { + fn wait<'s>(&'s mut self) -> PinnedFuture<'s, std::io::Result> { ChildProcessInstance::wait(self).boxed() } - fn graceful_shutdown(&mut self) -> PinnedFuture<'static, std::io::Result<()>> { + fn graceful_shutdown<'s>(&'s mut self) -> PinnedFuture<'s, std::io::Result<()>> { ChildProcessInstance::graceful_shutdown(self).boxed() } - fn kill(&mut self) -> PinnedFuture<'static, std::io::Result<()>> { + fn kill<'s>(&'s mut self) -> PinnedFuture<'s, std::io::Result<()>> { ChildProcessInstance::kill(self).boxed() } } -#[derive(Debug)] +#[derive(Debug, thiserror::Error)] pub enum RunnerSpawnError { /// The child process instance failed to spawn. 
- SpawnError(std::io::Error), + #[error("Failed to spawn child process: {0}")] + SpawnError(#[from] std::io::Error), + #[error("Other error: {0}")] Other(Box), } @@ -201,17 +204,19 @@ impl ChildProcessInstance for ChildProcess { self.inner.pid() } - fn wait( - &mut self, - ) -> impl Future> + Send + 'static { + fn wait<'s>( + &'s mut self, + ) -> impl Future> + Send + 's { self.inner.wait() } - fn graceful_shutdown(&mut self) -> impl Future> + Send + 'static { + fn graceful_shutdown<'s>( + &'s mut self, + ) -> impl Future> + Send + 's { self.inner.graceful_shutdown() } - fn kill(&mut self) -> impl Future> + Send + 'static { + fn kill<'s>(&'s mut self) -> impl Future> + Send + 's { self.inner.kill() } } diff --git a/crates/rmcp/src/transport/child_process2/tokio.rs b/crates/rmcp/src/transport/child_process2/tokio.rs new file mode 100644 index 00000000..dd75e0e1 --- /dev/null +++ b/crates/rmcp/src/transport/child_process2/tokio.rs @@ -0,0 +1,81 @@ +use tokio_util::compat::{Compat, TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; + +use crate::transport::child_process2::runner::{ + ChildProcessInstance, ChildProcessRunner, RunnerSpawnError, StdioConfig, +}; + +pub struct TokioChildProcessRunner {} + +pub struct TokioChildProcess { + inner: tokio::process::Child, +} + +impl ChildProcessInstance for TokioChildProcess { + type Stdin = Compat; + + type Stdout = Compat; + + type Stderr = Compat; + + fn take_stdin(&mut self) -> super::runner::StreamSlot { + match self.inner.stdin.take() { + Some(stdin) => super::runner::StreamSlot::Available(stdin.compat_write()), + None => super::runner::StreamSlot::Unused, + } + } + + fn take_stdout(&mut self) -> super::runner::StreamSlot { + match self.inner.stdout.take() { + Some(stdout) => super::runner::StreamSlot::Available(stdout.compat()), + None => super::runner::StreamSlot::Unused, + } + } + + fn take_stderr(&mut self) -> super::runner::StreamSlot { + match self.inner.stderr.take() { + Some(stderr) => 
super::runner::StreamSlot::Available(stderr.compat()), + None => super::runner::StreamSlot::Unused, + } + } + + fn pid(&self) -> u32 { + // TODO: Consider refactor to return Option to avoid confusion of 0 as a valid PID. + self.inner.id().unwrap_or(0) + } + + fn wait<'s>( + &'s mut self, + ) -> impl Future> + Send + 's { + self.inner.wait() + } + + fn graceful_shutdown<'s>( + &'s mut self, + ) -> impl Future> + Send + 's { + // TODO: Implement graceful shutdown on unix with SIGTERM. And look into graceful shutdown on windows as well. + self.inner.kill() + } + + fn kill<'s>(&'s mut self) -> impl Future> + Send + 's { + self.inner.kill() + } +} + +impl ChildProcessRunner for TokioChildProcessRunner { + type Instance = TokioChildProcess; + fn spawn( + command: &str, + args: &[&str], + stdio_configuration: StdioConfig, + ) -> Result { + tokio::process::Command::new(command) + .args(args) + .stdin(stdio_configuration.stdin) + .stdout(stdio_configuration.stdout) + .stderr(stdio_configuration.stderr) + .kill_on_drop(true) + .spawn() + .map(|child| TokioChildProcess { inner: child }) + .map_err(RunnerSpawnError::SpawnError) + } +} diff --git a/crates/rmcp/src/transport/child_process2/transport.rs b/crates/rmcp/src/transport/child_process2/transport.rs index b378309a..a6731347 100644 --- a/crates/rmcp/src/transport/child_process2/transport.rs +++ b/crates/rmcp/src/transport/child_process2/transport.rs @@ -24,7 +24,7 @@ where R: ServiceRole, { pub fn new(child: ChildProcess) -> Result> { - let (stdout, stdin, stderr, control) = child.split(); + let (stdout, stdin, _stderr, control) = child.split(); let framed_transport: AsyncRwTransport = AsyncRwTransport::new( Box::new( diff --git a/crates/rmcp/src/transport/common.rs b/crates/rmcp/src/transport/common.rs index 615b0e27..b41a8f3c 100644 --- a/crates/rmcp/src/transport/common.rs +++ b/crates/rmcp/src/transport/common.rs @@ -4,7 +4,7 @@ pub mod server_side_http; pub mod http_header; #[cfg(feature = "__reqwest")] -mod reqwest; 
+pub mod reqwest; // Note: This module provides SSE stream parsing and auto-reconnect utilities. // It's used by the streamable HTTP client (which receives SSE-formatted responses), diff --git a/crates/rmcp/src/transport/common/reqwest.rs b/crates/rmcp/src/transport/common/reqwest.rs index 42075921..696aa912 100644 --- a/crates/rmcp/src/transport/common/reqwest.rs +++ b/crates/rmcp/src/transport/common/reqwest.rs @@ -1,2 +1,2 @@ #[cfg(feature = "transport-streamable-http-client-reqwest")] -mod streamable_http_client; +pub mod streamable_http_client; diff --git a/crates/rmcp/src/transport/streamable_http_server/tower.rs b/crates/rmcp/src/transport/streamable_http_server/tower.rs index ba28a8bb..f62708e8 100644 --- a/crates/rmcp/src/transport/streamable_http_server/tower.rs +++ b/crates/rmcp/src/transport/streamable_http_server/tower.rs @@ -515,13 +515,15 @@ where let session_manager = self.session_manager.clone(); let session_id = session_id.clone(); async move { - let service = serve_server::( - service, transport, - ) - .await; - match service { - Ok(service) => { + let serve_result = + serve_server::( + service, transport, + ) + .await; + match serve_result { + Ok((service, work)) => { // on service created + tokio::spawn(work); let _ = service.waiting().await; } Err(e) => { diff --git a/crates/rmcp/src/util.rs b/crates/rmcp/src/util.rs index 97121ac3..912e1378 100644 --- a/crates/rmcp/src/util.rs +++ b/crates/rmcp/src/util.rs @@ -31,7 +31,7 @@ impl Sink for UnboundedSenderSink { fn poll_ready( self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, + _cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { let this = self.get_mut(); if this.sender.is_closed() { @@ -51,7 +51,7 @@ impl Sink for UnboundedSenderSink { fn poll_flush( self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, + _cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { // tokio's unbounded mpsc senders have no flushing required, since the // receiver is unbounded and will get all 
messages we send (unless we run @@ -61,7 +61,7 @@ impl Sink for UnboundedSenderSink { fn poll_close( self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, + _cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { // Like `poll_flush`, there is nothing to wait on here. A single // call to `mpsc_sender.send(...)` is immediate from the perspective diff --git a/crates/rmcp/tests/test_with_js.rs b/crates/rmcp/tests/test_with_js.rs index c1e5d81a..31842773 100644 --- a/crates/rmcp/tests/test_with_js.rs +++ b/crates/rmcp/tests/test_with_js.rs @@ -3,7 +3,11 @@ use rmcp::{ service::QuitReason, transport::{ ConfigureCommandExt, StreamableHttpClientTransport, StreamableHttpServerConfig, - TokioChildProcess, + child_process2::{ + runner::{ChildProcessControl, CommandBuilder}, + tokio::TokioChildProcessRunner, + transport::ChildProcessTransport, + }, streamable_http_server::{ session::local::LocalSessionManager, tower::StreamableHttpService, }, @@ -32,18 +36,26 @@ async fn test_with_js_stdio_server() -> anyhow::Result<()> { .spawn()? 
.wait() .await?; - let transport = - TokioChildProcess::new(tokio::process::Command::new("node").configure(|cmd| { - cmd.arg("tests/test_with_js/server.js"); - }))?; - let client = ().serve(transport).await?; + let node_cmd = CommandBuilder::::new("node") + .args(["tests/test_with_js/server.js"]) + .spawn_dyn()?; + + tracing::info!("Spawned child process with PID: {}", node_cmd.pid()); + + let transport = ChildProcessTransport::new(node_cmd) + .map_err(|e| anyhow::anyhow!("Failed to spawn child process: {e}"))?; + + let (client, work) = ().serve(transport).await?; + + tokio::spawn(work); + let resources = client.list_all_resources().await?; tracing::info!("{:#?}", resources); let tools = client.list_all_tools().await?; tracing::info!("{:#?}", tools); - client.cancel().await?; + client.cancel().await; Ok(()) } @@ -124,12 +136,13 @@ async fn test_with_js_streamable_http_server() -> anyhow::Result<()> { // waiting for server up tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await; - let client = ().serve(transport).await?; + let (client, work) = ().serve(transport).await?; + tokio::spawn(work); let resources = client.list_all_resources().await?; tracing::info!("{:#?}", resources); let tools = client.list_all_tools().await?; tracing::info!("{:#?}", tools); - let quit_reason = client.cancel().await?; + let quit_reason = client.cancel().await; server.kill().await?; assert!(matches!(quit_reason, QuitReason::Cancelled)); Ok(()) From d0bd6ca699f500e91c0d7f043323baa38ad72fee Mon Sep 17 00:00:00 2001 From: Aadam Zocolo Date: Sun, 1 Mar 2026 07:06:25 +0000 Subject: [PATCH 06/11] refactor(child-process): continue to build command abstraction also update some unit tests, remove old child_process --- crates/rmcp/Cargo.toml | 14 +- crates/rmcp/src/transport.rs | 8 +- crates/rmcp/src/transport/child_process.rs | 309 ------------------ crates/rmcp/src/transport/child_process2.rs | 4 +- .../src/transport/child_process2/runner.rs | 71 ++-- 
.../src/transport/child_process2/tokio.rs | 24 +- crates/rmcp/tests/test_with_js.rs | 4 +- crates/rmcp/tests/test_with_python.rs | 48 ++- 8 files changed, 104 insertions(+), 378 deletions(-) delete mode 100644 crates/rmcp/src/transport/child_process.rs diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index f8e8011e..d08951d0 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -123,10 +123,11 @@ transport-streamable-http-client-reqwest = [ transport-async-rw = ["tokio/io-util", "tokio-util/codec", "tokio-util/compat"] transport-io = ["transport-async-rw", "tokio/io-std"] -transport-child-process = [ +transport-child-process = ["transport-async-rw", "tokio/process"] +transport-child-process-tokio = [ "transport-async-rw", "tokio/process", - "dep:process-wrap", + "tokio/rt", ] transport-streamable-http-server = [ "transport-streamable-http-server-session", @@ -163,7 +164,13 @@ path = "tests/test_tool_macros.rs" [[test]] name = "test_with_python" -required-features = ["reqwest", "server", "client", "transport-child-process"] +required-features = [ + "reqwest", + "server", + "client", + "transport-child-process", + "transport-child-process-tokio", +] path = "tests/test_with_python.rs" [[test]] @@ -172,6 +179,7 @@ required-features = [ "server", "client", "transport-child-process", + "transport-child-process-tokio", "transport-streamable-http-server", "transport-streamable-http-client", "__reqwest", diff --git a/crates/rmcp/src/transport.rs b/crates/rmcp/src/transport.rs index f22b63d8..99f84c05 100644 --- a/crates/rmcp/src/transport.rs +++ b/crates/rmcp/src/transport.rs @@ -81,12 +81,12 @@ pub mod worker; #[cfg(feature = "transport-worker")] pub use worker::WorkerTransport; -#[cfg(feature = "transport-child-process")] -pub mod child_process; -#[cfg(feature = "transport-child-process")] -pub use child_process::{ConfigureCommandExt, TokioChildProcess}; #[cfg(feature = "transport-child-process")] pub mod child_process2; +#[cfg(feature = 
"transport-child-process")] +pub use child_process2::runner::{ + ChildProcess, ChildProcessInstance, ChildProcessRunner, CommandBuilder, +}; #[cfg(feature = "transport-io")] pub mod io; diff --git a/crates/rmcp/src/transport/child_process.rs b/crates/rmcp/src/transport/child_process.rs deleted file mode 100644 index 58e8b9a8..00000000 --- a/crates/rmcp/src/transport/child_process.rs +++ /dev/null @@ -1,309 +0,0 @@ -use std::process::Stdio; - -use futures::future::Future; -use process_wrap::tokio::{ChildWrapper, CommandWrap}; -use tokio::{ - io::AsyncRead, - process::{ChildStderr, ChildStdin, ChildStdout}, -}; - -use super::{RxJsonRpcMessage, Transport, TxJsonRpcMessage, async_rw::AsyncRwTransport}; -use crate::RoleClient; - -const MAX_WAIT_ON_DROP_SECS: u64 = 3; -/// The parts of a child process. -type ChildProcessParts = ( - Box, - ChildStdout, - ChildStdin, - Option, -); - -/// Extract the stdio handles from a spawned child. -/// Returns `(child, stdout, stdin, stderr)` where `stderr` is `Some` only -/// if the process was spawned with `Stdio::piped()`. 
-#[inline] -fn split_child_process(mut child: Box) -> std::io::Result { - let child_stdin = match child.inner_mut().stdin().take() { - Some(stdin) => stdin, - None => return Err(std::io::Error::other("stdin was already taken")), - }; - let child_stdout = match child.inner_mut().stdout().take() { - Some(stdout) => stdout, - None => return Err(std::io::Error::other("stdout was already taken")), - }; - let child_stderr = child.inner_mut().stderr().take(); - Ok((child, child_stdout, child_stdin, child_stderr)) -} - -pub struct TokioChildProcess { - child: ChildWithCleanup, - transport: AsyncRwTransport, -} - -pub struct ChildWithCleanup { - inner: Option>, -} - -impl Drop for ChildWithCleanup { - fn drop(&mut self) { - // We should not use start_kill(), instead we should use kill() to avoid zombies - if let Some(mut inner) = self.inner.take() { - // We don't care about the result, just try to kill it - tokio::spawn(async move { - if let Err(e) = Box::into_pin(inner.kill()).await { - tracing::warn!("Error killing child process: {}", e); - } - }); - } - } -} - -// we hold the child process with stdout, for it's easier to implement AsyncRead -pin_project_lite::pin_project! { - pub struct TokioChildProcessOut { - child: ChildWithCleanup, - #[pin] - child_stdout: ChildStdout, - } -} - -impl TokioChildProcessOut { - /// Get the process ID of the child process. - pub fn id(&self) -> Option { - self.child.inner.as_ref()?.id() - } -} - -impl AsyncRead for TokioChildProcessOut { - fn poll_read( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> std::task::Poll> { - self.project().child_stdout.poll_read(cx, buf) - } -} - -impl TokioChildProcess { - /// Convenience: spawn with default `piped` stdio - pub fn new(command: impl Into) -> std::io::Result { - let (proc, _ignored) = TokioChildProcessBuilder::new(command).spawn()?; - Ok(proc) - } - - /// Builder entry-point allowing fine-grained stdio control. 
- pub fn builder(command: impl Into) -> TokioChildProcessBuilder { - TokioChildProcessBuilder::new(command) - } - - /// Get the process ID of the child process. - pub fn id(&self) -> Option { - self.child.inner.as_ref()?.id() - } - - /// Gracefully shutdown the child process - /// - /// This will first close the transport to the child process (the server), - /// and wait for the child process to exit normally with a timeout. - /// If the child process doesn't exit within the timeout, it will be killed. - pub async fn graceful_shutdown(&mut self) -> std::io::Result<()> { - if let Some(mut child) = self.child.inner.take() { - self.transport.close().await?; - - let wait_fut = child.wait(); - tokio::select! { - _ = tokio::time::sleep(std::time::Duration::from_secs(MAX_WAIT_ON_DROP_SECS)) => { - if let Err(e) = Box::into_pin(child.kill()).await { - tracing::warn!("Error killing child: {e}"); - return Err(e); - } - }, - res = wait_fut => { - match res { - Ok(status) => { - tracing::info!("Child exited gracefully {}", status); - } - Err(e) => { - tracing::warn!("Error waiting for child: {e}"); - return Err(e); - } - } - } - } - } - Ok(()) - } - - /// Take ownership of the inner child process - pub fn into_inner(mut self) -> Option> { - self.child.inner.take() - } - - /// Split this helper into a reader (stdout) and writer (stdin). - #[deprecated( - since = "0.5.0", - note = "use the Transport trait implementation instead" - )] - pub fn split(self) -> (TokioChildProcessOut, ChildStdin) { - unimplemented!("This method is deprecated, use the Transport trait implementation instead"); - } -} - -/// Builder for `TokioChildProcess` allowing custom `Stdio` configuration. 
-pub struct TokioChildProcessBuilder { - cmd: CommandWrap, - stdin: Stdio, - stdout: Stdio, - stderr: Stdio, -} - -impl TokioChildProcessBuilder { - fn new(cmd: impl Into) -> Self { - Self { - cmd: cmd.into(), - stdin: Stdio::piped(), - stdout: Stdio::piped(), - stderr: Stdio::inherit(), - } - } - - /// Override the child stdin configuration. - pub fn stdin(mut self, io: impl Into) -> Self { - self.stdin = io.into(); - self - } - /// Override the child stdout configuration. - pub fn stdout(mut self, io: impl Into) -> Self { - self.stdout = io.into(); - self - } - /// Override the child stderr configuration. - pub fn stderr(mut self, io: impl Into) -> Self { - self.stderr = io.into(); - self - } - - /// Spawn the child process. Returns the transport plus an optional captured stderr handle. - pub fn spawn(mut self) -> std::io::Result<(TokioChildProcess, Option)> { - self.cmd - .command_mut() - .stdin(self.stdin) - .stdout(self.stdout) - .stderr(self.stderr); - - let (child, stdout, stdin, stderr_opt) = split_child_process(self.cmd.spawn()?)?; - - let transport = AsyncRwTransport::new(stdout, stdin); - let proc = TokioChildProcess { - child: ChildWithCleanup { inner: Some(child) }, - transport, - }; - Ok((proc, stderr_opt)) - } -} - -impl Transport for TokioChildProcess { - type Error = std::io::Error; - - fn send( - &mut self, - item: TxJsonRpcMessage, - ) -> impl Future> + Send + 'static { - self.transport.send(item) - } - - fn receive(&mut self) -> impl Future>> + Send { - self.transport.receive() - } - - fn close(&mut self) -> impl Future> + Send { - self.graceful_shutdown() - } -} - -pub trait ConfigureCommandExt { - fn configure(self, f: impl FnOnce(&mut Self)) -> Self; -} - -impl ConfigureCommandExt for tokio::process::Command { - fn configure(mut self, f: impl FnOnce(&mut Self)) -> Self { - f(&mut self); - self - } -} - -#[cfg(unix)] -#[cfg(test)] -mod tests { - use tokio::process::Command; - - use super::*; - - #[tokio::test] - async fn 
test_tokio_child_process_drop() { - let r = TokioChildProcess::new(Command::new("sleep").configure(|cmd| { - cmd.arg("30"); - })); - assert!(r.is_ok()); - let child_process = r.unwrap(); - let id = child_process.id(); - assert!(id.is_some()); - let id = id.unwrap(); - // Drop the child process - drop(child_process); - // Wait a moment to allow the cleanup task to run - tokio::time::sleep(std::time::Duration::from_secs(MAX_WAIT_ON_DROP_SECS + 1)).await; - // Check if the process is still running - let status = Command::new("ps") - .arg("-p") - .arg(id.to_string()) - .status() - .await; - match status { - Ok(status) => { - assert!( - !status.success(), - "Process with PID {} is still running", - id - ); - } - Err(e) => { - panic!("Failed to check process status: {}", e); - } - } - } - - #[tokio::test] - async fn test_tokio_child_process_graceful_shutdown() { - let r = TokioChildProcess::new(Command::new("sleep").configure(|cmd| { - cmd.arg("30"); - })); - assert!(r.is_ok()); - let mut child_process = r.unwrap(); - let id = child_process.id(); - assert!(id.is_some()); - let id = id.unwrap(); - child_process.graceful_shutdown().await.unwrap(); - // Wait a moment to allow the cleanup task to run - tokio::time::sleep(std::time::Duration::from_secs(MAX_WAIT_ON_DROP_SECS + 1)).await; - // Check if the process is still running - let status = Command::new("ps") - .arg("-p") - .arg(id.to_string()) - .status() - .await; - match status { - Ok(status) => { - assert!( - !status.success(), - "Process with PID {} is still running", - id - ); - } - Err(e) => { - panic!("Failed to check process status: {}", e); - } - } - } -} diff --git a/crates/rmcp/src/transport/child_process2.rs b/crates/rmcp/src/transport/child_process2.rs index c82e16db..7e74551d 100644 --- a/crates/rmcp/src/transport/child_process2.rs +++ b/crates/rmcp/src/transport/child_process2.rs @@ -1,3 +1,5 @@ pub mod runner; -pub mod tokio; pub mod transport; + +#[cfg(feature = "transport-child-process-tokio")] +pub mod 
tokio; diff --git a/crates/rmcp/src/transport/child_process2/runner.rs b/crates/rmcp/src/transport/child_process2/runner.rs index 6f062065..784f1697 100644 --- a/crates/rmcp/src/transport/child_process2/runner.rs +++ b/crates/rmcp/src/transport/child_process2/runner.rs @@ -2,7 +2,7 @@ use futures::{ FutureExt, io::{AsyncRead, AsyncWrite}, }; -use std::process::Stdio; +use std::{path::PathBuf, process::Stdio}; use crate::util::PinnedFuture; @@ -104,11 +104,7 @@ pub trait ChildProcessRunner { /// The implementation of the child process instance that this runner will spawn. type Instance: ChildProcessInstance; - fn spawn( - command: &str, - args: &[&str], - stdio_config: StdioConfig, - ) -> Result; + fn spawn(command_config: CommandConfig) -> Result; } /// A containing wrapper around a child process instance. This struct erases the type @@ -222,10 +218,8 @@ impl ChildProcessInstance for ChildProcess { } pub struct CommandBuilder { - command: String, - args: Vec, + config: CommandConfig, _marker: std::marker::PhantomData, - stderr: Stdio, } pub enum CommandBuilderError { @@ -249,61 +243,78 @@ impl CommandBuilder { let args = iter.map(|s| s.into()).collect(); Ok(Self { - command, - args, + config: CommandConfig { + command, + args, + cwd: None, + stdio_config: StdioConfig { + stdin: Stdio::piped(), + stdout: Stdio::piped(), + stderr: Stdio::inherit(), + }, + }, _marker: std::marker::PhantomData, - stderr: Stdio::inherit(), }) } /// Create a CommandBuilder from a command and an optional list of args. pub fn new(command: impl Into) -> Self { Self { - command: command.into(), - args: Vec::new(), + config: CommandConfig { + command: command.into(), + args: Vec::new(), + cwd: None, + stdio_config: StdioConfig { + stdin: Stdio::piped(), + stdout: Stdio::piped(), + stderr: Stdio::inherit(), + }, + }, _marker: std::marker::PhantomData, - stderr: Stdio::inherit(), } } /// Add a single argument to the command. 
pub fn arg(mut self, arg: impl Into) -> Self { - self.args.push(arg.into()); + self.config.args.push(arg.into()); self } /// Add multiple arguments to the command. pub fn args(mut self, args: impl IntoIterator>) -> Self { - self.args.extend(args.into_iter().map(|arg| arg.into())); + self.config + .args + .extend(args.into_iter().map(|arg| arg.into())); self } /// Sets what happens to stderr for the command. /// By default if not set, stderr is inherited from the parent process. pub fn stderr(mut self, _stdio: Stdio) -> Self { - self.stderr = _stdio; + self.config.stdio_config.stderr = _stdio; + self + } + + pub fn current_dir(mut self, cwd: impl Into) -> Self { + self.config.cwd = Some(cwd.into()); self } } +pub struct CommandConfig { + pub command: String, + pub args: Vec, + pub cwd: Option, + pub stdio_config: StdioConfig, +} + impl CommandBuilder where R: ChildProcessRunner, { /// Spawn the command into its typed child process instance type. pub fn spawn_raw(self) -> Result { - // We should always pipe stdin and stdout. - let stdio_config = StdioConfig { - stdin: Stdio::piped(), - stdout: Stdio::piped(), - stderr: self.stderr, - }; - - R::spawn( - &self.command, - &self.args.iter().map(|s| s.as_str()).collect::>(), - stdio_config, - ) + R::spawn(self.config) } /// Spawn a child process struct that erases the specific child process instance type, and only exposes the control methods. 
diff --git a/crates/rmcp/src/transport/child_process2/tokio.rs b/crates/rmcp/src/transport/child_process2/tokio.rs index dd75e0e1..0beed76c 100644 --- a/crates/rmcp/src/transport/child_process2/tokio.rs +++ b/crates/rmcp/src/transport/child_process2/tokio.rs @@ -1,11 +1,12 @@ use tokio_util::compat::{Compat, TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; use crate::transport::child_process2::runner::{ - ChildProcessInstance, ChildProcessRunner, RunnerSpawnError, StdioConfig, + ChildProcessInstance, ChildProcessRunner, CommandConfig, RunnerSpawnError, }; pub struct TokioChildProcessRunner {} +/// An implementation for the tokio Child Process pub struct TokioChildProcess { inner: tokio::process::Child, } @@ -63,16 +64,17 @@ impl ChildProcessInstance for TokioChildProcess { impl ChildProcessRunner for TokioChildProcessRunner { type Instance = TokioChildProcess; - fn spawn( - command: &str, - args: &[&str], - stdio_configuration: StdioConfig, - ) -> Result { - tokio::process::Command::new(command) - .args(args) - .stdin(stdio_configuration.stdin) - .stdout(stdio_configuration.stdout) - .stderr(stdio_configuration.stderr) + fn spawn(command_config: CommandConfig) -> Result { + tokio::process::Command::new(command_config.command) + .args(command_config.args) + .stdin(command_config.stdio_config.stdin) + .stdout(command_config.stdio_config.stdout) + .stderr(command_config.stdio_config.stderr) + .current_dir( + command_config + .cwd + .unwrap_or_else(|| std::env::current_dir().unwrap()), + ) .kill_on_drop(true) .spawn() .map(|child| TokioChildProcess { inner: child }) diff --git a/crates/rmcp/tests/test_with_js.rs b/crates/rmcp/tests/test_with_js.rs index 31842773..6e39b5da 100644 --- a/crates/rmcp/tests/test_with_js.rs +++ b/crates/rmcp/tests/test_with_js.rs @@ -2,7 +2,7 @@ use rmcp::{ ServiceExt, service::QuitReason, transport::{ - ConfigureCommandExt, StreamableHttpClientTransport, StreamableHttpServerConfig, + StreamableHttpClientTransport, 
StreamableHttpServerConfig, child_process2::{ runner::{ChildProcessControl, CommandBuilder}, tokio::TokioChildProcessRunner, @@ -44,7 +44,7 @@ async fn test_with_js_stdio_server() -> anyhow::Result<()> { tracing::info!("Spawned child process with PID: {}", node_cmd.pid()); let transport = ChildProcessTransport::new(node_cmd) - .map_err(|e| anyhow::anyhow!("Failed to spawn child process: {e}"))?; + .map_err(|e| anyhow::anyhow!("Failed to wrap child process: {e}"))?; let (client, work) = ().serve(transport).await?; diff --git a/crates/rmcp/tests/test_with_python.rs b/crates/rmcp/tests/test_with_python.rs index 3f883c96..014b959d 100644 --- a/crates/rmcp/tests/test_with_python.rs +++ b/crates/rmcp/tests/test_with_python.rs @@ -1,10 +1,16 @@ use std::process::Stdio; +use futures::AsyncReadExt; use rmcp::{ ServiceExt, - transport::{ConfigureCommandExt, TokioChildProcess}, + transport::{ + ChildProcess, ChildProcessInstance, + child_process2::{ + runner::CommandBuilder, tokio::TokioChildProcessRunner, + transport::ChildProcessTransport, + }, + }, }; -use tokio::io::AsyncReadExt; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; mod common; @@ -29,18 +35,21 @@ async fn init() -> anyhow::Result<()> { async fn test_with_python_server() -> anyhow::Result<()> { init().await?; - let transport = TokioChildProcess::new(tokio::process::Command::new("uv").configure(|cmd| { - cmd.arg("run") - .arg("server.py") - .current_dir("tests/test_with_python"); - }))?; + let server_command = CommandBuilder::::new("uv") + .args(["run", "server.py"]) + .current_dir("tests/test_with_python") + .spawn_dyn()?; + + let transport = ChildProcessTransport::new(server_command) + .map_err(|e| anyhow::anyhow!("Failed to wrap child process: {e}"))?; - let client = ().serve(transport).await?; + let (client, work) = ().serve(transport).await?; + tokio::spawn(work); let resources = client.list_all_resources().await?; tracing::info!("{:#?}", resources); let tools = 
client.list_all_tools().await?; tracing::info!("{:#?}", tools); - client.cancel().await?; + client.cancel().await; Ok(()) } @@ -48,15 +57,14 @@ async fn test_with_python_server() -> anyhow::Result<()> { async fn test_with_python_server_stderr() -> anyhow::Result<()> { init().await?; - let (transport, stderr) = - TokioChildProcess::builder(tokio::process::Command::new("uv").configure(|cmd| { - cmd.arg("run") - .arg("server.py") - .current_dir("tests/test_with_python"); - })) + let mut server_command = CommandBuilder::::new("uv") + .args(["run", "server.py"]) + .current_dir("tests/test_with_python") .stderr(Stdio::piped()) - .spawn()?; + .spawn_dyn()?; + let stderr: Option<::Stderr> = + server_command.take_stderr().into(); let mut stderr = stderr.expect("stderr must be piped"); let stderr_task = tokio::spawn(async move { @@ -65,10 +73,14 @@ async fn test_with_python_server_stderr() -> anyhow::Result<()> { Ok::<_, std::io::Error>(buffer) }); - let client = ().serve(transport).await?; + let transport = ChildProcessTransport::new(server_command) + .map_err(|e| anyhow::anyhow!("Failed to wrap child process: {e}"))?; + + let (client, work) = ().serve(transport).await?; + tokio::spawn(work); let _ = client.list_all_resources().await?; let _ = client.list_all_tools().await?; - client.cancel().await?; + client.cancel().await; let stderr_output = stderr_task.await??; assert!(stderr_output.contains("server starting up...")); From 02b53cd500705864fc61e42dd40776901b4f389c Mon Sep 17 00:00:00 2001 From: Aadam Zocolo Date: Sun, 1 Mar 2026 07:18:22 +0000 Subject: [PATCH 07/11] refactor(child-process): add env to command, move builder to separate file --- crates/rmcp/src/transport.rs | 6 +- crates/rmcp/src/transport/child_process2.rs | 1 + .../src/transport/child_process2/builder.rs | 149 ++++++++++++++++++ .../src/transport/child_process2/runner.rs | 123 +-------------- .../src/transport/child_process2/tokio.rs | 6 +- crates/rmcp/tests/test_with_python.rs | 2 +- 6 files changed, 
159 insertions(+), 128 deletions(-) create mode 100644 crates/rmcp/src/transport/child_process2/builder.rs diff --git a/crates/rmcp/src/transport.rs b/crates/rmcp/src/transport.rs index 99f84c05..c23f3a07 100644 --- a/crates/rmcp/src/transport.rs +++ b/crates/rmcp/src/transport.rs @@ -84,9 +84,9 @@ pub use worker::WorkerTransport; #[cfg(feature = "transport-child-process")] pub mod child_process2; #[cfg(feature = "transport-child-process")] -pub use child_process2::runner::{ - ChildProcess, ChildProcessInstance, ChildProcessRunner, CommandBuilder, -}; +pub use child_process2::builder::CommandBuilder; +#[cfg(feature = "transport-child-process")] +pub use child_process2::runner::{ChildProcess, ChildProcessInstance, ChildProcessRunner}; #[cfg(feature = "transport-io")] pub mod io; diff --git a/crates/rmcp/src/transport/child_process2.rs b/crates/rmcp/src/transport/child_process2.rs index 7e74551d..44db4a45 100644 --- a/crates/rmcp/src/transport/child_process2.rs +++ b/crates/rmcp/src/transport/child_process2.rs @@ -1,3 +1,4 @@ +pub mod builder; pub mod runner; pub mod transport; diff --git a/crates/rmcp/src/transport/child_process2/builder.rs b/crates/rmcp/src/transport/child_process2/builder.rs new file mode 100644 index 00000000..f1216dac --- /dev/null +++ b/crates/rmcp/src/transport/child_process2/builder.rs @@ -0,0 +1,149 @@ +use std::{collections::HashMap, hash::Hash, path::PathBuf, process::Stdio}; + +use crate::transport::{ + ChildProcess, ChildProcessRunner, child_process2::runner::RunnerSpawnError, +}; + +/// A builder for constructing a command to spawn a child process, with typical command +/// configuration options like `args` and `current_dir`. 
+pub struct CommandBuilder { + config: CommandConfig, + _marker: std::marker::PhantomData, +} + +#[derive(Debug, thiserror::Error)] +pub enum CommandBuilderError { + #[error("Command cannot be empty")] + EmptyCommand, +} + +impl CommandBuilder { + /// Create a CommandBuilder from an argv-style list of strings, where the first element is the command, and the rest are the args. + pub fn from_argv(argv: I) -> Result + where + I: IntoIterator, + S: Into, + { + let mut iter = argv.into_iter(); + + // Pop the first element as the command, and use the rest as args + let command = match iter.next() { + Some(cmd) => cmd.into(), + None => return Err(CommandBuilderError::EmptyCommand), + }; + + let args = iter.map(|s| s.into()).collect(); + Ok(Self { + config: CommandConfig { + command, + args, + ..Default::default() + }, + _marker: std::marker::PhantomData, + }) + } + + /// Create a CommandBuilder from a command, with no arguments; args can be added afterwards with `arg`/`args`. + pub fn new(command: impl Into) -> Self { + Self { + config: CommandConfig { + command: command.into(), + ..Default::default() + }, + _marker: std::marker::PhantomData, + } + } + + /// Add a single argument to the command. + pub fn arg(mut self, arg: impl Into) -> Self { + self.config.args.push(arg.into()); + self + } + + /// Add multiple arguments to the command. + pub fn args(mut self, args: impl IntoIterator>) -> Self { + self.config + .args + .extend(args.into_iter().map(|arg| arg.into())); + self + } + + /// Set an environment variable for the command. + pub fn env(mut self, key: impl Into, value: impl Into) -> Self { + self.config.env.insert(key.into(), value.into()); + self + } + + /// Set multiple environment variables for the command. + pub fn envs( + mut self, + envs: impl IntoIterator, impl Into)>, + ) -> Self { + self.config + .env + .extend(envs.into_iter().map(|(k, v)| (k.into(), v.into()))); + self + } + + /// Sets what happens to stderr for the command.
+ /// By default if not set, stderr is inherited from the parent process. + pub fn stderr(mut self, _stdio: Stdio) -> Self { + self.config.stdio_config.stderr = _stdio; + self + } + + pub fn current_dir(mut self, cwd: impl Into) -> Self { + self.config.cwd = Some(cwd.into()); + self + } +} + +/// A structure that requests how the child process streams should +/// be configured when spawning. +#[derive(Debug)] +pub struct StdioConfig { + pub stdin: Stdio, + pub stdout: Stdio, + pub stderr: Stdio, +} + +impl Default for StdioConfig { + fn default() -> Self { + Self { + stdin: Stdio::piped(), + stdout: Stdio::piped(), + stderr: Stdio::inherit(), + } + } +} + +/// A structure that requests how the command should be executed +#[derive(Debug, Default)] +pub struct CommandConfig { + pub command: String, + pub args: Vec, + pub cwd: Option, + pub stdio_config: StdioConfig, + pub env: HashMap, +} + +impl CommandBuilder +where + R: ChildProcessRunner, +{ + /// Spawn the command into its typed child process instance type. + pub fn spawn_raw(self) -> Result { + R::spawn(self.config) + } + + /// Spawn a child process struct that erases the specific child process instance type, and only exposes the control methods. + /// + /// Requires `R::Instance` to be [Send] and `'static`. 
+ pub fn spawn_dyn(self) -> Result + where + R::Instance: Send + 'static, + { + let instance = self.spawn_raw()?; + Ok(ChildProcess::new(instance)) + } +} diff --git a/crates/rmcp/src/transport/child_process2/runner.rs b/crates/rmcp/src/transport/child_process2/runner.rs index 784f1697..f447d9aa 100644 --- a/crates/rmcp/src/transport/child_process2/runner.rs +++ b/crates/rmcp/src/transport/child_process2/runner.rs @@ -2,9 +2,8 @@ use futures::{ FutureExt, io::{AsyncRead, AsyncWrite}, }; -use std::{path::PathBuf, process::Stdio}; -use crate::util::PinnedFuture; +use crate::{transport::child_process2::builder::CommandConfig, util::PinnedFuture}; /// A simple enum for describing if a stream is available, unused, or already taken. #[derive(Debug)] @@ -27,14 +26,6 @@ impl From> for Option { } } -/// A structure that requests how the child process streams should -/// be configured when spawning. -pub struct StdioConfig { - pub stdin: Stdio, - pub stdout: Stdio, - pub stderr: Stdio, -} - /// The contract for what an instance of a child process /// must provide to be used with a transport. pub trait ChildProcessInstance { @@ -216,115 +207,3 @@ impl ChildProcessInstance for ChildProcess { self.inner.kill() } } - -pub struct CommandBuilder { - config: CommandConfig, - _marker: std::marker::PhantomData, -} - -pub enum CommandBuilderError { - EmptyCommand, -} - -impl CommandBuilder { - /// Create a CommandBuilder from an argv-style list of strings, where the first element is the command, and the rest are the args. 
- pub fn from_argv(argv: I) -> Result - where - I: IntoIterator, - S: Into, - { - let mut iter = argv.into_iter(); - - // Pop the first element as the command, and use the rest as args - let command = match iter.next() { - Some(cmd) => cmd.into(), - None => return Err(CommandBuilderError::EmptyCommand), - }; - - let args = iter.map(|s| s.into()).collect(); - Ok(Self { - config: CommandConfig { - command, - args, - cwd: None, - stdio_config: StdioConfig { - stdin: Stdio::piped(), - stdout: Stdio::piped(), - stderr: Stdio::inherit(), - }, - }, - _marker: std::marker::PhantomData, - }) - } - - /// Create a CommandBuilder from a command and an optional list of args. - pub fn new(command: impl Into) -> Self { - Self { - config: CommandConfig { - command: command.into(), - args: Vec::new(), - cwd: None, - stdio_config: StdioConfig { - stdin: Stdio::piped(), - stdout: Stdio::piped(), - stderr: Stdio::inherit(), - }, - }, - _marker: std::marker::PhantomData, - } - } - - /// Add a single argument to the command. - pub fn arg(mut self, arg: impl Into) -> Self { - self.config.args.push(arg.into()); - self - } - - /// Add multiple arguments to the command. - pub fn args(mut self, args: impl IntoIterator>) -> Self { - self.config - .args - .extend(args.into_iter().map(|arg| arg.into())); - self - } - - /// Sets what happens to stderr for the command. - /// By default if not set, stderr is inherited from the parent process. - pub fn stderr(mut self, _stdio: Stdio) -> Self { - self.config.stdio_config.stderr = _stdio; - self - } - - pub fn current_dir(mut self, cwd: impl Into) -> Self { - self.config.cwd = Some(cwd.into()); - self - } -} - -pub struct CommandConfig { - pub command: String, - pub args: Vec, - pub cwd: Option, - pub stdio_config: StdioConfig, -} - -impl CommandBuilder -where - R: ChildProcessRunner, -{ - /// Spawn the command into its typed child process instance type. 
- pub fn spawn_raw(self) -> Result { - R::spawn(self.config) - } - - /// Spawn a child process struct that erases the specific child process instance type, and only exposes the control methods. - /// - /// Requires `R::Instance` to be [Send] and `'static`. - pub fn spawn_dyn(self) -> Result - where - R::Instance: Send + 'static, - { - let instance = self.spawn_raw()?; - Ok(ChildProcess::new(instance)) - } -} diff --git a/crates/rmcp/src/transport/child_process2/tokio.rs b/crates/rmcp/src/transport/child_process2/tokio.rs index 0beed76c..4733d457 100644 --- a/crates/rmcp/src/transport/child_process2/tokio.rs +++ b/crates/rmcp/src/transport/child_process2/tokio.rs @@ -1,7 +1,8 @@ use tokio_util::compat::{Compat, TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; -use crate::transport::child_process2::runner::{ - ChildProcessInstance, ChildProcessRunner, CommandConfig, RunnerSpawnError, +use crate::transport::child_process2::{ + builder::CommandConfig, + runner::{ChildProcessInstance, ChildProcessRunner, RunnerSpawnError}, }; pub struct TokioChildProcessRunner {} @@ -67,6 +68,7 @@ impl ChildProcessRunner for TokioChildProcessRunner { fn spawn(command_config: CommandConfig) -> Result { tokio::process::Command::new(command_config.command) .args(command_config.args) + .envs(command_config.env) .stdin(command_config.stdio_config.stdin) .stdout(command_config.stdio_config.stdout) .stderr(command_config.stdio_config.stderr) diff --git a/crates/rmcp/tests/test_with_python.rs b/crates/rmcp/tests/test_with_python.rs index 014b959d..d880eb03 100644 --- a/crates/rmcp/tests/test_with_python.rs +++ b/crates/rmcp/tests/test_with_python.rs @@ -6,7 +6,7 @@ use rmcp::{ transport::{ ChildProcess, ChildProcessInstance, child_process2::{ - runner::CommandBuilder, tokio::TokioChildProcessRunner, + builder::CommandBuilder, tokio::TokioChildProcessRunner, transport::ChildProcessTransport, }, }, From 86977b678b0d0118fec8be6b1ed25f0917604c4d Mon Sep 17 00:00:00 2001 From: Aadam Zocolo Date: 
Sun, 1 Mar 2026 07:32:01 +0000 Subject: [PATCH 08/11] refactor(example): fix example compilation --- crates/rmcp/src/error.rs | 2 -- crates/rmcp/src/handler/client/progress.rs | 4 +-- crates/rmcp/src/service.rs | 1 - .../src/transport/child_process2/builder.rs | 2 +- .../src/transport/child_process2/transport.rs | 4 +-- examples/simple-chat-client/Cargo.toml | 3 +- examples/simple-chat-client/src/config.rs | 34 ++++++++++++------- 7 files changed, 28 insertions(+), 22 deletions(-) diff --git a/crates/rmcp/src/error.rs b/crates/rmcp/src/error.rs index c7901f4b..3bd528f4 100644 --- a/crates/rmcp/src/error.rs +++ b/crates/rmcp/src/error.rs @@ -30,8 +30,6 @@ pub enum RmcpError { #[cfg(feature = "server")] #[error("Server initialization error: {0}")] ServerInitialize(#[from] crate::service::ServerInitializeError), - #[error("Runtime error: {0}")] - Runtime(#[from] tokio::task::JoinError), #[error("Transport creation error: {error}")] // TODO: Maybe we can introduce something like `TryIntoTransport` to auto wrap transport type, // but it could be an breaking change, so we could do it in the future. 
diff --git a/crates/rmcp/src/handler/client/progress.rs b/crates/rmcp/src/handler/client/progress.rs index 7dd84f08..892de197 100644 --- a/crates/rmcp/src/handler/client/progress.rs +++ b/crates/rmcp/src/handler/client/progress.rs @@ -1,7 +1,5 @@ -use std::{collections::HashMap, sync::Arc}; - use futures::{Stream, StreamExt}; -use tokio::sync::{RwLock, broadcast}; +use tokio::sync::broadcast; use tokio_stream::wrappers::BroadcastStream; use crate::{ diff --git a/crates/rmcp/src/service.rs b/crates/rmcp/src/service.rs index e3fae794..f4374dea 100644 --- a/crates/rmcp/src/service.rs +++ b/crates/rmcp/src/service.rs @@ -598,7 +598,6 @@ impl RunningServiceCancellationToken { pub enum QuitReason { Cancelled, Closed, - JoinError(tokio::task::JoinError), } /// Request execution context diff --git a/crates/rmcp/src/transport/child_process2/builder.rs b/crates/rmcp/src/transport/child_process2/builder.rs index f1216dac..cf1a2e11 100644 --- a/crates/rmcp/src/transport/child_process2/builder.rs +++ b/crates/rmcp/src/transport/child_process2/builder.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, hash::Hash, path::PathBuf, process::Stdio}; +use std::{collections::HashMap, path::PathBuf, process::Stdio}; use crate::transport::{ ChildProcess, ChildProcessRunner, child_process2::runner::RunnerSpawnError, diff --git a/crates/rmcp/src/transport/child_process2/transport.rs b/crates/rmcp/src/transport/child_process2/transport.rs index a6731347..30f15a2d 100644 --- a/crates/rmcp/src/transport/child_process2/transport.rs +++ b/crates/rmcp/src/transport/child_process2/transport.rs @@ -11,7 +11,7 @@ use crate::{ }; pub struct ChildProcessTransport { - child: Box, + _child: Box, framed_transport: AsyncRwTransport< R, Box, @@ -40,7 +40,7 @@ where ); Ok(Self { - child: control, + _child: control, framed_transport, }) } diff --git a/examples/simple-chat-client/Cargo.toml b/examples/simple-chat-client/Cargo.toml index e382e63c..ab24878c 100644 --- a/examples/simple-chat-client/Cargo.toml 
+++ b/examples/simple-chat-client/Cargo.toml @@ -17,6 +17,7 @@ toml = "1.0" rmcp = { workspace = true, features = [ "client", "transport-child-process", - "transport-streamable-http-client-reqwest" + "transport-child-process-tokio", + "transport-streamable-http-client-reqwest", ] } clap = { version = "4.0", features = ["derive"] } diff --git a/examples/simple-chat-client/src/config.rs b/examples/simple-chat-client/src/config.rs index e469c280..6d0292b3 100644 --- a/examples/simple-chat-client/src/config.rs +++ b/examples/simple-chat-client/src/config.rs @@ -1,7 +1,14 @@ -use std::{collections::HashMap, path::Path, process::Stdio}; +use std::{collections::HashMap, path::Path}; use anyhow::Result; -use rmcp::{RoleClient, ServiceExt, service::RunningService, transport::ConfigureCommandExt}; +use rmcp::{ + RoleClient, ServiceExt, + service::RunningService, + transport::{ + CommandBuilder, + child_process2::{tokio::TokioChildProcessRunner, transport::ChildProcessTransport}, + }, +}; use serde::{Deserialize, Serialize}; #[derive(Debug, Serialize, Deserialize)] @@ -47,22 +54,25 @@ impl McpServerTransportConfig { McpServerTransportConfig::Streamable { url } => { let transport = rmcp::transport::StreamableHttpClientTransport::from_uri(url.to_string()); - ().serve(transport).await? + let (service, work) = ().serve(transport).await?; + tokio::spawn(work); + service } McpServerTransportConfig::Stdio { command, args, envs, } => { - let transport = rmcp::transport::child_process::TokioChildProcess::new( - tokio::process::Command::new(command).configure(|cmd| { - cmd.args(args) - .envs(envs) - .stderr(Stdio::inherit()) - .stdout(Stdio::inherit()); - }), - )?; - ().serve(transport).await? 
+ let cmd = CommandBuilder::::new(command) + .args(args) + .envs(envs) + .spawn_dyn()?; + + let transport = ChildProcessTransport::new(cmd) + .map_err(|e| anyhow::anyhow!("Failed to wrap child process: {e}"))?; + let (service, work) = ().serve(transport).await?; + tokio::spawn(work); + service } }; Ok(client) From b746384d19e4cbde6d0c9d748165c3f571aa3317 Mon Sep 17 00:00:00 2001 From: Aadam Zocolo Date: Sun, 1 Mar 2026 07:34:11 +0000 Subject: [PATCH 09/11] refactor(child-process): rename module back to "child-process" --- crates/rmcp/src/transport.rs | 6 +++--- .../src/transport/{child_process2.rs => child_process.rs} | 0 .../transport/{child_process2 => child_process}/builder.rs | 4 +--- .../transport/{child_process2 => child_process}/runner.rs | 2 +- .../transport/{child_process2 => child_process}/tokio.rs | 2 +- .../{child_process2 => child_process}/transport.rs | 2 +- crates/rmcp/tests/test_with_js.rs | 2 +- crates/rmcp/tests/test_with_python.rs | 2 +- examples/simple-chat-client/src/config.rs | 2 +- 9 files changed, 10 insertions(+), 12 deletions(-) rename crates/rmcp/src/transport/{child_process2.rs => child_process.rs} (100%) rename crates/rmcp/src/transport/{child_process2 => child_process}/builder.rs (97%) rename crates/rmcp/src/transport/{child_process2 => child_process}/runner.rs (98%) rename crates/rmcp/src/transport/{child_process2 => child_process}/tokio.rs (98%) rename crates/rmcp/src/transport/{child_process2 => child_process}/transport.rs (96%) diff --git a/crates/rmcp/src/transport.rs b/crates/rmcp/src/transport.rs index c23f3a07..de12321c 100644 --- a/crates/rmcp/src/transport.rs +++ b/crates/rmcp/src/transport.rs @@ -82,11 +82,11 @@ pub mod worker; pub use worker::WorkerTransport; #[cfg(feature = "transport-child-process")] -pub mod child_process2; +pub mod child_process; #[cfg(feature = "transport-child-process")] -pub use child_process2::builder::CommandBuilder; +pub use child_process::builder::CommandBuilder; #[cfg(feature = 
"transport-child-process")] -pub use child_process2::runner::{ChildProcess, ChildProcessInstance, ChildProcessRunner}; +pub use child_process::runner::{ChildProcess, ChildProcessInstance, ChildProcessRunner}; #[cfg(feature = "transport-io")] pub mod io; diff --git a/crates/rmcp/src/transport/child_process2.rs b/crates/rmcp/src/transport/child_process.rs similarity index 100% rename from crates/rmcp/src/transport/child_process2.rs rename to crates/rmcp/src/transport/child_process.rs diff --git a/crates/rmcp/src/transport/child_process2/builder.rs b/crates/rmcp/src/transport/child_process/builder.rs similarity index 97% rename from crates/rmcp/src/transport/child_process2/builder.rs rename to crates/rmcp/src/transport/child_process/builder.rs index cf1a2e11..e23303ae 100644 --- a/crates/rmcp/src/transport/child_process2/builder.rs +++ b/crates/rmcp/src/transport/child_process/builder.rs @@ -1,8 +1,6 @@ use std::{collections::HashMap, path::PathBuf, process::Stdio}; -use crate::transport::{ - ChildProcess, ChildProcessRunner, child_process2::runner::RunnerSpawnError, -}; +use crate::transport::{ChildProcess, ChildProcessRunner, child_process::runner::RunnerSpawnError}; /// A builder for constructing a command to spawn a child process, with typical command /// configuration options like `args` and `current_dir`. 
diff --git a/crates/rmcp/src/transport/child_process2/runner.rs b/crates/rmcp/src/transport/child_process/runner.rs similarity index 98% rename from crates/rmcp/src/transport/child_process2/runner.rs rename to crates/rmcp/src/transport/child_process/runner.rs index f447d9aa..7c2bb3f4 100644 --- a/crates/rmcp/src/transport/child_process2/runner.rs +++ b/crates/rmcp/src/transport/child_process/runner.rs @@ -3,7 +3,7 @@ use futures::{ io::{AsyncRead, AsyncWrite}, }; -use crate::{transport::child_process2::builder::CommandConfig, util::PinnedFuture}; +use crate::{transport::child_process::builder::CommandConfig, util::PinnedFuture}; /// A simple enum for describing if a stream is available, unused, or already taken. #[derive(Debug)] diff --git a/crates/rmcp/src/transport/child_process2/tokio.rs b/crates/rmcp/src/transport/child_process/tokio.rs similarity index 98% rename from crates/rmcp/src/transport/child_process2/tokio.rs rename to crates/rmcp/src/transport/child_process/tokio.rs index 4733d457..edb3c061 100644 --- a/crates/rmcp/src/transport/child_process2/tokio.rs +++ b/crates/rmcp/src/transport/child_process/tokio.rs @@ -1,6 +1,6 @@ use tokio_util::compat::{Compat, TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; -use crate::transport::child_process2::{ +use crate::transport::child_process::{ builder::CommandConfig, runner::{ChildProcessInstance, ChildProcessRunner, RunnerSpawnError}, }; diff --git a/crates/rmcp/src/transport/child_process2/transport.rs b/crates/rmcp/src/transport/child_process/transport.rs similarity index 96% rename from crates/rmcp/src/transport/child_process2/transport.rs rename to crates/rmcp/src/transport/child_process/transport.rs index 30f15a2d..9f46d9d9 100644 --- a/crates/rmcp/src/transport/child_process2/transport.rs +++ b/crates/rmcp/src/transport/child_process/transport.rs @@ -6,7 +6,7 @@ use crate::{ transport::{ Transport, async_rw::AsyncRwTransport, - child_process2::runner::{ChildProcess, ChildProcessControl}, + 
child_process::runner::{ChildProcess, ChildProcessControl}, }, }; diff --git a/crates/rmcp/tests/test_with_js.rs b/crates/rmcp/tests/test_with_js.rs index 6e39b5da..ed1c0b7d 100644 --- a/crates/rmcp/tests/test_with_js.rs +++ b/crates/rmcp/tests/test_with_js.rs @@ -3,7 +3,7 @@ use rmcp::{ service::QuitReason, transport::{ StreamableHttpClientTransport, StreamableHttpServerConfig, - child_process2::{ + child_process::{ runner::{ChildProcessControl, CommandBuilder}, tokio::TokioChildProcessRunner, transport::ChildProcessTransport, diff --git a/crates/rmcp/tests/test_with_python.rs b/crates/rmcp/tests/test_with_python.rs index d880eb03..074b8447 100644 --- a/crates/rmcp/tests/test_with_python.rs +++ b/crates/rmcp/tests/test_with_python.rs @@ -5,7 +5,7 @@ use rmcp::{ ServiceExt, transport::{ ChildProcess, ChildProcessInstance, - child_process2::{ + child_process::{ builder::CommandBuilder, tokio::TokioChildProcessRunner, transport::ChildProcessTransport, }, diff --git a/examples/simple-chat-client/src/config.rs b/examples/simple-chat-client/src/config.rs index 6d0292b3..946436fd 100644 --- a/examples/simple-chat-client/src/config.rs +++ b/examples/simple-chat-client/src/config.rs @@ -6,7 +6,7 @@ use rmcp::{ service::RunningService, transport::{ CommandBuilder, - child_process2::{tokio::TokioChildProcessRunner, transport::ChildProcessTransport}, + child_process::{tokio::TokioChildProcessRunner, transport::ChildProcessTransport}, }, }; use serde::{Deserialize, Serialize}; From 0adab2536b347810afd743c849ecb1fd74b14aa1 Mon Sep 17 00:00:00 2001 From: Aadam Zocolo Date: Sun, 1 Mar 2026 19:32:28 +0000 Subject: [PATCH 10/11] refactor(test): re-introduce tests for child process dropping --- crates/rmcp/Cargo.toml | 2 - .../src/transport/child_process/runner.rs | 3 + .../rmcp/src/transport/child_process/tokio.rs | 79 ++++++++++++++++++- 3 files changed, 79 insertions(+), 5 deletions(-) diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index d08951d0..ae5262c8 100644 
--- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -74,8 +74,6 @@ chrono = { version = "0.4.38", default-features = false, features = [ "oldtime", ] } -[target.'cfg(test)'] - [features] default = ["base64", "macros", "server"] client = ["dep:tokio-stream"] diff --git a/crates/rmcp/src/transport/child_process/runner.rs b/crates/rmcp/src/transport/child_process/runner.rs index 7c2bb3f4..bc457bb2 100644 --- a/crates/rmcp/src/transport/child_process/runner.rs +++ b/crates/rmcp/src/transport/child_process/runner.rs @@ -87,6 +87,9 @@ pub enum RunnerSpawnError { /// The child process instance failed to spawn. #[error("Failed to spawn child process: {0}")] SpawnError(#[from] std::io::Error), + /// The child process instance did not have a PID assigned (this is unexpected for a spawned process). + #[error("Child process did not have a PID assigned after spawning")] + NoPidAssigned, #[error("Other error: {0}")] Other(Box), } diff --git a/crates/rmcp/src/transport/child_process/tokio.rs b/crates/rmcp/src/transport/child_process/tokio.rs index edb3c061..db835afb 100644 --- a/crates/rmcp/src/transport/child_process/tokio.rs +++ b/crates/rmcp/src/transport/child_process/tokio.rs @@ -10,6 +10,8 @@ pub struct TokioChildProcessRunner {} /// An implementation for the tokio Child Process pub struct TokioChildProcess { inner: tokio::process::Child, + /// The PID at the time of spawning. + pid: u32, } impl ChildProcessInstance for TokioChildProcess { @@ -41,8 +43,7 @@ impl ChildProcessInstance for TokioChildProcess { } fn pid(&self) -> u32 { - // TODO: Consider refactor to return Option to avoid confusion of 0 as a valid PID. 
- self.inner.id().unwrap_or(0) + self.pid } fn wait<'s>( @@ -79,7 +80,79 @@ impl ChildProcessRunner for TokioChildProcessRunner { ) .kill_on_drop(true) .spawn() - .map(|child| TokioChildProcess { inner: child }) .map_err(RunnerSpawnError::SpawnError) + .and_then(|child| { + let pid = child.id().ok_or_else(|| RunnerSpawnError::NoPidAssigned)?; + Ok(TokioChildProcess { inner: child, pid }) + }) + } +} + +#[cfg(test)] +mod test { + + use crate::transport::CommandBuilder; + use tokio::process::Command; + + use super::*; + + async fn check_pid(pid: u32) -> std::io::Result { + // This command will output only process numbers on each line. + let output = Command::new("ps") + .arg("-o") + .arg("pid=") + .arg("-p") + .arg(pid.to_string()) + .output() + .await?; + + let output_str = String::from_utf8_lossy(&output.stdout); + Ok(output_str + .lines() + .any(|line| line.trim() == pid.to_string())) + } + + #[cfg(unix)] + #[tokio::test] + async fn test_kill_on_drop() { + let child = CommandBuilder::::new("sleep") + .args(["10"]) + .spawn_raw() + .expect("Failed to spawn child process"); + + let pid = child.pid(); + + // Drop the child process without waiting for it to exit, which should kill it due to `kill_on_drop(true)`. + drop(child); + + // Wait a moment to ensure the process has been killed. + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + + let pid_found = check_pid(pid).await.expect("Failed to check if PID exists"); + + assert!(!pid_found, "Child process was not killed on drop"); + } + + #[tokio::test] + async fn test_graceful_shutdown() { + let mut child = CommandBuilder::::new("sleep") + .args(["10"]) + .spawn_raw() + .expect("Failed to spawn child process"); + + let pid = child.pid(); + + // Sleep a moment to ensure the process is running before we attempt to shut it down. 
+ tokio::time::sleep(std::time::Duration::from_secs(1)).await; + + child + .graceful_shutdown() + .await + .expect("Failed to gracefully shutdown child process"); + + // We should not need to wait here since we await the graceful shutdown above. + // Graceful shutdown *should* cover waiting for the process to exit. + let pid_found = check_pid(pid).await.expect("Failed to check if PID exists"); + assert!(!pid_found, "Child process was not shutdown"); } } From 007dd925c5ceb6f47edb4ef2db770821a9160fad Mon Sep 17 00:00:00 2001 From: Aadam Zocolo Date: Sun, 1 Mar 2026 19:58:50 +0000 Subject: [PATCH 11/11] refactor: revert some unnecessary module visibility changes fix unit test --- crates/rmcp/Cargo.toml | 1 + crates/rmcp/src/transport/common.rs | 2 +- crates/rmcp/src/transport/common/reqwest.rs | 2 +- crates/rmcp/tests/test_with_js.rs | 5 ++--- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index ae5262c8..9117db70 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -180,6 +180,7 @@ required-features = [ "transport-child-process-tokio", "transport-streamable-http-server", "transport-streamable-http-client", + "transport-streamable-http-client-reqwest", "__reqwest", ] path = "tests/test_with_js.rs" diff --git a/crates/rmcp/src/transport/common.rs b/crates/rmcp/src/transport/common.rs index b41a8f3c..615b0e27 100644 --- a/crates/rmcp/src/transport/common.rs +++ b/crates/rmcp/src/transport/common.rs @@ -4,7 +4,7 @@ pub mod server_side_http; pub mod http_header; #[cfg(feature = "__reqwest")] -pub mod reqwest; +mod reqwest; // Note: This module provides SSE stream parsing and auto-reconnect utilities. 
// It's used by the streamable HTTP client (which receives SSE-formatted responses), diff --git a/crates/rmcp/src/transport/common/reqwest.rs b/crates/rmcp/src/transport/common/reqwest.rs index 696aa912..42075921 100644 --- a/crates/rmcp/src/transport/common/reqwest.rs +++ b/crates/rmcp/src/transport/common/reqwest.rs @@ -1,2 +1,2 @@ #[cfg(feature = "transport-streamable-http-client-reqwest")] -pub mod streamable_http_client; +mod streamable_http_client; diff --git a/crates/rmcp/tests/test_with_js.rs b/crates/rmcp/tests/test_with_js.rs index ed1c0b7d..40228b58 100644 --- a/crates/rmcp/tests/test_with_js.rs +++ b/crates/rmcp/tests/test_with_js.rs @@ -2,10 +2,9 @@ use rmcp::{ ServiceExt, service::QuitReason, transport::{ - StreamableHttpClientTransport, StreamableHttpServerConfig, + CommandBuilder, StreamableHttpClientTransport, StreamableHttpServerConfig, child_process::{ - runner::{ChildProcessControl, CommandBuilder}, - tokio::TokioChildProcessRunner, + runner::ChildProcessControl, tokio::TokioChildProcessRunner, transport::ChildProcessTransport, }, streamable_http_server::{