From 875003120e1be298c4cc46bcdde795fdbe58db9d Mon Sep 17 00:00:00 2001 From: lutz-grex Date: Wed, 15 Apr 2026 11:50:54 +0200 Subject: [PATCH] feat(router): support runtime disabling of tools Add methods to disable/enable tools at runtime. Disabled tools are hidden from listing, lookup, and execution, including in composed routers. Closes #477 --- crates/rmcp/src/handler/server/router.rs | 1 + crates/rmcp/src/handler/server/router/tool.rs | 68 ++++++++- crates/rmcp/tests/test_tool_routers.rs | 130 ++++++++++++++++++ 3 files changed, 194 insertions(+), 5 deletions(-) diff --git a/crates/rmcp/src/handler/server/router.rs b/crates/rmcp/src/handler/server/router.rs index 08beb61d2..aeca04cbe 100644 --- a/crates/rmcp/src/handler/server/router.rs +++ b/crates/rmcp/src/handler/server/router.rs @@ -84,6 +84,7 @@ where match request { ClientRequest::CallToolRequest(request) => { if self.tool_router.has_route(request.params.name.as_ref()) + || self.tool_router.is_disabled(request.params.name.as_ref()) || !self.tool_router.transparent_when_not_found { let tool_call_context = crate::handler::server::tool::ToolCallContext::new( diff --git a/crates/rmcp/src/handler/server/router/tool.rs b/crates/rmcp/src/handler/server/router/tool.rs index 79a228ffe..e6d861a45 100644 --- a/crates/rmcp/src/handler/server/router/tool.rs +++ b/crates/rmcp/src/handler/server/router/tool.rs @@ -305,6 +305,8 @@ pub struct ToolRouter { pub map: std::collections::HashMap, ToolRoute>, pub transparent_when_not_found: bool, + + disabled: std::collections::HashSet>, } impl Default for ToolRouter { @@ -312,6 +314,7 @@ impl Default for ToolRouter { Self { map: std::collections::HashMap::new(), transparent_when_not_found: false, + disabled: std::collections::HashSet::new(), } } } @@ -320,6 +323,7 @@ impl Clone for ToolRouter { Self { map: self.map.clone(), transparent_when_not_found: self.transparent_when_not_found, + disabled: self.disabled.clone(), } } } @@ -329,7 +333,11 @@ impl IntoIterator for ToolRouter { type IntoIter = std::collections::hash_map::IntoValues, ToolRoute>; fn into_iter(self) -> Self::IntoIter { - self.map.into_values() + let mut map = self.map; + for name in &self.disabled { + map.remove(name); + } + map.into_values() } } @@ -341,6 +349,7 @@ where Self { map: std::collections::HashMap::new(), transparent_when_not_found: false, + disabled: std::collections::HashSet::new(), } } pub fn with_route(mut self, route: R) -> Self @@ -394,6 +403,7 @@ where } pub fn merge(&mut self, other: ToolRouter) { + self.disabled.extend(other.disabled); for item in other.map.into_values() { self.add_route(item); } @@ -401,17 +411,56 @@ where pub fn remove_route(&mut self, name: &str) { self.map.remove(name); + self.disabled.remove(name); } + pub fn has_route(&self, name: &str) -> bool { - self.map.contains_key(name) + self.map.contains_key(name) && !self.disabled.contains(name) + } + + /// Disable a tool by name so it is hidden from `list_all`, `get`, and + /// rejected by `call`. The tool remains in the router and can be + /// re-enabled later with [`enable_route`](Self::enable_route). + /// + /// The name is recorded even if no matching route exists yet, so routes + /// added later (via [`add_route`](Self::add_route) or + /// [`merge`](Self::merge)) will inherit the disabled state. + pub fn disable_route(&mut self, name: &str) { + self.disabled.insert(Cow::Owned(name.to_owned())); + } + + /// Re-enable a previously disabled tool. + pub fn enable_route(&mut self, name: &str) { + self.disabled.remove(name); + } + + /// Returns `true` if the tool exists in the router but is currently + /// disabled. + pub fn is_disabled(&self, name: &str) -> bool { + self.map.contains_key(name) && self.disabled.contains(name) + } + + /// Builder-style variant of [`disable_route`](Self::disable_route). + /// + /// The name is recorded even if no matching route has been added yet, + /// so it can be called before [`with_route`](Self::with_route) in a + /// builder chain. + pub fn with_disabled(mut self, name: impl Into>) -> Self { + self.disabled.insert(name.into()); + self } + pub async fn call( &self, context: ToolCallContext<'_, S>, ) -> Result { + let name = context.name(); + if self.disabled.contains(name) { + return Err(crate::ErrorData::invalid_params("tool not found", None)); + } let item = self .map - .get(context.name()) + .get(name) .ok_or_else(|| crate::ErrorData::invalid_params("tool not found", None))?; let result = (item.call)(context).await?; @@ -420,15 +469,24 @@ where } pub fn list_all(&self) -> Vec { - let mut tools: Vec<_> = self.map.values().map(|item| item.attr.clone()).collect(); + let mut tools: Vec<_> = self + .map + .values() + .filter(|item| !self.disabled.contains(&item.attr.name)) + .map(|item| item.attr.clone()) + .collect(); tools.sort_by(|a, b| a.name.cmp(&b.name)); tools } /// Get a tool definition by name. /// - /// Returns the tool if found, or `None` if no tool with the given name exists. + /// Returns the tool if found and enabled, or `None` if the tool does not + /// exist or is disabled. pub fn get(&self, name: &str) -> Option<&crate::model::Tool> { + if self.disabled.contains(name) { + return None; + } self.map.get(name).map(|r| &r.attr) } } diff --git a/crates/rmcp/tests/test_tool_routers.rs b/crates/rmcp/tests/test_tool_routers.rs index c10665064..d60aa5d32 100644 --- a/crates/rmcp/tests/test_tool_routers.rs +++ b/crates/rmcp/tests/test_tool_routers.rs @@ -84,3 +84,133 @@ fn test_tool_router_list_all_is_sorted() { "list_all() should return tools sorted alphabetically by name" ); } + +fn build_router() -> ToolRouter> { + ToolRouter::>::new() + .with_route((async_function_tool_attr(), async_function)) + .with_route((async_function2_tool_attr(), async_function2)) + + TestHandler::<()>::test_router_1() + + TestHandler::<()>::test_router_2() +} + +#[test] +fn test_disable_route() { + let mut router = build_router(); + assert_eq!(router.list_all().len(), 4); + assert!(router.has_route("async_function")); + assert!(router.get("async_function").is_some()); + + router.disable_route("async_function"); + + assert_eq!(router.list_all().len(), 3); + assert!(!router.has_route("async_function")); + assert!(router.get("async_function").is_none()); + assert!(router.is_disabled("async_function")); + + // other tools unaffected + assert!(router.has_route("async_function2")); + assert!(router.get("async_function2").is_some()); + assert!(!router.is_disabled("async_function2")); +} + +#[test] +fn test_enable_route() { + let mut router = build_router(); + router.disable_route("async_function"); + assert!(!router.has_route("async_function")); + + router.enable_route("async_function"); + assert!(router.has_route("async_function")); + assert!(router.get("async_function").is_some()); + assert!(!router.is_disabled("async_function")); + assert_eq!(router.list_all().len(), 4); +} + +#[test] +fn test_with_disabled_builder() { + let router = build_router() + .with_disabled("async_function") + .with_disabled("sync_method"); + + assert_eq!(router.list_all().len(), 2); + assert!(!router.has_route("async_function")); + assert!(!router.has_route("sync_method")); + assert!(router.has_route("async_function2")); + assert!(router.has_route("async_method")); +} + +#[test] +fn test_disabled_tools_survive_merge() { + let mut router_a = ToolRouter::>::new() + .with_route((async_function_tool_attr(), async_function)); + router_a.disable_route("async_function"); + + let router_b = ToolRouter::>::new() + .with_route((async_function2_tool_attr(), async_function2)); + + router_a.merge(router_b); + + assert_eq!(router_a.list_all().len(), 1); + assert!(router_a.is_disabled("async_function")); + assert!(router_a.has_route("async_function2")); +} + +#[test] +fn test_disable_nonexistent_tool() { + let mut router = build_router(); + // should not panic + router.disable_route("does_not_exist"); + assert_eq!(router.list_all().len(), 4); + // is_disabled returns false for tools not in the map + assert!(!router.is_disabled("does_not_exist")); +} + +#[test] +fn test_remove_route_clears_disabled_state() { + let mut router = build_router(); + router.disable_route("async_function"); + assert!(router.is_disabled("async_function")); + + router.remove_route("async_function"); + assert!(!router.is_disabled("async_function")); + assert!(!router.has_route("async_function")); +} + +#[test] +fn test_into_iter_skips_disabled() { + let router = build_router().with_disabled("async_function"); + let names: Vec<_> = router + .into_iter() + .map(|r| r.attr.name.to_string()) + .collect(); + assert_eq!(names.len(), 3); + assert!(!names.contains(&"async_function".to_string())); +} + +#[test] +fn test_pre_disable_before_add_route() { + // Disabling a name before adding a route with that name should + // result in the route being disabled once added. + let router = ToolRouter::>::new() + .with_disabled("async_function") + .with_route((async_function_tool_attr(), async_function)); + + assert_eq!(router.list_all().len(), 0); + assert!(router.is_disabled("async_function")); + assert!(!router.has_route("async_function")); +} + +#[test] +fn test_disabled_tool_invisible_across_all_queries() { + let router = build_router().with_disabled("async_function"); + + // Not listed + let names: Vec<_> = router.list_all().iter().map(|t| t.name.clone()).collect(); + assert!(!names.contains(&"async_function".into())); + // Not retrievable + assert!(router.get("async_function").is_none()); + // Not routable + assert!(!router.has_route("async_function")); + // But still known as disabled + assert!(router.is_disabled("async_function")); +}