Skip to content

Commit 21f3e2f

Browse files
Improve type annotations for overloaded client methods (#70)
* Fix type annotations in overload_client to preserve method return types * Further refine types in overload.py * refactor: simplify method overloading by using direct assignment Replace getattr/setattr with direct attribute assignment in overload_client function. Add type ignore comments to handle type checking. This change improves code readability while maintaining the same functionality. * chore: cleaned up imports in sync_client.py
1 parent 079fe22 commit 21f3e2f

File tree

3 files changed

+48
-40
lines changed

3 files changed

+48
-40
lines changed

src/humanloop/client.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,34 @@
1+
import logging
12
import os
23
import typing
34
from typing import Any, List, Optional, Sequence, Tuple
4-
import logging
55

66
import httpx
77
from opentelemetry.sdk.resources import Resource
88
from opentelemetry.sdk.trace import TracerProvider
99
from opentelemetry.trace import Tracer
1010

11+
from humanloop.base_client import AsyncBaseHumanloop, BaseHumanloop
1112
from humanloop.core.client_wrapper import SyncClientWrapper
12-
13+
from humanloop.decorators.flow import flow as flow_decorator_factory
14+
from humanloop.decorators.prompt import prompt_decorator_factory
15+
from humanloop.decorators.tool import tool_decorator_factory as tool_decorator_factory
16+
from humanloop.environment import HumanloopEnvironment
1317
from humanloop.evals import run_eval
1418
from humanloop.evals.types import (
1519
DatasetEvalConfig,
16-
EvaluatorEvalConfig,
1720
EvaluatorCheck,
21+
EvaluatorEvalConfig,
1822
FileEvalConfig,
1923
)
20-
21-
from humanloop.base_client import AsyncBaseHumanloop, BaseHumanloop
22-
from humanloop.overload import overload_client
23-
from humanloop.decorators.flow import flow as flow_decorator_factory
24-
from humanloop.decorators.prompt import prompt_decorator_factory
25-
from humanloop.decorators.tool import tool_decorator_factory as tool_decorator_factory
26-
from humanloop.environment import HumanloopEnvironment
2724
from humanloop.evaluations.client import EvaluationsClient
2825
from humanloop.otel import instrument_provider
2926
from humanloop.otel.exporter import HumanloopSpanExporter
3027
from humanloop.otel.processor import HumanloopSpanProcessor
28+
from humanloop.overload import overload_client
3129
from humanloop.prompt_utils import populate_template
3230
from humanloop.prompts.client import PromptsClient
33-
from humanloop.sync.sync_client import SyncClient, DEFAULT_CACHE_SIZE
31+
from humanloop.sync.sync_client import DEFAULT_CACHE_SIZE, SyncClient
3432

3533
logger = logging.getLogger("humanloop.sdk")
3634

@@ -168,6 +166,7 @@ def __init__(
168166

169167
# Overload the .log method of the clients to be aware of Evaluation Context
170168
# and the @flow decorator providing the trace_id
169+
# Additionally, call and log methods are overloaded in the prompts and agents client to support the use of local files
171170
self.prompts = overload_client(
172171
client=self.prompts, sync_client=self._sync_client, use_local_files=self.use_local_files
173172
)

src/humanloop/overload.py

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,32 @@
11
import inspect
22
import logging
33
import types
4-
from typing import Any, Dict, Optional, Union, Callable
4+
from typing import Any, Callable, Dict, Optional, TypeVar, Union
55

6+
from humanloop.agents.client import AgentsClient
67
from humanloop.context import (
78
get_decorator_context,
89
get_evaluation_context,
910
get_trace_id,
1011
)
12+
from humanloop.datasets.client import DatasetsClient
1113
from humanloop.error import HumanloopRuntimeError
12-
from humanloop.sync.sync_client import SyncClient
13-
from humanloop.prompts.client import PromptsClient
14+
from humanloop.evaluators.client import EvaluatorsClient
1415
from humanloop.flows.client import FlowsClient
15-
from humanloop.datasets.client import DatasetsClient
16-
from humanloop.agents.client import AgentsClient
16+
from humanloop.prompts.client import PromptsClient
17+
from humanloop.sync.sync_client import SyncClient
1718
from humanloop.tools.client import ToolsClient
18-
from humanloop.evaluators.client import EvaluatorsClient
1919
from humanloop.types import FileType
20+
from humanloop.types.agent_call_response import AgentCallResponse
2021
from humanloop.types.create_evaluator_log_response import CreateEvaluatorLogResponse
2122
from humanloop.types.create_flow_log_response import CreateFlowLogResponse
2223
from humanloop.types.create_prompt_log_response import CreatePromptLogResponse
2324
from humanloop.types.create_tool_log_response import CreateToolLogResponse
2425
from humanloop.types.prompt_call_response import PromptCallResponse
25-
from humanloop.types.agent_call_response import AgentCallResponse
2626

2727
logger = logging.getLogger("humanloop.sdk")
2828

29+
2930
LogResponseType = Union[
3031
CreatePromptLogResponse,
3132
CreateToolLogResponse,
@@ -39,6 +40,9 @@
3940
]
4041

4142

43+
T = TypeVar("T", bound=Union[PromptsClient, AgentsClient, ToolsClient, FlowsClient, DatasetsClient, EvaluatorsClient])
44+
45+
4246
def _get_file_type_from_client(
4347
client: Union[PromptsClient, AgentsClient, ToolsClient, FlowsClient, DatasetsClient, EvaluatorsClient],
4448
) -> FileType:
@@ -55,13 +59,13 @@ def _get_file_type_from_client(
5559
return "dataset"
5660
elif isinstance(client, EvaluatorsClient):
5761
return "evaluator"
58-
59-
raise ValueError(f"Unsupported client type: {type(client)}")
62+
else:
63+
raise ValueError(f"Unsupported client type: {type(client)}")
6064

6165

62-
def _handle_tracing_context(kwargs: Dict[str, Any], client: Any) -> Dict[str, Any]:
66+
def _handle_tracing_context(kwargs: Dict[str, Any], client: T) -> Dict[str, Any]:
6367
"""Handle tracing context for both log and call methods."""
64-
trace_id = get_trace_id()
68+
trace_id = get_trace_id()
6569
if trace_id is not None:
6670
if "flow" in str(type(client).__name__).lower():
6771
context = get_decorator_context()
@@ -86,7 +90,7 @@ def _handle_tracing_context(kwargs: Dict[str, Any], client: Any) -> Dict[str, An
8690

8791
def _handle_local_files(
8892
kwargs: Dict[str, Any],
89-
client: Any,
93+
client: T,
9094
sync_client: Optional[SyncClient],
9195
use_local_files: bool,
9296
) -> Dict[str, Any]:
@@ -136,7 +140,7 @@ def _handle_evaluation_context(kwargs: Dict[str, Any]) -> tuple[Dict[str, Any],
136140
return kwargs, None
137141

138142

139-
def _overload_log(self: Any, sync_client: Optional[SyncClient], use_local_files: bool, **kwargs) -> LogResponseType:
143+
def _overload_log(self: T, sync_client: Optional[SyncClient], use_local_files: bool, **kwargs) -> LogResponseType:
140144
try:
141145
# Special handling for flows - prevent direct log usage
142146
if type(self) is FlowsClient and get_trace_id() is not None:
@@ -158,7 +162,7 @@ def _overload_log(self: Any, sync_client: Optional[SyncClient], use_local_files:
158162
kwargs = _handle_local_files(kwargs, self, sync_client, use_local_files)
159163

160164
kwargs, eval_callback = _handle_evaluation_context(kwargs)
161-
response = self._log(**kwargs) # Use stored original method
165+
response = self._log(**kwargs) # type: ignore[union-attr] # Use stored original method
162166
if eval_callback is not None:
163167
eval_callback(response.id)
164168
return response
@@ -170,11 +174,11 @@ def _overload_log(self: Any, sync_client: Optional[SyncClient], use_local_files:
170174
raise HumanloopRuntimeError from e
171175

172176

173-
def _overload_call(self: Any, sync_client: Optional[SyncClient], use_local_files: bool, **kwargs) -> CallResponseType:
177+
def _overload_call(self: T, sync_client: Optional[SyncClient], use_local_files: bool, **kwargs) -> CallResponseType:
174178
try:
175179
kwargs = _handle_tracing_context(kwargs, self)
176180
kwargs = _handle_local_files(kwargs, self, sync_client, use_local_files)
177-
return self._call(**kwargs) # Use stored original method
181+
return self._call(**kwargs) # type: ignore[union-attr] # Use stored original method
178182
except HumanloopRuntimeError:
179183
# Re-raise HumanloopRuntimeError without wrapping to preserve the message
180184
raise
@@ -184,33 +188,37 @@ def _overload_call(self: Any, sync_client: Optional[SyncClient], use_local_files
184188

185189

186190
def overload_client(
187-
client: Any,
191+
client: T,
188192
sync_client: Optional[SyncClient] = None,
189193
use_local_files: bool = False,
190-
) -> Any:
194+
) -> T:
191195
"""Overloads client methods to add tracing, local file handling, and evaluation context."""
192196
# Store original log method as _log for all clients. Used in flow decorator
193197
if hasattr(client, "log") and not hasattr(client, "_log"):
194-
client._log = client.log # type: ignore[attr-defined]
198+
# Store original method with type ignore
199+
client._log = client.log # type: ignore
195200

196201
# Create a closure to capture sync_client and use_local_files
197-
def log_wrapper(self: Any, **kwargs) -> LogResponseType:
202+
def log_wrapper(self: T, **kwargs) -> LogResponseType:
198203
return _overload_log(self, sync_client, use_local_files, **kwargs)
199204

200-
client.log = types.MethodType(log_wrapper, client)
205+
# Replace the log method with type ignore
206+
client.log = types.MethodType(log_wrapper, client) # type: ignore
201207

202208
# Overload call method for Prompt and Agent clients
203209
if _get_file_type_from_client(client) in ["prompt", "agent"]:
204210
if sync_client is None and use_local_files:
205211
logger.error("sync_client is None but client has call method and use_local_files=%s", use_local_files)
206212
raise HumanloopRuntimeError("sync_client is required for clients that support call operations")
207213
if hasattr(client, "call") and not hasattr(client, "_call"):
208-
client._call = client.call # type: ignore[attr-defined]
214+
# Store original method with type ignore
215+
client._call = client.call # type: ignore
209216

210217
# Create a closure to capture sync_client and use_local_files
211-
def call_wrapper(self: Any, **kwargs) -> CallResponseType:
218+
def call_wrapper(self: T, **kwargs) -> CallResponseType:
212219
return _overload_call(self, sync_client, use_local_files, **kwargs)
213220

214-
client.call = types.MethodType(call_wrapper, client)
221+
# Replace the call method with type ignore
222+
client.call = types.MethodType(call_wrapper, client) # type: ignore
215223

216224
return client

src/humanloop/sync/sync_client.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1+
import json
12
import logging
2-
from pathlib import Path
3-
from typing import List, Optional, Tuple, TYPE_CHECKING, Union
4-
from functools import lru_cache
5-
import typing
63
import time
4+
import typing
5+
from functools import lru_cache
6+
from pathlib import Path
7+
from typing import TYPE_CHECKING, List, Optional, Tuple
8+
79
from humanloop.error import HumanloopRuntimeError
8-
import json
910

1011
if TYPE_CHECKING:
1112
from humanloop.base_client import BaseHumanloop

0 commit comments

Comments
 (0)