11import inspect
22import logging
33import types
4- from typing import Any , Dict , Optional , Union , Callable
4+ from typing import Any , Callable , Dict , Optional , TypeVar , Union
55
6+ from humanloop .agents .client import AgentsClient
67from humanloop .context import (
78 get_decorator_context ,
89 get_evaluation_context ,
910 get_trace_id ,
1011)
12+ from humanloop .datasets .client import DatasetsClient
1113from humanloop .error import HumanloopRuntimeError
12- from humanloop .sync .sync_client import SyncClient
13- from humanloop .prompts .client import PromptsClient
14+ from humanloop .evaluators .client import EvaluatorsClient
1415from humanloop .flows .client import FlowsClient
15- from humanloop .datasets .client import DatasetsClient
16- from humanloop .agents . client import AgentsClient
16+ from humanloop .prompts .client import PromptsClient
17+ from humanloop .sync . sync_client import SyncClient
1718from humanloop .tools .client import ToolsClient
18- from humanloop .evaluators .client import EvaluatorsClient
1919from humanloop .types import FileType
20+ from humanloop .types .agent_call_response import AgentCallResponse
2021from humanloop .types .create_evaluator_log_response import CreateEvaluatorLogResponse
2122from humanloop .types .create_flow_log_response import CreateFlowLogResponse
2223from humanloop .types .create_prompt_log_response import CreatePromptLogResponse
2324from humanloop .types .create_tool_log_response import CreateToolLogResponse
2425from humanloop .types .prompt_call_response import PromptCallResponse
25- from humanloop .types .agent_call_response import AgentCallResponse
2626
2727logger = logging .getLogger ("humanloop.sdk" )
2828
29+
2930LogResponseType = Union [
3031 CreatePromptLogResponse ,
3132 CreateToolLogResponse ,
3940]
4041
4142
43+ T = TypeVar ("T" , bound = Union [PromptsClient , AgentsClient , ToolsClient , FlowsClient , DatasetsClient , EvaluatorsClient ])
44+
45+
4246def _get_file_type_from_client (
4347 client : Union [PromptsClient , AgentsClient , ToolsClient , FlowsClient , DatasetsClient , EvaluatorsClient ],
4448) -> FileType :
@@ -55,13 +59,13 @@ def _get_file_type_from_client(
5559 return "dataset"
5660 elif isinstance (client , EvaluatorsClient ):
5761 return "evaluator"
58-
59- raise ValueError (f"Unsupported client type: { type (client )} " )
62+ else :
63+ raise ValueError (f"Unsupported client type: { type (client )} " )
6064
6165
62- def _handle_tracing_context (kwargs : Dict [str , Any ], client : Any ) -> Dict [str , Any ]:
66+ def _handle_tracing_context (kwargs : Dict [str , Any ], client : T ) -> Dict [str , Any ]:
6367 """Handle tracing context for both log and call methods."""
64- trace_id = get_trace_id ()
68+ trace_id = get_trace_id ()
6569 if trace_id is not None :
6670 if "flow" in str (type (client ).__name__ ).lower ():
6771 context = get_decorator_context ()
@@ -86,7 +90,7 @@ def _handle_tracing_context(kwargs: Dict[str, Any], client: Any) -> Dict[str, An
8690
8791def _handle_local_files (
8892 kwargs : Dict [str , Any ],
89- client : Any ,
93+ client : T ,
9094 sync_client : Optional [SyncClient ],
9195 use_local_files : bool ,
9296) -> Dict [str , Any ]:
@@ -136,7 +140,7 @@ def _handle_evaluation_context(kwargs: Dict[str, Any]) -> tuple[Dict[str, Any],
136140 return kwargs , None
137141
138142
139- def _overload_log (self : Any , sync_client : Optional [SyncClient ], use_local_files : bool , ** kwargs ) -> LogResponseType :
143+ def _overload_log (self : T , sync_client : Optional [SyncClient ], use_local_files : bool , ** kwargs ) -> LogResponseType :
140144 try :
141145 # Special handling for flows - prevent direct log usage
142146 if type (self ) is FlowsClient and get_trace_id () is not None :
@@ -158,7 +162,7 @@ def _overload_log(self: Any, sync_client: Optional[SyncClient], use_local_files:
158162 kwargs = _handle_local_files (kwargs , self , sync_client , use_local_files )
159163
160164 kwargs , eval_callback = _handle_evaluation_context (kwargs )
161- response = self ._log (** kwargs ) # Use stored original method
165+ response = self ._log (** kwargs ) # type: ignore[union-attr] # Use stored original method
162166 if eval_callback is not None :
163167 eval_callback (response .id )
164168 return response
@@ -170,11 +174,11 @@ def _overload_log(self: Any, sync_client: Optional[SyncClient], use_local_files:
170174 raise HumanloopRuntimeError from e
171175
172176
173- def _overload_call (self : Any , sync_client : Optional [SyncClient ], use_local_files : bool , ** kwargs ) -> CallResponseType :
177+ def _overload_call (self : T , sync_client : Optional [SyncClient ], use_local_files : bool , ** kwargs ) -> CallResponseType :
174178 try :
175179 kwargs = _handle_tracing_context (kwargs , self )
176180 kwargs = _handle_local_files (kwargs , self , sync_client , use_local_files )
177- return self ._call (** kwargs ) # Use stored original method
181+ return self ._call (** kwargs ) # type: ignore[union-attr] # Use stored original method
178182 except HumanloopRuntimeError :
179183 # Re-raise HumanloopRuntimeError without wrapping to preserve the message
180184 raise
@@ -184,33 +188,37 @@ def _overload_call(self: Any, sync_client: Optional[SyncClient], use_local_files
184188
185189
186190def overload_client (
187- client : Any ,
191+ client : T ,
188192 sync_client : Optional [SyncClient ] = None ,
189193 use_local_files : bool = False ,
190- ) -> Any :
194+ ) -> T :
191195 """Overloads client methods to add tracing, local file handling, and evaluation context."""
192196 # Store original log method as _log for all clients. Used in flow decorator
193197 if hasattr (client , "log" ) and not hasattr (client , "_log" ):
194- client ._log = client .log # type: ignore[attr-defined]
198+ # Store original method with type ignore
199+ client ._log = client .log # type: ignore
195200
196201 # Create a closure to capture sync_client and use_local_files
197- def log_wrapper (self : Any , ** kwargs ) -> LogResponseType :
202+ def log_wrapper (self : T , ** kwargs ) -> LogResponseType :
198203 return _overload_log (self , sync_client , use_local_files , ** kwargs )
199204
200- client .log = types .MethodType (log_wrapper , client )
205+ # Replace the log method with type ignore
206+ client .log = types .MethodType (log_wrapper , client ) # type: ignore
201207
202208 # Overload call method for Prompt and Agent clients
203209 if _get_file_type_from_client (client ) in ["prompt" , "agent" ]:
204210 if sync_client is None and use_local_files :
205211 logger .error ("sync_client is None but client has call method and use_local_files=%s" , use_local_files )
206212 raise HumanloopRuntimeError ("sync_client is required for clients that support call operations" )
207213 if hasattr (client , "call" ) and not hasattr (client , "_call" ):
208- client ._call = client .call # type: ignore[attr-defined]
214+ # Store original method with type ignore
215+ client ._call = client .call # type: ignore
209216
210217 # Create a closure to capture sync_client and use_local_files
211- def call_wrapper (self : Any , ** kwargs ) -> CallResponseType :
218+ def call_wrapper (self : T , ** kwargs ) -> CallResponseType :
212219 return _overload_call (self , sync_client , use_local_files , ** kwargs )
213220
214- client .call = types .MethodType (call_wrapper , client )
221+ # Replace the call method with type ignore
222+ client .call = types .MethodType (call_wrapper , client ) # type: ignore
215223
216224 return client
0 commit comments