Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/rmcp/src/handler/server/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ where
match request {
ClientRequest::CallToolRequest(request) => {
if self.tool_router.has_route(request.params.name.as_ref())
|| self.tool_router.is_disabled(request.params.name.as_ref())
|| !self.tool_router.transparent_when_not_found
{
let tool_call_context = crate::handler::server::tool::ToolCallContext::new(
Expand Down
68 changes: 63 additions & 5 deletions crates/rmcp/src/handler/server/router/tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,13 +305,16 @@ pub struct ToolRouter<S> {
pub map: std::collections::HashMap<Cow<'static, str>, ToolRoute<S>>,

pub transparent_when_not_found: bool,

disabled: std::collections::HashSet<Cow<'static, str>>,
}

impl<S> Default for ToolRouter<S> {
fn default() -> Self {
Self {
map: std::collections::HashMap::new(),
transparent_when_not_found: false,
disabled: std::collections::HashSet::new(),
}
}
}
Expand All @@ -320,6 +323,7 @@ impl<S> Clone for ToolRouter<S> {
Self {
map: self.map.clone(),
transparent_when_not_found: self.transparent_when_not_found,
disabled: self.disabled.clone(),
}
}
}
Expand All @@ -329,7 +333,11 @@ impl<S> IntoIterator for ToolRouter<S> {
type IntoIter = std::collections::hash_map::IntoValues<Cow<'static, str>, ToolRoute<S>>;

fn into_iter(self) -> Self::IntoIter {
self.map.into_values()
let mut map = self.map;
for name in &self.disabled {
map.remove(name);
}
map.into_values()
}
}

Expand All @@ -341,6 +349,7 @@ where
Self {
map: std::collections::HashMap::new(),
transparent_when_not_found: false,
disabled: std::collections::HashSet::new(),
}
}
pub fn with_route<R, A>(mut self, route: R) -> Self
Expand Down Expand Up @@ -394,24 +403,64 @@ where
}

pub fn merge(&mut self, other: ToolRouter<S>) {
self.disabled.extend(other.disabled);
for item in other.map.into_values() {
self.add_route(item);
}
}

pub fn remove_route(&mut self, name: &str) {
self.map.remove(name);
self.disabled.remove(name);
}

pub fn has_route(&self, name: &str) -> bool {
self.map.contains_key(name)
self.map.contains_key(name) && !self.disabled.contains(name)
}

/// Disable a tool by name so it is hidden from `list_all`, `get`, and
/// rejected by `call`. The tool remains in the router and can be
/// re-enabled later with [`enable_route`](Self::enable_route).
///
/// The name is recorded even if no matching route exists yet, so routes
/// added later (via [`add_route`](Self::add_route) or
/// [`merge`](Self::merge)) will inherit the disabled state.
pub fn disable_route(&mut self, name: &str) {
self.disabled.insert(Cow::Owned(name.to_owned()));
}

/// Re-enable a previously disabled tool.
pub fn enable_route(&mut self, name: &str) {
self.disabled.remove(name);
}

/// Returns `true` if the tool exists in the router but is currently
/// disabled.
pub fn is_disabled(&self, name: &str) -> bool {
self.map.contains_key(name) && self.disabled.contains(name)
}

/// Builder-style variant of [`disable_route`](Self::disable_route).
///
/// The name is recorded even if no matching route has been added yet,
/// so it can be called before [`with_route`](Self::with_route) in a
/// builder chain.
pub fn with_disabled(mut self, name: impl Into<Cow<'static, str>>) -> Self {
self.disabled.insert(name.into());
self
}

pub async fn call(
&self,
context: ToolCallContext<'_, S>,
) -> Result<CallToolResult, crate::ErrorData> {
let name = context.name();
if self.disabled.contains(name) {
return Err(crate::ErrorData::invalid_params("tool not found", None));
}
let item = self
.map
.get(context.name())
.get(name)
.ok_or_else(|| crate::ErrorData::invalid_params("tool not found", None))?;

let result = (item.call)(context).await?;
Expand All @@ -420,15 +469,24 @@ where
}

pub fn list_all(&self) -> Vec<crate::model::Tool> {
let mut tools: Vec<_> = self.map.values().map(|item| item.attr.clone()).collect();
let mut tools: Vec<_> = self
.map
.values()
.filter(|item| !self.disabled.contains(&item.attr.name))
.map(|item| item.attr.clone())
.collect();
tools.sort_by(|a, b| a.name.cmp(&b.name));
tools
}

/// Get a tool definition by name.
///
/// Returns the tool if found, or `None` if no tool with the given name exists.
/// Returns the tool if found and enabled, or `None` if the tool does not
/// exist or is disabled.
pub fn get(&self, name: &str) -> Option<&crate::model::Tool> {
if self.disabled.contains(name) {
return None;
}
self.map.get(name).map(|r| &r.attr)
}
}
Expand Down
130 changes: 130 additions & 0 deletions crates/rmcp/tests/test_tool_routers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,133 @@ fn test_tool_router_list_all_is_sorted() {
"list_all() should return tools sorted alphabetically by name"
);
}

fn build_router() -> ToolRouter<TestHandler<()>> {
ToolRouter::<TestHandler<()>>::new()
.with_route((async_function_tool_attr(), async_function))
.with_route((async_function2_tool_attr(), async_function2))
+ TestHandler::<()>::test_router_1()
+ TestHandler::<()>::test_router_2()
}

#[test]
fn test_disable_route() {
let mut router = build_router();
assert_eq!(router.list_all().len(), 4);
assert!(router.has_route("async_function"));
assert!(router.get("async_function").is_some());

router.disable_route("async_function");

assert_eq!(router.list_all().len(), 3);
assert!(!router.has_route("async_function"));
assert!(router.get("async_function").is_none());
assert!(router.is_disabled("async_function"));

// other tools unaffected
assert!(router.has_route("async_function2"));
assert!(router.get("async_function2").is_some());
assert!(!router.is_disabled("async_function2"));
}

#[test]
fn test_enable_route() {
let mut router = build_router();
router.disable_route("async_function");
assert!(!router.has_route("async_function"));

router.enable_route("async_function");
assert!(router.has_route("async_function"));
assert!(router.get("async_function").is_some());
assert!(!router.is_disabled("async_function"));
assert_eq!(router.list_all().len(), 4);
}

#[test]
fn test_with_disabled_builder() {
let router = build_router()
.with_disabled("async_function")
.with_disabled("sync_method");

assert_eq!(router.list_all().len(), 2);
assert!(!router.has_route("async_function"));
assert!(!router.has_route("sync_method"));
assert!(router.has_route("async_function2"));
assert!(router.has_route("async_method"));
}

#[test]
fn test_disabled_tools_survive_merge() {
let mut router_a = ToolRouter::<TestHandler<()>>::new()
.with_route((async_function_tool_attr(), async_function));
router_a.disable_route("async_function");

let router_b = ToolRouter::<TestHandler<()>>::new()
.with_route((async_function2_tool_attr(), async_function2));

router_a.merge(router_b);

assert_eq!(router_a.list_all().len(), 1);
assert!(router_a.is_disabled("async_function"));
assert!(router_a.has_route("async_function2"));
}

#[test]
fn test_disable_nonexistent_tool() {
let mut router = build_router();
// should not panic
router.disable_route("does_not_exist");
assert_eq!(router.list_all().len(), 4);
// is_disabled returns false for tools not in the map
assert!(!router.is_disabled("does_not_exist"));
}

#[test]
fn test_remove_route_clears_disabled_state() {
let mut router = build_router();
router.disable_route("async_function");
assert!(router.is_disabled("async_function"));

router.remove_route("async_function");
assert!(!router.is_disabled("async_function"));
assert!(!router.has_route("async_function"));
}

#[test]
fn test_into_iter_skips_disabled() {
let router = build_router().with_disabled("async_function");
let names: Vec<_> = router
.into_iter()
.map(|r| r.attr.name.to_string())
.collect();
assert_eq!(names.len(), 3);
assert!(!names.contains(&"async_function".to_string()));
}

#[test]
fn test_pre_disable_before_add_route() {
// Disabling a name before adding a route with that name should
// result in the route being disabled once added.
let router = ToolRouter::<TestHandler<()>>::new()
.with_disabled("async_function")
.with_route((async_function_tool_attr(), async_function));

assert_eq!(router.list_all().len(), 0);
assert!(router.is_disabled("async_function"));
assert!(!router.has_route("async_function"));
}

#[test]
fn test_disabled_tool_invisible_across_all_queries() {
let router = build_router().with_disabled("async_function");

// Not listed
let names: Vec<_> = router.list_all().iter().map(|t| t.name.clone()).collect();
assert!(!names.contains(&"async_function".into()));
// Not retrievable
assert!(router.get("async_function").is_none());
// Not routable
assert!(!router.has_route("async_function"));
// But still known as disabled
assert!(router.is_disabled("async_function"));
}