Skip to content

Commit 3fd5b7e

Browse files
vertex-sdk-bot authored and copybara-github committed
feat: GenAI Client(evals) - auto-infer metric/candidate and validate inputs for generate_loss_clusters
PiperOrigin-RevId: 894079615
1 parent 9e9dd70 commit 3fd5b7e

3 files changed

Lines changed: 1672 additions & 1311 deletions

File tree

tests/unit/vertexai/genai/test_evals.py

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,170 @@ def test_loss_analysis_result_show(self, capsys):
569569
assert "c1" in captured.out
570570

571571

572+
def _make_eval_result(
    metrics=None,
    candidate_names=None,
):
    """Helper to create an EvaluationResult with the given metrics and candidates."""
    metric_names = metrics or ["task_success_v1"]
    candidates = candidate_names or ["agent-1"]

    # One EvalCaseMetricResult per requested metric, keyed by metric name.
    per_metric_results = {
        name: common_types.EvalCaseMetricResult(metric_name=name)
        for name in metric_names
    }
    candidate_result = common_types.ResponseCandidateResult(
        response_index=0,
        metric_results=per_metric_results,
    )
    case_result = common_types.EvalCaseResult(
        eval_case_index=0,
        response_candidate_results=[candidate_result],
    )
    return common_types.EvaluationResult(
        eval_case_results=[case_result],
        metadata=common_types.EvaluationRunMetadata(
            candidate_names=candidates,
        ),
    )
602+
603+
604+
class TestResolveMetricName:
    """Unit tests for _resolve_metric_name."""

    def test_none_returns_none(self):
        assert _evals_utils._resolve_metric_name(None) is None

    def test_string_passes_through(self):
        name = "task_success_v1"
        assert _evals_utils._resolve_metric_name(name) == name

    def test_metric_object_extracts_name(self):
        metric_obj = common_types.Metric(name="multi_turn_task_success_v1")
        resolved = _evals_utils._resolve_metric_name(metric_obj)
        assert resolved == "multi_turn_task_success_v1"

    def test_object_with_name_attr(self):
        """Tests that any object with a .name attribute works (e.g., LazyLoadedPrebuiltMetric)."""

        class _StubMetric:
            name = "tool_use_quality_v1"

        resolved = _evals_utils._resolve_metric_name(_StubMetric())
        assert resolved == "tool_use_quality_v1"

    def test_lazy_loaded_prebuilt_metric_resolves_versioned_name(self):
        """Tests that LazyLoadedPrebuiltMetric resolves to the versioned API spec name."""

        class _StubLazyMetric:
            name = "MULTI_TURN_TASK_SUCCESS"

            def _get_api_metric_spec_name(self):
                return "multi_turn_task_success_v1"

        resolved = _evals_utils._resolve_metric_name(_StubLazyMetric())
        assert resolved == "multi_turn_task_success_v1"

    def test_lazy_loaded_prebuilt_metric_falls_back_to_name(self):
        """Tests fallback to .name when _get_api_metric_spec_name returns None."""

        class _StubLazyMetricNoSpec:
            name = "CUSTOM_METRIC"

            def _get_api_metric_spec_name(self):
                return None

        resolved = _evals_utils._resolve_metric_name(_StubLazyMetricNoSpec())
        assert resolved == "CUSTOM_METRIC"
651+
652+
653+
class TestResolveLossAnalysisConfig:
    """Unit tests for _resolve_loss_analysis_config."""

    def test_auto_infer_single_metric_and_candidate(self):
        result = _make_eval_result(
            metrics=["task_success_v1"], candidate_names=["agent-1"]
        )
        resolved = _evals_utils._resolve_loss_analysis_config(eval_result=result)
        assert resolved.metric == "task_success_v1"
        assert resolved.candidate == "agent-1"

    def test_explicit_metric_and_candidate(self):
        result = _make_eval_result(
            metrics=["m1", "m2"], candidate_names=["c1", "c2"]
        )
        resolved = _evals_utils._resolve_loss_analysis_config(
            eval_result=result, metric="m1", candidate="c2"
        )
        assert (resolved.metric, resolved.candidate) == ("m1", "c2")

    def test_config_provides_metric_and_candidate(self):
        result = _make_eval_result(metrics=["m1"], candidate_names=["c1"])
        provided = common_types.LossAnalysisConfig(
            metric="m1", candidate="c1", predefined_taxonomy="my_taxonomy"
        )
        resolved = _evals_utils._resolve_loss_analysis_config(
            eval_result=result, config=provided
        )
        assert resolved.metric == "m1"
        assert resolved.candidate == "c1"
        assert resolved.predefined_taxonomy == "my_taxonomy"

    def test_explicit_args_override_config(self):
        result = _make_eval_result(
            metrics=["m1", "m2"], candidate_names=["c1", "c2"]
        )
        provided = common_types.LossAnalysisConfig(metric="m1", candidate="c1")
        resolved = _evals_utils._resolve_loss_analysis_config(
            eval_result=result, config=provided, metric="m2", candidate="c2"
        )
        assert (resolved.metric, resolved.candidate) == ("m2", "c2")

    def test_error_multiple_metrics_no_explicit(self):
        result = _make_eval_result(metrics=["m1", "m2"], candidate_names=["c1"])
        with pytest.raises(ValueError, match="multiple metrics"):
            _evals_utils._resolve_loss_analysis_config(eval_result=result)

    def test_error_multiple_candidates_no_explicit(self):
        result = _make_eval_result(metrics=["m1"], candidate_names=["c1", "c2"])
        with pytest.raises(ValueError, match="multiple candidates"):
            _evals_utils._resolve_loss_analysis_config(eval_result=result)

    def test_error_invalid_metric(self):
        result = _make_eval_result(metrics=["m1"], candidate_names=["c1"])
        with pytest.raises(ValueError, match="not found in eval_result"):
            _evals_utils._resolve_loss_analysis_config(
                eval_result=result, metric="nonexistent"
            )

    def test_error_invalid_candidate(self):
        result = _make_eval_result(metrics=["m1"], candidate_names=["c1"])
        with pytest.raises(ValueError, match="not found in eval_result"):
            _evals_utils._resolve_loss_analysis_config(
                eval_result=result, candidate="nonexistent"
            )

    def test_no_candidates_defaults_to_candidate_1(self):
        result = _make_eval_result(metrics=["m1"], candidate_names=[])
        # Strip candidate names from metadata to force the fallback path.
        result = result.model_copy(
            update={"metadata": common_types.EvaluationRunMetadata()}
        )
        resolved = _evals_utils._resolve_loss_analysis_config(eval_result=result)
        assert resolved.metric == "m1"
        assert resolved.candidate == "candidate_1"

    def test_no_eval_case_results_raises(self):
        empty_result = common_types.EvaluationResult()
        with pytest.raises(ValueError, match="no metric results"):
            _evals_utils._resolve_loss_analysis_config(eval_result=empty_result)
734+
735+
572736
class TestEvals:
573737
"""Unit tests for the GenAI client."""
574738

vertexai/_genai/_evals_utils.py

Lines changed: 175 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -419,34 +419,181 @@ def _postprocess_user_scenarios_response(
419419
def _display_loss_analysis_result(
    result: types.LossAnalysisResult,
) -> None:
    """Displays a LossAnalysisResult as a formatted pandas DataFrame."""
    cfg = result.config
    metric_name = cfg.metric if cfg else None
    candidate_name = cfg.candidate if cfg else None

    rows: list[dict[str, Any]] = []
    for cluster in result.clusters or []:
        taxonomy = cluster.taxonomy_entry
        rows.append(
            {
                "metric": metric_name,
                "candidate": candidate_name,
                "cluster_id": cluster.cluster_id,
                "l1_category": taxonomy.l1_category if taxonomy else None,
                "l2_category": taxonomy.l2_category if taxonomy else None,
                "description": taxonomy.description if taxonomy else None,
                "item_count": cluster.item_count,
            }
        )

    if not rows:
        logger.info("No loss clusters found.")
        return

    frame = pd.DataFrame(rows)
    try:
        # Rich rendering when running inside a notebook environment.
        from IPython.display import display  # pylint: disable=g-import-not-at-top

        display(frame)
    except ImportError:
        print(frame.to_string())  # pylint: disable=print-function
450+
451+
452+
def _resolve_metric_name(
453+
metric: Optional[Any],
454+
) -> Optional[str]:
455+
"""Extracts a metric name string from a metric argument.
456+
457+
Accepts a string, a Metric object, or a LazyLoadedPrebuiltMetric
458+
(RubricMetric) and returns the metric name as a string.
459+
460+
For LazyLoadedPrebuiltMetric (e.g., RubricMetric.MULTI_TURN_TASK_SUCCESS),
461+
this resolves to the API metric spec name (e.g.,
462+
"multi_turn_task_success_v1") so it matches the keys in eval results.
463+
464+
Args:
465+
metric: A metric name string, Metric object, RubricMetric enum value, or
466+
None.
467+
468+
Returns:
469+
The metric name as a string, or None if metric is None.
470+
"""
471+
if metric is None:
472+
return None
473+
if isinstance(metric, str):
474+
return metric
475+
# LazyLoadedPrebuiltMetric: resolve to versioned API spec name.
476+
if hasattr(metric, "_get_api_metric_spec_name"):
477+
spec_name: Optional[str] = metric._get_api_metric_spec_name()
478+
if spec_name:
479+
return spec_name
480+
# Metric objects and other types with a .name attribute.
481+
if hasattr(metric, "name"):
482+
return str(metric.name)
483+
return str(metric)
484+
485+
486+
def _resolve_loss_analysis_config(
    eval_result: types.EvaluationResult,
    config: Optional[types.LossAnalysisConfig] = None,
    metric: Optional[str] = None,
    candidate: Optional[str] = None,
) -> types.LossAnalysisConfig:
    """Resolves and validates the LossAnalysisConfig for generate_loss_clusters.

    Auto-infers `metric` and `candidate` from the EvaluationResult when not
    explicitly provided. Validates that provided values exist in the eval
    result.

    Args:
      eval_result: The EvaluationResult from client.evals.evaluate().
      config: Optional explicit LossAnalysisConfig. The separate `metric` and
        `candidate` arguments, when provided, take precedence over the values
        in config; all other config fields (e.g. predefined_taxonomy) are
        preserved as-is.
      metric: Optional metric name override; wins over config.metric.
      candidate: Optional candidate name override; wins over config.candidate.

    Returns:
      A resolved LossAnalysisConfig with metric and candidate populated.

    Raises:
      ValueError: If metric/candidate cannot be inferred or are invalid.
    """
    # Start from config if provided, otherwise create a new one. Explicit
    # arguments override config values (`metric or config.metric`).
    if config is not None:
        resolved_metric = metric or config.metric
        resolved_candidate = candidate or config.candidate
        resolved_config = config.model_copy(
            update={"metric": resolved_metric, "candidate": resolved_candidate}
        )
    else:
        resolved_config = types.LossAnalysisConfig(
            metric=metric, candidate=candidate
        )

    # Collect available metric names from the eval result. Metric names are
    # the keys of each response candidate's metric_results mapping.
    available_metrics: set[str] = set()
    if eval_result.eval_case_results:
        for case_result in eval_result.eval_case_results:
            for resp_cand in case_result.response_candidate_results or []:
                for m_name in (resp_cand.metric_results or {}).keys():
                    available_metrics.add(m_name)

    # Collect available candidate names from metadata.
    available_candidates: list[str] = []
    if eval_result.metadata and eval_result.metadata.candidate_names:
        available_candidates = list(eval_result.metadata.candidate_names)

    # Auto-infer metric if not provided: unambiguous only when exactly one
    # metric is present; zero or multiple metrics require an explicit choice.
    if not resolved_config.metric:
        if len(available_metrics) == 1:
            resolved_config = resolved_config.model_copy(
                update={"metric": next(iter(available_metrics))}
            )
        elif len(available_metrics) == 0:
            raise ValueError(
                "Cannot infer metric: no metric results found in eval_result."
                " Please provide metric explicitly via"
                " config=types.LossAnalysisConfig(metric='...')."
            )
        else:
            raise ValueError(
                "Cannot infer metric: multiple metrics found in eval_result:"
                f" {sorted(available_metrics)}. Please provide metric"
                " explicitly via config=types.LossAnalysisConfig(metric='...')."
            )

    # Validate metric if provided explicitly.
    if available_metrics and resolved_config.metric not in available_metrics:
        raise ValueError(
            f"Metric '{resolved_config.metric}' not found in eval_result."
            f" Available metrics: {sorted(available_metrics)}."
        )

    # Auto-infer candidate if not provided. Unlike metric, an empty candidate
    # list falls back to a default name instead of raising.
    if not resolved_config.candidate:
        if len(available_candidates) == 1:
            resolved_config = resolved_config.model_copy(
                update={"candidate": available_candidates[0]}
            )
        elif len(available_candidates) == 0:
            # Fallback: use default candidate naming convention from SDK.
            resolved_config = resolved_config.model_copy(
                update={"candidate": "candidate_1"}
            )
            logger.warning(
                "No candidate names found in eval_result.metadata."
                " Defaulting to 'candidate_1'. If this is incorrect, provide"
                " candidate explicitly via"
                " config=types.LossAnalysisConfig(candidate='...')."
            )
        else:
            raise ValueError(
                "Cannot infer candidate: multiple candidates found in"
                f" eval_result: {available_candidates}. Please provide"
                " candidate explicitly via"
                " config=types.LossAnalysisConfig(candidate='...')."
            )

    # Validate candidate if provided explicitly and candidates are known.
    # When available_candidates is empty this check is skipped, so the
    # 'candidate_1' fallback above always passes.
    if (
        available_candidates
        and resolved_config.candidate not in available_candidates
    ):
        raise ValueError(
            f"Candidate '{resolved_config.candidate}' not found in"
            f" eval_result. Available candidates: {available_candidates}."
        )

    return resolved_config
450597

451598

452599
def _poll_operation(

0 commit comments

Comments
 (0)