
Commit 457495c

Update the external_evaluators to
In addition, correct some comments and remove some type ignores.
1 parent 75d1722 commit 457495c

3 files changed: 57 additions, 83 deletions


mypy.ini

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 [mypy]
 exclude = (?x)(
-    ^src/humanloop/eval_utils\.py$
+    ^src/humanloop/eval_utils/*\.py$
     | ^src/humanloop/prompt_utils\.py$
     )

src/humanloop/eval_utils/context.py

Lines changed: 3 additions & 3 deletions
@@ -4,14 +4,14 @@
 class EvaluationContext(TypedDict):
     """Context Log to Humanloop.

-    Global state that is set when an Evaluation is ran.
+    Per datapoint state that is set when an Evaluation is ran.
     """

     """Required for associating a Log with the Evaluation Run."""
     source_datapoint_id: str

-    """Exporter calls this so the eval_utils are notified to evaluate an uploaded Log."""
-    upload_callback: Callable[[dict], None]
+    """Overloaded .log method call."""
+    upload_callback: Callable[[str], None]

     """ID of the evaluated File."""
     file_id: str
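
The diff above narrows `upload_callback` so it receives only the new Log's ID. As a rough, illustrative sketch (not the SDK's actual code): a per-datapoint context like this is typically stashed in a `ContextVar` and consumed once. The field set and the `evaluation_context_variable` name are taken from the run.py diff below; everything else here is an assumption.

```python
# Illustrative sketch only -- fields beyond those visible in the diff are assumptions.
from contextvars import ContextVar
from typing import Callable, Optional, TypedDict


class EvaluationContext(TypedDict):
    """Per-datapoint state set while an Evaluation is running."""

    source_datapoint_id: str
    upload_callback: Callable[[str], None]  # now receives only the Log ID
    file_id: str
    run_id: str


evaluation_context_variable: ContextVar[Optional[EvaluationContext]] = ContextVar(
    "evaluation_context", default=None
)


def notify_evaluation(log_id: str) -> None:
    """Hand a freshly created Log ID back to the Evaluation, then mark the context consumed."""
    context = evaluation_context_variable.get()
    if context is not None:
        context["upload_callback"](log_id)
        evaluation_context_variable.set(None)
```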

src/humanloop/eval_utils/run.py

Lines changed: 53 additions & 79 deletions
@@ -54,7 +54,6 @@
 from humanloop.types.create_flow_log_response import CreateFlowLogResponse
 from humanloop.types.create_prompt_log_response import CreatePromptLogResponse
 from humanloop.types.create_tool_log_response import CreateToolLogResponse
-from humanloop.types.datapoint_response_target_value import DatapointResponseTargetValue
 from humanloop.types.evaluation_run_response import EvaluationRunResponse
 from humanloop.types.run_stats_response import RunStatsResponse
 from pydantic import ValidationError
@@ -115,9 +114,9 @@ def _is_evaluated_file(
         ) == log_args.get("path")

     # Copy the original log method in a hidden attribute
-    client._log = client.log  # type: ignore
+    client._log = client.log

-    def _overloaded_log(
+    def _overload_log(
         self,
         **kwargs,
     ) -> Union[
@@ -132,20 +131,13 @@ def _overloaded_log(
         # If the Evaluation Context is not set, an Evaluation is not running
         evaluation_context = None

-        if _is_evaluated_file(
-            evaluation_context=evaluation_context,  # type: ignore
-            log_args=kwargs,
-        ):
+        if _is_evaluated_file(evaluation_context=evaluation_context, log_args=kwargs):
             # If the .log API user does not provide the source_datapoint_id or run_id,
             # override them with the values from the EvaluationContext
             # _is_evaluated_file ensures that evaluation_context is not None
-            evaluation_context = typing.cast(
-                EvaluationContext,
-                evaluation_context,
-            )
             for attribute in ["source_datapoint_id", "run_id"]:
                 if attribute not in kwargs or kwargs[attribute] is None:
-                    kwargs[attribute] = evaluation_context[attribute]  # type: ignore
+                    kwargs[attribute] = evaluation_context[attribute]

         # Call the original .log method
         logger.debug(
@@ -156,33 +148,21 @@ def _overloaded_log(
         )
         response = self._log(**kwargs)

-        # Call the callback so the Evaluation can be updated
         if _is_evaluated_file(
-            evaluation_context=evaluation_context,  # type: ignore
+            evaluation_context=evaluation_context,
             log_args=kwargs,
         ):
-            # Notify that the Log has been added to the Evaluation
+            # Call the callback so the Evaluation can be updated
             # _is_evaluated_file ensures that evaluation_context is not None
-            evaluation_context = typing.cast(
-                EvaluationContext,
-                evaluation_context,
-            )
-            evaluation_context["upload_callback"](  # type: ignore
-                {
-                    **kwargs,
-                    # ID in kwargs refers to the File ID
-                    # Replace it with the Log ID
-                    "id": response.id,
-                }
-            )
+            evaluation_context["upload_callback"](log_id=response.id)

         # Mark the Evaluation Context as consumed
         evaluation_context_variable.set(None)

         return response

     # Replace the original log method with the overloaded one
-    client.log = types.MethodType(_overloaded_log, client)  # type: ignore
+    client.log = types.MethodType(_overload_log, client)
     # Return the client with the overloaded log method
     logger.debug("Overloaded the .log method of %s", client)
     return client
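
For reference, the pattern in the hunk above — keeping the original bound method on a hidden attribute and re-binding a wrapper with `types.MethodType` — can be reproduced in isolation. This is a minimal sketch with a made-up `DummyClient`, not the Humanloop SDK:

```python
import types


class DummyClient:
    """Stand-in for the SDK client; not the real Humanloop class."""

    def log(self, **kwargs):
        return {"id": "log_123", **kwargs}


def overload_log(client: DummyClient) -> DummyClient:
    # Keep the original bound method on a hidden attribute, as the diff does.
    client._log = client.log

    def _overload_log(self, **kwargs):
        # Enrich the call before delegating to the original method.
        kwargs.setdefault("source", "evaluation")
        return self._log(**kwargs)

    # Re-bind the wrapper to this instance only.
    client.log = types.MethodType(_overload_log, client)
    return client


client = overload_log(DummyClient())
print(client.log(path="my/prompt"))
# -> {'id': 'log_123', 'path': 'my/prompt', 'source': 'evaluation'}
```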
@@ -316,7 +296,7 @@ def run_eval(
         except ValidationError:
             flow_version = {"attributes": version}
         file_dict = {**file_, **flow_version}
-        hl_file = client.flows.upsert(**file_dict)  # type: ignore
+        hl_file = client.flows.upsert(**file_dict)

     elif type_ == "prompt":
         try:
@@ -325,7 +305,7 @@ def run_eval(
             logger.error(msg="Invalid Prompt `version` in your `file` request. \n\nValidation error: \n)")
             raise error_
         try:
-            hl_file = client.prompts.upsert(**file_dict)  # type: ignore
+            hl_file = client.prompts.upsert(**file_dict)
         except ApiError as error_:
             raise error_

@@ -335,10 +315,10 @@ def run_eval(
         except ValidationError as error_:
             logger.error(msg="Invalid Tool `version` in your `file` request. \n\nValidation error: \n)")
             raise error_
-        hl_file = client.tools.upsert(**file_dict)  # type: ignore
+        hl_file = client.tools.upsert(**file_dict)

     elif type_ == "evaluator":
-        hl_file = client.evaluators.upsert(**file_dict)  # type: ignore
+        hl_file = client.evaluators.upsert(**file_dict)

     else:
         raise NotImplementedError(f"Unsupported File type: {type_}")
@@ -396,7 +376,7 @@ def run_eval(
             break
     if requires_target:
         missing_target = 0
-        for datapoint in hl_dataset.datapoints:  # type: ignore
+        for datapoint in hl_dataset.datapoints:
             if not datapoint.target:
                 missing_target += 1
         if missing_target > 0:
@@ -410,15 +390,15 @@ def run_eval(
     try:
         evaluation = client.evaluations.create(
             name=name,
-            evaluators=[{"path": e["path"]} for e in evaluators],  # type: ignore
+            evaluators=[{"path": e["path"]} for e in evaluators],
             file={"id": hl_file.id},
         )
     except ApiError as error_:
         # If the name exists, go and get it # TODO: Update API GET to allow querying by name and file.
         if error_.status_code == 409:
             evals = client.evaluations.list(file_id=hl_file.id, size=50)
             for page in evals.iter_pages():
-                evaluation = next((e for e in page.items if e.name == name), None)  # type: ignore
+                evaluation = next((e for e in page.items if e.name == name), None)
         else:
             raise error_
     if not evaluation:
@@ -433,25 +413,19 @@ def run_eval(
     # Every Run will generate a new batch of Logs
     run_id = run.id

-    _PROGRESS_BAR = _SimpleProgressBar(len(hl_dataset.datapoints))  # type: ignore
+    _PROGRESS_BAR = _SimpleProgressBar(len(hl_dataset.datapoints))

     # Define the function to execute the `callable` in parallel and Log to Humanloop
     def process_datapoint(dp: Datapoint, file_id: str, file_path: str, run_id: str):
-        def upload_callback(log: dict):
+        def upload_callback(log_id: str):
             """Logic ran after the Log has been created."""
-            logger.debug(
-                "upload_callback on Thread %s: log %s datapoint_target %s",
-                threading.get_ident(),
-                log,
-                dp.target,
-            )
             _run_local_evaluators(
                 client=client,
-                log=log,
-                datapoint_target=dp.target,
+                log_id=log_id,
+                datapoint=dp,
                 local_evaluators=local_evaluators,
             )
-            _PROGRESS_BAR.increment()  # type: ignore
+            _PROGRESS_BAR.increment()

         datapoint_dict = dp.dict()
         # Set the Evaluation Context for current datapoint
@@ -471,6 +445,7 @@ def upload_callback(log: dict):
             # .get() is safe since process_datapoint is always called in the context of an Evaluation
             evaluation_context_variable.get(),
         )
+        # TODO: shouldn't this only be defined in case where we actually need to log?
         log_func = _get_log_func(
             client=client,
             file_type=type_,
@@ -481,18 +456,12 @@ def upload_callback(log: dict):
         start_time = datetime.now()
         try:
             if "messages" in datapoint_dict and datapoint_dict["messages"] is not None:
-                # function_ is decorated by Humanloop, the OTel Exporter will
-                # handle the logging, which will call the upload_callback
-                # function above when it's done
-                output = function_(  # type: ignore
+                output = function_(
                     **datapoint_dict["inputs"],
                     messages=datapoint_dict["messages"],
                 )
             else:
-                # function_ is decorated by Humanloop, the OTel Exporter will
-                # handle the logging, which will call the upload_callback
-                # function above when it's done
-                output = function_(**datapoint_dict["inputs"])  # type: ignore
+                output = function_(**datapoint_dict["inputs"])

             if not isinstance(output, str):
                 try:
@@ -509,7 +478,7 @@ def upload_callback(log: dict):
             logger.debug(
                 "process_datapoint on Thread %s: function_ %s is a simple callable, context was not consumed",
                 threading.get_ident(),
-                function_.__name__,  # type: ignore
+                function_.__name__,
             )
             log_func(
                 inputs=datapoint.inputs,
@@ -534,12 +503,12 @@ def upload_callback(log: dict):
     logger.info(f"{CYAN}Run ID: {run_id}{RESET}")

     # Generate locally if a file `callable` is provided
-    if function_:  # type: ignore
+    if function_:
         logger.info(
             f"{CYAN}\nRunning '{hl_file.name}' over the Dataset '{hl_dataset.name}' using {workers} workers{RESET} "
         )
         with ThreadPoolExecutor(max_workers=workers) as executor:
-            for datapoint in hl_dataset.datapoints:  # type: ignore
+            for datapoint in hl_dataset.datapoints:
                 executor.submit(
                     process_datapoint,
                     datapoint,
@@ -572,8 +541,8 @@ def upload_callback(log: dict):

     # Skip `check_evaluation_improvement` if no thresholds were provided and there is only one run.
     # (Or the logs would not be helpful)
-    if any(evaluator.get("threshold") is not None for evaluator in evaluators) or len(stats.run_stats) > 1:  # type: ignore
-        for evaluator in evaluators:  # type: ignore
+    if any(evaluator.get("threshold") is not None for evaluator in evaluators) or len(stats.run_stats) > 1:
+        for evaluator in evaluators:
             score, delta = _check_evaluation_improvement(
                 evaluation=evaluation,
                 stats=stats,
@@ -623,13 +592,13 @@ def _get_log_func(
         "run_id": run_id,
     }
     if file_type == "flow":
-        return partial(client.flows.log, **log_request, trace_status="complete")  # type: ignore
+        return partial(client.flows.log, **log_request, trace_status="complete")
     elif file_type == "prompt":
-        return partial(client.prompts.log, **log_request)  # type: ignore
+        return partial(client.prompts.log, **log_request)
     elif file_type == "evaluator":
-        return partial(client.evaluators.log, **log_request)  # type: ignore
+        return partial(client.evaluators.log, **log_request)
     elif file_type == "tool":
-        return partial(client.tools.log, **log_request)  # type: ignore
+        return partial(client.tools.log, **log_request)
     else:
         raise NotImplementedError(f"Unsupported File version: {file_type}")

@@ -643,10 +612,10 @@ def _get_score_from_evaluator_stat(
         if stat.total_logs:
             score = round(stat.num_true / stat.total_logs, 2)
     elif isinstance(stat, NumericStats):
-        score = round(stat.mean, 2)  # type: ignore
+        score = round(stat.mean, 2)
     else:
         raise ValueError(f"Unsupported Evaluator Stat type: {type(stat)}")
-    return score  # type: ignore
+    return score


 def _get_evaluator_stats_by_path(
@@ -660,7 +629,7 @@ def _get_evaluator_stats_by_path(
         evaluators_by_id[evaluator_stat.evaluator_version_id].version.path: evaluator_stat
         for evaluator_stat in stat.evaluator_stats
     }
-    return evaluator_stats_by_path  # type: ignore
+    return evaluator_stats_by_path


 def _check_evaluation_threshold(
@@ -675,14 +644,14 @@ def _check_evaluation_threshold(
     evaluator_stats_by_path = _get_evaluator_stats_by_path(
         stat=next(
             (stat for stat in stats.run_stats if stat.run_id == run_id),
-            None,  # type: ignore
+            None,
         ),
         evaluation=evaluation,
     )
     if evaluator_path in evaluator_stats_by_path:
         evaluator_stat = evaluator_stats_by_path[evaluator_path]
         score = _get_score_from_evaluator_stat(stat=evaluator_stat)
-        if score >= threshold:  # type: ignore
+        if score >= threshold:
             logger.info(
                 f"{GREEN}✅ Latest eval [{score}] above threshold [{threshold}] for evaluator {evaluator_path}.{RESET}"
             )
@@ -712,7 +681,7 @@ def _check_evaluation_improvement(
     latest_evaluator_stats_by_path = _get_evaluator_stats_by_path(
         stat=next(
             (stat for stat in stats.run_stats if stat.run_id == run_id),
-            None,  # type: ignore
+            None,
         ),
         evaluation=evaluation,
     )
@@ -731,37 +700,42 @@ def _check_evaluation_improvement(
         previous_score = _get_score_from_evaluator_stat(stat=previous_evaluator_stat)
         if latest_score is None or previous_score is None:
             raise ValueError(f"Could not find score for Evaluator {evaluator_path}.")
-        diff = round(latest_score - previous_score, 2)  # type: ignore
+        diff = round(latest_score - previous_score, 2)
         if diff >= 0:
             logger.info(f"{CYAN}Change of [{diff}] for Evaluator {evaluator_path}{RESET}")
-            return True, latest_score, diff  # type: ignore
+            return True, latest_score, diff
         else:
             logger.info(f"{CYAN}Change of [{diff}] for Evaluator {evaluator_path}{RESET}")
-            return False, latest_score, diff  # type: ignore
+            return False, latest_score, diff
     else:
         raise ValueError(f"Evaluator {evaluator_path} not found in the stats.")


 def _run_local_evaluators(
     client: "BaseHumanloop",
-    log: dict,
-    datapoint_target: typing.Optional[typing.Dict[str, DatapointResponseTargetValue]],
+    log_id: str,
+    datapoint: Optional[Datapoint],
     local_evaluators: list[Evaluator],
 ):
+    """Run local Evaluators on the Log and send the judgments to Humanloop."""
+    # Need to get the full log to pass to the evaluators
+    log = client.logs.get(id=log_id)
+    log_dict = log.dict()
+    datapoint_dict = datapoint.dict() if datapoint else None
     for local_evaluator in local_evaluators:
         start_time = datetime.now()
         try:
             eval_function = local_evaluator["callable"]
             if local_evaluator["args_type"] == "target_required":
                 judgement = eval_function(
-                    log,
-                    datapoint_target,
+                    log_dict,
+                    datapoint_dict,
                 )
             else:
-                judgement = eval_function(log)
+                judgement = eval_function(log_dict)

             _ = client.evaluators.log(
-                parent_id=log["id"],
+                parent_id=log_id,
                 judgment=judgement,
                 id=local_evaluator.get("id"),
                 path=local_evaluator.get("path"),
@@ -770,7 +744,7 @@ def _run_local_evaluators(
             )
         except Exception as e:
             _ = client.evaluators.log(
-                parent_id=log["id"],
+                parent_id=log_id,
                 path=local_evaluator.get("path"),
                 id=local_evaluator.get("id"),
                 error=str(e),
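
A consequence of the `_run_local_evaluators` change above is that local Evaluator callables now receive the full Log as a dict (fetched via `client.logs.get`) and, for `target_required` Evaluators, the full Datapoint dict rather than just its target. A hedged sketch of callables compatible with that calling convention — the `output` and `target` field names are assumptions, not guaranteed by the SDK:

```python
# Illustrative local Evaluators for the updated calling convention.
# Field names inside the dicts ("output", "target") are assumptions.
from typing import Optional


def exact_match(log_dict: dict, datapoint_dict: Optional[dict]) -> bool:
    """args_type == "target_required": compare the Log output against the Datapoint target."""
    target = (datapoint_dict or {}).get("target") or {}
    return log_dict.get("output") == target.get("output")


def is_concise(log_dict: dict) -> bool:
    """args_type == "target_free": judge the Log on its own."""
    return len(log_dict.get("output") or "") < 500
```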
