|
| 1 | +"""Human-in-the-loop (HITL) support for LlamaIndex tools. |
| 2 | +
|
| 3 | +Provides the ``requires_approval`` decorator for marking tool functions |
| 4 | +that need human approval before execution. |
| 5 | +
|
| 6 | +Example:: |
| 7 | +
|
| 8 | + from uipath_llamaindex.chat.tools import requires_approval |
| 9 | +
|
| 10 | + @requires_approval |
| 11 | + def transfer_funds(from_account: str, to_account: str, amount: float) -> str: |
| 12 | + \"\"\"Transfer funds between accounts.\"\"\" |
| 13 | + return f"Transferred ${amount:.2f} from {from_account} to {to_account}" |
| 14 | +""" |
| 15 | + |
| 16 | +from __future__ import annotations |
| 17 | + |
| 18 | +import functools |
| 19 | +import inspect |
| 20 | +import json |
| 21 | +from collections.abc import Callable |
| 22 | +from typing import Any |
| 23 | +from uuid import uuid4 |
| 24 | + |
| 25 | +from llama_index.core.tools import FunctionTool |
| 26 | +from llama_index.core.tools.utils import create_schema_from_function |
| 27 | +from llama_index.core.workflow import Context |
| 28 | +from uipath.core.chat import UiPathConversationToolCallConfirmationValue |
| 29 | +from workflows.events import HumanResponseEvent, InputRequiredEvent |
| 30 | + |
| 31 | +_CANCELLED_MESSAGE = "Cancelled by user" |
| 32 | + |
| 33 | + |
| 34 | +def _is_context_param(annotation: Any) -> bool: |
| 35 | + from typing import get_origin |
| 36 | + |
| 37 | + return annotation is Context or get_origin(annotation) is Context |
| 38 | + |
| 39 | + |
| 40 | +def requires_approval( |
| 41 | + func: Callable[..., Any] | None = None, |
| 42 | +) -> FunctionTool | Callable[[Callable[..., Any]], FunctionTool]: |
| 43 | + """Decorator that marks a tool function as requiring human approval. |
| 44 | +
|
| 45 | + When the agent calls a tool decorated with ``@requires_approval``, |
| 46 | + execution suspends and waits for a human to approve, edit, or reject |
| 47 | + the tool call before proceeding. |
| 48 | +
|
| 49 | + Can be used with or without parentheses:: |
| 50 | +
|
| 51 | + @requires_approval |
| 52 | + def my_tool(arg: str) -> str: ... |
| 53 | +
|
| 54 | + @requires_approval() |
| 55 | + def my_tool(arg: str) -> str: ... |
| 56 | +
|
| 57 | + Args: |
| 58 | + func: The tool function to wrap. If None, returns a decorator. |
| 59 | +
|
| 60 | + Returns: |
| 61 | + A FunctionTool that suspends for human approval before executing. |
| 62 | + """ |
| 63 | + |
| 64 | + def decorator(fn: Callable[..., Any]) -> FunctionTool: |
| 65 | + is_async = inspect.iscoroutinefunction(fn) |
| 66 | + |
| 67 | + # Determine if the original function has a ctx parameter |
| 68 | + original_sig = inspect.signature(fn) |
| 69 | + ctx_param_name = next( |
| 70 | + ( |
| 71 | + p.name |
| 72 | + for p in original_sig.parameters.values() |
| 73 | + if _is_context_param(p.annotation) |
| 74 | + ), |
| 75 | + None, |
| 76 | + ) |
| 77 | + |
| 78 | + # Build the schema from the original function (excluding ctx if present) |
| 79 | + ignore = [ctx_param_name] if ctx_param_name else [] |
| 80 | + fn_schema = create_schema_from_function(fn.__name__, fn, ignore_fields=ignore) |
| 81 | + |
| 82 | + @functools.wraps(fn) |
| 83 | + async def wrapper(ctx: Context, **kwargs: Any) -> Any: |
| 84 | + tool_call_id = str(uuid4()) |
| 85 | + input_schema = fn_schema.model_json_schema() if fn_schema else {} |
| 86 | + |
| 87 | + confirmation = UiPathConversationToolCallConfirmationValue( |
| 88 | + tool_call_id=tool_call_id, |
| 89 | + tool_name=fn.__name__, |
| 90 | + input_schema=input_schema, |
| 91 | + input_value=kwargs, |
| 92 | + ) |
| 93 | + interrupt_payload = json.dumps( |
| 94 | + { |
| 95 | + "type": "uipath_cas_tool_call_confirmation", |
| 96 | + "value": confirmation.model_dump(by_alias=True), |
| 97 | + } |
| 98 | + ) |
| 99 | + |
| 100 | + response: HumanResponseEvent = await ctx.wait_for_event( |
| 101 | + HumanResponseEvent, |
| 102 | + waiter_id=f"approval_{tool_call_id}", |
| 103 | + waiter_event=InputRequiredEvent(prefix=interrupt_payload), |
| 104 | + timeout=None, |
| 105 | + ) |
| 106 | + |
| 107 | + # Parse resume payload: |
| 108 | + # {"type": "uipath_cas_tool_call_confirmation", "value": {"approved": bool, "input": ...}} |
| 109 | + response_data: Any = response.response |
| 110 | + if isinstance(response_data, str): |
| 111 | + try: |
| 112 | + response_data = json.loads(response_data) |
| 113 | + except json.JSONDecodeError: |
| 114 | + pass |
| 115 | + |
| 116 | + if isinstance(response_data, dict): |
| 117 | + end_value = response_data.get("value", response_data) |
| 118 | + if not end_value.get("approved", True): |
| 119 | + return _CANCELLED_MESSAGE |
| 120 | + approved_kwargs: dict[str, Any] = end_value.get("input") or kwargs |
| 121 | + else: |
| 122 | + approved_kwargs = kwargs |
| 123 | + |
| 124 | + if ctx_param_name: |
| 125 | + approved_kwargs[ctx_param_name] = ctx |
| 126 | + |
| 127 | + if is_async: |
| 128 | + return await fn(**approved_kwargs) |
| 129 | + return fn(**approved_kwargs) |
| 130 | + |
| 131 | + return FunctionTool.from_defaults( |
| 132 | + async_fn=wrapper, |
| 133 | + name=fn.__name__, |
| 134 | + description=fn.__doc__ or "", |
| 135 | + fn_schema=fn_schema, |
| 136 | + ) |
| 137 | + |
| 138 | + if func is not None: |
| 139 | + return decorator(func) |
| 140 | + return decorator |
| 141 | + |
| 142 | + |
| 143 | +__all__ = ["requires_approval"] |
0 commit comments