class ComposerColumnNames:
    """Resolved output, error, DAG-run and cost column names for a composer.

    Every column name can be supplied explicitly. A name that is omitted (or
    otherwise falsy, e.g. an empty string) falls back to
    ``f"{composer_name}_<suffix>"``, where the suffix is ``output``,
    ``error``, ``dag_run`` or ``dag_cost`` respectively.

    Args:
        composer_name: Prefix used to build default column names. Required
            whenever at least one column name is not provided.
        output_col_name: Explicit name stored in ``output_col``.
        error_col_name: Explicit name stored in ``error_col``.
        dag_run_col_name: Explicit name stored in ``dag_run_col``.
        cost_col_name: Explicit name stored in ``cost_col``.

    Raises:
        ValueError: If a column name is missing while ``composer_name`` is
            ``None``, or if two resolved column names are identical.
    """

    def __init__(
        self,
        composer_name: Optional[str] = None,
        output_col_name: Optional[str] = None,
        error_col_name: Optional[str] = None,
        dag_run_col_name: Optional[str] = None,
        cost_col_name: Optional[str] = None,
    ) -> None:
        # Keyed by the suffix used for the generated default name.
        requested = {
            "output": output_col_name,
            "error": error_col_name,
            "dag_run": dag_run_col_name,
            "dag_cost": cost_col_name,
        }

        # Defaults can only be generated from a composer name.
        if composer_name is None and not all(requested.values()):
            raise ValueError(
                "If any of the column names are not provided, composer_name must be provided to generate default column names."
            )

        resolved = {
            suffix: explicit or f"{composer_name}_{suffix}"
            for suffix, explicit in requested.items()
        }

        # Two fields resolving to the same column name would shadow each
        # other, so reject the configuration up front instead of silently
        # accepting it.
        seen: set[str] = set()
        clashes: list[str] = []
        for col in resolved.values():
            if col in seen and col not in clashes:
                clashes.append(col)
            seen.add(col)
        if clashes:
            raise ValueError(
                f"Column names must be distinct; duplicated: {clashes}"
            )

        self.output_col = resolved["output"]
        self.error_col = resolved["error"]
        self.dag_run_col = resolved["dag_run"]
        self.cost_col = resolved["dag_cost"]

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}("
            f"output_col={self.output_col!r}, "
            f"error_col={self.error_col!r}, "
            f"dag_run_col={self.dag_run_col!r}, "
            f"cost_col={self.cost_col!r})"
        )
+ logger.warning( + "Failed to parse json metadata in row. Proceeding with no metadata." + ) + logger.debug(f"Metadata parsing error: {e}") ctx = EvalContext( prompt=str(row[prompt_col]), response=str(row[response_col]), - metadata=json.loads(row[metadata_col]) if metadata_col else None, + metadata=metadata, ) return self.run(ctx) @@ -368,7 +418,7 @@ def _visualize( traversed_edges: Optional[set[tuple[str, str]]] = None, final_output: Optional[Verdict] = None, ctx: Optional[EvalContext] = None, - ): + ): # pragma: no cover """Render the DAG as a PNG image. In a Jupyter notebook the image is displayed inline. When node_outputs/traversed_edges/final_output are provided (via visualize_run), @@ -606,7 +656,7 @@ def _truncate(s: str, n: int = 24) -> str: ) from e @requires_validate_and_build - def visualize(self): + def visualize(self): # pragma: no cover """Render the DAG structure as a PNG image (inline in Jupyter notebooks). The graph flows left to right. Node shapes and colors: @@ -626,7 +676,7 @@ def visualize(self): return self._visualize() @requires_validate_and_build - def visualize_run(self, ctx: EvalContext): + def visualize_run(self, ctx: EvalContext): # pragma: no cover """Run the DAG on ctx and return a visualization with the executed path highlighted. 
Identical layout to visualize(), with the following additions: diff --git a/tests/unit/evaluator/test_dag.py b/tests/unit/evaluator/test_composer.py similarity index 89% rename from tests/unit/evaluator/test_dag.py rename to tests/unit/evaluator/test_composer.py index 022b9b8..5c24388 100644 --- a/tests/unit/evaluator/test_dag.py +++ b/tests/unit/evaluator/test_composer.py @@ -7,7 +7,7 @@ import pytest from modelplane.evaluator.context import EvalContext -from modelplane.evaluator.dag import Composer +from modelplane.evaluator.dag import Composer, ComposerColumnNames from modelplane.evaluator.safety import Safety from .conftest import skip_in_ci @@ -22,6 +22,7 @@ UpperCaser, ) + def test_dag_outputs(simple_dag): assert simple_dag.verdict_type == Safety @@ -132,14 +133,11 @@ def test_dag_cache_miss_on_different_context(cached_minimal_dag): def test_dag_cacheable_node_without_cache_path_runs_each_time(sample_ctx): AlwaysTrueCacheable.run_count = 0 - dag = ( - Composer("no_cache", verdict_type=Safety) - .add_node( - AlwaysTrueCacheable( - name="always_true", - routes_true=[Safety(is_safe=True)], - routes_false=[Safety(is_safe=False)], - ) + dag = Composer("no_cache", verdict_type=Safety).add_node( + AlwaysTrueCacheable( + name="always_true", + routes_true=[Safety(is_safe=True)], + routes_false=[Safety(is_safe=False)], ) ) dag.run(sample_ctx) @@ -208,6 +206,7 @@ def test_dag_passes_updated_context_to_downstream_nodes(): .add_node(ThresholdArbiter(name="threshold_arbiter", threshold=0.5)) ) dag_output = dag.run(ctx) + assert dag_output.node_outputs["lower_caser"].updated_ctx is not None assert dag_output.node_outputs["lower_caser"].updated_ctx.response == "hello" # Scorer reads ctx.response; 1.0 only if it saw the lowercased update from lower_caser. 
def test_composer_names_partial_override_no_name_raises():
    """Overriding only some columns without a composer name is rejected."""
    overrides = {"output_col_name": "my_output"}
    with pytest.raises(ValueError, match="composer_name must be provided"):
        ComposerColumnNames(**overrides)