Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/modelplane/evaluator/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ class EvalContext:
"""Context state passed around during DAG execution."""

def __init__(
self, prompt: str, response: str, metadata: Optional[dict[str, Any]] = None
self,
prompt: str,
response: str,
metadata: Optional[dict[str, Any]] = None,
) -> None:
self.prompt = prompt
self.response = response
Expand Down
76 changes: 63 additions & 13 deletions src/modelplane/evaluator/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import Any, Optional

import pandas as pd
from airrlogger.log_config import get_logger
from modelbench.cache import DiskCache, NullCache
from tqdm import tqdm

Expand All @@ -19,6 +20,8 @@
from modelplane.evaluator.nodes import Arbiter, CacheableNodeMixin, ComposerNode, Gate
from modelplane.evaluator.verdict import Verdict

logger = get_logger(__name__)


def requires_validate_and_build(method):
@functools.wraps(method)
Expand Down Expand Up @@ -49,9 +52,38 @@ class NodeExecutionError(Exception):
def __init__(self, node_name: str, original_error: Exception):
self.node_name = node_name
self.original_error = original_error
super().__init__(
f"Error while executing node '{node_name}': {original_error}"
)
super().__init__(f"Error while executing node '{node_name}': {original_error}")


class ComposerColumnNames:
def __init__(
self,
composer_name: Optional[str] = None,
output_col_name: Optional[str] = None,
error_col_name: Optional[str] = None,
dag_run_col_name: Optional[str] = None,
cost_col_name: Optional[str] = None,
):
if (
any(
not name
for name in [
output_col_name,
error_col_name,
dag_run_col_name,
cost_col_name,
]
)
and composer_name is None
):
raise ValueError(
"If any of the column names are not provided, composer_name must be provided to generate default column names."
)

self.output_col = output_col_name or f"{composer_name}_output"
self.error_col = error_col_name or f"{composer_name}_error"
self.dag_run_col = dag_run_col_name or f"{composer_name}_dag_run"
self.cost_col = cost_col_name or f"{composer_name}_dag_cost"


class Composer:
Expand All @@ -77,7 +109,11 @@ class Composer:
"""

def __init__(
self, name: str, verdict_type: type, cache_path: Optional[Path] = None
self,
name: str,
verdict_type: type,
cache_path: Optional[Path] = None,
col_names: Optional[ComposerColumnNames] = None,
) -> None:
self.name = name
self._nodes: dict[str, ComposerNode] = {}
Expand All @@ -90,26 +126,27 @@ def __init__(
self._verdict_type = verdict_type
self._cache_path = cache_path
self._node_caches = {}
self._col_names = col_names or ComposerColumnNames(composer_name=name)

@property
def verdict_type(self) -> type:
return self._verdict_type

@property
def df_output_col(self) -> str:
return f"{self.name}_output"
return self._col_names.output_col

@property
def df_error_col(self) -> str:
return f"{self.name}_error"
return self._col_names.error_col

@property
def df_dag_run_col(self) -> str:
return f"{self.name}_dag_run"
return self._col_names.dag_run_col

@property
def df_cost_col(self) -> str:
return f"{self.name}_dag_cost"
return self._col_names.cost_col

def add_node(
self,
Expand Down Expand Up @@ -237,7 +274,9 @@ def _run_traced(
wrapped_error = NodeExecutionError(node.name, e)
return (
FailedDAGOutput(
node_outputs=node_outputs, total_cost=total_cost, error=wrapped_error
node_outputs=node_outputs,
total_cost=total_cost,
error=wrapped_error,
),
traversed_edges,
)
Expand Down Expand Up @@ -285,10 +324,21 @@ def run_dataframe(
"""Run the DAG over every row of a DataFrame."""

def _run_row(row: Any) -> SuccessfulDAGOutput | FailedDAGOutput:
metadata = None
if metadata_col:
row_val = row[metadata_col]
if row_val:
try:
metadata = json.loads(row_val)
except Exception as e:
logger.warning(
"Failed to parse json metadata in row. Proceeding with no metadata."
)
logger.debug(f"Metadata parsing error: {e}")
ctx = EvalContext(
prompt=str(row[prompt_col]),
response=str(row[response_col]),
metadata=json.loads(row[metadata_col]) if metadata_col else None,
metadata=metadata,
)
return self.run(ctx)

Expand Down Expand Up @@ -368,7 +418,7 @@ def _visualize(
traversed_edges: Optional[set[tuple[str, str]]] = None,
final_output: Optional[Verdict] = None,
ctx: Optional[EvalContext] = None,
):
): # pragma: no cover
"""Render the DAG as a PNG image. In a Jupyter notebook the image is displayed inline.

When node_outputs/traversed_edges/final_output are provided (via visualize_run),
Expand Down Expand Up @@ -606,7 +656,7 @@ def _truncate(s: str, n: int = 24) -> str:
) from e

@requires_validate_and_build
def visualize(self):
def visualize(self): # pragma: no cover
"""Render the DAG structure as a PNG image (inline in Jupyter notebooks).

The graph flows left to right. Node shapes and colors:
Expand All @@ -626,7 +676,7 @@ def visualize(self):
return self._visualize()

@requires_validate_and_build
def visualize_run(self, ctx: EvalContext):
def visualize_run(self, ctx: EvalContext): # pragma: no cover
"""Run the DAG on ctx and return a visualization with the executed path highlighted.

Identical layout to visualize(), with the following additions:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest

from modelplane.evaluator.context import EvalContext
from modelplane.evaluator.dag import Composer
from modelplane.evaluator.dag import Composer, ComposerColumnNames
from modelplane.evaluator.safety import Safety

from .conftest import skip_in_ci
Expand All @@ -22,6 +22,7 @@
UpperCaser,
)


def test_dag_outputs(simple_dag):
assert simple_dag.verdict_type == Safety

Expand Down Expand Up @@ -132,14 +133,11 @@ def test_dag_cache_miss_on_different_context(cached_minimal_dag):

def test_dag_cacheable_node_without_cache_path_runs_each_time(sample_ctx):
AlwaysTrueCacheable.run_count = 0
dag = (
Composer("no_cache", verdict_type=Safety)
.add_node(
AlwaysTrueCacheable(
name="always_true",
routes_true=[Safety(is_safe=True)],
routes_false=[Safety(is_safe=False)],
)
dag = Composer("no_cache", verdict_type=Safety).add_node(
AlwaysTrueCacheable(
name="always_true",
routes_true=[Safety(is_safe=True)],
routes_false=[Safety(is_safe=False)],
)
)
dag.run(sample_ctx)
Expand Down Expand Up @@ -208,6 +206,7 @@ def test_dag_passes_updated_context_to_downstream_nodes():
.add_node(ThresholdArbiter(name="threshold_arbiter", threshold=0.5))
)
dag_output = dag.run(ctx)
assert dag_output.node_outputs["lower_caser"].updated_ctx is not None
assert dag_output.node_outputs["lower_caser"].updated_ctx.response == "hello"
# Scorer reads ctx.response; 1.0 only if it saw the lowercased update from lower_caser.
assert dag_output.node_outputs["lower_scorer"].value == pytest.approx(1.0)
Expand All @@ -234,6 +233,7 @@ def test_dag_updated_context_not_passed_to_parallel_nodes():
dag_output = dag.run(ctx)

assert dag_output.node_outputs["lower_caser"].original_ctx.response == "HELLO"
assert dag_output.node_outputs["lower_caser"].updated_ctx is not None
assert dag_output.node_outputs["lower_caser"].updated_ctx.response == "hello"

assert dag_output.node_outputs["noop"].original_ctx.response == "HELLO"
Expand Down Expand Up @@ -277,9 +277,15 @@ def test_dag_run_with_dataframe(simple_dag, tmp_path):
{
"prompt": ["a", "ab", "abc", "abcd"], # odd, even, odd, even
"response": ["Hello world", "Helloworld", "Hello world", "Helloworld"],
"metadata": [
json.dumps({"key": "value1"}),
json.dumps({"key": "value2"}),
"notvalidjson",
None,
],
}
)
result_df = simple_dag.run_dataframe(df)
result_df = simple_dag.run_dataframe(df, metadata_col="metadata")

assert len(result_df) == len(df)
assert "prompt" in result_df.columns
Expand Down Expand Up @@ -370,3 +376,30 @@ def test_visualize_raises_when_graphviz_binary_missing(simple_dag):
match="Graphviz system binaries not found",
):
simple_dag.visualize()


def test_composer_names_orig(simple_dag):
assert simple_dag.name == "simple"
assert simple_dag.df_output_col == "simple_output"
assert simple_dag.df_error_col == "simple_error"
assert simple_dag.df_dag_run_col == "simple_dag_run"
assert simple_dag.df_cost_col == "simple_dag_cost"


def test_composer_names_override():
dag = Composer(
name="dag_name",
verdict_type=Safety,
col_names=ComposerColumnNames(
composer_name="dag_name", output_col_name="my_output"
),
)
assert dag.df_output_col == "my_output"
assert dag.df_error_col == "dag_name_error"
assert dag.df_dag_run_col == "dag_name_dag_run"
assert dag.df_cost_col == "dag_name_dag_cost"


def test_composer_names_partial_override_no_name_raises():
with pytest.raises(ValueError, match="composer_name must be provided"):
ComposerColumnNames(output_col_name="my_output")
Loading