diff --git a/pyproject.toml b/pyproject.toml index 8d57f865..3d884ddb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "uipath-langchain" -version = "0.8.6" +version = "0.8.7" description = "Python SDK that enables developers to build and deploy LangGraph agents to the UiPath Cloud Platform" readme = { file = "README.md", content-type = "text/markdown" } requires-python = ">=3.11" diff --git a/src/uipath_langchain/agent/tools/tool_factory.py b/src/uipath_langchain/agent/tools/tool_factory.py index 8a87fec8..8f7bc5ca 100644 --- a/src/uipath_langchain/agent/tools/tool_factory.py +++ b/src/uipath_langchain/agent/tools/tool_factory.py @@ -16,6 +16,8 @@ LowCodeAgentDefinition, ) +from uipath_langchain.chat.hitl import REQUIRE_CONVERSATIONAL_CONFIRMATION + from .context_tool import create_context_tool from .escalation_tool import create_escalation_tool from .extraction_tool import create_ixp_extraction_tool @@ -54,6 +56,15 @@ async def create_tools_from_resources( else: tools.append(tool) + if agent.is_conversational: + props = getattr(resource, "properties", None) + if props and getattr( + props, REQUIRE_CONVERSATIONAL_CONFIRMATION, False + ): + if tool.metadata is None: + tool.metadata = {} + tool.metadata[REQUIRE_CONVERSATIONAL_CONFIRMATION] = True + return tools diff --git a/src/uipath_langchain/agent/tools/tool_node.py b/src/uipath_langchain/agent/tools/tool_node.py index 6d6b9fc5..83072d42 100644 --- a/src/uipath_langchain/agent/tools/tool_node.py +++ b/src/uipath_langchain/agent/tools/tool_node.py @@ -7,6 +7,7 @@ from langchain_core.messages.tool import ToolCall, ToolMessage from langchain_core.tools import BaseTool from langgraph._internal._runnable import RunnableCallable +from langgraph.errors import GraphBubbleUp from langgraph.types import Command from pydantic import BaseModel from uipath.platform.resume_triggers import is_no_content_marker @@ -21,6 +22,7 @@ extract_current_tool_call_index, find_latest_ai_message, ) +from uipath_langchain.chat.hitl import request_tool_confirmation # the type safety can be improved with generics ToolWrapperReturnType = dict[str, Any] | Command[Any] | None @@ -79,6 +81,13 @@ def _func(self, state: AgentGraphState) -> OutputType: if call is None: return None + # prompt user for approval if tool requires confirmation + confirmation = request_tool_confirmation(call, self.tool) + + # user rejected the tool call + if confirmation is not None and confirmation.cancelled: + return self._process_result(call, confirmation.cancelled) + try: if self.wrapper: inputs = self._prepare_wrapper_inputs( @@ -87,7 +96,13 @@ def _func(self, state: AgentGraphState) -> OutputType: result = self.wrapper(*inputs) else: result = self.tool.invoke(call) - return self._process_result(call, result) + output = self._process_result(call, result) + # HITL approved - apply confirmation metadata to tool result message + if confirmation is not None: + confirmation.annotate_result(output) + return output + except GraphBubbleUp: + raise except Exception as e: if self.handle_tool_errors: return self._process_error_result(call, e) @@ -98,6 +113,13 @@ async def _afunc(self, state: AgentGraphState) -> OutputType: if call is None: return None + # prompt user for approval if tool requires confirmation + confirmation = request_tool_confirmation(call, self.tool) + + # user rejected the tool call + if confirmation is not None and confirmation.cancelled: + return self._process_result(call, confirmation.cancelled) + try: if self.awrapper: inputs = self._prepare_wrapper_inputs( @@ -106,7 +128,13 @@ async def _afunc(self, state: AgentGraphState) -> OutputType: result = await self.awrapper(*inputs) else: result = await self.tool.ainvoke(call) - return self._process_result(call, result) + output = self._process_result(call, result) + # HITL approved - apply confirmation metadata to tool result message + if confirmation is not None: + confirmation.annotate_result(output) + return output + except GraphBubbleUp: + raise except Exception as e: if self.handle_tool_errors: return self._process_error_result(call, e) diff --git a/src/uipath_langchain/chat/hitl.py b/src/uipath_langchain/chat/hitl.py index 625fc9a6..9cccb8c0 100644 --- a/src/uipath_langchain/chat/hitl.py +++ b/src/uipath_langchain/chat/hitl.py @@ -1,8 +1,10 @@ import functools import inspect +import json from inspect import Parameter -from typing import Annotated, Any, Callable +from typing import Annotated, Any, Callable, NamedTuple +from langchain_core.messages.tool import ToolCall, ToolMessage from langchain_core.tools import BaseTool, InjectedToolCallId from langchain_core.tools import tool as langchain_tool from langgraph.types import interrupt @@ -10,7 +12,46 @@ UiPathConversationToolCallConfirmationValue, ) -_CANCELLED_MESSAGE = "Cancelled by user" +CANCELLED_MESSAGE = "Cancelled by user" + +CONVERSATIONAL_APPROVED_TOOL_ARGS = "conversational_approved_tool_args" +REQUIRE_CONVERSATIONAL_CONFIRMATION = "require_conversational_confirmation" + + +class ConfirmationResult(NamedTuple): + """Result of a tool confirmation check.""" + + cancelled: ToolMessage | None # ToolMessage if cancelled, None if approved + args_modified: bool + approved_args: dict[str, Any] | None = None + + def annotate_result(self, output: dict[str, Any] | Any) -> None: + """Apply confirmation metadata to a tool result message.""" + msg = None + if isinstance(output, dict): + messages = output.get("messages") + if messages: + msg = messages[0] + if msg is None: + return + if self.approved_args is not None: + msg.response_metadata[CONVERSATIONAL_APPROVED_TOOL_ARGS] = ( + self.approved_args + ) + if self.args_modified: + try: + result_value = json.loads(msg.content) + except (json.JSONDecodeError, TypeError): + result_value = msg.content + msg.content = json.dumps( + { + "meta": { + "args_modified_by_user": True, + "executed_args": self.approved_args, + }, + "result": result_value, + } + ) def _patch_span_input(approved_args: dict[str, Any]) -> None: @@ -53,7 +94,7 @@ def _patch_span_input(approved_args: dict[str, Any]) -> None: pass -def _request_approval( +def request_approval( tool_args: dict[str, Any], tool: BaseTool, ) -> dict[str, Any] | None: @@ -89,7 +130,41 @@ def _request_approval( if not confirmation.get("approved", True): return None - return confirmation.get("input") or tool_args + return ( + confirmation.get("input") + if confirmation.get("input") is not None + else tool_args + ) + + +def request_tool_confirmation( + call: ToolCall, tool: BaseTool +) -> ConfirmationResult | None: + """Check whether a tool requires user confirmation and request approval""" + if not (tool.metadata and tool.metadata.get(REQUIRE_CONVERSATIONAL_CONFIRMATION)): + return None + + original_args = call["args"] + approved_args = request_approval( + {**original_args, "tool_call_id": call["id"]}, tool + ) + if approved_args is None: + cancelled_msg = ToolMessage( + content=CANCELLED_MESSAGE, + name=call["name"], + tool_call_id=call["id"], + ) + cancelled_msg.response_metadata[CONVERSATIONAL_APPROVED_TOOL_ARGS] = ( + original_args + ) + return ConfirmationResult(cancelled=cancelled_msg, args_modified=False) + # Mutate call args so the tool executes with the approved values + call["args"] = approved_args + return ConfirmationResult( + cancelled=None, + args_modified=approved_args != original_args, + approved_args=approved_args, + ) def requires_approval( @@ -107,9 +182,9 @@ def decorator(fn: Callable[..., Any]) -> BaseTool: # wrap the tool/function @functools.wraps(fn) def wrapper(**tool_args: Any) -> Any: - approved_args = _request_approval(tool_args, _created_tool[0]) + approved_args = request_approval(tool_args, _created_tool[0]) if approved_args is None: - return _CANCELLED_MESSAGE + return {"meta": CANCELLED_MESSAGE} _patch_span_input(approved_args) return fn(**approved_args) diff --git a/src/uipath_langchain/runtime/messages.py b/src/uipath_langchain/runtime/messages.py index 53712e91..9af60452 100644 --- a/src/uipath_langchain/runtime/messages.py +++ b/src/uipath_langchain/runtime/messages.py @@ -39,6 +39,8 @@ ) from uipath.runtime import UiPathRuntimeStorageProtocol +from uipath_langchain.chat.hitl import CONVERSATIONAL_APPROVED_TOOL_ARGS + from ._citations import CitationStreamProcessor, extract_citations_from_text logger = logging.getLogger(__name__) @@ -58,6 +60,7 @@ def __init__(self, runtime_id: str, storage: UiPathRuntimeStorageProtocol | None """Initialize the mapper with empty state.""" self.runtime_id = runtime_id self.storage = storage + self.tool_names_requiring_confirmation: set[str] = set() self.current_message: AIMessageChunk self.seen_message_ids: set[str] = set() self._storage_lock = asyncio.Lock() @@ -389,11 +392,17 @@ async def map_current_message_to_start_tool_call_events(self): tool_call_id_to_message_id_map[tool_call_id] = ( self.current_message.id ) - events.append( - self.map_tool_call_to_tool_call_start_event( - self.current_message.id, tool_call + + # if tool requires confirmation, we skip start tool call + if ( + tool_call["name"] + not in self.tool_names_requiring_confirmation + ): + events.append( + self.map_tool_call_to_tool_call_start_event( + self.current_message.id, tool_call + ) ) - ) if self.storage is not None: await self.storage.set_value( @@ -426,7 +435,19 @@ async def map_tool_message_to_events( # Keep as string if not valid JSON pass - events = [ + events: list[UiPathConversationMessageEvent] = [] + + # emit startToolCall for tools requiring confirmation after it's approved + approved_args = message.response_metadata.get(CONVERSATIONAL_APPROVED_TOOL_ARGS) + if approved_args is not None: + tool_call = ToolCall( + name=message.name or "", args=approved_args, id=message.tool_call_id + ) + events.append( + self.map_tool_call_to_tool_call_start_event(message_id, tool_call) + ) + + events.append( UiPathConversationMessageEvent( message_id=message_id, tool_call=UiPathConversationToolCallEvent( @@ -438,7 +459,7 @@ async def map_tool_message_to_events( ), ), ) - ] + ) if is_last_tool_call: events.append(self.map_to_message_end_event(message_id)) @@ -665,7 +686,7 @@ def _map_langchain_ai_message_to_uipath_message_data( role="assistant", content_parts=content_parts, tool_calls=uipath_tool_calls, - interrupts=[], # TODO: Interrupts + interrupts=[], ) diff --git a/src/uipath_langchain/runtime/runtime.py b/src/uipath_langchain/runtime/runtime.py index 228a5cdb..feb32701 100644 --- a/src/uipath_langchain/runtime/runtime.py +++ b/src/uipath_langchain/runtime/runtime.py @@ -29,6 +29,7 @@ ) from uipath.runtime.schema import UiPathRuntimeSchema +from uipath_langchain.chat.hitl import REQUIRE_CONVERSATIONAL_CONFIRMATION from uipath_langchain.runtime.errors import LangGraphErrorCode, LangGraphRuntimeError from uipath_langchain.runtime.messages import UiPathChatMessagesMapper from uipath_langchain.runtime.schema import get_entrypoints_schema, get_graph_schema @@ -64,6 +65,9 @@ def __init__( self.entrypoint: str | None = entrypoint self.callbacks: list[BaseCallbackHandler] = callbacks or [] self.chat = UiPathChatMessagesMapper(self.runtime_id, storage) + self.chat.tool_names_requiring_confirmation = ( + self._get_tool_names_requiring_confirmation() + ) self._middleware_node_names: set[str] = self._detect_middleware_nodes() async def execute( @@ -486,6 +490,18 @@ def _detect_middleware_nodes(self) -> set[str]: return middleware_nodes + def _get_tool_names_requiring_confirmation(self) -> set[str]: + names: set[str] = set() + for node_name, node_spec in self.graph.nodes.items(): + # langgraph's processing node.bound -> runnable.tool -> baseTool (if tool node) + tool = getattr(getattr(node_spec, "bound", None), "tool", None) + if tool is None: + continue + metadata = getattr(tool, "metadata", None) or {} + if metadata.get(REQUIRE_CONVERSATIONAL_CONFIRMATION): + names.add(getattr(tool, "name", node_name)) + return names + def _is_middleware_node(self, node_name: str) -> bool: """Check if a node name represents a middleware node.""" return node_name in self._middleware_node_names diff --git a/tests/agent/tools/test_tool_node.py b/tests/agent/tools/test_tool_node.py index af3da38c..08b591b0 100644 --- a/tests/agent/tools/test_tool_node.py +++ b/tests/agent/tools/test_tool_node.py @@ -1,6 +1,7 @@ """Tests for tool_node.py module.""" from typing import Any, Dict +from unittest.mock import patch import pytest from langchain_core.messages import AIMessage, HumanMessage @@ -13,11 +14,16 @@ AgentRuntimeError, AgentRuntimeErrorCode, ) +from uipath_langchain.agent.react.types import AgentGraphState from uipath_langchain.agent.tools.tool_node import ( ToolWrapperMixin, UiPathToolNode, create_tool_node, ) +from uipath_langchain.chat.hitl import ( + CANCELLED_MESSAGE, + CONVERSATIONAL_APPROVED_TOOL_ARGS, +) class MockTool(BaseTool): @@ -66,10 +72,9 @@ class FilteredState(BaseModel): session_id: str = "test_session" -class MockState(BaseModel): +class MockState(AgentGraphState): """Mock state for testing.""" - messages: list[Any] = [] user_id: str = "test_user" session_id: str = "test_session" @@ -310,8 +315,7 @@ def test_tool_error_propagates_when_handle_errors_false(self, mock_state): node = UiPathToolNode(failing_tool, handle_tool_errors=False) with pytest.raises(ValueError) as exc_info: - node._func(state) # type: ignore[arg-type] - + node._func(state) assert "Tool execution failed: test input" in str(exc_info.value) async def test_async_tool_error_propagates_when_handle_errors_false(self): @@ -328,8 +332,7 @@ async def test_async_tool_error_propagates_when_handle_errors_false(self): node = UiPathToolNode(failing_tool, handle_tool_errors=False) with pytest.raises(ValueError) as exc_info: - await node._afunc(state) # type: ignore[arg-type] - + await node._afunc(state) assert "Async tool execution failed: test input" in str(exc_info.value) def test_tool_error_captured_when_handle_errors_true(self): @@ -345,8 +348,7 @@ def test_tool_error_captured_when_handle_errors_true(self): node = UiPathToolNode(failing_tool, handle_tool_errors=True) - result = node._func(state) # type: ignore[arg-type] - + result = node._func(state) assert result is not None assert isinstance(result, dict) assert "messages" in result @@ -372,8 +374,7 @@ async def test_async_tool_error_captured_when_handle_errors_true(self): node = UiPathToolNode(failing_tool, handle_tool_errors=True) - result = await node._afunc(state) # type: ignore[arg-type] - + result = await node._afunc(state) assert result is not None assert isinstance(result, dict) assert "messages" in result @@ -482,3 +483,185 @@ def test_create_tool_node_with_handle_errors_true(self): node = result[tool_name] assert isinstance(node, UiPathToolNode) assert node.handle_tool_errors is True + + +class TestToolNodeConfirmation: + """Tests for confirmation flow in UiPathToolNode._func / _afunc.""" + + @pytest.fixture + def confirmation_tool(self): + """Tool with require_conversational_confirmation metadata.""" + return MockTool(metadata={"require_conversational_confirmation": True}) + + @pytest.fixture + def confirmation_state(self): + tool_call = { + "name": "mock_tool", + "args": {"input_text": "test input"}, + "id": "test_call_id", + } + ai_message = AIMessage(content="Using tool", tool_calls=[tool_call]) + return MockState(messages=[ai_message]) + + def test_no_confirmation_without_metadata(self): + """Tool without metadata executes normally, no interrupt.""" + tool = MockTool() # no metadata + node = UiPathToolNode(tool) + tool_call = { + "name": "mock_tool", + "args": {"input_text": "hello"}, + "id": "call_1", + } + state = MockState(messages=[AIMessage(content="go", tool_calls=[tool_call])]) + + result = node._func(state) + + assert result is not None + assert isinstance(result, dict) + assert "Mock result: hello" in result["messages"][0].content + + @patch("uipath_langchain.chat.hitl.request_approval", return_value=None) + def test_cancelled_returns_cancelled_message( + self, mock_approval, confirmation_tool, confirmation_state + ): + """Rejected confirmation returns CANCELLED_MESSAGE.""" + node = UiPathToolNode(confirmation_tool) + + result = node._func(confirmation_state) + + assert result is not None + assert isinstance(result, dict) + msg = result["messages"][0] + assert isinstance(msg, ToolMessage) + assert msg.content == CANCELLED_MESSAGE + + @patch( + "uipath_langchain.chat.hitl.request_approval", + return_value={"input_text": "test input"}, + ) + def test_approved_same_args_no_meta( + self, mock_approval, confirmation_tool, confirmation_state + ): + """Approved with same args → normal execution, no meta injected.""" + node = UiPathToolNode(confirmation_tool) + + result = node._func(confirmation_state) + + assert result is not None + assert isinstance(result, dict) + msg = result["messages"][0] + assert "args_modified_by_user" not in msg.content + assert "Mock result:" in msg.content + + @patch( + "uipath_langchain.chat.hitl.request_approval", + return_value={"input_text": "edited"}, + ) + def test_approved_modified_args_injects_meta( + self, mock_approval, confirmation_tool, confirmation_state + ): + """Approved with edited args → tool runs with new args, meta injected.""" + node = UiPathToolNode(confirmation_tool) + + result = node._func(confirmation_state) + + assert result is not None + assert isinstance(result, dict) + msg = result["messages"][0] + import json + + wrapped = json.loads(msg.content) + assert wrapped["meta"]["args_modified_by_user"] is True + assert wrapped["meta"]["executed_args"] == {"input_text": "edited"} + assert "Mock result: edited" in wrapped["result"] + + @patch("uipath_langchain.chat.hitl.request_approval", return_value=None) + async def test_async_cancelled( + self, mock_approval, confirmation_tool, confirmation_state + ): + """Async path: rejected confirmation returns CANCELLED_MESSAGE.""" + node = UiPathToolNode(confirmation_tool) + + result = await node._afunc(confirmation_state) + + assert result is not None + assert isinstance(result, dict) + msg = result["messages"][0] + assert msg.content == CANCELLED_MESSAGE + + @patch( + "uipath_langchain.chat.hitl.request_approval", + return_value={"input_text": "async edited"}, + ) + async def test_async_approved_modified_args( + self, mock_approval, confirmation_tool, confirmation_state + ): + """Async path: approved with edited args → meta injected.""" + node = UiPathToolNode(confirmation_tool) + + result = await node._afunc(confirmation_state) + + assert result is not None + assert isinstance(result, dict) + msg = result["messages"][0] + import json + + wrapped = json.loads(msg.content) + assert wrapped["meta"]["args_modified_by_user"] is True + assert wrapped["meta"]["executed_args"] == {"input_text": "async edited"} + assert "Async mock result: async edited" in wrapped["result"] + + @patch( + "uipath_langchain.chat.hitl.request_approval", + return_value={"input_text": "approved"}, + ) + def test_approved_attaches_approved_args_metadata( + self, mock_approval, confirmation_tool, confirmation_state + ): + """Approved path attaches approved args in response_metadata.""" + node = UiPathToolNode(confirmation_tool) + + result = node._func(confirmation_state) + + assert result is not None + assert isinstance(result, dict) + msg = result["messages"][0] + assert CONVERSATIONAL_APPROVED_TOOL_ARGS in msg.response_metadata + assert msg.response_metadata[CONVERSATIONAL_APPROVED_TOOL_ARGS] == { + "input_text": "approved" + } + + @patch("uipath_langchain.chat.hitl.request_approval", return_value=None) + def test_cancelled_attaches_original_args_metadata( + self, mock_approval, confirmation_tool, confirmation_state + ): + """Cancelled path attaches original args in response_metadata.""" + node = UiPathToolNode(confirmation_tool) + + result = node._func(confirmation_state) + + assert result is not None + assert isinstance(result, dict) + msg = result["messages"][0] + assert CONVERSATIONAL_APPROVED_TOOL_ARGS in msg.response_metadata + assert msg.response_metadata[CONVERSATIONAL_APPROVED_TOOL_ARGS] == { + "input_text": "test input" + } + + def test_no_confirmation_no_metadata(self): + """Non-confirmation tools don't get the approved args metadata.""" + tool = MockTool() # no confirmation metadata + node = UiPathToolNode(tool) + tool_call = { + "name": "mock_tool", + "args": {"input_text": "hello"}, + "id": "call_1", + } + state = MockState(messages=[AIMessage(content="go", tool_calls=[tool_call])]) + + result = node._func(state) + + assert result is not None + assert isinstance(result, dict) + msg = result["messages"][0] + assert CONVERSATIONAL_APPROVED_TOOL_ARGS not in msg.response_metadata diff --git a/tests/chat/test_hitl.py b/tests/chat/test_hitl.py new file mode 100644 index 00000000..0b0e9d46 --- /dev/null +++ b/tests/chat/test_hitl.py @@ -0,0 +1,185 @@ +"""Tests for hitl.py module.""" + +from typing import Any +from unittest.mock import patch + +from langchain_core.messages.tool import ToolCall, ToolMessage +from langchain_core.tools import BaseTool + +from uipath_langchain.chat.hitl import ( + CANCELLED_MESSAGE, + CONVERSATIONAL_APPROVED_TOOL_ARGS, + ConfirmationResult, + request_approval, + request_tool_confirmation, +) + + +class MockTool(BaseTool): + name: str = "mock_tool" + description: str = "A mock tool" + + def _run(self) -> str: + return "" + + +def _make_call(args: dict[str, Any] | None = None) -> ToolCall: + return ToolCall(name="mock_tool", args=args or {"query": "test"}, id="call_1") + + +class TestCheckToolConfirmation: + """Tests for request_tool_confirmation.""" + + def test_returns_none_when_no_metadata(self): + """No metadata → no confirmation needed.""" + tool = MockTool() + call = _make_call() + assert request_tool_confirmation(call, tool) is None + + def test_returns_none_when_flag_not_set(self): + """Metadata exists but flag is missing → no confirmation needed.""" + tool = MockTool(metadata={"other_key": True}) + call = _make_call() + assert request_tool_confirmation(call, tool) is None + + def test_returns_none_when_flag_false(self): + """Flag explicitly False → no confirmation needed.""" + tool = MockTool(metadata={"require_conversational_confirmation": False}) + call = _make_call() + assert request_tool_confirmation(call, tool) is None + + @patch("uipath_langchain.chat.hitl.request_approval", return_value=None) + def test_cancelled_returns_tool_message(self, mock_approval): + """User rejects → ConfirmationResult with cancelled ToolMessage and metadata.""" + tool = MockTool(metadata={"require_conversational_confirmation": True}) + call = _make_call() + + result = request_tool_confirmation(call, tool) + + assert result is not None + assert isinstance(result, ConfirmationResult) + assert result.cancelled is not None + assert isinstance(result.cancelled, ToolMessage) + assert result.cancelled.content == CANCELLED_MESSAGE + assert result.cancelled.name == "mock_tool" + assert result.cancelled.tool_call_id == "call_1" + assert result.args_modified is False + assert result.cancelled.response_metadata[ + CONVERSATIONAL_APPROVED_TOOL_ARGS + ] == {"query": "test"} + + @patch( + "uipath_langchain.chat.hitl.request_approval", + return_value={"query": "test"}, + ) + def test_approved_same_args(self, mock_approval): + """User approves without editing → cancelled=None, args_modified=False.""" + tool = MockTool(metadata={"require_conversational_confirmation": True}) + call = _make_call({"query": "test"}) + + result = request_tool_confirmation(call, tool) + + assert result is not None + assert result.cancelled is None + assert result.args_modified is False + assert result.approved_args == {"query": "test"} + + @patch( + "uipath_langchain.chat.hitl.request_approval", + return_value={"query": "edited"}, + ) + def test_approved_modified_args(self, mock_approval): + """User edits args → cancelled=None, args_modified=True, call updated.""" + tool = MockTool(metadata={"require_conversational_confirmation": True}) + call = _make_call({"query": "original"}) + + result = request_tool_confirmation(call, tool) + + assert result is not None + assert result.cancelled is None + assert result.args_modified is True + assert result.approved_args == {"query": "edited"} + assert call["args"] == {"query": "edited"} + + +class TestAnnotateResult: + """Tests for ConfirmationResult.annotate_result.""" + + def test_annotate_sets_metadata(self): + """annotate_result sets approved_args on response_metadata.""" + confirmation = ConfirmationResult( + cancelled=None, args_modified=False, approved_args={"query": "test"} + ) + msg = ToolMessage(content="result", tool_call_id="call_1") + output = {"messages": [msg]} + + confirmation.annotate_result(output) + + assert msg.response_metadata[CONVERSATIONAL_APPROVED_TOOL_ARGS] == { + "query": "test" + } + assert msg.content == "result" + + def test_annotate_wraps_content_when_modified(self): + """annotate_result wraps content with structured meta when args were modified.""" + confirmation = ConfirmationResult( + cancelled=None, args_modified=True, approved_args={"query": "edited"} + ) + msg = ToolMessage(content="result", tool_call_id="call_1") + output = {"messages": [msg]} + + confirmation.annotate_result(output) + + assert msg.response_metadata[CONVERSATIONAL_APPROVED_TOOL_ARGS] == { + "query": "edited" + } + import json + + wrapped = json.loads(msg.content) + assert wrapped["meta"]["args_modified_by_user"] is True + assert wrapped["meta"]["executed_args"] == {"query": "edited"} + assert wrapped["result"] == "result" + + +class TestRequestApprovalTruthiness: + """Tests for the truthiness fix in request_approval.""" + + @patch("uipath_langchain.chat.hitl.interrupt") + def test_empty_dict_input_preserved(self, mock_interrupt): + """Empty dict from user edits should not be replaced by original args.""" + mock_interrupt.return_value = {"value": {"approved": True, "input": {}}} + tool = MockTool() + result = request_approval({"query": "test", "tool_call_id": "c1"}, tool) + assert result == {} + + @patch("uipath_langchain.chat.hitl.interrupt") + def test_empty_list_input_preserved(self, mock_interrupt): + """Empty list from user edits should not be replaced by original args.""" + mock_interrupt.return_value = {"value": {"approved": True, "input": []}} + tool = MockTool() + result = request_approval({"query": "test", "tool_call_id": "c1"}, tool) + assert result == [] + + @patch("uipath_langchain.chat.hitl.interrupt") + def test_none_input_falls_back_to_original(self, mock_interrupt): + """None input should fall back to original tool_args.""" + mock_interrupt.return_value = {"value": {"approved": True, "input": None}} + tool = MockTool() + result = request_approval({"query": "test", "tool_call_id": "c1"}, tool) + assert result == {"query": "test"} + + @patch("uipath_langchain.chat.hitl.interrupt") + def test_missing_input_falls_back_to_original(self, mock_interrupt): + """Missing input key should fall back to original tool_args.""" + mock_interrupt.return_value = {"value": {"approved": True}} + tool = MockTool() + result = request_approval({"query": "test", "tool_call_id": "c1"}, tool) + assert result == {"query": "test"} + + @patch("uipath_langchain.chat.hitl.interrupt") + def test_rejected_returns_none(self, mock_interrupt): + """Rejected approval returns None.""" + mock_interrupt.return_value = {"value": {"approved": False}} + tool = MockTool() + result = request_approval({"query": "test", "tool_call_id": "c1"}, tool) + assert result is None diff --git a/tests/runtime/test_chat_message_mapper.py b/tests/runtime/test_chat_message_mapper.py index 3eabe5e6..9c29679a 100644 --- a/tests/runtime/test_chat_message_mapper.py +++ b/tests/runtime/test_chat_message_mapper.py @@ -1718,3 +1718,134 @@ def test_ai_message_with_media_citation(self): assert isinstance(source, UiPathConversationCitationSourceMedia) assert source.download_url == "https://r.com" assert source.page_number == "3" + + +class TestConfirmationToolDeferral: + """Tests for deferring startToolCall events for confirmation tools.""" + + @pytest.mark.asyncio + async def test_start_tool_call_skipped_for_confirmation_tool(self): + """AIMessageChunk with confirmation tool should NOT emit startToolCall.""" + storage = create_mock_storage() + storage.get_value.return_value = {} + mapper = UiPathChatMessagesMapper("test-runtime", storage) + mapper.tool_names_requiring_confirmation = {"confirm_tool"} + + # First chunk starts the message with a confirmation tool call + first_chunk = AIMessageChunk( + content="", + id="msg-1", + tool_calls=[{"id": "tc-1", "name": "confirm_tool", "args": {"x": 1}}], + ) + await mapper.map_event(first_chunk) + + # Last chunk triggers tool call start events + last_chunk = AIMessageChunk(content="", id="msg-1") + object.__setattr__(last_chunk, "chunk_position", "last") + result = await mapper.map_event(last_chunk) + + assert result is not None + tool_start_events = [ + e + for e in result + if e.tool_call is not None and e.tool_call.start is not None + ] + assert len(tool_start_events) == 0 + + @pytest.mark.asyncio + async def test_start_tool_call_emitted_for_non_confirmation_tool(self): + """Normal tools still emit startToolCall even when confirmation set is populated.""" + storage = create_mock_storage() + storage.get_value.return_value = {} + mapper = UiPathChatMessagesMapper("test-runtime", storage) + mapper.tool_names_requiring_confirmation = {"other_tool"} + + first_chunk = AIMessageChunk( + content="", + id="msg-2", + tool_calls=[{"id": "tc-2", "name": "normal_tool", "args": {}}], + ) + await mapper.map_event(first_chunk) + + last_chunk = AIMessageChunk(content="", id="msg-2") + object.__setattr__(last_chunk, "chunk_position", "last") + result = await mapper.map_event(last_chunk) + + assert result is not None + tool_start_events = [ + e + for e in result + if e.tool_call is not None and e.tool_call.start is not None + ] + assert len(tool_start_events) >= 1 + assert tool_start_events[0].tool_call is not None + assert tool_start_events[0].tool_call.start is not None + assert tool_start_events[0].tool_call.start.tool_name == "normal_tool" + + @pytest.mark.asyncio + async def test_deferred_start_tool_call_emitted_from_tool_message(self): + """ToolMessage with approved_tool_args should trigger startToolCall before endToolCall.""" + from uipath_langchain.chat.hitl import CONVERSATIONAL_APPROVED_TOOL_ARGS + + storage = create_mock_storage() + storage.get_value.return_value = {"tc-3": "msg-3"} + mapper = UiPathChatMessagesMapper("test-runtime", storage) + mapper.tool_names_requiring_confirmation = {"confirm_tool"} + + approved_args = {"query": "approved value"} + tool_msg = ToolMessage( + content='{"result": "ok"}', + tool_call_id="tc-3", + name="confirm_tool", + ) + tool_msg.response_metadata[CONVERSATIONAL_APPROVED_TOOL_ARGS] = approved_args + + result = await mapper.map_event(tool_msg) + + assert result is not None + # Should have: startToolCall, endToolCall, messageEnd + assert len(result) == 3 + + # First event: deferred startToolCall + start_event = result[0] + assert start_event.tool_call is not None + assert start_event.tool_call.start is not None + assert start_event.tool_call.start.tool_name == "confirm_tool" + assert start_event.tool_call.start.input == approved_args + + # Second event: endToolCall + end_event = result[1] + assert end_event.tool_call is not None + assert end_event.tool_call.end is not None + + @pytest.mark.asyncio + async def test_mixed_tools_only_confirmation_deferred(self): + """Mixed tools in one AIMessage: only confirmation tool's startToolCall is deferred.""" + storage = create_mock_storage() + storage.get_value.return_value = {} + mapper = UiPathChatMessagesMapper("test-runtime", storage) + mapper.tool_names_requiring_confirmation = {"confirm_tool"} + + first_chunk = AIMessageChunk( + content="", + id="msg-4", + tool_calls=[ + {"id": "tc-normal", "name": "normal_tool", "args": {"a": 1}}, + {"id": "tc-confirm", "name": "confirm_tool", "args": {"b": 2}}, + ], + ) + await mapper.map_event(first_chunk) + + last_chunk = AIMessageChunk(content="", id="msg-4") + object.__setattr__(last_chunk, "chunk_position", "last") + result = await mapper.map_event(last_chunk) + + assert result is not None + tool_start_names = [ + e.tool_call.start.tool_name + for e in result + if e.tool_call is not None and e.tool_call.start is not None + ] + # normal_tool should have startToolCall, confirm_tool should NOT + assert "normal_tool" in tool_start_names + assert "confirm_tool" not in tool_start_names diff --git a/uv.lock b/uv.lock index 168f7ac9..85d4f802 100644 --- a/uv.lock +++ b/uv.lock @@ -3324,7 +3324,7 @@ wheels = [ [[package]] name = "uipath-langchain" -version = "0.8.6" +version = "0.8.7" source = { editable = "." } dependencies = [ { name = "httpx" },