Skip to content
5 changes: 5 additions & 0 deletions crates/rmcp/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -241,3 +241,8 @@ required-features = [
"transport-streamable-http-server",
]
path = "tests/test_custom_headers.rs"

[[test]]
name = "test_sse_concurrent_streams"
required-features = ["server", "client", "transport-streamable-http-server", "transport-streamable-http-client", "reqwest"]
path = "tests/test_sse_concurrent_streams.rs"
103 changes: 77 additions & 26 deletions crates/rmcp/src/transport/streamable_http_server/session/local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,12 @@ pub struct LocalSessionWorker {
tx_router: HashMap<HttpRequestId, HttpRequestWise>,
resource_router: HashMap<ResourceKey, HttpRequestId>,
common: CachedTx,
/// Shadow senders for secondary SSE streams (e.g. from POST EventSource
/// reconnections). These keep the HTTP connections alive via SSE keep-alive
/// without receiving notifications, preventing MCP clients from entering
/// infinite reconnect loops when multiple EventSource connections compete
/// to replace the common channel.
shadow_txs: Vec<Sender<ServerSseMessage>>,
event_rx: Receiver<SessionEvent>,
session_config: SessionConfig,
}
Expand Down Expand Up @@ -513,36 +519,77 @@ impl LocalSessionWorker {
&mut self,
last_event_id: EventId,
) -> Result<StreamableHttpMessageReceiver, SessionError> {
// Clean up closed shadow senders before processing
self.shadow_txs.retain(|tx| !tx.is_closed());

match last_event_id.http_request_id {
Some(http_request_id) => {
let request_wise = self
.tx_router
.get_mut(&http_request_id)
.ok_or(SessionError::ChannelClosed(Some(http_request_id)))?;
let channel = tokio::sync::mpsc::channel(self.session_config.channel_capacity);
let (tx, rx) = channel;
request_wise.tx.tx = tx;
let index = last_event_id.index;
// sync messages after index
request_wise.tx.sync(index).await?;
Ok(StreamableHttpMessageReceiver {
http_request_id: Some(http_request_id),
inner: rx,
})
}
None => {
let channel = tokio::sync::mpsc::channel(self.session_config.channel_capacity);
let (tx, rx) = channel;
self.common.tx = tx;
let index = last_event_id.index;
// sync messages after index
self.common.sync(index).await?;
Ok(StreamableHttpMessageReceiver {
http_request_id: None,
inner: rx,
})
if let Some(request_wise) = self.tx_router.get_mut(&http_request_id) {
// Resume existing request-wise channel
let channel = tokio::sync::mpsc::channel(self.session_config.channel_capacity);
let (tx, rx) = channel;
request_wise.tx.tx = tx;
let index = last_event_id.index;
// sync messages after index
request_wise.tx.sync(index).await?;
Ok(StreamableHttpMessageReceiver {
http_request_id: Some(http_request_id),
inner: rx,
})
} else {
// Request-wise channel completed (POST response already delivered).
// The client's EventSource is reconnecting after the POST SSE stream
// ended. Fall through to common channel handling below.
tracing::debug!(
http_request_id,
"Request-wise channel completed, falling back to common channel"
);
self.resume_or_shadow_common(last_event_id.index).await
}
}
None => self.resume_or_shadow_common(last_event_id.index).await,
}
}

/// Resume the common channel, or create a shadow stream if the primary is
/// still active.
///
/// When the primary common channel is dead (receiver dropped), replace it
/// so this stream becomes the new primary notification channel. Cached
/// messages are replayed from `last_event_index` so the client receives
/// any events it missed (including server-initiated requests).
///
/// When the primary is still active, create a "shadow" stream — an idle SSE
/// connection kept alive by keep-alive pings. This prevents multiple
/// EventSource connections (e.g. from POST response reconnections) from
/// killing each other by repeatedly replacing the common channel sender.
/// Resume the common channel, or park this connection as a shadow stream
/// when the primary is still in use.
///
/// If the primary common channel's receiver has been dropped, this stream
/// is promoted to primary: its sender is swapped into `self.common` and the
/// cached messages after `last_event_index` are replayed so the client
/// catches up on anything it missed (including server-initiated requests).
///
/// If the primary is still alive, the new connection becomes a "shadow" —
/// an idle SSE stream kept open by keep-alive pings that never receives
/// notifications. This stops competing EventSource connections from
/// repeatedly stealing the common channel sender from one another.
async fn resume_or_shadow_common(
    &mut self,
    last_event_index: usize,
) -> Result<StreamableHttpMessageReceiver, SessionError> {
    let capacity = self.session_config.channel_capacity;
    let (tx, rx) = tokio::sync::mpsc::channel(capacity);
    match self.common.tx.is_closed() {
        true => {
            // Previous receiver is gone: take over as the new primary and
            // replay the cached backlog so no events are lost.
            tracing::debug!("Replacing dead common channel with new primary");
            self.common.tx = tx;
            self.common.sync(last_event_index).await?;
        }
        false => {
            // Primary still active: park the sender so this connection
            // idles on keep-alive instead of hijacking the common channel.
            tracing::debug!(
                shadow_count = self.shadow_txs.len(),
                "Common channel active, creating shadow stream"
            );
            self.shadow_txs.push(tx);
        }
    }
    Ok(StreamableHttpMessageReceiver {
        http_request_id: None,
        inner: rx,
    })
}

async fn close_sse_stream(
Expand Down Expand Up @@ -584,6 +631,9 @@ impl LocalSessionWorker {
let (tx, _rx) = tokio::sync::mpsc::channel(1);
self.common.tx = tx;

// Also close all shadow streams
self.shadow_txs.clear();

tracing::debug!("closed standalone SSE stream for server-initiated disconnection");
Ok(())
}
Expand Down Expand Up @@ -1036,6 +1086,7 @@ pub fn create_local_session(
tx_router: HashMap::new(),
resource_router: HashMap::new(),
common,
shadow_txs: Vec::new(),
event_rx,
session_config: config.clone(),
};
Expand Down
24 changes: 12 additions & 12 deletions crates/rmcp/src/transport/streamable_http_server/tower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,10 @@ where
.and_then(|v| v.to_str().ok())
.map(|s| s.to_owned().into());
let Some(session_id) = session_id else {
// unauthorized
// MCP spec: servers that require a session ID SHOULD respond with 400 Bad Request
return Ok(Response::builder()
.status(http::StatusCode::UNAUTHORIZED)
.body(Full::new(Bytes::from("Unauthorized: Session ID is required")).boxed())
.status(http::StatusCode::BAD_REQUEST)
.body(Full::new(Bytes::from("Bad Request: Session ID is required")).boxed())
.expect("valid response"));
};
// check if session exists
Expand All @@ -201,10 +201,10 @@ where
.await
.map_err(internal_error_response("check session"))?;
if !has_session {
// unauthorized
// MCP spec: server MUST respond with 404 Not Found for terminated/unknown sessions
return Ok(Response::builder()
.status(http::StatusCode::UNAUTHORIZED)
.body(Full::new(Bytes::from("Unauthorized: Session not found")).boxed())
.status(http::StatusCode::NOT_FOUND)
.body(Full::new(Bytes::from("Not Found: Session not found")).boxed())
.expect("valid response"));
}
// check if last event id is provided
Expand Down Expand Up @@ -313,10 +313,10 @@ where
.await
.map_err(internal_error_response("check session"))?;
if !has_session {
// unauthorized
// MCP spec: server MUST respond with 404 Not Found for terminated/unknown sessions
return Ok(Response::builder()
.status(http::StatusCode::UNAUTHORIZED)
.body(Full::new(Bytes::from("Unauthorized: Session not found")).boxed())
.status(http::StatusCode::NOT_FOUND)
.body(Full::new(Bytes::from("Not Found: Session not found")).boxed())
.expect("valid response"));
}

Expand Down Expand Up @@ -505,10 +505,10 @@ where
.and_then(|v| v.to_str().ok())
.map(|s| s.to_owned().into());
let Some(session_id) = session_id else {
// unauthorized
// MCP spec: servers that require a session ID SHOULD respond with 400 Bad Request
return Ok(Response::builder()
.status(http::StatusCode::UNAUTHORIZED)
.body(Full::new(Bytes::from("Unauthorized: Session ID is required")).boxed())
.status(http::StatusCode::BAD_REQUEST)
.body(Full::new(Bytes::from("Bad Request: Session ID is required")).boxed())
.expect("valid response"));
};
// close session
Expand Down
Loading
Loading