From a074f8f241639c34f02c9a3f03285f4bda1dcb3c Mon Sep 17 00:00:00 2001 From: Nijat K Date: Sun, 3 May 2026 20:43:39 -0400 Subject: [PATCH 1/8] Add Flow.model generated callable models Signed-off-by: Nijat K --- ccflow/__init__.py | 1 + ccflow/_flow_model_binding.py | 555 +++ ccflow/callable.py | 139 +- ccflow/context.py | 60 +- ccflow/evaluators/common.py | 256 +- ccflow/flow_model.py | 1969 +++++++++++ ccflow/tests/config/conf_flow.yaml | 73 + ccflow/tests/evaluators/test_common.py | 98 + ccflow/tests/test_callable.py | 249 ++ ccflow/tests/test_context.py | 9 +- ccflow/tests/test_evaluator.py | 70 +- ccflow/tests/test_flow_context.py | 354 ++ ccflow/tests/test_flow_model.py | 3028 +++++++++++++++++ ccflow/tests/test_flow_model_hydra.py | 135 + docs/wiki/Flow-Model.md | 488 +++ docs/wiki/Key-Features.md | 4 + docs/wiki/Workflows.md | 6 + docs/wiki/_Sidebar.md | 1 + .../config/flow_model_hydra_builder_demo.yaml | 26 + examples/flow_model_example.py | 101 + examples/flow_model_hydra_builder_demo.py | 113 + 21 files changed, 7709 insertions(+), 26 deletions(-) create mode 100644 ccflow/_flow_model_binding.py create mode 100644 ccflow/flow_model.py create mode 100644 ccflow/tests/config/conf_flow.yaml create mode 100644 ccflow/tests/test_flow_context.py create mode 100644 ccflow/tests/test_flow_model.py create mode 100644 ccflow/tests/test_flow_model_hydra.py create mode 100644 docs/wiki/Flow-Model.md create mode 100644 examples/config/flow_model_hydra_builder_demo.yaml create mode 100644 examples/flow_model_example.py create mode 100644 examples/flow_model_hydra_builder_demo.py diff --git a/ccflow/__init__.py b/ccflow/__init__.py index 71263b5..30ffd9c 100644 --- a/ccflow/__init__.py +++ b/ccflow/__init__.py @@ -11,6 +11,7 @@ from .callable import * from .context import * from .enums import Enum +from .flow_model import * from .global_state import * from .local_persistence import * from .models import * diff --git a/ccflow/_flow_model_binding.py b/ccflow/_flow_model_binding.py new file mode 100644 index 0000000..2e67eac --- /dev/null +++ b/ccflow/_flow_model_binding.py @@ -0,0 +1,555 @@ +"""Shared signature and context-contract analysis for Flow authoring APIs.""" + +import inspect +from dataclasses import dataclass, field +from functools import wraps +from types import UnionType +from typing import Annotated, Any, Callable, Dict, Literal, Optional, Tuple, Type, Union, get_args, get_origin, get_type_hints + +from .base import ContextBase, ResultBase +from .context import FlowContext +from .exttypes import PyObjectPath +from .local_persistence import create_ccflow_model +from .result import GenericResult + +_AnyCallable = Callable[..., Any] +_UNION_ORIGINS = (Union, UnionType) + + +class _InternalSentinel: + def __init__(self, name: str): + self._name = name + + def __repr__(self) -> str: + return self._name + + def __reduce__(self): + return (_get_internal_sentinel, (self._name,)) + + +def _get_internal_sentinel(name: str) -> _InternalSentinel: + return _INTERNAL_SENTINELS[name] + + +_INTERNAL_SENTINELS = { + "_UNSET": _InternalSentinel("_UNSET"), + "_REMOVED_CONTEXT_ARGS": _InternalSentinel("_REMOVED_CONTEXT_ARGS"), +} +_UNSET = _INTERNAL_SENTINELS["_UNSET"] +_REMOVED_CONTEXT_ARGS = _INTERNAL_SENTINELS["_REMOVED_CONTEXT_ARGS"] +_RESERVED_FLOW_MODEL_PARAM_NAMES = frozenset({"flow", "meta", "context_type", "result_type"}) + + +class _LazyMarker: + pass + + +class _FromContextMarker: + pass + + +class FromContext: + """Marker used in ``@Flow.model`` signatures for runtime/contextual inputs.""" + + def __class_getitem__(cls, item): + return Annotated[item, _FromContextMarker()] + + +class Lazy: + """Lazy dependency marker used only as ``Lazy[T]`` in type annotations.""" + + def __new__(cls, *args, **kwargs): + raise TypeError("Lazy(model)(...) has been removed. Use model.flow.with_context(...) for contextual rewrites.") + + def __class_getitem__(cls, item): + return Annotated[item, _LazyMarker()] + + +@dataclass(frozen=True) +class _ParsedAnnotation: + base: Any + is_lazy: bool + is_from_context: bool + + +@dataclass(frozen=True) +class _FlowModelParam: + name: str + annotation: Any + kind: str + is_lazy: bool + has_function_default: bool + function_default: Any = _UNSET + context_validation_annotation: Any = _UNSET + + @property + def is_contextual(self) -> bool: + return self.kind == "contextual" + + @property + def validation_annotation(self) -> Any: + if self.context_validation_annotation is not _UNSET: + return self.context_validation_annotation + return self.annotation + + +@dataclass(frozen=True) +class _FlowModelConfig: + func: _AnyCallable + return_annotation: Any + context_type: Type[ContextBase] + result_type: Type[ResultBase] + auto_wrap_result: bool + auto_unwrap: bool + parameters: Tuple[_FlowModelParam, ...] + declared_context_type: Optional[Type[ContextBase]] = None + path: Optional[PyObjectPath] = None + _regular_params: Tuple[_FlowModelParam, ...] = field(init=False, repr=False) + _contextual_params: Tuple[_FlowModelParam, ...] = field(init=False, repr=False) + _regular_param_names: Tuple[str, ...] = field(init=False, repr=False) + _contextual_param_names: Tuple[str, ...] = field(init=False, repr=False) + _params_by_name: Dict[str, _FlowModelParam] = field(init=False, repr=False) + + def __post_init__(self) -> None: + regular = tuple(param for param in self.parameters if not param.is_contextual) + contextual = tuple(param for param in self.parameters if param.is_contextual) + object.__setattr__(self, "_regular_params", regular) + object.__setattr__(self, "_contextual_params", contextual) + object.__setattr__(self, "_regular_param_names", tuple(param.name for param in regular)) + object.__setattr__(self, "_contextual_param_names", tuple(param.name for param in contextual)) + object.__setattr__(self, "_params_by_name", {param.name: param for param in self.parameters}) + + @property + def regular_params(self) -> Tuple[_FlowModelParam, ...]: + return self._regular_params + + @property + def contextual_params(self) -> Tuple[_FlowModelParam, ...]: + return self._contextual_params + + @property + def regular_param_names(self) -> Tuple[str, ...]: + return self._regular_param_names + + @property + def contextual_param_names(self) -> Tuple[str, ...]: + return self._contextual_param_names + + @property + def context_input_types(self) -> Dict[str, Any]: + return {param.name: param.validation_annotation for param in self.contextual_params} + + @property + def context_required_names(self) -> Tuple[str, ...]: + return tuple(param.name for param in self.contextual_params if not param.has_function_default) + + def param(self, name: str) -> _FlowModelParam: + return self._params_by_name[name] + + +@dataclass(frozen=True) +class _AutoContextSpec: + signature: inspect.Signature + base_class: Type[ContextBase] + class_name: str + fields: Dict[str, Tuple[Any, Any]] + + +def _callable_name(func: _AnyCallable) -> str: + return getattr(func, "__name__", type(func).__name__) + + +def _callable_qualname(func: _AnyCallable) -> str: + return getattr(func, "__qualname__", type(func).__qualname__) + + +def _resolved_flow_signature( + fn: _AnyCallable, + *, + resolved_hints: Optional[Dict[str, Any]] = None, + skip_self: bool = False, + require_return_annotation: bool = False, + annotation_error_suffix: str = "", + return_error_suffix: str = "", + function_name: Optional[str] = None, +) -> inspect.Signature: + sig = inspect.signature(fn) + resolved_hints = resolved_hints or {} + function_name = function_name or _callable_name(fn) + parameters = [] + + for name, param in sig.parameters.items(): + if skip_self and name == "self": + continue + if param.kind is inspect.Parameter.POSITIONAL_ONLY: + raise TypeError(f"Function {function_name} does not support positional-only parameter '{name}'.") + if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD): + raise TypeError(f"Function {function_name} does not support {param.kind.description} parameter '{name}'.") + + annotation = resolved_hints.get(name, param.annotation) + if annotation is inspect.Parameter.empty: + raise TypeError(f"Parameter '{name}' must have a type annotation{annotation_error_suffix}") + + parameters.append(param.replace(annotation=annotation)) + + return_annotation = resolved_hints.get("return", sig.return_annotation) + if require_return_annotation and return_annotation is inspect.Signature.empty: + raise TypeError(f"Function {function_name} must have a return type annotation{return_error_suffix}") + + return sig.replace(parameters=parameters, return_annotation=return_annotation) + + +def _parse_annotation(annotation: Any) -> _ParsedAnnotation: + is_lazy = False + is_from_context = False + + while get_origin(annotation) is Annotated: + args = get_args(annotation) + annotation = args[0] + for metadata in args[1:]: + if isinstance(metadata, _LazyMarker): + is_lazy = True + elif isinstance(metadata, _FromContextMarker): + is_from_context = True + + return _ParsedAnnotation(base=annotation, is_lazy=is_lazy, is_from_context=is_from_context) + + +def _strip_annotated(annotation: Any) -> Any: + while get_origin(annotation) is Annotated: + annotation = get_args(annotation)[0] + return annotation + + +def _is_result_annotation(annotation: Any) -> bool: + origin = get_origin(annotation) or annotation + if isinstance(origin, type) and issubclass(origin, ResultBase): + return True + + if get_origin(annotation) in _UNION_ORIGINS: + args = tuple(arg for arg in get_args(annotation) if arg is not type(None)) + return bool(args) and all(_is_result_annotation(arg) for arg in args) + + return False + + +def _context_type_annotations_compatible(func_annotation: Any, context_annotation: Any) -> bool: + func_annotation = _strip_annotated(func_annotation) + context_annotation = _strip_annotated(context_annotation) + + if func_annotation is Any: + return True + if context_annotation is Any: + return True + if func_annotation is context_annotation or func_annotation == context_annotation: + return True + + func_origin = get_origin(func_annotation) + context_origin = get_origin(context_annotation) + + if func_origin in _UNION_ORIGINS: + raw_func_args = get_args(func_annotation) + func_accepts_none = type(None) in raw_func_args + func_args = tuple(arg for arg in raw_func_args if arg is not type(None)) + if context_origin in _UNION_ORIGINS: + raw_context_args = get_args(context_annotation) + if type(None) in raw_context_args and not func_accepts_none: + return False + context_args = tuple(arg for arg in raw_context_args if arg is not type(None)) + if not context_args: + return func_accepts_none + return bool(context_args) and all( + any(_context_type_annotations_compatible(func_arg, context_arg) for func_arg in func_args) for context_arg in context_args + ) + if context_annotation is type(None): + return func_accepts_none + return any(_context_type_annotations_compatible(func_arg, context_annotation) for func_arg in func_args) + + if context_origin in _UNION_ORIGINS: + raw_context_args = get_args(context_annotation) + if type(None) in raw_context_args: + return False + context_args = tuple(arg for arg in raw_context_args if arg is not type(None)) + return bool(context_args) and all(_context_type_annotations_compatible(func_annotation, context_arg) for context_arg in context_args) + + if func_origin is Literal and context_origin is Literal: + return set(get_args(context_annotation)).issubset(set(get_args(func_annotation))) + if func_origin is Literal: + return False + if context_origin is Literal: + literal_values = get_args(context_annotation) + func_base = func_origin or func_annotation + if isinstance(func_base, type): + return all(isinstance(value, func_base) for value in literal_values) + return False + + func_base = func_origin or func_annotation + context_base = context_origin or context_annotation + if isinstance(func_base, type) and isinstance(context_base, type): + if not issubclass(context_base, func_base): + return False + elif func_base != context_base: + return False + + func_args = get_args(func_annotation) + context_args = get_args(context_annotation) + if bool(func_args) != bool(context_args): + return False + if len(func_args) != len(context_args): + return False + + for func_arg, context_arg in zip(func_args, context_args): + if get_origin(func_arg) is not None or get_origin(context_arg) is not None or isinstance(func_arg, type) or isinstance(context_arg, type): + if not _context_type_annotations_compatible(func_arg, context_arg): + return False + elif func_arg != context_arg: + return False + + return True + + +def _analyze_flow_function( + fn: _AnyCallable, + sig: inspect.Signature, + *, + is_model_dependency: Callable[[Any], bool], +) -> Tuple[_FlowModelParam, ...]: + analyzed_params = [] + + for param in sig.parameters.values(): + parsed = _parse_annotation(param.annotation) + if parsed.is_lazy and parsed.is_from_context: + raise TypeError(f"Parameter '{param.name}' cannot combine Lazy[...] and FromContext[...].") + has_default = param.default is not inspect.Parameter.empty + if parsed.is_lazy and has_default and not is_model_dependency(param.default): + raise TypeError(f"Parameter '{param.name}' is marked Lazy[...] and must default to a CallableModel dependency.") + if parsed.is_from_context and has_default and is_model_dependency(param.default): + raise TypeError(f"Parameter '{param.name}' is marked FromContext[...] and cannot default to a CallableModel.") + + analyzed_params.append( + _FlowModelParam( + name=param.name, + annotation=parsed.base, + kind="contextual" if parsed.is_from_context else "regular", + is_lazy=parsed.is_lazy, + has_function_default=has_default, + function_default=param.default if has_default else _UNSET, + ) + ) + + return tuple(analyzed_params) + + +def _validate_declared_context_type(context_type: Any, contextual_params: Tuple[_FlowModelParam, ...]) -> Type[ContextBase]: + if not isinstance(context_type, type) or not issubclass(context_type, ContextBase): + raise TypeError(f"context_type must be a ContextBase subclass, got {context_type!r}") + + context_fields = getattr(context_type, "model_fields", {}) + contextual_names = {param.name for param in contextual_params} + + missing = sorted(name for name in contextual_names if name not in context_fields) + if missing: + raise TypeError(f"context_type {context_type.__name__} must define fields for all FromContext parameters: {', '.join(missing)}") + + required_extra_fields = sorted( + name for name, info in context_fields.items() if name not in ContextBase.model_fields and name not in contextual_names and info.is_required() + ) + if required_extra_fields: + raise TypeError( + f"context_type {context_type.__name__} has required fields that are not declared as FromContext parameters: " + f"{', '.join(required_extra_fields)}" + ) + + for param in contextual_params: + ctx_field = context_fields[param.name] + if not _context_type_annotations_compatible(param.annotation, ctx_field.annotation): + raise TypeError( + f"FromContext parameter '{param.name}' annotates {param.annotation!r}, but " + f"context_type {context_type.__name__} declares {ctx_field.annotation!r}." + ) + + return context_type + + +def _analyze_flow_model( + fn: _AnyCallable, + sig: inspect.Signature, + *, + context_type: Optional[Type[ContextBase]], + auto_unwrap: bool, + is_model_dependency: Callable[[Any], bool], +) -> _FlowModelConfig: + parameters = _analyze_flow_function(fn, sig, is_model_dependency=is_model_dependency) + reserved = sorted(param.name for param in parameters if param.name in _RESERVED_FLOW_MODEL_PARAM_NAMES) + if reserved: + names = ", ".join(repr(name) for name in reserved) + raise TypeError(f"Parameter name(s) {names} are reserved for generated model framework attributes.") + + contextual_params = tuple(param for param in parameters if param.is_contextual) + declared_context_type = None + if context_type is not None and not contextual_params: + raise TypeError("context_type=... requires FromContext[...] parameters.") + if context_type is not None: + declared_context_type = _validate_declared_context_type(context_type, contextual_params) + + if declared_context_type is not None: + updated_params = [] + context_fields = declared_context_type.model_fields + for param in parameters: + if not param.is_contextual: + updated_params.append(param) + continue + updated_params.append( + _FlowModelParam( + name=param.name, + annotation=param.annotation, + kind=param.kind, + is_lazy=param.is_lazy, + has_function_default=param.has_function_default, + function_default=param.function_default, + context_validation_annotation=context_fields[param.name].annotation, + ) + ) + parameters = tuple(updated_params) + + auto_wrap_result = not _is_result_annotation(sig.return_annotation) + result_type = GenericResult[sig.return_annotation] if auto_wrap_result else sig.return_annotation + + return _FlowModelConfig( + func=fn, + return_annotation=sig.return_annotation, + context_type=FlowContext, + result_type=result_type, + auto_wrap_result=auto_wrap_result, + auto_unwrap=auto_unwrap, + parameters=parameters, + declared_context_type=declared_context_type, + ) + + +def _analyze_flow_context_transform( + fn: _AnyCallable, + sig: inspect.Signature, + *, + is_model_dependency: Callable[[Any], bool], +) -> _FlowModelConfig: + parameters = _analyze_flow_function(fn, sig, is_model_dependency=is_model_dependency) + lazy_params = [param.name for param in parameters if param.is_lazy] + if lazy_params: + raise TypeError(f"Flow.context_transform does not support Lazy[...] parameter(s): {', '.join(lazy_params)}") + return _FlowModelConfig( + func=fn, + return_annotation=sig.return_annotation, + context_type=FlowContext, + result_type=GenericResult, + auto_wrap_result=False, + auto_unwrap=False, + parameters=parameters, + path=PyObjectPath(f"{getattr(fn, '__module__', __name__)}.{_callable_name(fn)}"), + ) + + +def _analyze_auto_context_function( + func: _AnyCallable, + *, + parent: Optional[Type[ContextBase]], + resolved_hints: Dict[str, Any], + is_model_dependency: Callable[[Any], bool], +) -> _AutoContextSpec: + sig = _resolved_flow_signature( + func, + resolved_hints=resolved_hints, + skip_self=True, + require_return_annotation=True, + annotation_error_suffix=" when auto_context=True", + return_error_suffix=" when auto_context=True", + function_name=_callable_qualname(func), + ) + base_class = parent or ContextBase + + if parent is not None: + parent_fields = set(parent.model_fields.keys()) - set(ContextBase.model_fields.keys()) + sig_params = set(sig.parameters) + missing = parent_fields - sig_params + if missing: + raise TypeError(f"Parent context fields {missing} must be included in function signature") + + for fname in parent_fields: + parent_annotation = parent.model_fields[fname].annotation + func_annotation = sig.parameters[fname].annotation + if func_annotation is inspect.Parameter.empty: + continue + if not _context_type_annotations_compatible(func_annotation, parent_annotation): + raise TypeError( + f"auto_context field '{fname}' has annotation {func_annotation!r} which is incompatible " + f"with parent field annotation {parent_annotation!r}" + ) + + lazy_params = [name for name, param in sig.parameters.items() if _parse_annotation(param.annotation).is_lazy] + if lazy_params: + raise TypeError(f"Flow.call(auto_context=...) does not support Lazy[...] parameter(s): {', '.join(lazy_params)}") + + model_defaults = [ + name for name, param in sig.parameters.items() if param.default is not inspect.Parameter.empty and is_model_dependency(param.default) + ] + if model_defaults: + raise TypeError(f"Flow.call(auto_context=...) parameters cannot default to CallableModel dependencies: {', '.join(model_defaults)}") + + fields = {} + parent_model_fields = {} if parent is None else parent.model_fields + for name, param in sig.parameters.items(): + if name in parent_model_fields: + field_info = parent_model_fields[name] + fields[name] = (field_info.annotation, field_info) + continue + fields[name] = (param.annotation, ... if param.default is inspect.Parameter.empty else param.default) + return _AutoContextSpec( + signature=sig, + base_class=base_class, + class_name=f"{_callable_qualname(func)}_AutoContext", + fields=fields, + ) + + +def _normalize_auto_context_parent(auto_context: Any) -> Type[ContextBase]: + if auto_context is True: + return ContextBase + if inspect.isclass(auto_context) and issubclass(auto_context, ContextBase): + return auto_context + raise TypeError(f"auto_context must be False, True, or a ContextBase subclass, got {auto_context!r}") + + +def _wrap_auto_context_call( + func: _AnyCallable, + *, + parent: Type[ContextBase], + is_model_dependency: Callable[[Any], bool], +) -> _AnyCallable: + resolved_hints = get_type_hints(func, include_extras=True) + spec = _analyze_auto_context_function( + func, + parent=parent, + resolved_hints=resolved_hints, + is_model_dependency=is_model_dependency, + ) + + auto_context_class = create_ccflow_model(spec.class_name, __base__=spec.base_class, **spec.fields) + + @wraps(func) + def wrapper(self, context): + fn_kwargs = {name: getattr(context, name) for name in spec.fields} + return func(self, **fn_kwargs) + + context_default = inspect.Signature.empty + if all(not field.is_required() for field in auto_context_class.model_fields.values()): + context_default = auto_context_class() + + wrapper.__signature__ = inspect.Signature( + parameters=[ + inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), + inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=auto_context_class, default=context_default), + ], + return_annotation=spec.signature.return_annotation, + ) + wrapper.__auto_context__ = auto_context_class + return wrapper diff --git a/ccflow/callable.py b/ccflow/callable.py index a5c545f..986c04b 100644 --- a/ccflow/callable.py +++ b/ccflow/callable.py @@ -13,13 +13,30 @@ import abc import logging +from dataclasses import dataclass from functools import lru_cache, wraps from inspect import Signature, isclass, signature -from typing import Any, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Dict, + Generic, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, + get_args, + get_origin, +) from pydantic import BaseModel as PydanticBaseModel, ConfigDict, Field, InstanceOf, PrivateAttr, TypeAdapter, field_validator, model_validator from typing_extensions import override +from ._flow_model_binding import _normalize_auto_context_parent, _wrap_auto_context_call from .base import ( BaseModel, ContextBase, @@ -29,6 +46,9 @@ ) from .validators import str_to_log_level +if TYPE_CHECKING: + from .flow_model import FlowAPI + __all__ = ( "GraphDepType", "GraphDepList", @@ -50,6 +70,14 @@ log = logging.getLogger(__name__) +@dataclass(frozen=True) +class EvaluationDependency: + """Internal marker for a dependency invocation in an effective identity payload.""" + + model: Any + context: Any + + # ***************************************************************************** # Base CallableModel definitions, before introducing the Flow decorator or # any evaluators @@ -175,6 +203,18 @@ def __deps__( Implementations should be decorated with Flow.call. """ + def _evaluation_identity_payload( + self, + context: Any, + ) -> Optional[Any]: + """Return an effective evaluation identity payload when available. + + Returning ``None`` keeps the model on the existing structural key path. + This is intentionally narrow and internal: only models whose effective + invocation can be described declaratively should override it. + """ + return None + CallableModelType = TypeVar("CallableModelType", bound=_CallableModel) @@ -393,6 +433,23 @@ class Flow(PydanticBaseModel): @staticmethod def call(*args, **kwargs): """Decorator for methods on callable models""" + auto_context = kwargs.pop("auto_context", False) + + if auto_context is not False: + context_parent = _normalize_auto_context_parent(auto_context) + + def auto_context_decorator(fn: Callable[..., Any]) -> Callable[..., Any]: + wrapped = _wrap_auto_context_call( + fn, + parent=context_parent, + is_model_dependency=lambda value: isinstance(value, CallableModel), + ) + return FlowOptions(**kwargs)(wrapped) + + if len(args) == 1 and callable(args[0]): + return auto_context_decorator(args[0]) + return auto_context_decorator + if len(args) == 1 and callable(args[0]): # No arguments to decorator, this is the decorator fn = args[0] @@ -418,6 +475,79 @@ def deps(*args, **kwargs): # Note that the code below is executed only once return FlowOptionsDeps(**kwargs) + @staticmethod + def model(*args, **kwargs): + """Decorator that generates a CallableModel class from a plain Python function. + + The generated model participates in the normal CallableModel execution + path, including evaluation, caching, dependency discovery, registry use, + and serialization. The function signature declares both the model's + construction-time inputs and its runtime/contextual inputs. + + Args: + context_type: Optional ContextBase subclass used only to validate/coerce + `FromContext[...]` inputs against an existing nominal context shape + auto_unwrap: When True, `.flow.compute(...)` unwraps auto-wrapped + `GenericResult(value=...)` outputs back to the annotated return type. + Explicit `ResultBase` returns are left unchanged. Default: False. + model_base: Optional custom `CallableModel` subclass to use as an + additional base for the generated model class. + cacheable: Enable caching of results (default: False) + volatile: Mark as volatile (default: False) + log_level: Logging verbosity (default: logging.DEBUG) + validate_result: Validate return type (default: True) + verbose: Verbose logging output (default: True) + evaluator: Custom evaluator (default: None) + + Primary authoring model: + Mark runtime/contextual inputs explicitly with `FromContext[...]`. + Ordinary unmarked parameters are regular bound inputs and are never + read implicitly from the runtime context. + + @Flow.model + def greeting(prefix: str, name: FromContext[str]) -> str: + return f"{prefix}, {name}" + + model = greeting(prefix="Hello") + assert model.flow.compute(name="Ada").value == "Hello, Ada" + + Dependencies: + Any ordinary parameter can be bound either to a literal value or + to another CallableModel. When a CallableModel is supplied, the + generated model treats it as an upstream dependency and resolves it + with the current context before calling the underlying function. + + `FromContext[...]` parameters are different: they may be satisfied by + runtime context, construction-time contextual defaults, or function + defaults, but not by CallableModel values. + + Usage: + @Flow.model + def length(text: FromContext[str]) -> int: + return len(text) + + @Flow.model + def score(base: int, bonus: FromContext[int]) -> int: + return base + bonus + + model = score(base=length()) + result = model.flow.compute(text="abcd", bonus=3) + assert result.value == 7 + + Returns: + A factory function that creates CallableModel instances + """ + from .flow_model import flow_model + + return flow_model(*args, **kwargs) + + @staticmethod + def context_transform(*args, **kwargs): + """Decorator that turns a top-level function into a serializable with_context() transform factory.""" + from .flow_model import flow_context_transform + + return flow_context_transform(*args, **kwargs) + # ***************************************************************************** # Define "Evaluators" and associated types @@ -672,6 +802,13 @@ def __deps__( """ return [] + @property + def flow(self) -> "FlowAPI": + """Access flow helpers for execution, context transforms, and introspection.""" + from .flow_model import FlowAPI + + return FlowAPI(self) + class WrapperModel(CallableModel, Generic[CallableModelType], abc.ABC): """Abstract class that represents a wrapper around an underlying model, with the same context and return types. diff --git a/ccflow/context.py b/ccflow/context.py index 9a04fad..56a3774 100644 --- a/ccflow/context.py +++ b/ccflow/context.py @@ -1,16 +1,18 @@ """This module defines re-usable contexts for the "Callable Model" framework defined in flow.callable.py.""" +from collections.abc import Mapping from datetime import date, datetime -from typing import Generic, Hashable, Optional, Sequence, Set, TypeVar +from typing import Any, Generic, Hashable, Optional, Sequence, Set, TypeVar from deprecated import deprecated -from pydantic import field_validator, model_validator +from pydantic import ConfigDict, PrivateAttr, field_validator, model_validator from .base import ContextBase from .exttypes import Frequency from .validators import normalize_date, normalize_datetime __all__ = ( + "FlowContext", "NullContext", "GenericContext", "DateContext", @@ -89,6 +91,60 @@ # Starting 0.8.0 Nullcontext is an alias to ContextBase NullContext = ContextBase + +class FlowContext(ContextBase): + """Universal context for @Flow.model functions. + + Instead of generating a new ContextBase subclass for each @Flow.model, + this single class with extra="allow" serves as the universal carrier. + Validation happens via TypedDict + TypeAdapter at compute() time. + + This design avoids: + - Proliferation of dynamic _funcname_Context classes + - Class registration overhead for serialization + - Pickling issues with Ray/distributed computing + """ + + model_config = ConfigDict(extra="allow", frozen=True) + _frozen_hash_key: Hashable | None = PrivateAttr(default=None) + _hash_value: int | None = PrivateAttr(default=None) + + def _hash_key(self) -> Hashable: + if self._frozen_hash_key is None: + self._frozen_hash_key = _freeze_for_hash(self.model_dump(mode="python")) + return self._frozen_hash_key + + def __eq__(self, other: Any) -> bool: + if self is other: + return True + if not isinstance(other, FlowContext): + return False + return self._hash_key() == other._hash_key() + + def __hash__(self) -> int: + if self._hash_value is None: + self._hash_value = hash(self._hash_key()) + return self._hash_value + + +def _freeze_for_hash(value: Any) -> Hashable: + if isinstance(value, Mapping): + return tuple(sorted(((key, _freeze_for_hash(item)) for key, item in value.items()), key=lambda item: repr(item[0]))) + if isinstance(value, (list, tuple)): + return tuple(_freeze_for_hash(item) for item in value) + if isinstance(value, (set, frozenset)): + return frozenset(_freeze_for_hash(item) for item in value) + if hasattr(value, "model_dump"): + return (type(value), _freeze_for_hash(value.model_dump(mode="python"))) + try: + hash(value) + except TypeError as exc: + if hasattr(value, "__dict__"): + return (type(value), _freeze_for_hash(vars(value))) + raise TypeError(f"FlowContext contains an unhashable value of type {type(value).__name__}: {value!r}") from exc + return value + + C = TypeVar("C", bound=Hashable) diff --git a/ccflow/evaluators/common.py b/ccflow/evaluators/common.py index 2cab984..7baa529 100644 --- a/ccflow/evaluators/common.py +++ b/ccflow/evaluators/common.py @@ -7,18 +7,21 @@ from types import MappingProxyType from typing import Any, Callable, Dict, List, Optional, Set, Union -from pydantic import Field, PrivateAttr, field_validator +from pydantic import Field, PrivateAttr, ValidationError, field_validator from typing_extensions import override from ..base import BaseModel, make_lazy_result from ..callable import ( CallableModel, ContextBase, + EvaluationDependency, EvaluatorBase, ModelEvaluationContext, ResultType, TransparentModelEvaluationContext, + WrapperModel, ) +from ..utils.tokenize import compute_cache_token, compute_data_token __all__ = [ "cache_key", @@ -36,6 +39,14 @@ log = logging.getLogger(__name__) +class _EffectiveEvaluationKeyUnavailable(Exception): + """Internal signal to use the existing structural evaluation key.""" + + +_EFFECTIVE_IDENTITY_DECLINED_ERRORS = (TypeError, ValueError, ValidationError) +_EFFECTIVE_EVALUATION_KEY_VERSION = "ccflow_effective_evaluation_key_v1" + + def combine_evaluators(first: Optional[EvaluatorBase], second: Optional[EvaluatorBase]) -> EvaluatorBase: """Helper function to combine evaluators into a new evaluator. @@ -226,21 +237,180 @@ def _format_result(self, result: ResultType) -> str: return f"{msg_str}{pformat(result_dict, **self.format_config.pformat_config)}" -def cache_key(flow_obj: Union[ModelEvaluationContext, ContextBase, CallableModel]) -> bytes: - """Returns a key suitable for use in caching and dependency graph deduplication. +def _unwrap_evaluation_context(evaluation_context: ModelEvaluationContext) -> tuple[ModelEvaluationContext, str, List[CallableModel]]: + """Strip transparent evaluator wrappers and keep opaque wrappers in order. + + This preserves the existing structural cache-key behavior: transparent + evaluators are ignored, while non-transparent evaluators remain part of the + identity. The returned function name is the innermost non-``__call__`` name, + so ``__deps__`` does not collapse into ``__call__`` when wrapped. + """ + fn = evaluation_context.fn + outer_to_inner_evaluators = [] + while isinstance(evaluation_context.context, ModelEvaluationContext): + fn = evaluation_context.fn if evaluation_context.fn != "__call__" else fn + if not isinstance(evaluation_context, TransparentModelEvaluationContext): + outer_to_inner_evaluators.append(evaluation_context.model) + evaluation_context = evaluation_context.context + return evaluation_context, fn if fn != "__call__" else evaluation_context.fn, outer_to_inner_evaluators + + +def _evaluator_identity_payload(outer_to_inner_evaluators: List[CallableModel]) -> List[Dict[str, Any]]: + return [evaluator.model_dump(mode="python") for evaluator in outer_to_inner_evaluators] + + +def _memo_token(model: CallableModel, context: Any) -> tuple[int, str]: + if hasattr(context, "model_dump"): + context_value = context.model_dump(mode="python") + else: + context_value = context + return (id(model), compute_data_token((type(context), context_value))) + + +def _effective_model_key( + model: CallableModel, + context: Any, + memo: Dict[tuple[int, str], bytes], + active: Set[tuple[int, str]], +) -> Optional[bytes]: + """Return a model's opt-in effective key, or ``None`` for normal opt-out. + + Plain ``CallableModel`` instances opt out by returning ``None`` from + ``_evaluation_identity_payload()``. Dependency invocations inside the + payload are resolved by ``_resolve_effective_identity_payload()`` so models + declare what matters without constructing recursive keys themselves. + """ + token = _memo_token(model, context) + if token in memo: + return memo[token] + if token in active: + raise _EffectiveEvaluationKeyUnavailable("recursive effective identity") + + active.add(token) + try: + try: + payload = model._evaluation_identity_payload(context) + except _EFFECTIVE_IDENTITY_DECLINED_ERRORS as exc: + # Identity derivation runs before the actual call and may encounter + # the same validation failures as evaluation context construction. + # Falling back preserves existing behavior instead of turning key + # computation into a new failure mode for ordinary models. + raise _EffectiveEvaluationKeyUnavailable(str(exc)) from exc + # For normal CallableModels, `_evaluation_identity_payload` defaults to + # None, so we should hit this path + if payload is None: + return None + payload = _resolve_effective_identity_payload(payload, memo, active) + key = compute_cache_token( + data_values=[(_EFFECTIVE_EVALUATION_KEY_VERSION, payload)], + behavior_classes=[type(model)], + ).encode("utf-8") + memo[token] = key + return key + finally: + active.discard(token) + + +def _resolve_effective_identity_payload( + value: Any, + memo: Dict[tuple[int, str], bytes], + active: Set[tuple[int, str]], +) -> Any: + """Replace dependency invocation markers with recursive effective keys.""" + if isinstance(value, EvaluationDependency): + try: + evaluation = value.model.__call__.get_evaluation_context(value.model, value.context) + except _EFFECTIVE_IDENTITY_DECLINED_ERRORS as exc: + raise _EffectiveEvaluationKeyUnavailable(f"dependency {type(value.model).__name__} could not build evaluation context: {exc}") from exc + return _effective_evaluation_key(evaluation, memo=memo, active=active) + if isinstance(value, dict): + return {key: _resolve_effective_identity_payload(item, memo, active) for key, item in value.items()} + if isinstance(value, list): + return [_resolve_effective_identity_payload(item, memo, active) for item in value] + if isinstance(value, tuple): + return tuple(_resolve_effective_identity_payload(item, memo, active) for item in value) + return value + + +def _effective_evaluation_key( + evaluation_context: ModelEvaluationContext, + memo: Optional[Dict[tuple[int, str], bytes]] = None, + active: Optional[Set[tuple[int, str]]] = None, +) -> bytes: + """Use opt-in effective identity for ``__call__``; otherwise preserve ``cache_key()``.""" + memo = {} if memo is None else memo + active = set() if active is None else active + inner, fn, outer_to_inner_evaluators = _unwrap_evaluation_context(evaluation_context) + if fn != "__call__": + # Keep non-call evaluations, especially ``__deps__``, on the exact + # public structural key path. Effective identity is only meant to + # narrow normal model execution where generated ``@Flow.model`` code + # knows which ambient ``FlowContext`` fields affect the result. + return cache_key(evaluation_context) + if outer_to_inner_evaluators: + # Non-transparent evaluators can inspect the full ModelEvaluationContext + # and change the returned value based on ambient context fields that an + # opt-in model would otherwise ignore. Use the structural key whenever + # such an evaluator is part of the call chain; missing an optimization is + # preferable to returning a value cached under a narrower model identity. + return cache_key(evaluation_context) + + try: + key = _effective_model_key(inner.model, inner.context, memo, active) + except _EffectiveEvaluationKeyUnavailable as exc: + # Effective identity is an optimization/semantic narrowing for opt-in + # generated models. If deriving it is unclear, do not make cache/graph + # key construction a new failure mode; use the old structural key. + log.debug("Falling back to structural evaluation key for %s.__call__: %s", type(inner.model).__name__, exc) + return cache_key(evaluation_context) + if key is None: + # This is the ordinary path for existing CallableModel classes. The + # base implementation returns None, so their cache and graph identities + # remain byte-for-byte equivalent to ``cache_key(evaluation_context)``. + return cache_key(evaluation_context) + + # Preserve the existing evaluation-context semantics around the narrowed + # model key: options still distinguish evaluations and transparent + # evaluators are ignored. + return compute_cache_token( + data_values=[ + ( + _EFFECTIVE_EVALUATION_KEY_VERSION, + "evaluation_context", + { + "fn": fn, + "options": inner.options, + }, + key, + ) + ], + ).encode("utf-8") + + +def cache_key(flow_obj: Union[ModelEvaluationContext, ContextBase, CallableModel], *, effective: bool = False) -> bytes: + """Returns a key suitable for caching and dependency graph deduplication. For ``ModelEvaluationContext`` inputs, strips ``TransparentModelEvaluationContext`` layers (evaluators that don't modify the return value) so that the key depends only on the underlying model, context, fn, options, and any non-transparent evaluators in the chain. + By default, this key is structural. Passing ``effective=True`` for a + ``ModelEvaluationContext`` enables generated-model effective identity, which + lets opt-in models ignore unused ambient context fields. Non-opt-in models + and non-evaluation inputs still use the structural key. + When the underlying model has callable methods, a behavior token (SHA-256 of method bytecode) is included so that code changes invalidate the cache. Args: flow_obj: The object to be tokenized to form the cache key. + effective: Whether to use generated-model effective identity for model + evaluations that opt into it. Defaults to ``False`` to preserve the + public structural semantics. """ - from ..utils.tokenize import compute_cache_token + if effective and isinstance(flow_obj, ModelEvaluationContext): + return _effective_evaluation_key(flow_obj) if isinstance(flow_obj, ModelEvaluationContext): flow_obj, fn, non_transparent = _flatten_cache_key_context(flow_obj) @@ -273,9 +443,11 @@ def is_transparent(self, context: ModelEvaluationContext) -> bool: def key(self, context: ModelEvaluationContext): """Function to convert a ModelEvaluationContext to a cache key. - Delegates to ``cache_key()`` which strips transparent evaluator layers. + Generated ``@Flow.model`` instances can use a narrower effective key + that ignores unused ambient context fields; ordinary ``CallableModel`` + paths fall back to the structural key. """ - return cache_key(context) + return cache_key(context, effective=True) @property def cache(self): @@ -311,22 +483,64 @@ class CallableModelGraph(BaseModel): root_id: bytes -def _build_dependency_graph(evaluation_context: ModelEvaluationContext, graph: CallableModelGraph, parent_key: Optional[bytes] = None): - key = cache_key(evaluation_context) - if parent_key: +def _is_wrapper_to_wrapped_edge(parent_model: Optional[CallableModel], current_model: CallableModel) -> bool: + # Effective identity can intentionally collapse a wrapper model and its + # wrapped model to the same graph/cache key. Only that wrapper-to-wrapped + # edge should be treated as a duplicate self-edge. + return isinstance(parent_model, WrapperModel) and parent_model.model is current_model + + +def _build_dependency_graph( + evaluation_context: ModelEvaluationContext, + graph: CallableModelGraph, + parent_key: Optional[bytes] = None, + parent_model: Optional[CallableModel] = None, +): + # Generated/bound ``@Flow.model`` nodes can use effective identity so unused + # ambient FlowContext fields do not split the graph. Normal CallableModel + # nodes opt out and therefore still receive ``cache_key(evaluation_context)``. + key = _effective_evaluation_key(evaluation_context) + unwrapped_evaluation_context, _, _ = _unwrap_evaluation_context(evaluation_context) + current_model = unwrapped_evaluation_context.model + is_same_evaluation_key = parent_key == key + is_collapsed_wrapper_child = is_same_evaluation_key and _is_wrapper_to_wrapped_edge(parent_model, current_model) + + # Bound/wrapper models can share an effective graph key with their wrapped + # model after context rewriting. Adding the wrapper -> wrapped edge in that + # case would create a fake self-loop in the public graph because both ends + # have the same key. Suppress only that edge; real cycles between ordinary + # models are still recorded. + if parent_key and not is_collapsed_wrapper_child: graph.graph[parent_key].add(key) if key not in graph.ids: graph.ids[key] = evaluation_context - if key not in graph.graph: + is_new_graph_key = key not in graph.graph + if is_new_graph_key: graph.graph[key] = set() - # Note that __deps__ will be evaluated using whatever evaluator is configured for the model, - # which could include logging, caching, etc. - deps = evaluation_context.model.__deps__(evaluation_context.context) - # Sequential evaluation of dependencies of dependencies (could have other implementations) - for model, contexts in deps: - for context in contexts: - sub_evaluation_context = model.__call__.get_evaluation_context(model, context) - _build_dependency_graph(sub_evaluation_context, graph, parent_key=key) + + # Main used ``key not in graph.graph`` as the traversal guard. That is no + # longer enough once effective identity can merge multiple model objects + # to one key: a bound wrapper and its wrapped model may share the graph node, + # but the wrapped model still has dependencies that must be traversed. + # + # Preserve normal graph deduplication by key, and make the only exception + # the exact same-key wrapper -> wrapped edge. + if not is_new_graph_key and not is_collapsed_wrapper_child: + return + + # Note that __deps__ will be evaluated using whatever evaluator is configured for the model, + # which could include logging, caching, etc. + deps = evaluation_context.model.__deps__(evaluation_context.context) + # Recursively walk dependency contexts depth-first to build the complete graph. + for model, contexts in deps: + for context in contexts: + sub_evaluation_context = model.__call__.get_evaluation_context(model, context) + _build_dependency_graph( + sub_evaluation_context, + graph, + parent_key=key, + parent_model=current_model, + ) def get_dependency_graph(evaluation_context: ModelEvaluationContext) -> CallableModelGraph: @@ -335,7 +549,11 @@ def get_dependency_graph(evaluation_context: ModelEvaluationContext) -> Callable Args: evaluation_context: The model and context to build the graph for. """ - root_key = cache_key(evaluation_context) + # Keep the root id on the same identity function used for every graph node. + # For existing models this is still ``cache_key(evaluation_context)``; for + # generated flow models it is the narrowed key that ignores unused ambient + # context fields. + root_key = _effective_evaluation_key(evaluation_context) graph = CallableModelGraph(ids={}, graph={}, root_id=root_key) _build_dependency_graph(evaluation_context, graph) return graph diff --git a/ccflow/flow_model.py b/ccflow/flow_model.py new file mode 100644 index 0000000..c358cec --- /dev/null +++ b/ccflow/flow_model.py @@ -0,0 +1,1969 @@ +"""Generated ``@Flow.model`` implementation. + +This module is intentionally the owner of the generated-model API. The public +surface is small: + +* ``@Flow.model`` turns a typed Python function into a ``CallableModel`` factory. +* ``FromContext[T]`` marks function parameters that should come from runtime + context instead of model construction. +* ``Lazy[T]`` marks a dependency that should be passed as a thunk and evaluated + only if user code calls it. +* ``model.flow.compute(...)`` and ``model.flow.with_context(...)`` provide the + ergonomic execution and contextual binding API. + +The implementation has four moving parts that should stay conceptually +separate: + +* Signature analysis lives in ``_flow_model_binding.py`` so ``callable.py`` does + not become a generated-model implementation module. +* Runtime context construction turns arbitrary context objects/kwargs into the + narrow context shape required by a target model. +* Context bindings are represented as serializable specs, then applied at + execution time before the wrapped model validates its context. +* Effective identity describes the parts of generated and bound model + invocations that are known before evaluation so cache/graph keys can ignore + unused ambient ``FlowContext`` fields. + +Two invariants matter more than cleverness here: + +* Existing ``CallableModel`` behavior must remain structural unless a generated + or bound model explicitly opts into effective identity. +* Public ``cache_key(...)`` stays structural by default; evaluators and graph + construction can request effective identity explicitly for opt-in models. + +File layout: + +1. Internal data structures and small value helpers. +2. Type coercion, result unwrapping, and registry-reference helpers. +3. Context-transform and generated-model serialization helpers. +4. Runtime context contracts and dependency context projection. +5. Contextual value resolution and effective identity. +6. ``with_context`` binding validation/application. +7. ``model.flow`` APIs and ``BoundModel``. +8. Generated model class construction and decorators. +""" + +import inspect +import logging +import sys +from base64 import b64decode, b64encode +from collections import OrderedDict +from functools import lru_cache, wraps +from typing import ( + Annotated, + Any, + Callable, + ClassVar, + Dict, + List, + Literal, + Mapping, + NamedTuple, + Optional, + Set, + Tuple, + Type, + cast, + get_args, + get_origin, + get_type_hints, +) + +from pydantic import BaseModel as PydanticModel, Field, TypeAdapter, ValidationError, model_validator +from pydantic.errors import PydanticSchemaGenerationError, PydanticUndefinedAnnotation + +from ._flow_model_binding import ( + _REMOVED_CONTEXT_ARGS, + _UNION_ORIGINS, + _UNSET, + FromContext, + Lazy, + _analyze_flow_context_transform, + _analyze_flow_model, + _callable_name, + _FlowModelConfig, + _FlowModelParam, + _resolved_flow_signature, + _strip_annotated, +) +from .base import BaseModel, ContextBase, ContextType, ResultBase +from .callable import CallableModel, EvaluationDependency, Flow, FlowOptions, GraphDepList, WrapperModel +from .context import FlowContext +from .exttypes import PyObjectPath +from .local_persistence import register_ccflow_import_path +from .result import GenericResult + +__all__ = ( + "FlowAPI", + "BoundModel", + "FromContext", + "Lazy", + "ContextTransform", + "clear_flow_model_caches", + "flow_context_transform", +) + +_AnyCallable = Callable[..., Any] +log = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Internal data structures +# --------------------------------------------------------------------------- + + +class _UnsetFlowInput: + def __repr__(self) -> str: + return "" + + def __reduce__(self): + return (_unset_flow_input_factory, ()) + + +_UNSET_FLOW_INPUT = _UnsetFlowInput() +_TYPE_ADAPTER_CACHE_MAXSIZE = 256 +_HASHABLE_TYPE_ADAPTER_CACHE: "OrderedDict[Any, TypeAdapter]" = OrderedDict() +_UNHASHABLE_TYPE_ADAPTER_CACHE: "OrderedDict[int, Tuple[Any, TypeAdapter]]" = OrderedDict() + + +def _unset_flow_input_factory() -> _UnsetFlowInput: + return _UNSET_FLOW_INPUT + + +def _is_unset_flow_input(value: Any) -> bool: + return value is _UNSET_FLOW_INPUT + + +_ModelContextContract = NamedTuple( + "_ModelContextContract", + [ + ("runtime_context_type", Type[ContextBase]), + ("input_types", Optional[Dict[str, Any]]), + ("required_names", Tuple[str, ...]), + ("generated_model", Optional["_GeneratedFlowModelBase"]), + ], +) + + +class ContextTransform(PydanticModel): + """Serializable binding produced by ``@Flow.context_transform``. + + Importable top-level transforms are stored by import path. Local/nested + transforms fall back to a cloudpickled config payload so bound models can + survive pickle, cloudpickle, and Ray round trips. + """ + + kind: Literal["context_transform"] = "context_transform" + path: Optional[PyObjectPath] = None + serialized_config: Optional[str] = None + bound_args: Dict[str, Any] = Field(default_factory=dict) + + @model_validator(mode="after") + def _validate_location(self): + if (self.path is None) == (self.serialized_config is None): + raise ValueError("ContextTransform must define exactly one of path or serialized_config.") + return self + + +class StaticValueSpec(PydanticModel): + """A ``with_context(field=value)`` static contextual override.""" + + kind: Literal["static_value"] = "static_value" + value: Any + + +class FieldContextSpec(PydanticModel): + """A ``with_context(field=transform(...))`` contextual override.""" + + kind: Literal["context_value"] = "context_value" + binding: ContextTransform + + +class PatchContextSpec(PydanticModel): + """A positional ``with_context(transform(...))`` mapping patch.""" + + kind: Literal["context_patch"] = "context_patch" + binding: ContextTransform + + +_FieldOverrideSpec = StaticValueSpec | FieldContextSpec + + +class _BoundContextSpec(PydanticModel): + """Normalized, serializable representation of all context bindings.""" + + patches: List[PatchContextSpec] = Field(default_factory=list) + field_overrides: Dict[str, _FieldOverrideSpec] = Field(default_factory=dict) + + +# --------------------------------------------------------------------------- +# Small value helpers +# --------------------------------------------------------------------------- + + +def _context_values(context: ContextBase) -> Dict[str, Any]: + return dict(context) + + +def _context_transform_repr(transform: Any) -> str: + if isinstance(transform, ContextTransform): + name = _callable_name(_load_context_transform_config_from_binding(transform).func) + if not transform.bound_args: + return f"{name}()" + args = ", ".join(f"{key}={value!r}" for key, value in sorted(transform.bound_args.items())) + return f"{name}({args})" + return repr(transform) + + +def _context_transform_identifier(binding: ContextTransform) -> str: + if binding.path is not None: + return str(binding.path) + return _callable_name(_load_context_transform_config_from_binding(binding).func) + + +def _is_model_dependency(value: Any) -> bool: + return isinstance(value, CallableModel) + + +def _bound_field_names(model: Any) -> set[str]: + fields_set = getattr(model, "model_fields_set", None) + if fields_set is not None: + return set(fields_set) + return set() + + +def _concrete_context_type(context_type: Any) -> Optional[Type[ContextBase]]: + if isinstance(context_type, type) and issubclass(context_type, ContextBase): + return context_type + + if get_origin(context_type) in _UNION_ORIGINS: + for arg in get_args(context_type): + if arg is type(None): + continue + if isinstance(arg, type) and issubclass(arg, ContextBase): + return arg + + return None + + +# --------------------------------------------------------------------------- +# Type coercion, lazy thunks, and registry references +# --------------------------------------------------------------------------- + + +def _remember_type_adapter(cache: "OrderedDict[Any, Any]", key: Any, value: Any) -> Any: + cache[key] = value + cache.move_to_end(key) + if len(cache) > _TYPE_ADAPTER_CACHE_MAXSIZE: + cache.popitem(last=False) + return value + + +def _type_adapter(annotation: Any) -> TypeAdapter: + try: + adapter = _HASHABLE_TYPE_ADAPTER_CACHE.pop(annotation) + except TypeError: + key = id(annotation) + cached = _UNHASHABLE_TYPE_ADAPTER_CACHE.pop(key, None) + if cached is not None and cached[0] is annotation: + _UNHASHABLE_TYPE_ADAPTER_CACHE[key] = cached + return cached[1] + adapter = TypeAdapter(annotation) + return _remember_type_adapter(_UNHASHABLE_TYPE_ADAPTER_CACHE, key, (annotation, adapter))[1] + except KeyError: + adapter = TypeAdapter(annotation) + return _remember_type_adapter(_HASHABLE_TYPE_ADAPTER_CACHE, annotation, adapter) + _HASHABLE_TYPE_ADAPTER_CACHE[annotation] = adapter + return adapter + + +def _can_validate_type(annotation: Any) -> bool: + try: + _type_adapter(annotation) + except (PydanticSchemaGenerationError, PydanticUndefinedAnnotation, TypeError, ValueError): + return False + return True + + +def _expected_type_repr(annotation: Any) -> str: + try: + return annotation.__name__ + except AttributeError: + return repr(annotation) + + +def _coerce_value(name: str, value: Any, annotation: Any, source: str) -> Any: + if not _can_validate_type(annotation): + return value + try: + return _type_adapter(annotation).validate_python(value) + except (ValidationError, ValueError, TypeError) as exc: + expected = _expected_type_repr(annotation) + raise TypeError(f"{source} '{name}': expected {expected}, got {type(value).__name__} ({value!r})") from exc + + +def _unwrap_model_result(value: Any) -> Any: + if isinstance(value, GenericResult): + return value.value + return value + + +def _make_lazy_thunk(value: CallableModel, context: ContextBase) -> Callable[[], Any]: + cache: Dict[str, Any] = {} + + def thunk(): + if "result" not in cache: + dependency_model, dependency_context = _resolved_dependency_invocation(value, context) + cache["result"] = _unwrap_model_result(dependency_model(dependency_context)) + return cache["result"] + + return thunk + + +def _make_coercing_lazy_thunk(inner_thunk: Callable[[], Any], name: str, annotation: Any) -> Callable[[], Any]: + cache: Dict[str, Any] = {} + + def thunk(): + if "result" not in cache: + cache["result"] = _coerce_value(name, inner_thunk(), annotation, "Regular parameter") + return cache["result"] + + return thunk + + +def _maybe_auto_unwrap_external_result(target: CallableModel, result: Any) -> Any: + generated = _generated_model_instance(target) + if generated is None: + return result + + config = type(generated).__flow_model_config__ + if config.auto_wrap_result and config.auto_unwrap: + return _unwrap_model_result(result) + return result + + +def _type_accepts_str(annotation: Any) -> bool: + if annotation is Any or annotation is inspect.Parameter.empty: + return True + if annotation is str: + return True + origin = get_origin(annotation) + if origin is Annotated: + return _type_accepts_str(get_args(annotation)[0]) + if origin in _UNION_ORIGINS: + return any(_type_accepts_str(arg) for arg in get_args(annotation) if arg is not type(None)) + return False + + +def _resolve_registry_candidate(value: str) -> Any: + try: + candidate = BaseModel.model_validate(value) + except ValidationError: + return None + return candidate if isinstance(candidate, BaseModel) else None + + +def _registry_candidate_allowed(expected_type: Any, candidate: Any) -> bool: + if _is_model_dependency(candidate): + return True + if not _can_validate_type(expected_type): + return True + try: + _type_adapter(expected_type).validate_python(candidate) + except ValidationError: + return False + return True + + +def _ensure_top_level_named_function(fn: _AnyCallable, *, decorator_name: str) -> None: + if not inspect.isfunction(fn): + raise TypeError(f"{decorator_name} only supports Python functions.") + + name = getattr(fn, "__name__", "") + if name == "": + raise TypeError(f"{decorator_name} only supports named Python functions.") + + +# --------------------------------------------------------------------------- +# Context-transform serialization and generated-model persistence +# --------------------------------------------------------------------------- + + +@lru_cache(maxsize=None) +def _load_context_transform_factory(path: str) -> _AnyCallable: + return PyObjectPath(path).object + + +@lru_cache(maxsize=None) +def _load_context_transform_config(path: str) -> _FlowModelConfig: + factory = _load_context_transform_factory(path) + config = getattr(factory, "__flow_context_transform_config__", None) + if not isinstance(config, _FlowModelConfig): + raise TypeError(f"Stored context transform path '{path}' does not resolve to a Flow.context_transform binding.") + return config + + +def _serialize_context_transform_config(config: _FlowModelConfig) -> str: + import cloudpickle + + payload = cloudpickle.dumps(config, protocol=5) + return b64encode(payload).decode("ascii") + + +@lru_cache(maxsize=None) +def _load_serialized_context_transform_config(serialized_config: str) -> _FlowModelConfig: + import cloudpickle + + config = cloudpickle.loads(b64decode(serialized_config.encode("ascii"))) + if not isinstance(config, _FlowModelConfig): + raise TypeError("Stored context transform payload does not contain a Flow.context_transform binding.") + return config + + +def _load_context_transform_config_from_binding(binding: ContextTransform) -> _FlowModelConfig: + if binding.path is not None: + return _load_context_transform_config(str(binding.path)) + if binding.serialized_config is None: + raise TypeError("ContextTransform has neither path nor serialized_config.") + return _load_serialized_context_transform_config(binding.serialized_config) + + +def clear_flow_model_caches() -> None: + """Clear module-level caches used by Flow.model internals.""" + + _HASHABLE_TYPE_ADAPTER_CACHE.clear() + _UNHASHABLE_TYPE_ADAPTER_CACHE.clear() + _load_context_transform_factory.cache_clear() + _load_context_transform_config.cache_clear() + _load_serialized_context_transform_config.cache_clear() + + +def _is_mapping_annotation(annotation: Any) -> bool: + if annotation is inspect.Signature.empty: + return False + annotation = _strip_annotated(annotation) + origin = get_origin(annotation) + if origin in _UNION_ORIGINS: + variants = [arg for arg in get_args(annotation) if arg is not type(None)] + return bool(variants) and all(_is_mapping_annotation(arg) for arg in variants) + origin = origin or annotation + try: + return issubclass(origin, Mapping) + except TypeError: + return False + + +def _restore_pickled_flow_model(type_path: str, state: Dict[str, Any]) -> BaseModel: + cls = cast(type[BaseModel], PyObjectPath(type_path).object) + instance = cls.__new__(cls) + instance.__setstate__(state) + return instance + + +def _restore_pickled_local_flow_model(serialized_factory_payload: bytes, state: Dict[str, Any]) -> BaseModel: + import cloudpickle + + fn, factory_kwargs = cloudpickle.loads(serialized_factory_payload) + factory = flow_model(fn, **factory_kwargs) + cls = cast(type[BaseModel], getattr(factory, "_generated_model")) + instance = cls.__new__(cls) + instance.__setstate__(state) + return instance + + +def _restore_generated_flow_model(factory_path: str, state: Dict[str, Any]) -> BaseModel: + """Restore a generated flow model by importing its factory function. + + This is the cross-process-safe restore path: importing the factory's module + triggers the ``@Flow.model`` decorator, which re-creates the GeneratedModel + class. We then reconstruct the instance from the pickled state. + """ + factory = PyObjectPath(factory_path).object + generated_cls = getattr(factory, "_generated_model", None) + if generated_cls is None: + raise ImportError(f"Cannot restore generated flow model: '{factory_path}' does not have a _generated_model attribute.") + instance = generated_cls.__new__(generated_cls) + instance.__setstate__(state) + return instance + + +def _is_importable_function(func: _AnyCallable) -> bool: + """Return True if *func* is a top-level, importable named function.""" + module = getattr(func, "__module__", None) + name = getattr(func, "__name__", None) + qualname = getattr(func, "__qualname__", None) + return bool(module and module != "__main__" and name and qualname and qualname == name and "" not in qualname) + + +def _importable_function_path(func: _AnyCallable) -> Optional[str]: + if not _is_importable_function(func): + return None + return f"{func.__module__}.{func.__name__}" + + +def _generated_model_factory_path_for_pickle(config: _FlowModelConfig, generated_cls: type) -> Optional[str]: + path = _importable_function_path(config.func) + if path is None: + return None + try: + factory = PyObjectPath(path).object + except ImportError: + return None + if getattr(factory, "_generated_model", None) is generated_cls: + return path + return None + + +def _register_generated_model_class(config: _FlowModelConfig, generated_cls: type) -> None: + """Make generated classes importable when their factory function is importable. + + Importable module-level ``@Flow.model`` functions should serialize by a + stable module path. Local, nested, and ``__main__`` definitions still use + local-persistence registration because there is no durable import path for + their generated class. + """ + + if _importable_function_path(config.func) is None: + register_ccflow_import_path(generated_cls) + return + + module_name = getattr(config.func, "__module__", None) + module = sys.modules.get(module_name or "") + qualname = getattr(generated_cls, "__qualname__", "") + if module is None or not qualname or "" in qualname: + register_ccflow_import_path(generated_cls) + return + + obj = module + parts = qualname.split(".") + for part in parts[:-1]: + obj = getattr(obj, part, None) + if obj is None: + register_ccflow_import_path(generated_cls) + return + setattr(obj, parts[-1], generated_cls) + + +def _context_transform_should_use_import_path(config: _FlowModelConfig) -> bool: + path = config.path + if path is None or not _is_importable_function(config.func): + return False + try: + resolved = PyObjectPath(str(path)).object + except ImportError: + return True + return isinstance(getattr(resolved, "__flow_context_transform_config__", None), _FlowModelConfig) + + +# --------------------------------------------------------------------------- +# Runtime context contracts and dependency projection +# --------------------------------------------------------------------------- + + +def _runtime_context_for_model(model: CallableModel, values: Dict[str, Any]) -> ContextBase: + """Build the runtime context object expected by ``model`` from raw values.""" + + contract = _model_context_contract(model) + if contract.runtime_context_type is FlowContext: + return FlowContext(**values) + return contract.runtime_context_type.model_validate(values) + + +def _project_context_values_for_model(model: CallableModel, values: Dict[str, Any]) -> Dict[str, Any]: + """Keep only the context fields a target model knows how to consume. + + Generated models and ordinary ``ContextBase`` subclasses have declared input + fields. ``FlowContext`` and opaque context types do not, so their context is + forwarded unchanged. + """ + + contract = _model_context_contract(model) + if contract.runtime_context_type is FlowContext or contract.input_types is None: + return values + return {name: values[name] for name in contract.input_types if name in values} + + +def _dependency_context_values(model: CallableModel, context: ContextBase) -> Dict[str, Any]: + return _project_context_values_for_model(model, _context_values(context)) + + +def _dependency_context_for_model(model: CallableModel, context: ContextBase) -> ContextBase: + return _runtime_context_for_model(model, _dependency_context_values(model, context)) + + +def _resolved_dependency_invocation(value: CallableModel, context: ContextBase) -> Tuple[CallableModel, ContextBase]: + """Return the concrete ``(model, context)`` pair for a dependency call. + + Bound models must receive the full ambient ``FlowContext`` so their binding + transforms can read source fields before narrowing to the wrapped model's + context. Unbound dependencies can be projected immediately. + """ + + if isinstance(value, BoundModel): + return value, FlowContext(**_context_values(context)) + return value, _dependency_context_for_model(value, context) + + +def _merge_context_specs( + existing: _BoundContextSpec, patches: List[PatchContextSpec], field_overrides: Dict[str, _FieldOverrideSpec] +) -> _BoundContextSpec: + return _BoundContextSpec( + patches=[*existing.patches, *patches], + field_overrides={**existing.field_overrides, **field_overrides}, + ) + + +def _generated_model_instance(stage: Any) -> Optional["_GeneratedFlowModelBase"]: + model = stage.model if isinstance(stage, BoundModel) else stage + if isinstance(model, _GeneratedFlowModelBase): + return model + return None + + +def _model_context_contract( + model: CallableModel, +) -> _ModelContextContract: + """Describe how ``model`` consumes runtime context. + + This is the central adapter between generated models, plain CallableModels, + optional/opaque context annotations, and ``FlowContext``. Callers use it to + decide which fields are contextual inputs, which are required, and which + concrete ``ContextBase`` subclass should validate runtime values. + """ + + generated = _generated_model_instance(model) + if generated is not None: + config = type(generated).__flow_model_config__ + return _ModelContextContract(FlowContext, dict(config.context_input_types), config.context_required_names, generated) + + context_cls = _concrete_context_type(model.context_type) + if context_cls is None: + return _ModelContextContract(FlowContext, None, (), None) + if context_cls is FlowContext or not hasattr(context_cls, "model_fields"): + return _ModelContextContract(context_cls, None, (), None) + return _ModelContextContract( + context_cls, + {name: info.annotation for name, info in context_cls.model_fields.items()}, + tuple(name for name, info in context_cls.model_fields.items() if info.is_required()), + None, + ) + + +def _model_base_field_names(generated: "_GeneratedFlowModelBase") -> set[str]: + """Return field names from model_base that aren't function parameters or internal fields.""" + config = type(generated).__flow_model_config__ + param_names = {param.name for param in config.parameters} + return {name for name in type(generated).model_fields if name not in param_names and name != "meta"} + + +def _missing_regular_param_names(model: "_GeneratedFlowModelBase", config: _FlowModelConfig) -> List[str]: + missing = [] + for param in config.regular_params: + value = getattr(model, param.name, _UNSET_FLOW_INPUT) + if _is_unset_flow_input(value): + missing.append(param.name) + return missing + + +# --------------------------------------------------------------------------- +# Generated model input resolution +# --------------------------------------------------------------------------- + + +def _resolve_regular_param_value(model: "_GeneratedFlowModelBase", param: _FlowModelParam, context: ContextBase) -> Any: + value = getattr(model, param.name, _UNSET_FLOW_INPUT) + if _is_unset_flow_input(value): + raise TypeError( + f"Regular parameter '{param.name}' for {_callable_name(type(model).__flow_model_config__.func)} is still unbound. " + "Bind it at construction time." + ) + if _is_model_dependency(value): + if param.is_lazy: + return _make_lazy_thunk(value, context) + dependency_model, dependency_context = _resolved_dependency_invocation(value, context) + return _unwrap_model_result(dependency_model(dependency_context)) + if param.is_lazy: + raise TypeError(f"Parameter '{param.name}' is marked Lazy[...] and must be bound to a CallableModel dependency.") + return value + + +def _collect_contextual_values( + model: "_GeneratedFlowModelBase", + config: _FlowModelConfig, + explicit_values: Dict[str, Any], +) -> Tuple[Dict[str, Any], List[str]]: + """Resolve ``FromContext`` values from runtime values, model defaults, and function defaults.""" + + resolved: Dict[str, Any] = {} + missing: List[str] = [] + + for param in config.contextual_params: + if param.name in explicit_values: + resolved[param.name] = explicit_values[param.name] + continue + + value = getattr(model, param.name, _UNSET_FLOW_INPUT) + if not _is_unset_flow_input(value): + resolved[param.name] = value + continue + + if param.has_function_default: + resolved[param.name] = param.function_default + continue + + missing.append(param.name) + + return resolved, missing + + +def _resolved_contextual_inputs(model: "_GeneratedFlowModelBase", config: _FlowModelConfig, context: ContextBase) -> Dict[str, Any]: + """Validate and return the contextual kwargs passed to the user function.""" + + resolved, missing_contextual = _collect_contextual_values(model, config, _context_values(context)) + + if missing_contextual: + missing = ", ".join(sorted(missing_contextual)) + raise TypeError( + f"Missing contextual input(s) for {_callable_name(config.func)}: {missing}. " + "Supply them via the runtime context, compute(), with_context(), or construction-time contextual defaults." + ) + + if config.declared_context_type is not None: + return _validate_declared_context_values(config, resolved) + + return { + param.name: _coerce_value(param.name, resolved[param.name], param.validation_annotation, "Context field") + for param in config.contextual_params + } + + +def _validate_declared_context_values(config: _FlowModelConfig, values: Dict[str, Any]) -> Dict[str, Any]: + if config.declared_context_type is None: + return values + + validated = config.declared_context_type.model_validate(values) + return {param.name: getattr(validated, param.name) for param in config.contextual_params} + + +def _validate_declared_context_field(config: _FlowModelConfig, name: str, value: Any) -> Any: + if config.declared_context_type is None: + return _UNSET + + try: + validated = config.declared_context_type.model_validate({name: value}) + except ValidationError as exc: + field_errors = [error for error in exc.errors() if error.get("loc") and error["loc"][0] == name] + if field_errors: + raise + return _UNSET + return getattr(validated, name) + + +def _coerce_contextual_value(config: _FlowModelConfig, param: _FlowModelParam, value: Any, source: str) -> Any: + declared_value = _validate_declared_context_field(config, param.name, value) + if declared_value is not _UNSET: + return declared_value + return _coerce_value(param.name, value, param.validation_annotation, source) + + +def _coerce_model_context_value(model: CallableModel, field_name: str, value: Any, source: str) -> Any: + """Coerce a value for a contextual field when the target field type is known.""" + + generated = _generated_model_instance(model) + if generated is not None: + config = type(generated).__flow_model_config__ + if field_name in config.contextual_param_names: + return _coerce_contextual_value(config, config.param(field_name), value, source) + + contract = _model_context_contract(model) + if contract.input_types is None or field_name not in contract.input_types: + return value + return _coerce_value(field_name, value, contract.input_types[field_name], source) + + +# --------------------------------------------------------------------------- +# Effective identity helpers +# --------------------------------------------------------------------------- + + +def _identity_context_values_for_model(model: CallableModel, context: ContextBase) -> Dict[str, Any]: + return _identity_context_values_for_model_values(model, _context_values(context)) + + +def _identity_context_values_for_model_values(model: CallableModel, values: Dict[str, Any]) -> Dict[str, Any]: + """Project context values for identity, not execution. + + This intentionally mirrors context projection but is kept separate because + identity cares about "what affects the result", while execution cares about + "what context object can validate for the target model". + """ + + contract = _model_context_contract(model) + if contract.input_types is None: + return values + return {name: values[name] for name in contract.input_types if name in values} + + +def _identity_context_values_and_missing_for_model(model: CallableModel, values: Dict[str, Any]) -> Tuple[Dict[str, Any], Tuple[str, ...]]: + """Return identity-relevant context values and required fields still missing.""" + + generated = _generated_model_instance(model) + if generated is not None: + config = type(generated).__flow_model_config__ + resolved, missing = _collect_contextual_values(generated, config, values) + return ( + {param.name: resolved[param.name] for param in config.contextual_params if param.name in resolved}, + tuple(missing), + ) + + context_values = _identity_context_values_for_model_values(model, values) + missing = tuple(name for name in _model_context_contract(model).required_names if name not in context_values) + return context_values, missing + + +def _context_transform_missing_context_names(binding: ContextTransform, values: Dict[str, Any]) -> Tuple[str, ...]: + config = _load_context_transform_config_from_binding(binding) + return tuple(param.name for param in config.contextual_params if param.name not in values and not param.has_function_default) + + +def _evaluate_context_transform_from_values(binding: ContextTransform, values: Dict[str, Any]) -> Any: + """Run a context transform against a raw value mapping. + + Transform contextual inputs are read from the original runtime context for a + binding layer. This keeps field transforms independent of earlier patches + in the same ``with_context(...)`` call and makes ordering rules explicit. + """ + + config = _load_context_transform_config_from_binding(binding) + kwargs = _bound_context_transform_regular_kwargs(config, binding) + + for param in config.contextual_params: + if param.name in values: + kwargs[param.name] = _coerce_value(param.name, values[param.name], param.annotation, "Context transform field") + elif param.has_function_default: + kwargs[param.name] = param.function_default + else: + raise TypeError( + f"Missing contextual input(s) for context transform {_callable_name(config.func)}: {param.name}. " + "Supply them via the runtime context or with_context() ordering." + ) + + return config.func(**kwargs) + + +def _apply_context_spec_values_for_identity( + model: CallableModel, context_spec: "_BoundContextSpec", context: ContextBase +) -> Tuple[Dict[str, Any], Tuple[Tuple[str, Tuple[str, ...]], ...]]: + """Apply a binding spec for identity derivation. + + Unlike execution, identity must be able to describe a binding even when a + transform cannot yet run because some source context fields are missing. In + that case the missing transform inputs are recorded so two unresolved lazy + dependencies do not collapse accidentally. + """ + + current_values = _context_values(context) + missing_transforms: List[Tuple[str, Tuple[str, ...]]] = [] + + for patch in context_spec.patches: + missing = _context_transform_missing_context_names(patch.binding, _context_values(context)) + if missing: + missing_transforms.append((_context_transform_identifier(patch.binding), missing)) + continue + result = _evaluate_context_transform_from_values(patch.binding, _context_values(context)) + current_values.update(_validate_patch_result(model, result)) + + for name, spec in context_spec.field_overrides.items(): + if isinstance(spec, StaticValueSpec): + current_values[name] = spec.value + continue + + missing = _context_transform_missing_context_names(spec.binding, _context_values(context)) + if missing: + missing_transforms.append((name, missing)) + current_values.pop(name, None) + continue + result = _evaluate_context_transform_from_values(spec.binding, _context_values(context)) + current_values[name] = _coerce_model_context_value(model, name, result, "with_context()") + + return current_values, tuple(missing_transforms) + + +def _unresolved_lazy_dependency_descriptor( + value: CallableModel, + context_values: Dict[str, Any], + missing_context: Tuple[str, ...], + missing_transform_context: Tuple[Tuple[str, Tuple[str, ...]], ...] = (), +) -> Dict[str, Any]: + """Describe a lazy dependency whose runtime context cannot be resolved yet.""" + + return { + "kind": "unresolved_lazy_dependency", + "model_type": str(PyObjectPath.validate(type(value))), + "model": value.model_dump(mode="python"), + "context_type": str(PyObjectPath.validate(FlowContext)), + "context": context_values, + "missing_context": missing_context, + "missing_transform_context": missing_transform_context, + } + + +def _lazy_dependency_identity( + value: CallableModel, + context: ContextBase, +) -> Tuple[Optional[Dict[str, Any]], Optional[CallableModel], Optional[ContextBase]]: + """Resolve or describe a lazy dependency for effective identity. + + If all required context is available, return the concrete dependency + invocation so evaluator/common can recursively derive its effective key. If + not, return a stable unresolved descriptor instead of evaluating or raising. + + This deliberately does not try to prove whether the lazy thunk will be used + by the downstream function. That would require executing user logic before + key construction. The current policy is conservative: resolvable lazy + dependencies participate in identity; lazy dependencies with unresolved + runtime context are represented by descriptors. + """ + + if isinstance(value, BoundModel): + dependency_model = value.model + values, missing_transform_context = _apply_context_spec_values_for_identity(dependency_model, value.context_spec, context) + if missing_transform_context: + context_values: Dict[str, Any] = {} + missing_context = _model_context_contract(dependency_model).required_names + else: + context_values, missing_context = _identity_context_values_and_missing_for_model(dependency_model, values) + if missing_context or missing_transform_context: + return _unresolved_lazy_dependency_descriptor(value, context_values, missing_context, missing_transform_context), None, None + return None, *_resolved_dependency_invocation(value, context) + + dependency_model = value + context_values, missing_context = _identity_context_values_and_missing_for_model(dependency_model, _context_values(context)) + if missing_context: + return _unresolved_lazy_dependency_descriptor(value, context_values, missing_context), None, None + return None, *_resolved_dependency_invocation(value, context) + + +def _validate_bound_param_value( + config: _FlowModelConfig, + param: _FlowModelParam, + value: Any, + source: str, +) -> Any: + """Validate a construction-time bound parameter for a generated model.""" + + if param.is_contextual: + if _is_model_dependency(value): + raise TypeError( + f"Parameter '{param.name}' is marked FromContext[...] and cannot be bound to a CallableModel. " + "Bind a literal contextual default or supply it via compute()/with_context()." + ) + return _coerce_contextual_value(config, param, value, source) + + if param.is_lazy and not _is_model_dependency(value): + raise TypeError(f"Parameter '{param.name}' is marked Lazy[...] and must be bound to a CallableModel dependency.") + if _is_model_dependency(value): + return value + return _coerce_value(param.name, value, param.annotation, source) + + +def _generated_model_identity_payload( + model: "_GeneratedFlowModelBase", + context: ContextBase, +) -> Optional[Dict[str, Any]]: + """Describe the generated model's effective invocation for cache keys. + + Contract: + - contextual identity is projected to the ``FromContext[...]`` fields the + generated model consumes; + - unused ambient ``FlowContext`` fields are ignored; + - regular literal inputs are included directly; + - regular ``CallableModel`` inputs are recorded as dependency invocations; + - lazy ``CallableModel`` inputs are recorded conservatively when their + context can be resolved, even if the lazy thunk is not called later; and + - unresolved lazy dependency runtime context is recorded explicitly instead of + forcing eager dependency resolution. + + Returning ``None`` asks the evaluator to use the structural key. + """ + + config = type(model).__flow_model_config__ + regular_inputs = [] + for param in config.regular_params: + value = getattr(model, param.name, _UNSET_FLOW_INPUT) + if _is_unset_flow_input(value): + return None + + descriptor = {"name": param.name, "lazy": param.is_lazy} + if _is_model_dependency(value): + if param.is_lazy: + unresolved, dependency_model, dependency_context = _lazy_dependency_identity(value, context) + if unresolved is not None: + descriptor.update(unresolved) + else: + assert dependency_model is not None + assert dependency_context is not None + descriptor.update({"kind": "dependency", "evaluation": EvaluationDependency(dependency_model, dependency_context)}) + else: + dependency_model, dependency_context = _resolved_dependency_invocation(value, context) + descriptor.update({"kind": "dependency", "evaluation": EvaluationDependency(dependency_model, dependency_context)}) + else: + descriptor.update({"kind": "literal", "value": value}) + regular_inputs.append(descriptor) + + model_base_fields = {name: getattr(model, name) for name in sorted(_model_base_field_names(model))} + + return { + "kind": "generated_flow_model_v1", + "model_type": str(PyObjectPath.validate(type(model))), + "contextual_inputs": _resolved_contextual_inputs(model, config, context), + "regular_inputs": regular_inputs, + "model_base_fields": model_base_fields, + } + + +# --------------------------------------------------------------------------- +# Static binding resolution and with_context normalization +# --------------------------------------------------------------------------- + + +def _resolved_static_contextual_values( + model: "_GeneratedFlowModelBase", + config: _FlowModelConfig, + static_overrides: Optional[Dict[str, StaticValueSpec]] = None, +) -> Optional[Dict[str, Any]]: + override_values = {name: spec.value for name, spec in (static_overrides or {}).items()} + resolved, missing = _collect_contextual_values(model, config, override_values) + return None if missing else resolved + + +def _validate_bound_declared_context_defaults(model: "_GeneratedFlowModelBase", config: _FlowModelConfig) -> None: + resolved = _resolved_static_contextual_values(model, config) + if resolved is None: + return + + validated = _validate_declared_context_values(config, resolved) + for param in config.contextual_params: + value = getattr(model, param.name, _UNSET_FLOW_INPUT) + if _is_unset_flow_input(value): + continue + object.__setattr__(model, param.name, validated[param.name]) + + +def _bound_context_transform_regular_kwargs(config: _FlowModelConfig, binding: ContextTransform) -> Dict[str, Any]: + """Return already-bound regular kwargs for a context transform invocation.""" + + kwargs: Dict[str, Any] = {} + for param in config.regular_params: + if param.name in binding.bound_args: + kwargs[param.name] = binding.bound_args[param.name] + elif param.has_function_default: + kwargs[param.name] = param.function_default + else: + raise TypeError(f"Context transform '{_callable_name(config.func)}' is missing required regular parameter '{param.name}'.") + return kwargs + + +def _evaluate_static_context_transform(binding: ContextTransform) -> Any: + """Evaluate a transform at binding time if it has no required contextual inputs.""" + + config = _load_context_transform_config_from_binding(binding) + kwargs = _bound_context_transform_regular_kwargs(config, binding) + + for param in config.contextual_params: + if param.has_function_default: + kwargs[param.name] = param.function_default + continue + return _UNSET + + return config.func(**kwargs) + + +def _static_field_override_value(model: CallableModel, field_name: str, spec: _FieldOverrideSpec) -> Any: + if isinstance(spec, StaticValueSpec): + return spec.value + + value = _evaluate_static_context_transform(spec.binding) + if value is _UNSET: + return _UNSET + + contract = _model_context_contract(model) + if contract.input_types is None or field_name not in contract.input_types: + return value + return _coerce_model_context_value(model, field_name, value, "with_context()") + + +def _statically_resolved_context_values(model: CallableModel, context_spec: _BoundContextSpec) -> Optional[Dict[str, Any]]: + """Return static binding values when the whole spec can be resolved without runtime context.""" + + values: Dict[str, Any] = {} + + for patch in context_spec.patches: + result = _evaluate_static_context_transform(patch.binding) + if result is _UNSET: + return None + values.update(_validate_patch_result(model, result)) + + for name, spec in context_spec.field_overrides.items(): + value = _static_field_override_value(model, name, spec) + if value is _UNSET: + return None + values[name] = value + + return values + + +def _statically_resolved_context_field_names(model: CallableModel, context_spec: _BoundContextSpec) -> Set[str]: + names: Set[str] = set() + + for patch in context_spec.patches: + result = _evaluate_static_context_transform(patch.binding) + if result is _UNSET: + continue + names.update(_validate_patch_result(model, result)) + + for name, spec in context_spec.field_overrides.items(): + if _static_field_override_value(model, name, spec) is not _UNSET: + names.add(name) + + return names + + +def _context_transform_input_types(binding: ContextTransform, *, required_only: bool) -> Dict[str, Any]: + config = _load_context_transform_config_from_binding(binding) + names = config.context_required_names if required_only else config.contextual_param_names + return {name: config.context_input_types[name] for name in names} + + +def _validate_static_context_spec_declared_context(model: CallableModel, context_spec: _BoundContextSpec) -> _BoundContextSpec: + generated = _generated_model_instance(model) + if generated is None: + return context_spec + + config = type(generated).__flow_model_config__ + if config.declared_context_type is None: + return context_spec + + static_context_values = _statically_resolved_context_values(model, context_spec) + if static_context_values is None: + return context_spec + + static_overrides = {name: StaticValueSpec(value=value) for name, value in static_context_values.items()} + resolved = _resolved_static_contextual_values(generated, config, static_overrides) + if resolved is None: + return context_spec + + _validate_declared_context_values(config, resolved) + return context_spec + + +def _validate_with_context_field_names(model: CallableModel, names: List[str]) -> None: + contract = _model_context_contract(model) + if contract.input_types is not None: + invalid = sorted(set(names) - set(contract.input_types)) + if invalid: + names = ", ".join(invalid) + raise TypeError(f"with_context() only accepts contextual fields. Invalid field(s): {names}.") + + +def _binding_uses_patch_shape(binding: ContextTransform) -> bool: + return _is_mapping_annotation(_load_context_transform_config_from_binding(binding).return_annotation) + + +def _validate_context_transform_factory_kwargs(config: _FlowModelConfig, kwargs: Dict[str, Any]) -> Dict[str, Any]: + unknown = sorted(set(kwargs) - {param.name for param in config.parameters}) + if unknown: + raise TypeError(f"{_callable_name(config.func)}() got unexpected keyword argument(s): {', '.join(unknown)}") + + contextual = sorted(name for name in kwargs if config.param(name).is_contextual) + if contextual: + raise TypeError( + f"{_callable_name(config.func)}() only binds regular parameters. Do not pass contextual parameter(s): {', '.join(contextual)}." + ) + + missing = [param.name for param in config.regular_params if param.name not in kwargs and not param.has_function_default] + if missing: + raise TypeError(f"{_callable_name(config.func)}() is missing required regular parameter(s): {', '.join(missing)}") + + validated: Dict[str, Any] = {} + for param in config.regular_params: + if param.name not in kwargs: + continue + value = kwargs[param.name] + validated[param.name] = _coerce_value(param.name, value, param.annotation, "Context transform argument") + return validated + + +def _validate_patch_result(model: CallableModel, result: Any) -> Dict[str, Any]: + if not isinstance(result, Mapping): + raise TypeError( + f"Patch context transform for {model!r} must return a mapping of contextual field names to values, got {type(result).__name__}." + ) + + patch = dict(result) + if not all(isinstance(name, str) for name in patch): + raise TypeError("Patch context transforms must return a mapping with string field names.") + + _validate_with_context_field_names(model, list(patch)) + contract = _model_context_contract(model) + if contract.input_types is None: + return patch + + return {name: _coerce_model_context_value(model, name, value, "with_context() patch") for name, value in patch.items()} + + +def _normalize_with_context(model: CallableModel, patches: Tuple[Any, ...], field_overrides: Dict[str, Any]) -> _BoundContextSpec: + """Validate and normalize user-facing ``with_context(...)`` arguments.""" + + normalized_patches = [] + for patch in patches: + if callable(patch): + raise TypeError("with_context() no longer accepts raw callables. Replace the callable with a top-level @Flow.context_transform binding.") + if not isinstance(patch, ContextTransform): + raise TypeError("Positional with_context() arguments must be @Flow.context_transform bindings that return a mapping.") + if not _binding_uses_patch_shape(patch): + raise TypeError( + "Field context transforms must be passed by keyword to with_context(...). Patch transforms belong in positional arguments." + ) + normalized_patches.append(PatchContextSpec(binding=patch)) + + _validate_with_context_field_names(model, list(field_overrides)) + contract = _model_context_contract(model) + normalized_field_overrides: Dict[str, _FieldOverrideSpec] = {} + for name, value in field_overrides.items(): + if callable(value): + raise TypeError("with_context() no longer accepts raw callables. Replace the callable with a top-level @Flow.context_transform binding.") + if isinstance(value, ContextTransform): + if _binding_uses_patch_shape(value): + raise TypeError("Patch transforms must be passed positionally to with_context(...), not as keyword field overrides.") + normalized_field_overrides[name] = FieldContextSpec(binding=value) + continue + normalized_field_overrides[name] = StaticValueSpec( + value=value + if contract.input_types is None or name not in contract.input_types + else _coerce_model_context_value(model, name, value, "with_context()") + ) + + context_spec = _BoundContextSpec(patches=normalized_patches, field_overrides=normalized_field_overrides) + return _validate_static_context_spec_declared_context(model, context_spec) + + +# --------------------------------------------------------------------------- +# Bound context application and compute context construction +# --------------------------------------------------------------------------- + + +def _apply_context_spec_values(model: CallableModel, context_spec: _BoundContextSpec, context: ContextBase) -> Dict[str, Any]: + """Apply a binding spec at execution time and return rewritten context values.""" + + current_values = _context_values(context) + + for patch in context_spec.patches: + result = _evaluate_context_transform_from_values(patch.binding, _context_values(context)) + current_values.update(_validate_patch_result(model, result)) + + for name, spec in context_spec.field_overrides.items(): + if isinstance(spec, StaticValueSpec): + current_values[name] = spec.value + continue + result = _evaluate_context_transform_from_values(spec.binding, _context_values(context)) + current_values[name] = _coerce_model_context_value(model, name, result, "with_context()") + + return current_values + + +def _apply_context_spec(model: CallableModel, context_spec: _BoundContextSpec, context: ContextBase) -> ContextBase: + """Apply bindings, project to the wrapped model, and build its runtime context.""" + + if not context_spec.patches and not context_spec.field_overrides: + return _dependency_context_for_model(model, context) + + values = _apply_context_spec_values(model, context_spec, context) + return _runtime_context_for_model(model, _project_context_values_for_model(model, values)) + + +def _build_compute_context(model: CallableModel, context: Any, kwargs: Dict[str, Any]) -> Optional[ContextBase]: + """Construct the context used by ``FlowAPI.compute`` for a target model. + + ``compute`` is intentionally not a second constructor. For generated models + it only supplies contextual inputs; regular parameters and model_base fields + must already be bound on the model instance. + """ + + if context is not _UNSET and kwargs: + raise TypeError("compute() accepts either one context object or contextual keyword arguments, but not both.") + + ctx_type = model.context_type + _ctx_is_optional = get_origin(ctx_type) in _UNION_ORIGINS and type(None) in get_args(ctx_type) + + contract = _model_context_contract(model) + + if context is not _UNSET: + if context is None and _ctx_is_optional: + return None + if isinstance(context, FlowContext): + return context + if isinstance(context, ContextBase): + return _runtime_context_for_model(model, _context_values(context)) + return contract.runtime_context_type.model_validate(context) + + if contract.generated_model is None: + if not kwargs and _ctx_is_optional: + return None + return contract.runtime_context_type.model_validate(kwargs) + + generated = contract.generated_model + config = type(generated).__flow_model_config__ + regular_kwargs = sorted(name for name in config.regular_param_names if name in kwargs) + unresolved_regular = sorted(name for name in regular_kwargs if _is_unset_flow_input(getattr(generated, name, _UNSET_FLOW_INPUT))) + if unresolved_regular: + names = ", ".join(unresolved_regular) + raise TypeError( + f"compute() cannot satisfy unbound regular parameter(s): {names}. " + "Bind them at construction time; compute() only supplies runtime context." + ) + + already_bound_regular = sorted(name for name in regular_kwargs if name not in unresolved_regular) + if already_bound_regular: + names = ", ".join(already_bound_regular) + raise TypeError( + f"compute() does not accept regular parameter override(s): {names}. " + "Those parameters are already bound on the model. Pass a context object if you need ambient fields with the same names." + ) + + base_kwargs = sorted(name for name in _model_base_field_names(generated) if name in kwargs) + if base_kwargs: + names = ", ".join(base_kwargs) + raise TypeError( + f"compute() does not accept model configuration override(s): {names}. Those fields are bound on the model at construction time." + ) + + ambient = dict(kwargs) + for param in config.contextual_params: + if param.name not in kwargs: + continue + ambient[param.name] = _coerce_contextual_value(config, param, kwargs[param.name], "compute() input") + return FlowContext(**ambient) + + +def _is_optional_context_type(context_type: Any) -> bool: + return get_origin(context_type) in _UNION_ORIGINS and type(None) in get_args(context_type) + + +def _bound_model_preserves_none_context(bound_model: "BoundModel") -> bool: + return ( + not bound_model.context_spec.patches + and not bound_model.context_spec.field_overrides + and _is_optional_context_type(bound_model.model.context_type) + ) + + +def _build_bound_compute_context(bound_model: "BoundModel", context: Any, kwargs: Dict[str, Any]) -> Optional[ContextBase]: + """Construct the ambient context passed into a ``BoundModel`` by ``compute``.""" + + if context is not _UNSET and kwargs: + raise TypeError("compute() accepts either one context object or contextual keyword arguments, but not both.") + if context is not _UNSET: + return context + if not kwargs and _bound_model_preserves_none_context(bound_model): + return None + return FlowContext(**kwargs) + + +# --------------------------------------------------------------------------- +# model.flow API and BoundModel wrapper +# --------------------------------------------------------------------------- + + +class FlowAPI: + """API namespace exposed as ``model.flow``. + + ``FlowAPI`` works for both generated models and ordinary ``CallableModel`` + instances. Generated models get richer introspection because their function + signature declares regular and contextual inputs. Plain models only expose + what can be inferred from their ``context_type`` and pydantic fields. + """ + + def __init__(self, model: CallableModel): + self._model = model + + @property + def _compute_target(self) -> CallableModel: + return self._model + + def compute(self, context: Any = _UNSET, /, _options: Optional[FlowOptions] = None, **kwargs) -> Any: + """Evaluate the model after building a runtime context from ``context`` or kwargs.""" + + target = self._compute_target + built_context = _build_compute_context(target, context, kwargs) + return _maybe_auto_unwrap_external_result(target, target(built_context, _options=_options)) + + @property + def context_inputs(self) -> Dict[str, Any]: + """Contextual input names and expected types for this model.""" + + contract = _model_context_contract(self._model) + return dict(contract.input_types or {}) + + @property + def unbound_inputs(self) -> Dict[str, Any]: + """Required contextual inputs that are not already satisfied.""" + + contract = _model_context_contract(self._model) + if contract.generated_model is None and _is_optional_context_type(self._model.context_type): + return {} + if contract.generated_model is None: + return {} if contract.input_types is None else {name: contract.input_types[name] for name in contract.required_names} + + generated = contract.generated_model + config = type(generated).__flow_model_config__ + result = {} + for param in config.contextual_params: + if not _is_unset_flow_input(getattr(generated, param.name, _UNSET_FLOW_INPUT)): + continue + if param.has_function_default: + continue + result[param.name] = param.annotation if contract.input_types is None else contract.input_types[param.name] + return result + + @property + def bound_inputs(self) -> Dict[str, Any]: + """Inputs already fixed by construction-time values or static context bindings.""" + + generated = _model_context_contract(self._model).generated_model + if generated is not None: + config = type(generated).__flow_model_config__ + result: Dict[str, Any] = {} + explicit_fields = _bound_field_names(generated) + for param in config.regular_params: + value = getattr(generated, param.name, _UNSET_FLOW_INPUT) + if _is_unset_flow_input(value): + continue + result[param.name] = value + for param in config.contextual_params: + if param.name not in explicit_fields: + continue + value = getattr(generated, param.name, _UNSET_FLOW_INPUT) + if _is_unset_flow_input(value): + continue + result[param.name] = value + for name in _model_base_field_names(generated): + if name in explicit_fields: + result[name] = getattr(generated, name) + return result + + result: Dict[str, Any] = {} + model_fields = getattr(self._model.__class__, "model_fields", {}) + for name in model_fields: + if name == "meta": + continue + result[name] = getattr(self._model, name) + return result + + def with_context(self, *patches, **field_overrides) -> "BoundModel": + """Return a wrapper that rewrites runtime context before evaluating this model.""" + + context_spec = _normalize_with_context(self._model, patches, field_overrides) + return BoundModel(model=self._model, context_spec=context_spec) + + +class BoundModel(WrapperModel): + """A wrapper that rewrites context for exactly one wrapped model. + + ``BoundModel`` is deliberately a ``WrapperModel`` rather than mutating the + wrapped model. This keeps context bindings scoped to the edge where they + are used, lets dependency graphs show the wrapped model, and derives + effective identity from the rewritten wrapped invocation. + """ + + context_spec: _BoundContextSpec = Field(default_factory=_BoundContextSpec, repr=False) + + def __reduce__(self): + return (_restore_pickled_flow_model, (str(PyObjectPath.validate(type(self))), self.__getstate__())) + + def _rewrite_context(self, context: ContextBase) -> ContextBase: + """Apply this wrapper's context bindings to an ambient runtime context.""" + + return _apply_context_spec(self.model, self.context_spec, context) + + @property + def context_type(self) -> Any: + if _bound_model_preserves_none_context(self): + return self.model.context_type + return FlowContext + + @Flow.call + def __call__(self, context: ContextType) -> ResultBase: + """Evaluate the wrapped model after rewriting context.""" + + if context is None and _bound_model_preserves_none_context(self): + return self.model(None) + return self.model(self._rewrite_context(context)) + + @Flow.deps + def __deps__(self, context: ContextType) -> GraphDepList: + """Expose the wrapped model as the single dependency of this binding wrapper.""" + + if context is None and _bound_model_preserves_none_context(self): + return self.model.__deps__(None) + return [(self.model, [self._rewrite_context(context)])] + + def __repr__(self) -> str: + args = [_context_transform_repr(patch.binding) for patch in self.context_spec.patches] + args.extend( + f"{name}={_context_transform_repr(spec.binding if isinstance(spec, FieldContextSpec) else spec.value)}" + for name, spec in self.context_spec.field_overrides.items() + ) + return f"{self.model!r}.flow.with_context({', '.join(args)})" + + def _evaluation_identity_payload( + self, + context: ContextBase, + ) -> Optional[Any]: + """Describe this binding in terms of the rewritten wrapped call.""" + + return { + "kind": "bound_model_v1", + "model": EvaluationDependency(self.model, self._rewrite_context(context)), + } + + @property + def flow(self) -> "FlowAPI": + return _BoundFlowAPI(self) + + +class _BoundFlowAPI(FlowAPI): + """``model.flow`` implementation for ``BoundModel`` wrappers.""" + + def __init__(self, bound_model: BoundModel): + self._bound = bound_model + super().__init__(bound_model.model) + + @property + def _compute_target(self) -> CallableModel: + return self._bound + + def compute(self, context: Any = _UNSET, /, _options: Optional[FlowOptions] = None, **kwargs) -> Any: + """Evaluate the bound wrapper after building its ambient context.""" + + built_context = _build_bound_compute_context(self._bound, context, kwargs) + return _maybe_auto_unwrap_external_result(self._bound, self._bound(built_context, _options=_options)) + + @property + def bound_inputs(self) -> Dict[str, Any]: + result = super().bound_inputs + for patch in self._bound.context_spec.patches: + patch_result = _evaluate_static_context_transform(patch.binding) + if patch_result is not _UNSET: + result.update(_validate_patch_result(self._bound.model, patch_result)) + for name, spec in self._bound.context_spec.field_overrides.items(): + value = _static_field_override_value(self._bound.model, name, spec) + if value is not _UNSET: + result[name] = value + else: + result.pop(name, None) + return result + + @property + def context_inputs(self) -> Dict[str, Any]: + result = super().context_inputs + for name in _statically_resolved_context_field_names(self._bound.model, self._bound.context_spec): + result.pop(name, None) + for patch in self._bound.context_spec.patches: + if _evaluate_static_context_transform(patch.binding) is _UNSET: + result.update(_context_transform_input_types(patch.binding, required_only=False)) + for name, spec in self._bound.context_spec.field_overrides.items(): + if isinstance(spec, FieldContextSpec) and _static_field_override_value(self._bound.model, name, spec) is _UNSET: + result.pop(name, None) + result.update(_context_transform_input_types(spec.binding, required_only=False)) + return result + + @property + def unbound_inputs(self) -> Dict[str, Any]: + result = super().unbound_inputs + for name in _statically_resolved_context_field_names(self._bound.model, self._bound.context_spec): + result.pop(name, None) + for patch in self._bound.context_spec.patches: + if _evaluate_static_context_transform(patch.binding) is _UNSET: + result.update(_context_transform_input_types(patch.binding, required_only=True)) + for name, spec in self._bound.context_spec.field_overrides.items(): + if isinstance(spec, FieldContextSpec) and _static_field_override_value(self._bound.model, name, spec) is _UNSET: + result.pop(name, None) + result.update(_context_transform_input_types(spec.binding, required_only=True)) + return result + + def with_context(self, *patches, **field_overrides) -> BoundModel: + context_spec = _normalize_with_context(self._bound.model, patches, field_overrides) + merged = _merge_context_specs(self._bound.context_spec, context_spec.patches, context_spec.field_overrides) + return BoundModel( + model=self._bound.model, + context_spec=_validate_static_context_spec_declared_context(self._bound.model, merged), + ) + + +class _GeneratedFlowModelBase(CallableModel): + """Base class for all classes created by ``@Flow.model``.""" + + __flow_model_config__: ClassVar[_FlowModelConfig] + + def __reduce__(self): + """Prefer import-path restoration, falling back to serialized local factories.""" + + config = type(self).__flow_model_config__ + factory_path = _generated_model_factory_path_for_pickle(config, type(self)) + if factory_path is not None: + return (_restore_generated_flow_model, (factory_path, self.__getstate__())) + import cloudpickle + + payload = (config.func, type(self).__flow_model_factory_kwargs__) + return (_restore_pickled_local_flow_model, (cloudpickle.dumps(payload, protocol=5), self.__getstate__())) + + @model_validator(mode="before") + @classmethod + def _resolve_registry_refs(cls, values): + """Resolve registry string references for regular dependency fields.""" + + if not isinstance(values, dict): + return values + + config = getattr(cls, "__flow_model_config__", None) + if config is None: + return values + + resolved = dict(values) + for param in config.regular_params: + if param.name not in resolved: + continue + value = resolved[param.name] + if not isinstance(value, str): + continue + if _type_accepts_str(param.annotation): + continue + candidate = _resolve_registry_candidate(value) + if candidate is None: + continue + if _registry_candidate_allowed(param.annotation, candidate): + resolved[param.name] = candidate + return resolved + + @model_validator(mode="after") + def _validate_flow_model_fields(self): + """Validate all bound regular and contextual defaults after pydantic construction.""" + + config = self.__class__.__flow_model_config__ + + for param in config.parameters: + value = getattr(self, param.name, _UNSET_FLOW_INPUT) + if _is_unset_flow_input(value): + continue + object.__setattr__( + self, + param.name, + _validate_bound_param_value(config, param, value, "Contextual default" if param.is_contextual else "Field"), + ) + + _validate_bound_declared_context_defaults(self, config) + return self + + @property + def context_type(self) -> Type[ContextBase]: + return self.__class__.__flow_model_config__.context_type + + @property + def result_type(self) -> Type[ResultBase]: + return self.__class__.__flow_model_config__.result_type + + @property + def flow(self) -> FlowAPI: + return FlowAPI(self) + + def _evaluation_identity_payload( + self, + context: ContextBase, + ) -> Optional[Any]: + return _generated_model_identity_payload(self, context) + + +# --------------------------------------------------------------------------- +# Generated model method builders and decorators +# --------------------------------------------------------------------------- + + +def _make_call_impl(config: _FlowModelConfig) -> _AnyCallable: + """Create the ``__call__`` implementation for one generated model class.""" + + def __call__(self, context): + """Resolve bound inputs, dependency inputs, and context inputs, then call the user function.""" + + missing_regular = _missing_regular_param_names(self, config) + if missing_regular: + missing = ", ".join(sorted(missing_regular)) + raise TypeError( + f"Missing regular parameter(s) for {_callable_name(config.func)}: {missing}. " + "Bind them at construction time; compute() only supplies contextual inputs." + ) + + fn_kwargs: Dict[str, Any] = {} + for param in config.regular_params: + value = _resolve_regular_param_value(self, param, context) + if param.is_lazy: + fn_kwargs[param.name] = _make_coercing_lazy_thunk(value, param.name, param.annotation) + else: + fn_kwargs[param.name] = _coerce_value(param.name, value, param.annotation, "Regular parameter") + + fn_kwargs.update(_resolved_contextual_inputs(self, config, context)) + + raw_result = config.func(**fn_kwargs) + if config.auto_wrap_result: + return config.result_type.model_validate(raw_result) + return raw_result + + cast(Any, __call__).__signature__ = inspect.Signature( + parameters=[ + inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), + inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=config.context_type), + ], + return_annotation=config.result_type, + ) + return __call__ + + +def _make_deps_impl(config: _FlowModelConfig) -> _AnyCallable: + """Create the ``__deps__`` implementation for one generated model class.""" + + def __deps__(self, context): + """Declare non-lazy regular ``CallableModel`` inputs as graph dependencies.""" + + missing_regular = _missing_regular_param_names(self, config) + if missing_regular: + missing = ", ".join(sorted(missing_regular)) + raise TypeError(f"Missing regular parameter(s) for {_callable_name(config.func)}: {missing}. Bind them before dependency evaluation.") + + deps = [] + for param in config.regular_params: + if param.is_lazy: + continue + value = getattr(self, param.name, _UNSET_FLOW_INPUT) + if isinstance(value, CallableModel): + dependency_model, dependency_context = _resolved_dependency_invocation(value, context) + deps.append((dependency_model, [dependency_context])) + return deps + + cast(Any, __deps__).__signature__ = inspect.Signature( + parameters=[ + inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), + inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=config.context_type), + ], + return_annotation=GraphDepList, + ) + return __deps__ + + +def _factory_param_annotation(param: _FlowModelParam) -> Any: + if param.is_contextual: + return FromContext[param.annotation] + if param.is_lazy: + return Lazy[param.annotation] + return param.annotation + + +def _factory_signature(config: _FlowModelConfig, generated_cls: Type[BaseModel]) -> inspect.Signature: + """Return the public construction signature for a generated factory.""" + + parameters = [] + for param in config.parameters: + parameters.append( + inspect.Parameter( + param.name, + inspect.Parameter.KEYWORD_ONLY, + annotation=_factory_param_annotation(param), + default=param.function_default if param.has_function_default else _UNSET_FLOW_INPUT, + ) + ) + + param_names = {param.name for param in config.parameters} + for name, field_info in generated_cls.model_fields.items(): + if name == "meta" or name in param_names: + continue + default = _UNSET_FLOW_INPUT if field_info.is_required() or getattr(field_info, "default_factory", None) is not None else field_info.default + parameters.append( + inspect.Parameter( + name, + inspect.Parameter.KEYWORD_ONLY, + annotation=field_info.annotation, + default=default, + ) + ) + + return inspect.Signature(parameters=parameters, return_annotation=generated_cls) + + +def _resolve_generated_model_bases(model_base: Type[CallableModel]) -> Tuple[type, ...]: + """Return the class bases for a generated model, preserving custom model bases.""" + + if not isinstance(model_base, type) or not issubclass(model_base, CallableModel): + raise TypeError(f"model_base must be a CallableModel subclass, got {model_base!r}") + + if issubclass(model_base, _GeneratedFlowModelBase): + return (model_base,) + if model_base is CallableModel: + return (_GeneratedFlowModelBase,) + return (_GeneratedFlowModelBase, model_base) + + +def flow_context_transform(func: Optional[_AnyCallable] = None) -> _AnyCallable: + """Decorator that turns a function into a serializable ``with_context`` transform factory. + + Regular parameters are bound when the transform factory is called. + ``FromContext`` parameters are read from the runtime context when the bound + model executes. Transform functions returning mappings are positional patch + transforms; transforms returning scalar values are field transforms. + """ + + def decorator(fn: _AnyCallable) -> _AnyCallable: + """Analyze one transform function and return its binding factory.""" + + _ensure_top_level_named_function(fn, decorator_name="@Flow.context_transform") + try: + resolved_hints = get_type_hints(fn, include_extras=True) + except AttributeError: + resolved_hints = {} + sig = _resolved_flow_signature( + fn, + resolved_hints=resolved_hints, + require_return_annotation=True, + function_name=_callable_name(fn), + ) + config = _analyze_flow_context_transform(fn, sig, is_model_dependency=_is_model_dependency) + serialized_config = None if _context_transform_should_use_import_path(config) else _serialize_context_transform_config(config) + + @wraps(fn) + def factory(**kwargs) -> ContextTransform: + """Bind regular transform arguments into a serializable spec.""" + + return ContextTransform( + path=config.path if serialized_config is None else None, + serialized_config=serialized_config, + bound_args=_validate_context_transform_factory_kwargs(config, kwargs), + ) + + cast(Any, factory).__flow_context_transform_config__ = config + return factory + + if func is not None: + return decorator(func) + return decorator + + +def flow_model( + func: Optional[_AnyCallable] = None, + *, + context_args: Any = _REMOVED_CONTEXT_ARGS, + context_type: Optional[Type[ContextBase]] = None, + auto_unwrap: bool = False, + model_base: Type[CallableModel] = CallableModel, + cacheable: Any = _UNSET, + volatile: Any = _UNSET, + log_level: Any = _UNSET, + validate_result: Any = _UNSET, + verbose: Any = _UNSET, + evaluator: Any = _UNSET, +) -> _AnyCallable: + """Decorator that generates a ``CallableModel`` class from a plain function. + + Unmarked parameters become construction-time model fields. Parameters + annotated as ``FromContext[T]`` are contextual inputs supplied by + ``FlowContext``, a declared ``context_type``, ``compute(...)`` kwargs, or + ``with_context(...)`` bindings. The returned object is a factory that + creates instances of the generated model class. + """ + + if context_args is not _REMOVED_CONTEXT_ARGS: + raise TypeError("context_args=... has been removed. Mark runtime/contextual parameters with FromContext[...] instead.") + + def decorator(fn: _AnyCallable) -> _AnyCallable: + """Analyze one user function and synthesize its generated model class.""" + + try: + resolved_hints = get_type_hints(fn, include_extras=True) + except AttributeError: + resolved_hints = {} + sig = _resolved_flow_signature( + fn, + resolved_hints=resolved_hints, + require_return_annotation=True, + function_name=_callable_name(fn), + ) + config = _analyze_flow_model( + fn, + sig, + context_type=context_type, + auto_unwrap=auto_unwrap, + is_model_dependency=_is_model_dependency, + ) + + annotations: Dict[str, Any] = {} + namespace: Dict[str, Any] = { + "__module__": getattr(fn, "__module__", __name__), + "__qualname__": f"_{_callable_name(fn)}_Model", + "__call__": Flow.call( + **{ + name: value + for name, value in [ + ("cacheable", cacheable), + ("volatile", volatile), + ("log_level", log_level), + ("validate_result", validate_result), + ("verbose", verbose), + ("evaluator", evaluator), + ] + if value is not _UNSET + } + )(_make_call_impl(config)), + "__deps__": Flow.deps(_make_deps_impl(config)), + } + + for param in config.parameters: + annotations[param.name] = Any + if param.is_contextual: + namespace[param.name] = Field(default_factory=_unset_flow_input_factory, exclude_if=_is_unset_flow_input) + elif param.has_function_default: + namespace[param.name] = param.function_default + else: + namespace[param.name] = Field(default_factory=_unset_flow_input_factory, exclude_if=_is_unset_flow_input) + + namespace["__annotations__"] = annotations + + GeneratedModel = cast( + type[_GeneratedFlowModelBase], + type(f"_{_callable_name(fn)}_Model", _resolve_generated_model_bases(model_base), namespace), + ) + GeneratedModel.__flow_model_config__ = config + GeneratedModel.__flow_model_factory_kwargs__ = { + "context_type": context_type, + "auto_unwrap": auto_unwrap, + "model_base": model_base, + "cacheable": cacheable, + "volatile": volatile, + "log_level": log_level, + "validate_result": validate_result, + "verbose": verbose, + "evaluator": evaluator, + } + _register_generated_model_class(config, GeneratedModel) + GeneratedModel.model_rebuild() + + @wraps(fn) + def factory(**kwargs) -> _GeneratedFlowModelBase: + """Create a generated model instance with regular/contextual defaults bound.""" + + return GeneratedModel(**kwargs) + + cast(Any, factory)._generated_model = GeneratedModel + cast(Any, factory).__signature__ = _factory_signature(config, GeneratedModel) + factory.__doc__ = fn.__doc__ + return factory + + if func is not None: + return decorator(func) + return decorator diff --git a/ccflow/tests/config/conf_flow.yaml b/ccflow/tests/config/conf_flow.yaml new file mode 100644 index 0000000..a697229 --- /dev/null +++ b/ccflow/tests/config/conf_flow.yaml @@ -0,0 +1,73 @@ +# Flow.model configurations for Hydra integration tests. + +flow_loader: + _target_: ccflow.tests.test_flow_model.basic_loader + source: fixture_input + multiplier: 5 + +flow_processor: + _target_: ccflow.tests.test_flow_model.string_processor + prefix: "value=" + suffix: "!" + +flow_source: + _target_: ccflow.tests.test_flow_model.data_source + base_value: 100 + +flow_transformer: + _target_: ccflow.tests.test_flow_model.data_transformer + source: flow_source + factor: 3 + +flow_stage1: + _target_: ccflow.tests.test_flow_model.pipeline_stage1 + initial: 10 + +flow_stage2: + _target_: ccflow.tests.test_flow_model.pipeline_stage2 + stage1_output: flow_stage1 + multiplier: 2 + +flow_stage3: + _target_: ccflow.tests.test_flow_model.pipeline_stage3 + stage2_output: flow_stage2 + offset: 50 + +diamond_source: + _target_: ccflow.tests.test_flow_model.data_source + base_value: 10 + +diamond_branch_a: + _target_: ccflow.tests.test_flow_model.data_transformer + source: diamond_source + factor: 2 + +diamond_branch_b: + _target_: ccflow.tests.test_flow_model.data_transformer + source: diamond_source + factor: 5 + +diamond_aggregator: + _target_: ccflow.tests.test_flow_model.data_aggregator + input_a: diamond_branch_a + input_b: diamond_branch_b + operation: add + +flow_date_loader: + _target_: ccflow.tests.test_flow_model.date_range_loader_previous_day + source: calendar_feed + include_weekends: false + +flow_date_processor: + _target_: ccflow.tests.test_flow_model.date_range_processor + raw_data: flow_date_loader + normalize: true + +contextual_loader_model: + _target_: ccflow.tests.test_flow_model.contextual_loader + source: data_source + +contextual_processor_model: + _target_: ccflow.tests.test_flow_model.contextual_processor + data: contextual_loader_model + prefix: output diff --git a/ccflow/tests/evaluators/test_common.py b/ccflow/tests/evaluators/test_common.py index 6db6d07..dd39c97 100644 --- a/ccflow/tests/evaluators/test_common.py +++ b/ccflow/tests/evaluators/test_common.py @@ -315,6 +315,30 @@ def __call__(self, context: ModelEvaluationContext): ) assert cache_key(opaque2) == cache_key(opaque2b) + def test_opaque_mec_order_preserved(self): + """Non-transparent evaluator wrapper order is identity-significant.""" + + class OpaqueEval(EvaluatorBase): + tag: str + + def __call__(self, context: ModelEvaluationContext): + return context() + + m = MyDateCallable(offset=1) + ctx = DateContext(date=date(2022, 1, 1)) + inner = ModelEvaluationContext(model=m, context=ctx) + + inner_then_outer = ModelEvaluationContext( + model=OpaqueEval(tag="outer"), + context=ModelEvaluationContext(model=OpaqueEval(tag="inner"), context=inner), + ) + outer_then_inner = ModelEvaluationContext( + model=OpaqueEval(tag="inner"), + context=ModelEvaluationContext(model=OpaqueEval(tag="outer"), context=inner), + ) + + assert cache_key(inner_then_outer) != cache_key(outer_then_inner) + def test_fn_deps_preserved_through_transparent(self): """fn='__deps__' is preserved when walking through transparent layers.""" m = MyDateCallable(offset=1) @@ -362,6 +386,52 @@ def test_basic(self): self.assertNotIn(key, evaluator.cache) self.assertNotIn(key, evaluator.ids) + def test_plain_callable_key_matches_public_cache_key(self): + """Existing CallableModels stay on the structural cache-key path.""" + m1 = MyDateCallable(offset=1) + evaluator = MemoryCacheEvaluator() + context = DateContext(date=date(2022, 1, 1)) + model_evaluation_context = ModelEvaluationContext(model=m1, context=context, options=dict(cacheable=True)) + + self.assertEqual(evaluator.key(model_evaluation_context), cache_key(model_evaluation_context)) + self.assertEqual(evaluator.key(model_evaluation_context), cache_key(model_evaluation_context, effective=True)) + + def test_plain_callable_key_fallback_does_not_log(self): + """Normal opt-out from effective identity should stay quiet.""" + m1 = MyDateCallable(offset=1) + evaluator = MemoryCacheEvaluator() + context = DateContext(date=date(2022, 1, 1)) + model_evaluation_context = ModelEvaluationContext(model=m1, context=context, options=dict(cacheable=True)) + + with self.assertNoLogs("ccflow.evaluators.common", level="DEBUG"): + self.assertEqual(evaluator.key(model_evaluation_context), cache_key(model_evaluation_context)) + + def test_effective_key_exception_fallback_logs_debug(self): + """Unexpected effective-identity failures are visible without breaking existing calls.""" + + class BadIdentityCallable(MyDateCallable): + def _evaluation_identity_payload(self, context): + raise ValueError("identity broke") + + m1 = BadIdentityCallable(offset=1) + evaluator = MemoryCacheEvaluator() + context = DateContext(date=date(2022, 1, 1)) + model_evaluation_context = ModelEvaluationContext(model=m1, context=context, options=dict(cacheable=True)) + + with self.assertLogs("ccflow.evaluators.common", level="DEBUG") as captured: + self.assertEqual(evaluator.key(model_evaluation_context), cache_key(model_evaluation_context)) + self.assertIn("Falling back to structural evaluation key for BadIdentityCallable.__call__: identity broke", captured.output[0]) + + def test_plain_callable_deps_key_matches_public_cache_key(self): + """Non-__call__ evaluations stay structural.""" + m1 = MyDateCallable(offset=1) + evaluator = MemoryCacheEvaluator() + context = DateContext(date=date(2022, 1, 1)) + model_evaluation_context = ModelEvaluationContext(model=m1, context=context, fn="__deps__", options=dict(cacheable=True)) + + self.assertEqual(evaluator.key(model_evaluation_context), cache_key(model_evaluation_context)) + self.assertEqual(evaluator.key(model_evaluation_context), cache_key(model_evaluation_context, effective=True)) + def test_caching(self): # Create some hard-to hash structure with all kinds of custom types # We will put this on the callable to make sure caching still works @@ -511,6 +581,34 @@ def test_graph_deps_diamond(self): elif v.model.meta.name == "n0": self.assertEqual(set(graph.ids[dep_key].context.model.meta.name for dep_key in graph.graph[k]), set()) + def test_plain_callable_graph_keys_match_public_cache_key(self): + """Dependency graphs for ordinary CallableModels keep structural keys.""" + n0 = NodeModel(meta=dict(name="n0")) + n1 = NodeModel(meta=dict(name="n1"), deps_model=[n0]) + n2 = NodeModel(meta=dict(name="n2"), deps_model=[n0]) + root = NodeModel(meta=dict(name="n3"), deps_model=[n1, n2]) + context = DateContext(date=date(2022, 1, 1)) + + graph = get_dependency_graph(ModelEvaluationContext(model=root, context=context)) + + for key, evaluation_context in graph.ids.items(): + self.assertEqual(key, cache_key(evaluation_context)) + self.assertEqual(graph.root_id, cache_key(ModelEvaluationContext(model=root, context=context))) + + def test_plain_callable_graph_deduplicates_equal_models_by_key(self): + """Ordinary graph traversal should keep main's key-only deduplication.""" + leaf1 = NodeModel(meta=dict(name="leaf")) + leaf2 = NodeModel(meta=dict(name="leaf")) + root = NodeModel(meta=dict(name="root"), deps_model=[leaf1, leaf2]) + context = DateContext(date=date(2022, 1, 1)) + + NodeModel._deps_calls = [] + graph = get_dependency_graph(ModelEvaluationContext(model=root, context=context)) + + self.assertEqual(NodeModel._deps_calls, [("root", date(2022, 1, 1)), ("leaf", date(2022, 1, 1))]) + self.assertEqual(len(graph.ids), 2) + self.assertEqual(graph.ids.keys(), graph.graph.keys()) + def test_graph_deps_circular(self): root = CircularModel() context = DateContext(date=date(2022, 1, 1)) diff --git a/ccflow/tests/test_callable.py b/ccflow/tests/test_callable.py index 43f86b5..8c62ac0 100644 --- a/ccflow/tests/test_callable.py +++ b/ccflow/tests/test_callable.py @@ -14,6 +14,7 @@ Flow, GenericResult, GraphDepList, + Lazy, MetaData, ModelRegistry, NullContext, @@ -783,3 +784,251 @@ class MyCallableParent_bad_decorator(MyCallableParent): @Flow.deps def foo(self, context): return [] + + +class TestAutoContext(TestCase): + """Tests for the opt-in @Flow.call(auto_context=...) path.""" + + def test_basic_usage_with_kwargs(self): + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, x: int, y: str = "default") -> GenericResult: + return GenericResult(value=f"{x}-{y}") + + model = AutoContextCallable() + + self.assertEqual(model(x=42, y="hello").value, "42-hello") + self.assertEqual(model(x=10).value, "10-default") + + def test_no_arg_call_uses_generated_context_defaults_only_for_auto_context(self): + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, x: int = 1, y: str = "a") -> GenericResult: + return GenericResult(value=f"{x}-{y}") + + class PlainCallable(CallableModel): + @Flow.call + def __call__(self, context: MyContext = MyContext(a="plain")) -> MyResult: + return MyResult(x=1, y=context.a) + + self.assertEqual(AutoContextCallable()().value, "1-a") + self.assertEqual(PlainCallable()().y, "plain") + + def test_no_arg_call_still_rejects_required_generated_context_fields(self): + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, x: int, y: str = "a") -> GenericResult: + return GenericResult(value=f"{x}-{y}") + + with self.assertRaisesRegex(TypeError, "missing 1 required positional argument: 'context'"): + AutoContextCallable()() + + def test_auto_context_attribute_and_registration(self): + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, value: int) -> GenericResult: + return GenericResult(value=value) + + inner = AutoContextCallable.__call__.__wrapped__ + self.assertTrue(hasattr(inner, "__auto_context__")) + + auto_ctx = inner.__auto_context__ + self.assertTrue(issubclass(auto_ctx, ContextBase)) + self.assertIn("value", auto_ctx.model_fields) + self.assertTrue(hasattr(auto_ctx, "__ccflow_import_path__")) + self.assertTrue(auto_ctx.__ccflow_import_path__.startswith(LOCAL_ARTIFACTS_MODULE_NAME)) + + def test_call_with_context_object(self): + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, x: int, y: str = "default") -> GenericResult: + return GenericResult(value=f"{x}-{y}") + + auto_ctx = AutoContextCallable.__call__.__wrapped__.__auto_context__ + ctx = auto_ctx(x=99, y="context") + + self.assertEqual(AutoContextCallable()(ctx).value, "99-context") + + def test_with_parent_context(self): + class ParentContext(ContextBase): + base_value: str = "base" + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=ParentContext) + def __call__(self, *, x: int, base_value: str) -> GenericResult: + return GenericResult(value=f"{x}-{base_value}") + + auto_ctx = AutoContextCallable.__call__.__wrapped__.__auto_context__ + + self.assertTrue(issubclass(auto_ctx, ParentContext)) + self.assertIn("base_value", auto_ctx.model_fields) + self.assertIn("x", auto_ctx.model_fields) + self.assertEqual(AutoContextCallable()(x=42, base_value="custom").value, "42-custom") + + def test_parent_fields_must_be_in_signature(self): + class ParentContext(ContextBase): + required_field: str + + with self.assertRaisesRegex(TypeError, "must be included in function signature"): + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=ParentContext) + def __call__(self, *, x: int) -> GenericResult: + return GenericResult(value=x) + + def test_parent_field_type_incompatibility_rejected(self): + class ParentContext(ContextBase): + base: int + + with self.assertRaisesRegex(TypeError, "incompatible"): + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=ParentContext) + def __call__(self, *, base: str) -> GenericResult: + return GenericResult(value=base) + + def test_parent_field_defaults_remain_authoritative_for_auto_context(self): + class ParentContext(ContextBase): + base: str = "parent" + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=ParentContext) + def __call__(self, *, base: str = "function") -> GenericResult: + return GenericResult(value=base) + + self.assertEqual(AutoContextCallable()().value, "parent") + + def test_cloudpickle_roundtrip(self): + class AutoContextCallable(CallableModel): + multiplier: int = 2 + + @Flow.call(auto_context=True) + def __call__(self, *, x: int) -> GenericResult: + return GenericResult(value=x * self.multiplier) + + restored = rcploads(rcpdumps(AutoContextCallable(multiplier=3))) + + self.assertEqual(restored(x=10).value, 30) + + def test_ray_task_execution(self): + class AutoContextCallable(CallableModel): + factor: int = 2 + + @Flow.call(auto_context=True) + def __call__(self, *, x: int, y: int = 1) -> GenericResult: + return GenericResult(value=(x + y) * self.factor) + + @ray.remote + def run_callable(model, **kwargs): + return model(**kwargs).value + + with ray.init(num_cpus=1): + result = ray.get(run_callable.remote(AutoContextCallable(factor=5), x=10, y=2)) + + self.assertEqual(result, 60) + + def test_postponed_annotations_are_resolved(self): + namespace = {} + exec( + """ +from __future__ import annotations + +from ccflow import CallableModel, Flow, GenericResult + + +class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, x: int) -> GenericResult[int]: + return GenericResult(value=x) + + +result = AutoContextCallable().flow.compute(x=1) +""", + namespace, + namespace, + ) + + self.assertEqual(namespace["result"].value, 1) + + def test_postponed_annotations_unresolved_names_stay_loud(self): + namespace = {} + with self.assertRaises(NameError): + exec( + """ +from __future__ import annotations + +from ccflow import CallableModel, Flow, GenericResult + + +class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, x: MissingType) -> GenericResult[int]: + return GenericResult(value=x) +""", + namespace, + namespace, + ) + + def test_normal_keyword_only_flow_call_without_auto_context_still_fails(self): + class BadCallable(CallableModel): + @Flow.call + def __call__(self, *, x: int, y: str = "default") -> GenericResult: + return GenericResult(value=f"{x}-{y}") + + with self.assertRaisesRegex(ValueError, "__call__ method must take a single argument, named 'context'"): + BadCallable() + + def test_invalid_auto_context_value(self): + with self.assertRaisesRegex(TypeError, "auto_context must be False, True, or a ContextBase subclass"): + + @Flow.call(auto_context="invalid") + def bad_func(self, *, x: int) -> GenericResult: + return GenericResult(value=x) + + def test_auto_context_rejects_var_args(self): + with self.assertRaisesRegex(TypeError, "variadic positional"): + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *args: int) -> GenericResult: + return GenericResult(value=len(args)) + + def test_auto_context_rejects_var_kwargs(self): + with self.assertRaisesRegex(TypeError, "variadic keyword"): + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, **kwargs: int) -> GenericResult: + return GenericResult(value=len(kwargs)) + + def test_auto_context_requires_return_annotation(self): + with self.assertRaisesRegex(TypeError, "must have a return type annotation"): + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, value: int): + return GenericResult(value=value) + + def test_auto_context_rejects_missing_annotation(self): + with self.assertRaisesRegex(TypeError, "must have a type annotation"): + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, value) -> GenericResult: + return GenericResult(value=value) + + def test_auto_context_rejects_lazy_annotation(self): + with self.assertRaisesRegex(TypeError, "Lazy"): + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, value: Lazy[int]) -> GenericResult: + return GenericResult(value=value) + + def test_auto_context_rejects_callable_model_default(self): + with self.assertRaisesRegex(TypeError, "CallableModel"): + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, value: int = MyCallable(i=1)) -> GenericResult: + return GenericResult(value=value) diff --git a/ccflow/tests/test_context.py b/ccflow/tests/test_context.py index ad98bd9..64d71e8 100644 --- a/ccflow/tests/test_context.py +++ b/ccflow/tests/test_context.py @@ -275,8 +275,13 @@ def split_camel(name: str): def test_inheritance(self): """Test that if a context has a superset of fields of another context, it is a subclass of that context.""" - for parent_name, parent_class in self.classes.items(): - for child_name, child_class in self.classes.items(): + # Exclude FlowContext from this test - it's a special universal carrier with no + # declared fields (uses extra="allow"), so the "superset implies subclass" logic + # doesn't apply to it. + classes_to_check = {name: cls for name, cls in self.classes.items() if name != "FlowContext"} + + for parent_name, parent_class in classes_to_check.items(): + for child_name, child_class in classes_to_check.items(): if parent_class is child_class: continue diff --git a/ccflow/tests/test_evaluator.py b/ccflow/tests/test_evaluator.py index cc34155..dabf815 100644 --- a/ccflow/tests/test_evaluator.py +++ b/ccflow/tests/test_evaluator.py @@ -1,9 +1,21 @@ from datetime import date from unittest import TestCase -from ccflow import DateContext, Evaluator, ModelEvaluationContext +import pytest -from .evaluators.util import MyDateCallable +from ccflow import CallableModel, DateContext, Evaluator, Flow, ModelEvaluationContext + +from .evaluators.util import MyDateCallable, MyResult + + +class MyAutoContextDateCallable(CallableModel): + """Auto context version of MyDateCallable for testing evaluators.""" + + offset: int + + @Flow.call(auto_context=DateContext) + def __call__(self, *, date: date) -> MyResult: + return MyResult(x=date.day + self.offset) class TestEvaluator(TestCase): @@ -32,3 +44,57 @@ def test_evaluator_deps(self): evaluator = Evaluator() out2 = evaluator.__deps__(model_evaluation_context) self.assertEqual(out2, out) + + +@pytest.mark.parametrize( + "callable_class", + [MyDateCallable, MyAutoContextDateCallable], + ids=["standard", "auto_context"], +) +class TestEvaluatorParametrized: + """Test evaluators work with both standard and auto_context callables.""" + + def test_evaluator_with_context_object(self, callable_class): + """Test evaluator with a context object.""" + m1 = callable_class(offset=1) + context = DateContext(date=date(2022, 1, 1)) + model_evaluation_context = ModelEvaluationContext(model=m1, context=context) + + out = model_evaluation_context() + assert out == MyResult(x=2) # day 1 + offset 1 + + evaluator = Evaluator() + out2 = evaluator(model_evaluation_context) + assert out2 == out + + def test_evaluator_with_fn_specified(self, callable_class): + """Test evaluator with fn='__call__' explicitly specified.""" + m1 = callable_class(offset=1) + context = DateContext(date=date(2022, 1, 1)) + model_evaluation_context = ModelEvaluationContext(model=m1, context=context, fn="__call__") + + out = model_evaluation_context() + assert out == MyResult(x=2) + + def test_evaluator_direct_call_matches(self, callable_class): + """Test that evaluator result matches direct call.""" + m1 = callable_class(offset=5) + context = DateContext(date=date(2022, 1, 15)) + + # Direct call + direct_result = m1(context) + + # Via evaluator + model_evaluation_context = ModelEvaluationContext(model=m1, context=context) + evaluator_result = model_evaluation_context() + + assert direct_result == evaluator_result + assert direct_result == MyResult(x=20) # day 15 + offset 5 + + def test_evaluator_with_kwargs(self, callable_class): + """Test that evaluator works when callable is called with kwargs.""" + m1 = callable_class(offset=1) + + # Call with kwargs + result = m1(date=date(2022, 1, 10)) + assert result == MyResult(x=11) # day 10 + offset 1 diff --git a/ccflow/tests/test_flow_context.py b/ccflow/tests/test_flow_context.py new file mode 100644 index 0000000..1bccc55 --- /dev/null +++ b/ccflow/tests/test_flow_context.py @@ -0,0 +1,354 @@ +"""Tests for FlowContext, FlowAPI, and BoundModel under the FromContext design.""" + +import pickle +from concurrent.futures import ThreadPoolExecutor +from datetime import date, timedelta + +import cloudpickle +import pytest + +from ccflow import BoundModel, CallableModel, ContextBase, Flow, FlowContext, FromContext, GenericResult + + +class NumberContext(ContextBase): + x: int + + +class OffsetModel(CallableModel): + offset: int + + @Flow.call + def __call__(self, context: NumberContext) -> GenericResult[int]: + return GenericResult(value=context.x + self.offset) + + +@Flow.context_transform +def shift_start_date(start_date: FromContext[date], days: int) -> date: + return start_date - timedelta(days=days) + + +@Flow.context_transform +def shift_window(start_date: FromContext[date], end_date: FromContext[date], days: int) -> dict[str, object]: + return { + "start_date": start_date - timedelta(days=days), + "end_date": end_date - timedelta(days=days), + } + + +@Flow.context_transform +def offset_value(value: FromContext[int], amount: int) -> int: + return value + amount + + +@Flow.context_transform +def offset_b(b: FromContext[int], amount: int) -> int: + return b + amount + + +@Flow.context_transform +def double_x(x: FromContext[int]) -> int: + return x * 2 + + +def test_flow_context_basic_properties(): + ctx = FlowContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31), label="x") + assert ctx.start_date == date(2024, 1, 1) + assert ctx.end_date == date(2024, 1, 31) + assert ctx.label == "x" + assert dict(ctx) == {"start_date": date(2024, 1, 1), "end_date": date(2024, 1, 31), "label": "x"} + + +def test_flow_context_value_semantics_and_hash(): + first = FlowContext(x=1, values=[1, 2]) + second = FlowContext(x=1, values=[1, 2]) + third = FlowContext(x=2, values=[1, 2]) + + assert first == second + assert first != third + assert len({first, second, third}) == 2 + + +def test_flow_context_hash_handles_nested_models_and_rejects_opaque_unhashable_values(): + class WithDict: + __hash__ = None + + def __init__(self): + self.values = {"a": [1, 2]} + + class UnhashableNoState: + __slots__ = () + __hash__ = None + + nested = FlowContext(model=NumberContext(x=1), values={"items": [{2, 1}]}) + same = FlowContext(model=NumberContext(x=1), values={"items": [{1, 2}]}) + + assert nested == nested + assert nested != {"model": NumberContext(x=1)} + assert hash(nested) == hash(same) + assert hash(FlowContext(value=WithDict())) == hash(FlowContext(value=WithDict())) + + with pytest.raises(TypeError, match="unhashable value"): + hash(FlowContext(value=UnhashableNoState())) + + +def test_flow_context_pickle_and_cloudpickle_roundtrip(): + ctx = FlowContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31), tags=frozenset({"a", "b"})) + assert pickle.loads(pickle.dumps(ctx)) == ctx + assert cloudpickle.loads(cloudpickle.dumps(ctx)) == ctx + + +def test_flow_api_introspection_for_from_context_model(): + @Flow.model + def add(a: int, b: FromContext[int], c: FromContext[int] = 5) -> int: + return a + b + c + + model = add(a=10) + assert model.flow.context_inputs == {"b": int, "c": int} + assert model.flow.unbound_inputs == {"b": int} + assert model.flow.bound_inputs == {"a": 10} + assert model.flow.compute(b=2).value == 17 + + +def test_flow_api_compute_accepts_single_context_or_kwargs_but_not_both(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + model = add(a=10) + assert model.flow.compute(b=5).value == 15 + assert model.flow.compute(FlowContext(b=6)).value == 16 + + with pytest.raises(TypeError, match="either one context object or contextual keyword arguments"): + model.flow.compute(FlowContext(b=5), b=6) + + +def test_bound_model_with_context_static_and_transform(): + @Flow.model + def load_window(start_date: FromContext[date], end_date: FromContext[date]) -> GenericResult[dict]: + return GenericResult(value={"start": start_date, "end": end_date}) + + model = load_window() + shifted = model.flow.with_context( + shift_window(days=7), + end_date=date(2024, 1, 31), + ) + + result = shifted(FlowContext(start_date=date(2024, 1, 8), end_date=date(2024, 2, 7))) + assert result.value == {"start": date(2024, 1, 1), "end": date(2024, 1, 31)} + + +def test_bound_model_with_context_is_branch_local_and_chained(): + @Flow.model + def source(value: FromContext[int]) -> int: + return value + + @Flow.model + def combine(left: int, right: int, value: FromContext[int]) -> int: + return left + right + value + + base = source() + left = base.flow.with_context(value=offset_value(amount=1)) + right = base.flow.with_context(value=offset_value(amount=2)).flow.with_context(value=offset_value(amount=10)) + model = combine(left=left, right=right) + + assert model.flow.compute(value=5).value == (6 + 15 + 5) + + +def test_chained_with_context_merges_patch_transforms(): + @Flow.model + def load(start_date: FromContext[date], end_date: FromContext[date]) -> dict: + return {"start": start_date, "end": end_date} + + base = load() + # First with_context applies a patch, second chains another patch on top + chained = base.flow.with_context(shift_window(days=7)).flow.with_context(shift_window(days=3)) + + # Both patches should be present in the merged context spec. + assert len(chained.context_spec.patches) == 2 + + # Patches evaluate against the original context, merge left-to-right. + # patch1: start - 7, end - 7 => Jan 1, Jan 24 + # patch2: start - 3, end - 3 => Jan 5, Jan 28 (overwrites patch1 keys) + result = chained(FlowContext(start_date=date(2024, 1, 8), end_date=date(2024, 1, 31))) + assert result.value == {"start": date(2024, 1, 5), "end": date(2024, 1, 28)} + + +def test_compute_kwargs_can_supply_ambient_context_for_upstream_transforms(): + @Flow.model + def source(value: FromContext[int]) -> int: + return value + + @Flow.model + def combine(left: int, right: int, bonus: FromContext[int]) -> int: + return left + right + bonus + + base = source() + model = combine( + left=base.flow.with_context(value=offset_value(amount=1)), + right=base.flow.with_context(value=offset_value(amount=10)), + ) + + assert model.flow.context_inputs == {"bonus": int} + assert model.flow.compute(value=5, bonus=100).value == (6 + 15 + 100) + + +def test_bound_model_rejects_regular_field_context_overrides(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + with pytest.raises(TypeError, match="only accepts contextual fields"): + add(a=1).flow.with_context(a=3) + + +def test_bound_model_repr_matches_user_facing_api(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + model = add(a=1) + bound = model.flow.with_context(b=offset_b(amount=1)) + assert repr(bound) == f"{model!r}.flow.with_context(b=offset_b(amount=1))" + + +def test_bound_model_serialization_roundtrip_preserves_static_transforms(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + bound = add(a=10).flow.with_context(b=5) + + dumped = bound.model_dump(mode="python") + assert dumped["context_spec"] == {"patches": [], "field_overrides": {"b": {"kind": "static_value", "value": 5}}} + + restored = type(bound).model_validate(dumped) + assert restored.flow.compute().value == 15 + assert restored.model.flow.bound_inputs == {"a": 10} + + +def test_bound_model_json_roundtrip_preserves_context_transforms(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + bound = add(a=10).flow.with_context(b=offset_b(amount=1)) + dumped = bound.model_dump(mode="json") + assert dumped["context_spec"]["field_overrides"]["b"]["binding"]["path"].endswith(".offset_b") + + restored = type(bound).model_validate(dumped) + assert restored.flow.compute(b=4).value == 15 + + +def test_bound_model_cloudpickle_roundtrip_preserves_context_transforms(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + bound = add(a=10).flow.with_context(b=offset_b(amount=1)) + restored = cloudpickle.loads(cloudpickle.dumps(bound)) + assert restored.flow.compute(b=4).value == 15 + + +def test_bound_model_plain_pickle_roundtrip_preserves_context_transforms(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + bound = add(a=10).flow.with_context(b=offset_b(amount=1)) + restored = pickle.loads(pickle.dumps(bound, protocol=5)) + assert restored.flow.compute(b=4).value == 15 + + +def test_transformed_dag_cloudpickle_roundtrip_preserves_context_transforms(): + @Flow.model + def source(value: FromContext[int]) -> int: + return value + + @Flow.model + def combine(left: int, right: int, value: FromContext[int]) -> int: + return left + right + value + + base = source() + model = combine( + left=base.flow.with_context(value=offset_value(amount=1)), + right=base.flow.with_context(value=offset_value(amount=10)), + ) + restored = cloudpickle.loads(cloudpickle.dumps(model)) + + assert restored.flow.compute(value=5).value == (6 + 15 + 5) + + +def test_bound_model_pydantic_roundtrip_preserves_context_transforms(): + """model_dump(mode='python') + model_validate must preserve serialized context bindings.""" + + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + bound = add(a=10).flow.with_context(b=offset_b(amount=1)) + assert bound.flow.compute(b=4).value == 15 + + dumped = bound.model_dump(mode="python") + assert dumped["context_spec"]["field_overrides"]["b"]["binding"]["kind"] == "context_transform" + + restored = type(bound).model_validate(dumped) + assert restored.flow.compute(b=4).value == 15 + + +def test_bound_model_context_spec_dump_contains_patch_and_field_specs(): + """model_dump(mode='json') should emit explicit tagged context-spec objects.""" + + @Flow.model + def load_window(start_date: FromContext[date], end_date: FromContext[date]) -> GenericResult[dict]: + return GenericResult(value={"start": start_date, "end": end_date}) + + dumped = load_window().flow.with_context(shift_window(days=7), start_date=shift_start_date(days=1)).model_dump(mode="json") + assert dumped["context_spec"]["patches"][0]["kind"] == "context_patch" + assert dumped["context_spec"]["field_overrides"]["start_date"]["kind"] == "context_value" + + +def test_regular_callable_models_still_support_with_context(): + model = OffsetModel(offset=10) + shifted = model.flow.with_context(x=double_x()) + assert shifted(NumberContext(x=5)).value == 20 + + +def test_flow_api_for_regular_callable_model(): + model = OffsetModel(offset=10) + assert model.flow.compute(x=5).value == 15 + assert model.flow.context_inputs == {"x": int} + assert model.flow.unbound_inputs == {"x": int} + assert model.flow.bound_inputs == {"offset": 10} + + +def test_generated_flow_model_compute_is_thread_safe(): + @Flow.model + def add(a: int, b: FromContext[int], c: FromContext[int]) -> int: + return a + b + c + + model = add(a=10) + + def worker(n: int) -> int: + return model.flow.compute(b=n, c=n + 1).value + + with ThreadPoolExecutor(max_workers=8) as executor: + results = list(executor.map(worker, range(20))) + + assert results == [10 + n + n + 1 for n in range(20)] + + +def test_bound_model_restore_is_thread_safe(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + dumped = add(a=10).flow.with_context(b=5).model_dump(mode="python") + + def worker(_: int) -> int: + restored = BoundModel.model_validate(dumped) + return restored.flow.compute().value + + with ThreadPoolExecutor(max_workers=8) as executor: + results = list(executor.map(worker, range(20))) + + assert results == [15] * 20 diff --git a/ccflow/tests/test_flow_model.py b/ccflow/tests/test_flow_model.py new file mode 100644 index 0000000..6c82625 --- /dev/null +++ b/ccflow/tests/test_flow_model.py @@ -0,0 +1,3028 @@ +"""Focused tests for the FromContext-based Flow.model API.""" + +import base64 +import graphlib +import inspect +import pickle +import subprocess +import sys +from datetime import date, timedelta +from types import ModuleType +from typing import Annotated, Any, Literal, Optional + +import pytest +import ray +from pydantic import Field, ValidationError, model_validator +from ray.cloudpickle import dumps as rcpdumps, loads as rcploads + +import ccflow +import ccflow._flow_model_binding as flow_binding_module +import ccflow.flow_model as flow_model_module +from ccflow import ( + CallableModel, + ContextBase, + DateRangeContext, + EvaluatorBase, + Flow, + FlowContext, + FlowOptionsOverride, + FromContext, + GenericResult, + Lazy, + ModelEvaluationContext, + ModelRegistry, +) +from ccflow.callable import FlowOptions +from ccflow.evaluators import GraphEvaluator, LoggingEvaluator, MemoryCacheEvaluator, cache_key, combine_evaluators, get_dependency_graph +from ccflow.exttypes import PyObjectPath + + +class SimpleContext(ContextBase): + value: int + + +class ParentRangeContext(ContextBase): + start_date: date + end_date: date + + +class RichRangeContext(ParentRangeContext): + label: str = "child" + + +class OrderedContext(ContextBase): + a: int + b: int + + @model_validator(mode="after") + def _validate_order(self): + if self.a > self.b: + raise ValueError("a must be <= b") + return self + + +@Flow.model +def basic_loader(source: str, multiplier: int, value: FromContext[int]) -> GenericResult[int]: + return GenericResult(value=value * multiplier) + + +@Flow.model +def string_processor(value: FromContext[int], prefix: str = "value=", suffix: str = "!") -> GenericResult[str]: + return GenericResult(value=f"{prefix}{value}{suffix}") + + +@Flow.model +def data_source(base_value: int, value: FromContext[int]) -> GenericResult[int]: + return GenericResult(value=value + base_value) + + +@Flow.model +def data_transformer(source: int, factor: int) -> GenericResult[int]: + return GenericResult(value=source * factor) + + +@Flow.model +def data_aggregator(input_a: int, input_b: int, operation: str = "add") -> GenericResult[int]: + if operation == "add": + return GenericResult(value=input_a + input_b) + raise ValueError(f"unsupported operation: {operation}") + + +@Flow.model +def pipeline_stage1(initial: int, value: FromContext[int]) -> GenericResult[int]: + return GenericResult(value=value + initial) + + +@Flow.model +def pipeline_stage2(stage1_output: int, multiplier: int) -> GenericResult[int]: + return GenericResult(value=stage1_output * multiplier) + + +@Flow.model +def pipeline_stage3(stage2_output: int, offset: int) -> GenericResult[int]: + return GenericResult(value=stage2_output + offset) + + +@Flow.model +def date_range_loader_previous_day( + source: str, + start_date: FromContext[date], + end_date: FromContext[date], + include_weekends: bool = False, +) -> GenericResult[dict]: + del include_weekends + return GenericResult( + value={ + "source": source, + "start_date": str(start_date - timedelta(days=1)), + "end_date": str(end_date), + } + ) + + +@Flow.model +def date_range_processor(raw_data: dict, normalize: bool = False) -> GenericResult[str]: + prefix = "normalized:" if normalize else "raw:" + return GenericResult(value=f"{prefix}{raw_data['source']}:{raw_data['start_date']} to {raw_data['end_date']}") + + +@Flow.model +def contextual_loader(source: str, start_date: FromContext[date], end_date: FromContext[date]) -> GenericResult[dict]: + return GenericResult( + value={ + "source": source, + "start_date": str(start_date), + "end_date": str(end_date), + } + ) + + +@Flow.model +def contextual_processor( + prefix: str, + data: dict, + start_date: FromContext[date], + end_date: FromContext[date], +) -> GenericResult[str]: + del start_date, end_date + return GenericResult(value=f"{prefix}:{data['source']}:{data['start_date']} to {data['end_date']}") + + +@Flow.context_transform +def increment_b(b: FromContext[int], amount: int) -> int: + return b + amount + + +@Flow.context_transform +def shift_integer_window(start_date: FromContext[int], end_date: FromContext[int], amount: int) -> dict[str, object]: + return { + "start_date": start_date + amount, + "end_date": end_date + amount, + } + + +@Flow.context_transform +def bump_start_date(start_date: FromContext[int], amount: int) -> int: + return start_date + amount + + +@Flow.context_transform +def annotated_start_patch(start_date: FromContext[int]) -> Annotated[dict[str, object], "meta"]: + return {"start_date": start_date + 1} + + +@Flow.context_transform +def optional_start_patch(start_date: FromContext[int]) -> dict[str, object] | None: + return {"start_date": start_date + 2} + + +@Flow.context_transform +def parity_bucket(raw: FromContext[int]) -> int: + return raw % 2 + + +@Flow.context_transform +def seed_plus_one(seed: FromContext[int]) -> int: + return seed + 1 + + +@Flow.context_transform +def non_idempotent_a_step(a: FromContext[int]) -> int: + return 2 if a == 1 else 3 + + +@Flow.context_transform +def static_bad() -> int: + return 2 + + +def lazy_context_transform_for_rejection(value: Lazy[int]) -> int: + return value() + + +@Flow.context_transform +def static_patch() -> dict[str, object]: + return {"a": 2} + + +def test_module_level_flow_model_examples_and_transforms_execute(): + assert data_aggregator(input_a=1, input_b=2, operation="add").flow.compute().value == 3 + with pytest.raises(ValueError, match="unsupported operation"): + data_aggregator(input_a=1, input_b=2, operation="multiply").flow.compute() + + raw = date_range_loader_previous_day(source="library").flow.compute(start_date=date(2024, 1, 2), end_date=date(2024, 1, 3)).value + assert raw == {"source": "library", "start_date": "2024-01-01", "end_date": "2024-01-03"} + assert date_range_processor(raw_data=raw).flow.compute().value.startswith("raw:library") + assert date_range_processor(raw_data=raw, normalize=True).flow.compute().value.startswith("normalized:library") + + contextual = contextual_loader(source="library").flow.compute(start_date=date(2024, 1, 1), end_date=date(2024, 1, 2)).value + assert contextual_processor(prefix="p", data=contextual).flow.compute(start_date=date(2024, 1, 1), end_date=date(2024, 1, 2)).value == ( + "p:library:2024-01-01 to 2024-01-02" + ) + + assert ( + pipeline_stage3(stage2_output=pipeline_stage2(stage1_output=pipeline_stage1(initial=2), multiplier=3), offset=4).flow.compute(value=5).value + == 25 + ) + + @Flow.model + def load(start_date: FromContext[int], end_date: FromContext[int], bucket: FromContext[int]) -> int: + return start_date * 100 + end_date * 10 + bucket + + shifted = load().flow.with_context( + shift_integer_window(amount=2), + start_date=bump_start_date(amount=10), + bucket=parity_bucket(), + ) + assert shifted.flow.compute(start_date=1, end_date=5, raw=7).value == 1171 + + @Flow.model + def add(a: FromContext[int]) -> int: + return a + + assert add().flow.with_context(a=non_idempotent_a_step()).flow.compute(a=1).value == 2 + assert add().flow.with_context(a=non_idempotent_a_step()).flow.compute(a=10).value == 3 + assert add().flow.with_context(static_patch()).flow.compute(a=1).value == 2 + + +def test_context_transform_internal_error_and_repr_paths(): + assert flow_model_module._context_transform_repr(static_patch()) == "static_patch()" + assert flow_model_module._context_transform_repr(increment_b(amount=2)) == "increment_b(amount=2)" + assert flow_model_module._context_transform_repr(123) == "123" + assert flow_model_module._context_transform_identifier(increment_b(amount=1)).endswith(".increment_b") + + with pytest.raises(ValidationError, match="exactly one"): + flow_model_module.ContextTransform() + + with pytest.raises(ValidationError, match="exactly one"): + flow_model_module.ContextTransform(path="ccflow.tests.test_flow_model.increment_b", serialized_config="also-set") + + with pytest.raises(TypeError, match="does not resolve to a Flow.context_transform binding"): + flow_model_module._load_context_transform_config("ccflow.tests.test_flow_model.lazy_context_transform_for_rejection") + + invalid_payload = base64.b64encode(pickle.dumps({"not": "a config"}, protocol=5)).decode("ascii") + with pytest.raises(TypeError, match="payload does not contain"): + flow_model_module._load_serialized_context_transform_config(invalid_payload) + + invalid_binding = flow_model_module.ContextTransform.model_construct(path=None, serialized_config=None, bound_args={}) + with pytest.raises(TypeError, match="neither path nor serialized_config"): + flow_model_module._load_context_transform_config_from_binding(invalid_binding) + + with pytest.raises(ImportError, match="does not have a _generated_model"): + flow_model_module._restore_generated_flow_model("ccflow.tests.test_flow_model.lazy_context_transform_for_rejection", {}) + + +def test_flow_model_low_level_value_helpers_cover_edge_paths(): + assert flow_model_module._bound_field_names(object()) == set() + assert flow_model_module._concrete_context_type(int | None) is None + no_name_annotation = int | str + assert flow_model_module._expected_type_repr(no_name_annotation) == repr(no_name_annotation) + assert flow_model_module._coerce_value("x", "still-raw", object(), "test") == "still-raw" + assert flow_model_module._unwrap_model_result(7) == 7 + assert flow_model_module._type_accepts_str(Annotated[str, "meta"]) is True + assert flow_model_module._type_accepts_str(int | str) is True + assert flow_binding_module._is_result_annotation(GenericResult[int] | None) is True + assert flow_model_module._registry_candidate_allowed(object(), data_source(base_value=1)) is True + assert flow_model_module._registry_candidate_allowed(int, GenericResult(value=1)) is False + assert flow_model_module._is_mapping_annotation(inspect.Signature.empty) is False + assert flow_model_module._is_mapping_annotation(123) is False + generated_type = type(basic_loader(source="s", multiplier=2)) + assert flow_model_module._resolve_generated_model_bases(generated_type) == (generated_type,) + assert callable(Flow.context_transform()) + + metadata: list[object] = [] + annotation = Annotated[int, metadata] + assert flow_model_module._type_adapter(annotation) is flow_model_module._type_adapter(annotation) + + with pytest.raises(TypeError, match="only supports Python functions"): + flow_model_module._ensure_top_level_named_function(123, decorator_name="@Flow.model") + with pytest.raises(TypeError, match="only supports named Python functions"): + flow_model_module._ensure_top_level_named_function(lambda: None, decorator_name="@Flow.model") + + +def test_lazy_thunks_and_regular_resolution_edge_paths(): + calls = {"dependency": 0, "inner": 0} + + @Flow.model + def source(value: FromContext[int]) -> int: + calls["dependency"] += 1 + return value + 10 + + thunk = flow_model_module._make_lazy_thunk(source(), FlowContext(value=2)) + assert thunk() == 12 + assert thunk() == 12 + assert calls["dependency"] == 1 + + def inner(): + calls["inner"] += 1 + return "13" + + coercing = flow_model_module._make_coercing_lazy_thunk(inner, "value", int) + assert coercing() == 13 + assert coercing() == 13 + assert calls["inner"] == 1 + + @Flow.model + def missing_regular(x: int) -> int: + return x + + missing_config = type(missing_regular()).__flow_model_config__ + with pytest.raises(TypeError, match="still unbound"): + flow_model_module._resolve_regular_param_value(missing_regular(), missing_config.param("x"), FlowContext()) + + @Flow.model + def lazy_consumer(x: Lazy[int]) -> int: + return x() + + lazy_model = getattr(lazy_consumer, "_generated_model").model_construct(x=1) + lazy_config = type(lazy_model).__flow_model_config__ + with pytest.raises(TypeError, match="must be bound to a CallableModel"): + flow_model_module._resolve_regular_param_value(lazy_model, lazy_config.param("x"), FlowContext()) + + +def test_context_transform_validation_and_static_resolution_edge_paths(): + @Flow.context_transform + def default_amount(amount: int = 5) -> int: + return amount + + @Flow.context_transform + def default_seed(seed: FromContext[int] = 9) -> int: + return seed + 1 + + @Flow.context_transform + def dynamic_patch(seed: FromContext[int]) -> dict[str, object]: + return {"a": seed} + + @Flow.model + def add(a: FromContext[int], b: FromContext[int]) -> int: + return a + b + + assert add().flow.with_context(a=default_amount(), b=default_seed()).flow.compute().value == 15 + assert flow_model_module._evaluate_context_transform_from_values(default_seed(), {}) == 10 + with pytest.raises(TypeError, match="Missing contextual input"): + flow_model_module._evaluate_context_transform_from_values(seed_plus_one(), {}) + + dynamic_spec = flow_model_module._BoundContextSpec( + patches=[flow_model_module.PatchContextSpec(binding=dynamic_patch())], + field_overrides={}, + ) + assert flow_model_module._statically_resolved_context_values(add(), dynamic_spec) is None + assert flow_model_module._statically_resolved_context_field_names(add(), dynamic_spec) == set() + + identity_values, missing_transforms = flow_model_module._apply_context_spec_values_for_identity(add(), dynamic_spec, FlowContext(b=2)) + assert identity_values == {"b": 2} + assert missing_transforms == ((flow_model_module._context_transform_identifier(dynamic_patch()), ("seed",)),) + + missing_regular = default_amount() + config = flow_model_module._load_context_transform_config_from_binding(default_amount()) + assert flow_model_module._bound_context_transform_regular_kwargs(config, missing_regular) == {"amount": 5} + + with pytest.raises(TypeError, match="unexpected keyword"): + increment_b(amount=1, extra=2) + with pytest.raises(TypeError, match="Do not pass contextual"): + increment_b(b=1, amount=2) + with pytest.raises(TypeError, match="missing required regular"): + increment_b() + + with pytest.raises(TypeError, match="must return a mapping"): + flow_model_module._validate_patch_result(add(), 1) + with pytest.raises(TypeError, match="string field names"): + flow_model_module._validate_patch_result(add(), {1: 2}) + + class OpaqueModel: + context_type = object + + assert flow_model_module._validate_patch_result(OpaqueModel(), {"x": 1}) == {"x": 1} + flow_model_module._validate_with_context_field_names(OpaqueModel(), ["anything"]) + assert ( + flow_model_module._static_field_override_value(OpaqueModel(), "anything", flow_model_module.FieldContextSpec(binding=default_amount())) == 5 + ) + + with pytest.raises(TypeError, match="raw callables"): + add().flow.with_context(lambda: {"a": 1}) + with pytest.raises(TypeError, match="Positional with_context"): + add().flow.with_context(123) + + +def test_additional_flow_model_source_edge_paths(monkeypatch): + @Flow.context_transform + def default_seed(seed: FromContext[int] = 9) -> int: + return seed + 1 + + @Flow.context_transform + def dynamic_patch(seed: FromContext[int]) -> dict[str, object]: + return {"a": seed} + + @Flow.model + def add(a: FromContext[int], b: FromContext[int]) -> int: + return a + b + + @Flow.model + def regular_required(x: int) -> int: + return x + + @Flow.model + def lazy_consumer(x: Lazy[int]) -> int: + return x() + + restored = flow_model_module._restore_generated_flow_model( + "ccflow.tests.test_flow_model.basic_loader", + basic_loader(source="s", multiplier=2).__getstate__(), + ) + assert restored.flow.compute(value=3).value == 6 + + class FailingPath: + def __init__(self, path): + self.path = path + + @property + def object(self): + raise ImportError(self.path) + + original_path = flow_model_module.PyObjectPath + monkeypatch.setattr(flow_model_module, "PyObjectPath", FailingPath) + config = type(basic_loader(source="s", multiplier=2)).__flow_model_config__ + assert flow_model_module._generated_model_factory_path_for_pickle(config, type(basic_loader(source="s", multiplier=2))) is None + monkeypatch.setattr(flow_model_module, "PyObjectPath", original_path) + + assert flow_model_module._registry_candidate_allowed(int, 1) is True + opaque_model = type("OpaqueModel", (), {"context_type": object})() + assert flow_model_module._coerce_model_context_value(opaque_model, "anything", "raw", "test") == "raw" + assert flow_model_module._generated_model_identity_payload(regular_required(), FlowContext()) is None + + context_spec = flow_model_module._BoundContextSpec( + patches=[flow_model_module.PatchContextSpec(binding=dynamic_patch())], + field_overrides={"b": flow_model_module.FieldContextSpec(binding=default_seed())}, + ) + values, missing = flow_model_module._apply_context_spec_values_for_identity(add(), context_spec, FlowContext(seed=1)) + assert values == {"a": 1, "b": 2, "seed": 1} + assert missing == () + assert flow_model_module._statically_resolved_context_values(add(), context_spec) is None + + bound = add().flow.with_context(dynamic_patch()) + assert bound.flow.context_inputs == {"a": int, "b": int, "seed": int} + assert bound.flow.unbound_inputs == {"a": int, "b": int, "seed": int} + + with pytest.raises(TypeError, match="missing required regular"): + flow_model_module._bound_context_transform_regular_kwargs( + flow_model_module._load_context_transform_config_from_binding(increment_b(amount=1)), + increment_b(amount=1).model_copy(update={"bound_args": {}}), + ) + with pytest.raises(TypeError, match="Missing regular parameter"): + regular_required().__deps__(FlowContext()) + assert lazy_consumer(x=data_source(base_value=1)).__deps__(FlowContext(value=1)) == [] + assert getattr(basic_loader, "_generated_model")._resolve_registry_refs("raw") == "raw" + assert flow_model_module._GeneratedFlowModelBase._resolve_registry_refs({}) == {} + + def transform_with_bad_hints(value: FromContext[int]) -> int: + return value + + def raise_attribute_error(*args, **kwargs): + raise AttributeError("bad hints") + + monkeypatch.setattr(flow_model_module, "get_type_hints", raise_attribute_error) + assert Flow.context_transform(transform_with_bad_hints)().serialized_config is not None + + +def test_plain_and_bound_optional_compute_paths_and_identity_helpers(): + class AnyContextModel: + context_type = object + + class FlowContextModel: + context_type = FlowContext + + class OptionalContextModel(CallableModel): + @Flow.call + def __call__(self, context: Optional[SimpleContext] = None) -> GenericResult[int]: + return GenericResult(value=0 if context is None else context.value) + + assert flow_model_module._model_context_contract(AnyContextModel()).input_types is None + assert flow_model_module._model_context_contract(FlowContextModel()).input_types is None + assert flow_model_module._identity_context_values_for_model(AnyContextModel(), FlowContext(extra=1)) == {"extra": 1} + assert OptionalContextModel().flow.compute(None).value == 0 + assert OptionalContextModel().flow.compute().value == 0 + assert OptionalContextModel().flow.unbound_inputs == {} + + bound = OptionalContextModel().flow.with_context() + assert bound.flow.compute(FlowContext(value=3)).value == 3 + with pytest.raises(TypeError, match="either one context object"): + bound.flow.compute(FlowContext(value=3), value=4) + assert bound.flow._compute_target is bound + + +def test_bound_optional_none_context_preserves_wrapped_dependencies(): + class Dep(CallableModel): + @Flow.call + def __call__(self, context: FlowContext) -> GenericResult[int]: + return GenericResult(value=1) + + class Root(CallableModel): + dep: Dep + + @Flow.call + def __call__(self, context: Optional[FlowContext] = None) -> GenericResult[int]: + return GenericResult(value=self.dep(FlowContext()).value + (0 if context is None else context.bonus)) + + @Flow.deps + def __deps__(self, context: Optional[FlowContext]) -> list[tuple[CallableModel, list[ContextBase]]]: + return [(self.dep, [FlowContext()])] + + root = Root(dep=Dep()) + bound = root.flow.with_context() + graph = get_dependency_graph(bound.__call__.get_evaluation_context(bound, None)) + + assert len(graph.ids) == 2 + assert bound.flow.compute().value == 1 + + +def test_from_context_anchor_behavior(): + @Flow.model + def foo(a: int, b: FromContext[int]) -> int: + return a + b + + assert foo(a=11).flow.compute(b=12).value == 23 + assert foo(a=11, b=12).flow.compute().value == 23 + + with pytest.raises(TypeError, match="compute\\(\\) cannot satisfy unbound regular parameter\\(s\\): a"): + foo().flow.compute(a=11, b=12) + + +def test_regular_param_accepts_upstream_model(): + @Flow.model + def source(value: FromContext[int], offset: int) -> int: + return value + offset + + @Flow.model + def foo(a: int, b: FromContext[int]) -> int: + return a + b + + model = foo(a=source(offset=5)) + assert model.flow.compute(value=7, b=12).value == 24 + assert model.flow.compute(FlowContext(value=7, b=12)).value == 24 + + +def test_regular_param_upstream_dependency_coerced(): + """Upstream model returning str should be coerced to downstream int annotation.""" + + @Flow.model + def str_source(tag: FromContext[str]) -> str: + return tag + + @Flow.model + def consumer(x: int, bonus: FromContext[int]) -> int: + return x + bonus + + model = consumer(x=str_source()) + # str_source returns "42" (a str); consumer expects x: int; should be coerced + result = model.flow.compute(tag="42", bonus=10) + assert result.value == 52 + assert isinstance(result.value, int) + + # Also test that invalid coercion raises + with pytest.raises(TypeError, match="Regular parameter"): + model.flow.compute(tag="not_a_number", bonus=10) + + +def test_regular_param_lazy_upstream_dependency_coerced(): + """Lazy upstream model output should be coerced on first call.""" + + @Flow.model + def lazy_source(v: FromContext[int]) -> str: + return str(v) + + @Flow.model + def consumer(x: Lazy[int], bonus: FromContext[int]) -> int: + return x() + bonus + + model = consumer(x=lazy_source()) + result = model.flow.compute(v=7, bonus=3) + assert result.value == 10 + assert isinstance(result.value, int) + + +def test_regular_param_plain_callable_model_projects_dependency_context(): + class ValueContext(ContextBase): + value: int + + class PlainSource(CallableModel): + @property + def context_type(self): + return ValueContext + + @property + def result_type(self): + return GenericResult[int] + + @Flow.call + def __call__(self, context: ValueContext) -> GenericResult[int]: + return GenericResult(value=context.value * 10) + + @Flow.model + def root(x: int, bonus: FromContext[int]) -> int: + return x + bonus + + model = root(x=PlainSource()) + deps = model.__deps__(FlowContext(value=3, bonus=7)) + + assert len(deps) == 1 + dep_model, dep_contexts = deps[0] + assert isinstance(dep_model, PlainSource) + assert dep_contexts == [ValueContext(value=3)] + assert model.flow.compute(FlowContext(value=3, bonus=7)).value == 37 + + +def test_bound_regular_param_name_can_collide_with_ambient_context(): + @Flow.model + def source(a: FromContext[int]) -> int: + return a + + @Flow.model + def combine(a: int, left: int, bonus: FromContext[int]) -> int: + return a + left + bonus + + model = combine(a=100, left=source()) + assert model.flow.compute(FlowContext(a=7, bonus=5)).value == 112 + + with pytest.raises(TypeError, match="does not accept regular parameter override\\(s\\): a"): + model.flow.compute(a=7, bonus=5) + + +def test_contextual_param_rejects_callable_model(): + @Flow.model + def source(offset: int, value: FromContext[int]) -> GenericResult[int]: + return GenericResult(value=value + offset) + + @Flow.model + def foo(a: int, b: FromContext[int]) -> int: + return a + b + + with pytest.raises(TypeError, match="cannot be bound to a CallableModel"): + foo(a=1, b=source(offset=2)) + + +def test_contextual_construction_defaults_and_bound_inputs(): + @Flow.model + def foo(a: int, b: FromContext[int]) -> int: + return a + b + + model = foo(a=11, b=12) + assert model.flow.bound_inputs == {"a": 11, "b": 12} + assert model.flow.context_inputs == {"b": int} + assert model.flow.unbound_inputs == {} + assert model.flow.compute().value == 23 + + +def test_contextual_function_defaults_remain_contextual(): + @Flow.model + def foo(a: int, b: FromContext[int] = 5) -> int: + return a + b + + model = foo(a=2) + assert model.flow.bound_inputs == {"a": 2} + assert model.flow.context_inputs == {"b": int} + assert model.flow.unbound_inputs == {} + assert model.flow.compute().value == 7 + assert model.flow.compute(b=10).value == 12 + + +def test_compute_rejects_kwargs_for_already_bound_regular_params(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + model = add(a=1) + with pytest.raises(TypeError, match="does not accept regular parameter override\\(s\\): a"): + model.flow.compute(a=999, b=2) + + +def test_context_type_accepts_richer_subclass_for_from_context(): + @Flow.model(context_type=ParentRangeContext) + def span_days(multiplier: int, start_date: FromContext[date], end_date: FromContext[date]) -> int: + return multiplier * ((end_date - start_date).days + 1) + + model = span_days(multiplier=2) + assert model.flow.compute(start_date="2024-01-01", end_date="2024-01-03").value == 6 + assert model(RichRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 4), label="x")).value == 8 + + +def test_declared_context_type_introspection_reports_effective_field_type(): + class ModeContext(ContextBase): + mode: Literal["a"] + + @Flow.model(context_type=ModeContext) + def choose(mode: FromContext[str]) -> str: + return mode + + model = choose() + + assert model.flow.context_inputs == {"mode": Literal["a"]} + assert model.flow.unbound_inputs == {"mode": Literal["a"]} + assert model.flow.compute(mode="a").value == "a" + with pytest.raises(ValidationError): + model.flow.compute(mode="b") + + +def test_context_type_validation_applies_to_resolved_contextual_values(): + @Flow.model(context_type=OrderedContext) + def add(a: FromContext[int], b: FromContext[int]) -> int: + return a + b + + with pytest.raises(ValueError, match="a must be <= b"): + add().flow.compute(a=2, b=1) + + with pytest.raises(ValueError, match="a must be <= b"): + add(a=2, b=1) + + +def test_context_type_validates_construction_time_contextual_defaults_early(): + class PositiveContext(ContextBase): + x: int = Field(gt=0) + + @Flow.model(context_type=PositiveContext) + def identity(x: FromContext[int]) -> int: + return x + + with pytest.raises(ValueError, match="greater than 0"): + identity(x=-1) + + +def test_context_type_validates_partial_contextual_defaults_early(): + class PositivePairContext(ContextBase): + x: int = Field(gt=0) + y: int + + @Flow.model(context_type=PositivePairContext) + def add(x: FromContext[int], y: FromContext[int]) -> int: + return x + y + + with pytest.raises(ValueError, match="greater than 0"): + add(x=-1) + + +def test_context_type_validates_static_with_context_overrides_early(): + @Flow.model(context_type=OrderedContext) + def add(a: FromContext[int], b: FromContext[int]) -> int: + return a + b + + with pytest.raises(ValueError, match="a must be <= b"): + add().flow.with_context(a=2, b=1) + + +def test_context_type_validates_chained_static_with_context_overrides_early(): + @Flow.model(context_type=OrderedContext) + def add(a: FromContext[int], b: FromContext[int]) -> int: + return a + b + + with pytest.raises(ValueError, match="a must be <= b"): + add().flow.with_context(a=2).flow.with_context(b=1) + + +def test_context_type_validates_partial_static_with_context_overrides_early(): + class PositivePairContext(ContextBase): + x: int = Field(gt=0) + y: int + + @Flow.model(context_type=PositivePairContext) + def add(x: FromContext[int], y: FromContext[int]) -> int: + return x + y + + with pytest.raises(ValueError, match="greater than 0"): + add().flow.with_context(x=-1) + + +def test_context_type_validates_static_field_transform_overrides_early(): + @Flow.model(context_type=OrderedContext) + def add(a: FromContext[int], b: FromContext[int] = 1) -> int: + return a + b + + with pytest.raises(ValueError, match="a must be <= b"): + add().flow.with_context(a=static_bad()) + + +def test_context_type_validates_static_patch_transform_overrides_early(): + @Flow.model(context_type=OrderedContext) + def add(a: FromContext[int], b: FromContext[int] = 1) -> int: + return a + b + + with pytest.raises(ValueError, match="a must be <= b"): + add().flow.with_context(static_patch()) + + +def test_context_named_parameters_are_just_regular_parameters(): + @Flow.model + def loader(context: DateRangeContext, source: str = "db") -> GenericResult[str]: + return GenericResult(value=f"{source}:{context.start_date}:{context.end_date}") + + model = loader(context=DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 2)), source="api") + assert model.flow.bound_inputs["context"] == DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 2)) + assert model.flow.context_inputs == {} + assert model.flow.compute().value == "api:2024-01-01:2024-01-02" + + with pytest.raises(TypeError, match="Missing regular parameter\\(s\\) for loader: context"): + loader(source="api").flow.compute(start_date="2024-01-01", end_date="2024-01-02") + + +def test_auto_unwrap_defaults_to_false_for_auto_wrapped_results(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + model = add(a=10) + result = model.flow.compute(b=5) + assert model.result_type == GenericResult[int] + assert inspect.signature(type(model).__call__).return_annotation == GenericResult[int] + assert type(result) is GenericResult[int] + assert repr(result) == "GenericResult[int](value=15)" + assert result.value == 15 + + +def test_compute_does_not_unwrap_explicit_generic_result_returns(): + @Flow.model + def load(value: FromContext[int]) -> GenericResult[int]: + return GenericResult(value=value * 2) + + model = load() + result = model.flow.compute(value=3) + assert model.result_type == GenericResult[int] + assert type(result) is GenericResult[int] + assert repr(result) == "GenericResult[int](value=6)" + assert result.value == 6 + + +def test_auto_unwrap_can_be_enabled_for_auto_wrapped_results(): + @Flow.model(auto_unwrap=True) + def add(a: int, b: FromContext[int]) -> int: + return a + b + + assert add(a=10).flow.compute(b=5) == 15 + + +def test_auto_unwrap_only_affects_external_compute_results(): + @Flow.model + def source(value: FromContext[int]) -> int: + return value * 2 + + @Flow.model(auto_unwrap=True) + def add(left: int, bonus: FromContext[int]) -> int: + return left + bonus + + model = add(left=source()) + assert model.flow.compute(FlowContext(value=4, bonus=3)) == 11 + + +def test_auto_wrap_validates_return_type(): + @Flow.model + def bad(x: FromContext[int]) -> int: + return "oops" + + with pytest.raises(ValidationError, match=r"GenericResult\[int\]"): + bad().flow.compute(x=1) + + +def test_auto_wrap_coerces_compatible_return(): + @Flow.model + def coerce(x: FromContext[int]) -> float: + return 3 + + result = coerce().flow.compute(x=1) + assert result.value == 3.0 + assert isinstance(result.value, float) + + +def test_model_base_allows_custom_callable_model_subclass(): + class CustomFlowBase(CallableModel): + multiplier: int = 1 + + @model_validator(mode="after") + def _validate_multiplier(self): + if self.multiplier <= 0: + raise ValueError("multiplier must be positive") + return self + + def scaled(self, value: int) -> int: + return value * self.multiplier + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.model(model_base=CustomFlowBase) + def add(a: int, b: FromContext[int]) -> int: + return a + b + + model = add(a=10, multiplier=3) + assert isinstance(model, CustomFlowBase) + assert model.multiplier == 3 + assert model.scaled(4) == 12 + assert model.flow.compute(b=5).value == 15 + + with pytest.raises(ValueError, match="multiplier must be positive"): + add(a=10, multiplier=0) + + +def test_model_base_must_be_callable_model_subclass(): + with pytest.raises(TypeError, match="model_base must be a CallableModel subclass"): + + @Flow.model(model_base=int) + def add(a: int, b: FromContext[int]) -> int: + return a + b + + +def test_context_named_regular_parameter_can_coexist_with_from_context(): + @Flow.model + def mixed(context: SimpleContext, y: FromContext[int]) -> int: + return context.value + y + + model = mixed(context=SimpleContext(value=10)) + assert model.flow.bound_inputs == {"context": SimpleContext(value=10)} + assert model.flow.context_inputs == {"y": int} + assert model.flow.compute(y=5).value == 15 + + +@pytest.mark.parametrize("reserved_name", ["flow", "meta", "context_type", "result_type"]) +def test_flow_model_rejects_reserved_parameter_names(reserved_name): + namespace = {"Flow": Flow, "FromContext": FromContext} + exec( + f"def bad({reserved_name}: str, value: FromContext[int]) -> str:\n return str(value)\n", + namespace, + ) + + with pytest.raises(TypeError, match=f"Parameter name\\(s\\) '{reserved_name}' are reserved"): + Flow.model(namespace["bad"]) + + +def test_context_args_keyword_is_removed(): + with pytest.raises(TypeError, match="context_args=... has been removed"): + + @Flow.model(context_args=["x"]) + def bad(x: int) -> int: + return x + + +def test_context_type_requires_from_context(): + with pytest.raises(TypeError, match="context_type=... requires FromContext"): + + @Flow.model(context_type=DateRangeContext) + def bad(x: int) -> int: + return x + + +def test_lazy_dependency_remains_lazy(): + calls = {"source": 0} + + @Flow.model + def source(value: FromContext[int]) -> int: + calls["source"] += 1 + return value * 10 + + @Flow.model + def choose(value: int, lazy_value: Lazy[int], threshold: FromContext[int]) -> int: + if value > threshold: + return value + return lazy_value() + + eager = choose(value=50, lazy_value=source()) + assert eager.flow.compute(FlowContext(value=3, threshold=10)).value == 50 + assert calls["source"] == 0 + + deferred = choose(value=5, lazy_value=source()) + assert deferred.flow.compute(FlowContext(value=3, threshold=10)).value == 30 + assert calls["source"] == 1 + + +def test_lazy_parameter_requires_dependency_binding(): + @Flow.model + def source(value: FromContext[int]) -> int: + return value + + @Flow.model + def choose(lazy_value: Lazy[int]) -> int: + return lazy_value() + + with pytest.raises(TypeError, match="Lazy"): + choose(lazy_value=1) + + assert choose(lazy_value=source()).flow.compute(value=3).value == 3 + + +def test_lazy_parameter_rejects_literal_function_default(): + with pytest.raises(TypeError, match="Lazy"): + + @Flow.model + def choose(lazy_value: Lazy[int] = 1) -> int: + return lazy_value() + + +def test_lazy_runtime_helper_is_removed(): + @Flow.model + def source(value: FromContext[int]) -> GenericResult[int]: + return GenericResult(value=value) + + with pytest.raises(TypeError, match="Lazy\\(model\\)\\(\\.\\.\\.\\) has been removed"): + Lazy(source()) + + +def test_lazy_and_from_context_combination_is_rejected(): + with pytest.raises(TypeError, match="cannot combine Lazy"): + + @Flow.model + def bad(x: Lazy[FromContext[int]]) -> int: + return x() + + +def test_auto_wrap_and_serialization_roundtrip(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + model = add(a=10) + dumped = model.model_dump(mode="python") + restored = type(model).model_validate(dumped) + + assert restored.flow.bound_inputs == {"a": 10} + assert restored.flow.unbound_inputs == {"b": int} + assert restored.flow.compute(b=5).value == 15 + + +def test_generated_models_cloudpickle_roundtrip(): + @Flow.model + def multiply(a: int, b: FromContext[int]) -> int: + return a * b + + model = multiply(a=6) + restored = rcploads(rcpdumps(model, protocol=5)) + assert restored.flow.compute(b=7).value == 42 + + +def test_generated_models_plain_pickle_roundtrip(): + @Flow.model + def multiply(a: int, b: FromContext[int]) -> int: + return a * b + + model = multiply(a=6) + restored = pickle.loads(pickle.dumps(model, protocol=5)) + assert restored.flow.compute(b=7).value == 42 + + +def test_generated_model_direct_call_plain_pickle_uses_serialized_factory(monkeypatch): + module = ModuleType("ccflow_test_direct_model") + + def multiply(a: int, b: FromContext[int]) -> int: + return a * b + + multiply.__module__ = module.__name__ + multiply.__qualname__ = multiply.__name__ + module.multiply = multiply + monkeypatch.setitem(sys.modules, module.__name__, module) + + factory = Flow.model(multiply) + assert not hasattr(module.multiply, "_generated_model") + + model = factory(a=6) + restored = pickle.loads(pickle.dumps(model, protocol=5)) + assert restored.flow.compute(b=7).value == 42 + + +def test_generated_models_cloudpickle_preserves_unset_validation_sentinel(): + @Flow.model + def multiply(a: int, b: FromContext[int]) -> int: + return a * b + + model = multiply(a=6) + restored = rcploads(rcpdumps(model, protocol=5)) + param = type(restored).__flow_model_config__.contextual_params[0] + + assert param.context_validation_annotation is flow_model_module._UNSET + assert param.validation_annotation is int + + +def test_generated_model_pydantic_roundtrip_via_base_model(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + from ccflow import BaseModel + + model = add(a=10) + dumped = model.model_dump(mode="python") + + # BaseModel.model_validate uses the _target_ field (PyObjectPath) to + # reconstruct the correct generated class. + restored = BaseModel.model_validate(dumped) + assert type(restored) is type(model) + assert restored.flow.compute(b=5).value == 15 + + +def test_generated_model_json_roundtrip(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + from ccflow import BaseModel + + model = add(a=10, b=3) + json_str = model.model_dump_json() + restored = BaseModel.model_validate_json(json_str) + assert type(restored) is type(model) + assert restored.flow.compute().value == 13 + assert restored.flow.compute(b=7).value == 17 + + +def test_importable_generated_model_uses_stable_module_path_for_type_serialization(): + model = basic_loader(source="library", multiplier=3) + stable_path = f"{__name__}._basic_loader_Model" + + assert getattr(sys.modules[__name__], "_basic_loader_Model") is type(model) + assert "__ccflow_import_path__" not in type(model).__dict__ + assert str(PyObjectPath.validate(type(model))) == stable_path + assert str(model.model_dump(mode="python")["type_"]) == stable_path + + +def test_importable_generated_model_json_roundtrip_cross_process(): + model = basic_loader(source="library", multiplier=3) + payload = model.model_dump_json() + script = ( + "from ccflow import BaseModel\n" + f"data = {payload!r}\n" + "model = BaseModel.model_validate_json(data)\n" + "result = model.flow.compute(value=4)\n" + "assert result.value == 12, f'Expected 12, got {result.value}'\n" + ) + result = subprocess.run([sys.executable, "-c", script], capture_output=True, text=True, timeout=30) + assert result.returncode == 0, f"Cross-process JSON roundtrip failed:\n{result.stderr}" + + +def test_importable_bound_model_context_transform_json_roundtrip_cross_process(): + model = basic_loader(source="library", multiplier=3).flow.with_context(value=increment_b(amount=3)) + payload = model.model_dump_json() + script = ( + "from ccflow import BaseModel\n" + f"data = {payload!r}\n" + "model = BaseModel.model_validate_json(data)\n" + "result = model.flow.compute(b=1)\n" + "assert result.value == 12, f'Expected 12, got {result.value}'\n" + ) + result = subprocess.run([sys.executable, "-c", script], capture_output=True, text=True, timeout=30) + assert result.returncode == 0, f"Cross-process bound-model JSON roundtrip failed:\n{result.stderr}" + + +def test_dependency_graph_cloudpickle_roundtrip(): + from ccflow.evaluators import get_dependency_graph + + @Flow.model + def source(value: FromContext[int]) -> int: + return value * 10 + + @Flow.model + def root(x: int, penalty: FromContext[int]) -> int: + return x + penalty + + model = root(x=source()) + ctx = FlowContext(value=10, penalty=1) + graph = get_dependency_graph(model.__call__.get_evaluation_context(model, ctx)) + + restored = rcploads(rcpdumps(graph)) + assert restored.root_id == graph.root_id + assert set(restored.graph.keys()) == set(graph.graph.keys()) + assert set(restored.ids.keys()) == set(graph.ids.keys()) + + # The restored graph's evaluation contexts should still be functional + for key in graph.ids: + original_ec = graph.ids[key] + restored_ec = restored.ids[key] + assert type(restored_ec.model).__name__ == type(original_ec.model).__name__ + assert restored_ec.fn == original_ec.fn + + +def test_with_context_validates_static_override_types(): + """Static value type mismatch should be caught when context specs are normalized.""" + + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + with pytest.raises(TypeError, match="with_context\\(\\)"): + add(a=1).flow.with_context(b="not_an_int") + + +def test_context_transform_serializes_import_path_and_bound_args(): + binding = increment_b(amount=3) + assert isinstance(binding, flow_model_module.ContextTransform) + assert binding.kind == "context_transform" + assert binding.path is not None + assert binding.serialized_config is None + assert binding.bound_args == {"amount": 3} + assert str(binding.path).endswith(".increment_b") + + +def test_context_transform_rejects_none_for_required_param(): + with pytest.raises(TypeError, match="Context transform argument"): + increment_b(amount=None) + + +def test_context_transform_rejects_lazy_params(): + with pytest.raises(TypeError, match="does not support Lazy"): + Flow.context_transform(lazy_context_transform_for_rejection) + + +def test_context_transform_direct_call_uses_serialized_payload_when_original_binding_is_plain(monkeypatch): + module = ModuleType("ccflow_test_direct_context_transform") + + def increment(value: FromContext[int]) -> int: + return value + 1 + + increment.__module__ = module.__name__ + increment.__qualname__ = increment.__name__ + module.increment = increment + monkeypatch.setitem(sys.modules, module.__name__, module) + + transform_factory = Flow.context_transform(increment) + binding = transform_factory() + + assert binding.path is None + assert binding.serialized_config is not None + + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + bound = add(a=10).flow.with_context(b=binding) + restored = pickle.loads(pickle.dumps(bound, protocol=5)) + assert restored.flow.compute(value=4).value == 15 + + +def test_context_transform_supports_nested_functions_with_serialized_payload(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + @Flow.context_transform + def nested_transform(b: FromContext[int], amount: int) -> int: + return b + amount + + binding = nested_transform(amount=3) + assert binding.path is None + assert binding.serialized_config is not None + + bound = add(a=1).flow.with_context(b=binding) + restored = rcploads(rcpdumps(bound)) + assert restored.flow.compute(b=4).value == 8 + + +def test_context_transform_supports_non_importable_main_functions_with_serialized_payload(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + def main_transform(value: FromContext[int]) -> int: + return value + 1 + + main_transform.__module__ = "__main__" + main_transform.__qualname__ = main_transform.__name__ + + transformed = Flow.context_transform(main_transform) + binding = transformed() + assert binding.path is None + assert binding.serialized_config is not None + + bound = add(a=1).flow.with_context(b=binding) + restored = rcploads(rcpdumps(bound)) + assert restored.flow.compute(value=4).value == 6 + + +def test_context_transform_nested_function_survives_ray_task(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + @Flow.context_transform + def nested_transform(b: FromContext[int], amount: int) -> int: + return b + amount + + bound = add(a=1).flow.with_context(b=nested_transform(amount=3)) + + @ray.remote + def run_model(model): + return model.flow.compute(b=4).value + + with ray.init(num_cpus=1): + assert ray.get(run_model.remote(bound)) == 8 + + +def test_with_context_rejects_raw_callables(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + with pytest.raises(TypeError, match="no longer accepts raw callables"): + add(a=1).flow.with_context(b=lambda ctx: ctx.b + 1) + + +def test_with_context_rejects_wrong_transform_position(): + @Flow.model + def load(start_date: FromContext[int], end_date: FromContext[int]) -> int: + return start_date + end_date + + with pytest.raises(TypeError, match="Field context transforms must be passed by keyword"): + load().flow.with_context(increment_b(amount=1)) + + with pytest.raises(TypeError, match="Patch transforms must be passed positionally"): + load().flow.with_context(start_date=shift_integer_window(amount=10)) + + +def test_with_context_accepts_wrapped_mapping_patch_annotations(): + @Flow.model + def load(start_date: FromContext[int], end_date: FromContext[int]) -> int: + return start_date * 1000 + end_date + + annotated = load().flow.with_context(annotated_start_patch()) + optional = load().flow.with_context(optional_start_patch()) + + assert annotated.flow.compute(start_date=1, end_date=5).value == 2005 + assert optional.flow.compute(start_date=1, end_date=5).value == 3005 + + +def test_patch_then_keyword_override_precedence(): + @Flow.model + def load(start_date: FromContext[int], end_date: FromContext[int]) -> int: + return start_date * 1000 + end_date + + bound = load().flow.with_context(shift_integer_window(amount=10), start_date=100) + result = bound(FlowContext(start_date=1, end_date=2)) + assert result.value == 100_012 + + dumped = bound.model_dump(mode="json") + assert dumped["context_spec"]["patches"][0]["binding"]["bound_args"] == {"amount": 10} + assert dumped["context_spec"]["field_overrides"]["start_date"]["kind"] == "static_value" + + +def test_transforms_evaluate_against_original_runtime_context(): + @Flow.model + def load(start_date: FromContext[int], end_date: FromContext[int]) -> int: + return start_date * 1000 + end_date + + bound = load().flow.with_context( + shift_integer_window(amount=10), + start_date=bump_start_date(amount=100), + ) + + result = bound(FlowContext(start_date=1, end_date=2)) + assert result.value == 101_012 + + +def test_graph_integration_fanout_fanin(): + @Flow.model + def source(base: int, value: FromContext[int]) -> int: + return value + base + + @Flow.model + def scale(data: int, factor: int) -> int: + return data * factor + + @Flow.model + def merge(left: int, right: int, bonus: FromContext[int]) -> int: + return left + right + bonus + + src = source(base=10) + left = scale(data=src, factor=2) + right = scale(data=src, factor=5) + model = merge(left=left, right=right) + + assert model.flow.compute(FlowContext(value=3, bonus=7)).value == ((3 + 10) * 2) + ((3 + 10) * 5) + 7 + + +def test_graph_integration_cycle_raises_cleanly(): + @Flow.model + def increment(x: int, n: FromContext[int]) -> int: + return x + n + + root = increment() + branch = increment(x=root) + object.__setattr__(root, "x", branch) + + with FlowOptionsOverride(options={"evaluator": GraphEvaluator()}): + with pytest.raises(graphlib.CycleError): + root.flow.compute(n=1) + + +def test_large_contextual_contract_stress(): + @Flow.model + def total( + base: int, + x1: FromContext[int], + x2: FromContext[int], + x3: FromContext[int], + x4: FromContext[int], + x5: FromContext[int], + x6: FromContext[int], + ) -> int: + return base + x1 + x2 + x3 + x4 + x5 + x6 + + model = total(base=10) + assert model.flow.context_inputs == {"x1": int, "x2": int, "x3": int, "x4": int, "x5": int, "x6": int} + assert model.flow.compute(x1=1, x2=2, x3=3, x4=4, x5=5, x6=6).value == 31 + + +def test_registry_integration_for_generated_models(): + registry = ModelRegistry.root().clear() + model = basic_loader(source="library", multiplier=3) + registry.add("loader", model) + + retrieved = registry["loader"] + assert isinstance(retrieved, CallableModel) + assert retrieved(SimpleContext(value=4)).value == 12 + + +def test_any_annotation_preserves_literal_strings(): + """A parameter typed Any should keep literal strings; registry should not steal them.""" + registry = ModelRegistry.root().clear() + dep_model = basic_loader(source="library", multiplier=1) + registry.add("my_key", dep_model) + + @Flow.model + def uses_any(x: Any, y: FromContext[int]) -> int: + return y if isinstance(x, str) else 999 + + model = uses_any(x="my_key") + result = model.flow.compute(y=3) + assert result.value == 3, "literal string should not be replaced by registry entry for Any-typed param" + + +def test_unexpected_type_adapter_errors_are_not_silently_swallowed(): + class BrokenSchema: + @classmethod + def __get_pydantic_core_schema__(cls, source, handler): + raise RuntimeError("boom") + + @Flow.model + def bad(x: BrokenSchema, y: FromContext[int]) -> int: + del x, y + return 0 + + with pytest.raises(RuntimeError, match="boom"): + bad(x=object()) + + +def test_unexpected_type_validation_errors_are_not_rewritten(): + from pydantic_core import core_schema + + class BrokenValidation: + @classmethod + def __get_pydantic_core_schema__(cls, source, handler): + del source, handler + + def validate(value): + del value + raise RuntimeError("boom") + + return core_schema.no_info_plain_validator_function(validate) + + @Flow.model + def bad(x: BrokenValidation, y: FromContext[int]) -> int: + del x, y + return 0 + + with pytest.raises(RuntimeError, match="boom"): + bad(x=object()) + + +def test_unexpected_type_hint_resolution_errors_propagate(monkeypatch): + def broken_get_type_hints(*args, **kwargs): + raise RuntimeError("boom") + + monkeypatch.setattr(flow_model_module, "get_type_hints", broken_get_type_hints) + + def add(x: int) -> int: + return x + + with pytest.raises(RuntimeError, match="boom"): + Flow.model(add) + + +def test_generated_model_flow_api_introspection_and_execution(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + model = add(a=10) + assert model.flow.context_inputs == {"b": int} + assert model.flow.bound_inputs == {"a": 10} + assert model.flow.unbound_inputs == {"b": int} + assert model.flow.compute(b=5).value == 15 + + +def test_generated_factory_signature_is_keyword_only_and_includes_model_base_fields(): + sig = inspect.signature(basic_loader) + + assert all(param.kind is inspect.Parameter.KEYWORD_ONLY for param in sig.parameters.values()) + assert list(sig.parameters) == ["source", "multiplier", "value"] + assert sig.parameters["source"].default is flow_model_module._UNSET_FLOW_INPUT + assert sig.parameters["multiplier"].default is flow_model_module._UNSET_FLOW_INPUT + assert sig.parameters["value"].default is flow_model_module._UNSET_FLOW_INPUT + + with pytest.raises(TypeError, match="positional"): + basic_loader("library", 3) + + class CustomFlowBase(CallableModel): + multiplier: int = 1 + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.model(model_base=CustomFlowBase) + def add(a: int, b: FromContext[int]) -> int: + return a + b + + custom_sig = inspect.signature(add) + assert custom_sig.parameters["a"].kind is inspect.Parameter.KEYWORD_ONLY + assert custom_sig.parameters["b"].kind is inspect.Parameter.KEYWORD_ONLY + assert custom_sig.parameters["multiplier"].kind is inspect.Parameter.KEYWORD_ONLY + assert custom_sig.parameters["multiplier"].default == 1 + + +def test_type_adapter_caches_are_bounded_and_clearable(monkeypatch): + monkeypatch.setattr(flow_model_module, "_TYPE_ADAPTER_CACHE_MAXSIZE", 2) + flow_model_module.clear_flow_model_caches() + + try: + for annotation in (int, str, float): + flow_model_module._type_adapter(annotation) + + assert list(flow_model_module._HASHABLE_TYPE_ADAPTER_CACHE) == [str, float] + + unhashable_annotations = ( + Annotated[int, []], + Annotated[str, []], + Annotated[float, []], + ) + for annotation in unhashable_annotations: + flow_model_module._type_adapter(annotation) + + assert len(flow_model_module._UNHASHABLE_TYPE_ADAPTER_CACHE) == 2 + assert [entry[0] for entry in flow_model_module._UNHASHABLE_TYPE_ADAPTER_CACHE.values()] == list(unhashable_annotations[-2:]) + + flow_model_module.clear_flow_model_caches() + assert not flow_model_module._HASHABLE_TYPE_ADAPTER_CACHE + assert not flow_model_module._UNHASHABLE_TYPE_ADAPTER_CACHE + finally: + flow_model_module.clear_flow_model_caches() + + +def test_plain_callable_flow_api_paths(): + class PlainModel(CallableModel): + @property + def context_type(self): + return SimpleContext + + @property + def result_type(self): + return GenericResult[int] + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.deps + def __deps__(self, context: SimpleContext): + del context + return [] + + model = PlainModel() + + assert model.flow.context_inputs == {"value": int} + assert model.flow.unbound_inputs == {"value": int} + assert model.flow.bound_inputs == {} + assert model.flow.compute({"value": 3}).value == 3 + + with pytest.raises(TypeError, match="either one context object or contextual keyword arguments"): + model.flow.compute(SimpleContext(value=1), value=2) + + +def test_unhashable_annotations_still_validate(): + annotation = Annotated[int, []] + + @Flow.model + def add(x: annotation, y: FromContext[annotation]) -> int: + return x + y + + assert add(x="2").flow.compute(y="3").value == 5 + + +def test_compute_accepts_context_object_for_from_context_models(): + model = basic_loader(source="library", multiplier=3) + + assert model.flow.context_inputs == {"value": int} + assert model.flow.unbound_inputs == {"value": int} + assert model.flow.compute({"value": 4}).value == 12 + assert model.flow.compute(SimpleContext(value=5)).value == 15 + + with pytest.raises(TypeError, match="either one context object or contextual keyword arguments"): + model.flow.compute(SimpleContext(value=1), value=2) + + +def test_additional_validation_and_hint_fallback_paths(monkeypatch): + class MissingFieldContext(ContextBase): + start_date: date + + with pytest.raises(TypeError, match="must define fields for all FromContext parameters"): + + @Flow.model(context_type=MissingFieldContext) + def bad_missing(start_date: FromContext[date], end_date: FromContext[date]) -> int: + return 0 + + class ExtraRequiredContext(ContextBase): + start_date: date + end_date: date + label: str + + with pytest.raises(TypeError, match="has required fields that are not declared as FromContext parameters"): + + @Flow.model(context_type=ExtraRequiredContext) + def bad_extra(start_date: FromContext[date], end_date: FromContext[date]) -> int: + return 0 + + class BadAnnotationContext(ContextBase): + value: str + + with pytest.raises(TypeError, match="annotates"): + + @Flow.model(context_type=BadAnnotationContext) + def bad_annotation(value: FromContext[int]) -> int: + return value + + with pytest.raises(TypeError, match="context_type must be a ContextBase subclass"): + + @Flow.model(context_type=int) + def bad_context_type(value: FromContext[int]) -> int: + return value + + @Flow.model + def source(value: FromContext[int]) -> GenericResult[int]: + return GenericResult(value=value) + + with pytest.raises(TypeError, match="cannot default to a CallableModel"): + + @Flow.model + def bad_default(value: FromContext[int] = source()) -> int: + return value + + with pytest.raises(TypeError, match="return type annotation"): + + @Flow.model + def missing_return(value: int): + return value + + with pytest.raises(TypeError, match="does not support positional-only parameter 'value'"): + + @Flow.model + def bad_positional_only(value: int, /, bonus: FromContext[int]) -> int: + return value + bonus + + with pytest.raises(TypeError, match="does not support variadic positional parameter 'values'"): + + @Flow.model + def bad_varargs(*values: int) -> int: + return sum(values) + + with pytest.raises(TypeError, match="does not support variadic keyword parameter 'values'"): + + @Flow.model + def bad_varkw(**values: int) -> int: + return sum(values.values()) + + @Flow.model + def keyword_only(value: int, *, bonus: FromContext[int]) -> int: + return value + bonus + + assert keyword_only(value=2).flow.compute(bonus=3).value == 5 + + @Flow.model + def keyword_only_context(*, context: SimpleContext, offset: int) -> int: + return context.value + offset + + assert keyword_only_context(context=SimpleContext(value=3), offset=4).flow.compute().value == 7 + + def missing_hints(*args, **kwargs): + raise AttributeError("missing hints") + + monkeypatch.setattr(flow_model_module, "get_type_hints", missing_hints) + + @Flow.model + def add(x: int, y: FromContext[int]) -> int: + return x + y + + assert add(x=1).flow.compute(y=2).value == 3 + + +def test_unresolved_forward_refs_do_not_silently_strip_from_context(): + namespace: dict[str, object] = {} + + with pytest.raises(NameError, match="MissingType"): + exec( + """ +from __future__ import annotations +from ccflow import Flow, FromContext + +@Flow.context_transform +def transform(a: MissingType, b: FromContext[int]) -> int: + return b +""", + namespace, + ) + + with pytest.raises(NameError, match="MissingType"): + exec( + """ +from __future__ import annotations +from ccflow import Flow, FromContext + +@Flow.model +def model(a: MissingType, b: FromContext[int]) -> int: + return b +""", + namespace, + ) + + +def test_context_type_validates_parameterized_annotations(): + class IntListContext(ContextBase): + vals: list[int] + + @Flow.model(context_type=IntListContext) + def total(vals: FromContext[list[int]]) -> int: + return sum(vals) + + assert total().flow.compute(vals=["1", "2"]).value == 3 + + class IntDictContext(ContextBase): + vals: dict[str, int] + + @Flow.model(context_type=IntDictContext) + def total_dict(vals: FromContext[dict[str, int]]) -> int: + return sum(vals.values()) + + assert total_dict().flow.compute(vals={"a": "1", "b": 2}).value == 3 + + class StrListContext(ContextBase): + vals: list[str] + + with pytest.raises(TypeError, match="annotates"): + + @Flow.model(context_type=StrListContext) + def bad(vals: FromContext[list[int]]) -> int: + return sum(vals) + + +def test_context_type_rejects_nullable_field_for_non_nullable_from_context(): + class OptionalValueContext(ContextBase): + value: int | None + + with pytest.raises(TypeError, match="annotates"): + + @Flow.model(context_type=OptionalValueContext) + def add_one(value: FromContext[int]) -> int: + return value + 1 + + +@pytest.mark.parametrize( + ("func_annotation", "context_annotation", "expected"), + [ + (int, int, True), + (int, str, False), + (int | None, int, True), + (int, int | None, False), + (int | None, int | None, True), + (list[int], list[str], False), + (list[int], list[int], True), + (int | str, int, True), + (int | str, int | None, False), + (int | None, type(None), True), + (int | None, int | str, False), + (int, int | str, False), + (object, int | str, True), + (Literal["a"], Literal["a"], True), + (Literal["a"], Literal["b"], False), + (str, Literal["a"], True), + (Literal["a", "b"], Literal["a"], True), + (Literal["a"], str, False), + (list[int], Literal["a"], False), + (Annotated[int, "meta"], int, True), + (dict[str, list[int]], dict[str, list[int]], True), + (dict[str, list[int]], dict[str, list[str]], False), + (list, list[int], False), + (tuple[int], tuple[int, str], False), + (Any, int, True), + (Any, str, True), + (int, Any, True), + (str, Any, True), + (Any, Any, True), + ], +) +def test_context_type_annotations_compatible_cases(func_annotation, context_annotation, expected): + assert flow_binding_module._context_type_annotations_compatible(func_annotation, context_annotation) is expected + + +def test_compute_forwards_options_with_custom_evaluator(): + calls = {"count": 0} + + @Flow.model + def counter(value: FromContext[int]) -> int: + calls["count"] += 1 + return value + + cache = MemoryCacheEvaluator() + model = counter() + + result1 = model.flow.compute(value=10, _options=FlowOptions(evaluator=cache, cacheable=True)) + result2 = model.flow.compute(value=10, _options=FlowOptions(evaluator=cache, cacheable=True)) + + assert result1.value == 10 + assert result2.value == 10 + assert calls["count"] == 1 + + +def test_unset_flow_input_pickle_roundtrip_preserves_singleton(): + restored = pickle.loads(pickle.dumps(flow_model_module._UNSET_FLOW_INPUT, protocol=5)) + assert restored is flow_model_module._UNSET_FLOW_INPUT + + +def test_compute_forwards_options_with_graph_evaluator(): + @Flow.model + def source(value: FromContext[int]) -> int: + return value * 10 + + @Flow.model + def root(x: int, bonus: FromContext[int]) -> int: + return x + bonus + + model = root(x=source()) + + # GraphEvaluator evaluates in topo order; verify _options flows through + # and the graph evaluator is actually used (doesn't raise CycleError, computes correctly) + result = model.flow.compute( + FlowContext(value=3, bonus=7), + _options=FlowOptions(evaluator=GraphEvaluator()), + ) + + assert result.value == 37 + + +def test_compute_forwards_options_through_bound_model(): + calls = {"count": 0} + + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + calls["count"] += 1 + return a + b + + cache = MemoryCacheEvaluator() + bound = add(a=10).flow.with_context(b=5) + + result1 = bound.flow.compute(_options=FlowOptions(evaluator=cache, cacheable=True)) + result2 = bound.flow.compute(_options=FlowOptions(evaluator=cache, cacheable=True)) + + assert result1.value == 15 + assert result2.value == 15 + assert calls["count"] == 1 + + +def test_compute_forwards_options_for_plain_callable_model(): + calls = {"count": 0} + + class Counter(CallableModel): + offset: int + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + calls["count"] += 1 + return GenericResult(value=context.value + self.offset) + + cache = MemoryCacheEvaluator() + model = Counter(offset=5) + + result1 = model.flow.compute(value=10, _options=FlowOptions(evaluator=cache, cacheable=True)) + result2 = model.flow.compute(value=10, _options=FlowOptions(evaluator=cache, cacheable=True)) + + assert result1.value == 15 + assert result2.value == 15 + assert calls["count"] == 1 + + +def test_bound_plain_callable_compute_applies_context_before_validation(): + calls = {"count": 0} + + class RequiredContext(ContextBase): + a: int + b: int + + class PlainSource(CallableModel): + @Flow.call + def __call__(self, context: RequiredContext) -> GenericResult[int]: + calls["count"] += 1 + return GenericResult(value=context.a + context.b) + + assert PlainSource().flow.with_context(a=1).flow.compute(b=2).value == 3 + assert calls["count"] == 1 + + +def test_bound_plain_callable_empty_with_context_preserves_optional_none_context(): + calls = {"count": 0} + + class OptionalContext(ContextBase): + value: int = 1 + + class PlainSource(CallableModel): + @Flow.call + def __call__(self, context: Optional[OptionalContext] = None) -> GenericResult[int]: + calls["count"] += 1 + return GenericResult(value=0 if context is None else context.value) + + bound = PlainSource().flow.with_context() + + assert bound.flow.compute().value == 0 + assert bound(None).value == 0 + assert bound({"value": 5}).value == 5 + assert bound.flow.compute(value=5).value == 5 + assert bound.flow.compute(_options=FlowOptions(evaluator=GraphEvaluator())).value == 0 + assert calls["count"] == 5 + + +def test_bound_plain_callable_dynamic_context_transform_runs_before_validation(): + calls = {"count": 0} + + class RequiredContext(ContextBase): + a: int + b: int + + class PlainSource(CallableModel): + @Flow.call + def __call__(self, context: RequiredContext) -> GenericResult[int]: + calls["count"] += 1 + return GenericResult(value=context.a + context.b) + + assert PlainSource().flow.with_context(a=seed_plus_one()).flow.compute(seed=1, b=10).value == 12 + assert calls["count"] == 1 + + +def test_bound_plain_callable_direct_call_applies_context_before_validation(): + calls = {"count": 0} + + class RequiredContext(ContextBase): + a: int + b: int + + class PlainSource(CallableModel): + @Flow.call + def __call__(self, context: RequiredContext) -> GenericResult[int]: + calls["count"] += 1 + return GenericResult(value=context.a + context.b) + + bound = PlainSource().flow.with_context(a=1) + + assert bound(FlowContext(b=2)).value == 3 + assert bound.__deps__(FlowContext(b=2)) == [(bound.model, [RequiredContext(a=1, b=2)])] + assert calls["count"] == 1 + + +def test_bound_plain_callable_compute_preserves_bound_scoped_options(): + calls = {"source": 0, "evaluator": 0} + + class RequiredContext(ContextBase): + a: int + b: int + + class PlainSource(CallableModel): + @Flow.call + def __call__(self, context: RequiredContext) -> GenericResult[int]: + calls["source"] += 1 + return GenericResult(value=context.a + context.b) + + class OffsetEvaluator(EvaluatorBase): + def __call__(self, context: ModelEvaluationContext): + calls["evaluator"] += 1 + result = context() + return result.model_copy(update={"value": result.value + 100}) + + bound = PlainSource().flow.with_context(a=1) + + with FlowOptionsOverride(options={"evaluator": OffsetEvaluator()}, models=(bound,)): + assert bound.flow.compute(b=2).value == 103 + + assert calls == {"source": 1, "evaluator": 1} + + +def test_bound_plain_callable_dependency_preserves_bound_scoped_options(): + calls = {"source": 0, "consumer": 0, "evaluator": 0} + + class RequiredContext(ContextBase): + a: int + b: int + + class PlainSource(CallableModel): + @Flow.call + def __call__(self, context: RequiredContext) -> GenericResult[int]: + calls["source"] += 1 + return GenericResult(value=context.a + context.b) + + class OffsetEvaluator(EvaluatorBase): + def __call__(self, context: ModelEvaluationContext): + calls["evaluator"] += 1 + result = context() + return result.model_copy(update={"value": result.value + 100}) + + @Flow.model + def consumer(x: int) -> int: + calls["consumer"] += 1 + return x + + bound = PlainSource().flow.with_context(a=1) + model = consumer(x=bound) + + with FlowOptionsOverride(options={"evaluator": OffsetEvaluator()}, models=(bound,)): + assert model.flow.compute(b=2).value == 103 + + assert calls == {"source": 1, "consumer": 1, "evaluator": 1} + + +def test_bound_plain_callable_dependency_identity_ignores_unused_ambient_context(): + calls = {"source": 0, "consumer": 0} + + class RequiredContext(ContextBase): + a: int + b: int + + class PlainSource(CallableModel): + @Flow.call + def __call__(self, context: RequiredContext) -> GenericResult[int]: + calls["source"] += 1 + return GenericResult(value=context.a + context.b) + + @Flow.model + def consumer(x: int) -> int: + calls["consumer"] += 1 + return x + + cache = MemoryCacheEvaluator() + model = consumer(x=PlainSource().flow.with_context(a=1, b=2)) + + with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): + assert model.flow.compute(unused="one").value == 3 + assert model.flow.compute(unused="two").value == 3 + + assert calls == {"source": 1, "consumer": 1} + + +def test_bound_dependency_identity_rewrites_dynamic_context_once(): + calls = {"source": 0, "consumer": 0} + + @Flow.model + def source(a: FromContext[int]) -> int: + calls["source"] += 1 + return a + + @Flow.model + def consumer(x: int) -> int: + calls["consumer"] += 1 + return x + + cache = MemoryCacheEvaluator() + model = consumer(x=source().flow.with_context(a=non_idempotent_a_step())) + + with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): + assert model.flow.compute(a=1).value == 2 + assert model.flow.compute(a=2).value == 3 + + assert calls == {"source": 2, "consumer": 2} + + +def test_lazy_bound_plain_callable_dependency_preserves_bound_scoped_options(): + calls = {"source": 0, "consumer": 0, "evaluator": 0} + + class RequiredContext(ContextBase): + a: int + b: int + + class PlainSource(CallableModel): + @Flow.call + def __call__(self, context: RequiredContext) -> GenericResult[int]: + calls["source"] += 1 + return GenericResult(value=context.a + context.b) + + class OffsetEvaluator(EvaluatorBase): + def __call__(self, context: ModelEvaluationContext): + calls["evaluator"] += 1 + result = context() + return result.model_copy(update={"value": result.value + 100}) + + @Flow.model + def consumer(lazy_value: Lazy[int], use_lazy: FromContext[bool]) -> int: + calls["consumer"] += 1 + return lazy_value() if use_lazy else 0 + + bound = PlainSource().flow.with_context(a=1) + model = consumer(lazy_value=bound) + + with FlowOptionsOverride(options={"evaluator": OffsetEvaluator()}, models=(bound,)): + assert model.flow.compute(b=2, use_lazy=True).value == 103 + + assert calls == {"source": 1, "consumer": 1, "evaluator": 1} + + +def test_bound_flow_unbound_inputs_subtracts_static_context(): + class RequiredContext(ContextBase): + a: int + b: int + + class PlainSource(CallableModel): + @Flow.call + def __call__(self, context: RequiredContext) -> GenericResult[int]: + return GenericResult(value=context.a + context.b) + + @Flow.model + def add(a: FromContext[int], b: FromContext[int]) -> int: + return a + b + + assert PlainSource().flow.with_context(a=1).flow.unbound_inputs == {"b": int} + assert add().flow.with_context(a=1).flow.unbound_inputs == {"b": int} + assert add().flow.with_context(a=static_bad()).flow.unbound_inputs == {"b": int} + assert add().flow.with_context(static_patch()).flow.unbound_inputs == {"b": int} + assert add().flow.with_context(a=1, b=2).flow.unbound_inputs == {} + + +def test_bound_flow_unbound_inputs_reflects_dynamic_field_transform_inputs(): + @Flow.model + def add(a: FromContext[int], b: FromContext[int]) -> int: + return a + b + + bound = add().flow.with_context(a=seed_plus_one()) + + assert bound.flow.compute(seed=1, b=10).value == 12 + assert bound.flow.context_inputs == {"b": int, "seed": int} + assert bound.flow.unbound_inputs == {"b": int, "seed": int} + + +def test_bound_flow_bound_inputs_include_static_context_bindings(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + bound = add(a=1).flow.with_context(b=2) + + assert bound.flow.context_inputs == {} + assert bound.flow.unbound_inputs == {} + assert bound.flow.bound_inputs == {"a": 1, "b": 2} + + +def test_bound_flow_bound_inputs_drops_static_patch_after_dynamic_override(): + @Flow.model + def add(a: FromContext[int], b: FromContext[int]) -> int: + return a + b + + bound = add().flow.with_context(static_patch()).flow.with_context(a=seed_plus_one()) + + assert bound.flow.compute(seed=3, b=10).value == 14 + assert bound.flow.bound_inputs == {} + assert bound.flow.context_inputs == {"b": int, "seed": int} + assert bound.flow.unbound_inputs == {"b": int, "seed": int} + + +def test_generated_model_cache_ignores_unused_flow_context_fields(): + calls = {"source": 0, "root": 0} + + @Flow.model + def source(value: FromContext[int]) -> int: + calls["source"] += 1 + return value * 10 + + @Flow.model + def root(x: int, bonus: FromContext[int]) -> int: + calls["root"] += 1 + return x + bonus + + cache = MemoryCacheEvaluator() + model = root(x=source()) + + with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): + assert model(FlowContext(value=3, bonus=7, unused="one")).value == 37 + assert model(FlowContext(value=3, bonus=7, unused="two")).value == 37 + + assert calls == {"source": 1, "root": 1} + assert len(cache.cache) == 2 + + +def test_generated_model_cache_uses_effective_key_through_transparent_evaluator(): + calls = {"count": 0} + + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + calls["count"] += 1 + return a + b + + cache = MemoryCacheEvaluator() + evaluator = combine_evaluators(LoggingEvaluator(), cache) + model = add(a=10) + + with FlowOptionsOverride(options={"evaluator": evaluator, "cacheable": True}): + assert model.flow.compute(b=1, unused="one").value == 11 + assert model.flow.compute(b=1, unused="two").value == 11 + + assert calls["count"] == 1 + assert len(cache.cache) == 1 + + +def test_cache_key_effective_option_exposes_generated_model_identity(): + @Flow.model + def source(value: FromContext[int]) -> int: + return value * 10 + + @Flow.model + def root(x: int, bonus: FromContext[int]) -> int: + return x + bonus + + model = root(x=source()) + ctx1 = FlowContext(value=3, bonus=7, unused="one") + ctx2 = FlowContext(value=3, bonus=7, unused="two") + eval1 = model.__call__.get_evaluation_context(model, ctx1) + eval2 = model.__call__.get_evaluation_context(model, ctx2) + cache = MemoryCacheEvaluator() + + assert cache_key(eval1) != cache_key(eval2) + assert cache_key(eval1, effective=True) == cache_key(eval2, effective=True) + assert cache.key(eval1) == cache_key(eval1, effective=True) + + +def test_cache_key_effective_option_preserves_plain_callable_structural_identity(): + calls = {"count": 0} + + class Counter(CallableModel): + @Flow.call + def __call__(self, context: FlowContext) -> GenericResult[int]: + calls["count"] += 1 + return GenericResult(value=context.value) + + model = Counter() + eval1 = model.__call__.get_evaluation_context(model, FlowContext(value=10, unused="one")) + eval2 = model.__call__.get_evaluation_context(model, FlowContext(value=10, unused="two")) + + assert cache_key(eval1) != cache_key(eval2) + assert cache_key(eval1, effective=True) == cache_key(eval1) + assert cache_key(eval2, effective=True) == cache_key(eval2) + + +def test_generated_model_effective_cache_key_includes_behavior_token(monkeypatch): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + def helper_v1(): + return 1 + + def helper_v2(): + return 2 + + model = add(a=10) + model_type = type(model) + context = model.__call__.get_evaluation_context(model, FlowContext(b=1, unused="same"), _options={"cacheable": True}) + cache = MemoryCacheEvaluator() + + monkeypatch.setattr(model_type, "__ccflow_tokenizer_deps__", [helper_v1], raising=False) + key1 = cache.key(context) + + monkeypatch.setattr(model_type, "__ccflow_tokenizer_deps__", [helper_v2], raising=False) + monkeypatch.delattr(model_type, "__ccflow_tokenizer_cache__", raising=False) + key2 = cache.key(context) + + assert key1 != key2 + + +def test_generated_model_effective_cache_key_includes_opaque_evaluator_behavior(): + class OpaqueA(EvaluatorBase): + tag: str = "same" + + def __call__(self, context: ModelEvaluationContext): + return context() + + class OpaqueB(EvaluatorBase): + tag: str = "same" + + def __call__(self, context: ModelEvaluationContext): + result = context() + return result + + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + model = add(a=10) + inner = model.__call__.get_evaluation_context(model, FlowContext(b=1, unused="same"), _options={"cacheable": True}) + cache = MemoryCacheEvaluator() + + key1 = cache.key(ModelEvaluationContext(model=OpaqueA(), context=inner)) + key2 = cache.key(ModelEvaluationContext(model=OpaqueB(), context=inner)) + + assert key1 != key2 + + +def test_generated_model_cache_changes_when_consumed_context_field_changes(): + calls = {"count": 0} + + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + calls["count"] += 1 + return a + b + + cache = MemoryCacheEvaluator() + model = add(a=10) + + with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): + assert model.flow.compute(b=1, unused="same").value == 11 + assert model.flow.compute(b=2, unused="same").value == 12 + + assert calls["count"] == 2 + assert len(cache.cache) == 2 + + +def test_generated_model_cache_changes_when_regular_literal_input_changes(): + calls = {"count": 0} + + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + calls["count"] += 1 + return a + b + + cache = MemoryCacheEvaluator() + + with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): + assert add(a=10).flow.compute(b=1).value == 11 + assert add(a=20).flow.compute(b=1).value == 21 + + assert calls["count"] == 2 + assert len(cache.cache) == 2 + + +def test_generated_model_cache_uses_structural_key_with_nontransparent_evaluator(): + calls = {"count": 0} + + class OpaqueEvaluator(EvaluatorBase): + def __call__(self, context: ModelEvaluationContext): + return context() + + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + calls["count"] += 1 + return a + b + + cache = MemoryCacheEvaluator() + evaluator = combine_evaluators(OpaqueEvaluator(), cache) + model = add(a=10) + + with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): + assert model.flow.compute(b=1, unused="plain").value == 11 + + with FlowOptionsOverride(options={"evaluator": evaluator, "cacheable": True}): + assert model.flow.compute(b=1, unused="one").value == 11 + assert model.flow.compute(b=1, unused="two").value == 11 + + assert calls["count"] == 3 + assert len(cache.cache) == 3 + + +def test_generated_model_cache_does_not_ignore_context_read_by_nontransparent_evaluator(): + class AddAmbient(EvaluatorBase): + def __call__(self, context: ModelEvaluationContext): + result = context() + return result.model_copy(update={"value": result.value + context.context.unused}) + + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + cache = MemoryCacheEvaluator() + evaluator = combine_evaluators(AddAmbient(), cache) + model = add(a=10) + + with FlowOptionsOverride(options={"evaluator": evaluator, "cacheable": True}): + assert model.flow.compute(b=1, unused=100).value == 111 + assert model.flow.compute(b=1, unused=200).value == 211 + + assert len(cache.cache) == 2 + + +def test_generated_model_cache_uses_effective_key_when_result_validation_disabled(): + calls = {"count": 0} + + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + calls["count"] += 1 + return a + b + + cache = MemoryCacheEvaluator() + model = add(a=10) + + with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True, "validate_result": False}): + assert model.flow.compute(b=1, unused="one").value == 11 + assert model.flow.compute(b=1, unused="two").value == 11 + + assert calls["count"] == 1 + assert len(cache.cache) == 1 + + +def test_generated_model_cache_key_preserves_result_validation_option(): + calls = {"count": 0} + + @Flow.model + def raw_result(value: FromContext[int]) -> GenericResult[int]: + calls["count"] += 1 + return {"value": str(value)} + + cache = MemoryCacheEvaluator() + model = raw_result() + + with FlowOptionsOverride(options=FlowOptions(evaluator=cache, cacheable=True, validate_result=False)): + first = model.flow.compute(value=3, unused="one") + + with FlowOptionsOverride(options=FlowOptions(evaluator=cache, cacheable=True, validate_result=True)): + second = model.flow.compute(value=3, unused="two") + + assert first == {"value": "3"} + assert second == GenericResult[int](value=3) + assert calls["count"] == 2 + assert len(cache.cache) == 2 + + +def test_generated_model_cache_ignores_unresolved_unused_lazy_dependency_context(): + calls = {"choose": 0} + + @Flow.model + def source(missing: FromContext[int]) -> int: + return missing * 10 + + @Flow.model + def choose(x: int, lazy_value: Lazy[int], use_lazy: FromContext[bool]) -> int: + calls["choose"] += 1 + return lazy_value() if use_lazy else x + + cache = MemoryCacheEvaluator() + model = choose(x=7, lazy_value=source()) + + with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): + assert model.flow.compute(use_lazy=False, unused="one").value == 7 + assert model.flow.compute(use_lazy=False, unused="two").value == 7 + + assert calls["choose"] == 1 + assert len(cache.cache) == 1 + + +def test_unused_lazy_plain_dependency_defers_missing_context_validation(): + calls = {"source": 0, "choose": 0} + + class RequiredContext(ContextBase): + missing: int + + class PlainSource(CallableModel): + @Flow.call + def __call__(self, context: RequiredContext) -> GenericResult[int]: + calls["source"] += 1 + return GenericResult(value=context.missing * 10) + + @Flow.model + def choose(x: int, lazy_value: Lazy[int], use_lazy: FromContext[bool]) -> int: + calls["choose"] += 1 + return lazy_value() if use_lazy else x + + cache = MemoryCacheEvaluator() + model = choose(x=7, lazy_value=PlainSource()) + + with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): + assert model.flow.compute(use_lazy=False, unused="one").value == 7 + assert model.flow.compute(use_lazy=False, unused="two").value == 7 + with pytest.raises(ValidationError): + model.flow.compute(use_lazy=True) + + assert calls == {"source": 0, "choose": 2} + assert len(cache.cache) == 1 + + +def test_unused_lazy_bound_plain_dependency_applies_static_context_identity(): + calls = {"source": 0, "choose": 0} + + class RequiredContext(ContextBase): + a: int + b: int + + class PlainSource(CallableModel): + @Flow.call + def __call__(self, context: RequiredContext) -> GenericResult[int]: + calls["source"] += 1 + return GenericResult(value=context.a + context.b) + + @Flow.model + def choose(x: int, lazy_value: Lazy[int], use_lazy: FromContext[bool]) -> int: + calls["choose"] += 1 + return lazy_value() if use_lazy else x + + cache = MemoryCacheEvaluator() + model = choose(x=7, lazy_value=PlainSource().flow.with_context(a=1)) + + with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): + assert model.flow.compute(use_lazy=False, unused="one").value == 7 + assert model.flow.compute(use_lazy=False, unused="two").value == 7 + with pytest.raises(ValidationError): + model.flow.compute(use_lazy=True) + + assert calls == {"source": 0, "choose": 2} + assert len(cache.cache) == 1 + + +def test_unused_lazy_bound_plain_dependency_dynamic_transform_can_leave_missing_context(): + calls = {"source": 0, "choose": 0} + + class RequiredContext(ContextBase): + a: int + b: int + + class PlainSource(CallableModel): + @Flow.call + def __call__(self, context: RequiredContext) -> GenericResult[int]: + calls["source"] += 1 + return GenericResult(value=context.a + context.b) + + @Flow.model + def choose(x: int, lazy_value: Lazy[int], use_lazy: FromContext[bool]) -> int: + calls["choose"] += 1 + return lazy_value() if use_lazy else x + + cache = MemoryCacheEvaluator() + model = choose(x=7, lazy_value=PlainSource().flow.with_context(b=seed_plus_one())) + + with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): + assert model.flow.compute(use_lazy=False, seed=1, unused="one").value == 7 + assert model.flow.compute(use_lazy=False, seed=1, unused="two").value == 7 + + assert calls == {"source": 0, "choose": 1} + assert len(cache.cache) == 1 + + +def test_unused_lazy_bound_plain_dependency_fully_resolved_identity_ignores_ambient_context(): + calls = {"source": 0, "choose": 0} + + class RequiredContext(ContextBase): + a: int + b: int + + class PlainSource(CallableModel): + @Flow.call + def __call__(self, context: RequiredContext) -> GenericResult[int]: + calls["source"] += 1 + return GenericResult(value=context.a + context.b) + + @Flow.model + def choose(x: int, lazy_value: Lazy[int], use_lazy: FromContext[bool]) -> int: + calls["choose"] += 1 + return lazy_value() if use_lazy else x + + cache = MemoryCacheEvaluator() + model = choose(x=7, lazy_value=PlainSource().flow.with_context(a=1, b=2)) + + with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): + assert model.flow.compute(use_lazy=False, unused="one").value == 7 + assert model.flow.compute(use_lazy=False, unused="two").value == 7 + + assert calls == {"source": 0, "choose": 1} + assert len(cache.cache) == 1 + + +def test_unused_lazy_bound_dependency_uses_partial_context_identity(): + calls = {"source": 0, "choose": 0} + + @Flow.model + def source(a: FromContext[int], b: FromContext[int]) -> int: + calls["source"] += 1 + return a + b + + @Flow.model + def choose(x: int, lazy_value: Lazy[int], use_lazy: FromContext[bool]) -> int: + calls["choose"] += 1 + return lazy_value() if use_lazy else x + + cache = MemoryCacheEvaluator() + model = choose(x=7, lazy_value=source().flow.with_context(a=1)) + + with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): + assert model.flow.compute(use_lazy=False, a=100).value == 7 + assert model.flow.compute(use_lazy=False, a=200).value == 7 + with pytest.raises(TypeError, match="Missing contextual input"): + model.flow.compute(use_lazy=True, a=300) + + assert calls == {"source": 0, "choose": 2} + assert len(cache.cache) == 1 + + +def test_unused_lazy_resolved_dependency_identity_is_conservative(): + calls = {"source": 0, "choose": 0} + + @Flow.model + def source(a: FromContext[int]) -> int: + calls["source"] += 1 + return a + + @Flow.model + def choose(x: int, lazy_value: Lazy[int], use_lazy: FromContext[bool]) -> int: + calls["choose"] += 1 + return lazy_value() if use_lazy else x + + cache = MemoryCacheEvaluator() + + with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): + assert choose(x=7, lazy_value=source().flow.with_context(a=1)).flow.compute(use_lazy=False).value == 7 + assert choose(x=7, lazy_value=source().flow.with_context(a=2)).flow.compute(use_lazy=False).value == 7 + + assert calls == {"source": 0, "choose": 2} + assert len(cache.cache) == 2 + + +def test_used_lazy_bound_dependency_identity_applies_dynamic_context_transform(): + calls = {"source": 0, "choose": 0} + + @Flow.model + def source(a: FromContext[int], b: FromContext[int]) -> int: + calls["source"] += 1 + return a + b + + @Flow.model + def choose(lazy_value: Lazy[int], use_lazy: FromContext[bool]) -> int: + calls["choose"] += 1 + return lazy_value() if use_lazy else 0 + + cache = MemoryCacheEvaluator() + model = choose(lazy_value=source().flow.with_context(a=seed_plus_one())) + + with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): + assert model.flow.compute(use_lazy=True, seed=1, b=10).value == 12 + assert model.flow.compute(use_lazy=True, seed=2, b=10).value == 13 + + assert calls == {"source": 2, "choose": 2} + assert len(cache.cache) == 6 + + +def test_unused_lazy_bound_dependency_records_missing_transform_context(): + calls = {"source": 0, "choose": 0} + + @Flow.model + def source(a: FromContext[int], b: FromContext[int]) -> int: + calls["source"] += 1 + return a + b + + @Flow.model + def choose(x: int, lazy_value: Lazy[int], use_lazy: FromContext[bool]) -> int: + calls["choose"] += 1 + return lazy_value() if use_lazy else x + + cache = MemoryCacheEvaluator() + model = choose(x=7, lazy_value=source().flow.with_context(a=seed_plus_one())) + + with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): + assert model.flow.compute(use_lazy=False, unused="one").value == 7 + assert model.flow.compute(use_lazy=False, unused="two").value == 7 + assert model.flow.compute(use_lazy=False, a=100, b=1).value == 7 + assert model.flow.compute(use_lazy=False, a=200, b=1).value == 7 + + assert calls == {"source": 0, "choose": 1} + assert len(cache.cache) == 1 + + +def test_used_lazy_generated_dependency_identity_respects_contextual_defaults(): + calls = {"dep": 0, "source": 0, "choose": 0} + + @Flow.model + def dep(v: FromContext[int]) -> int: + calls["dep"] += 1 + return v + + @Flow.model + def source(d: int, a: FromContext[int], b: FromContext[int]) -> int: + calls["source"] += 1 + return a + b + d + + @Flow.model + def choose(lazy_value: Lazy[int], use_lazy: FromContext[bool]) -> int: + calls["choose"] += 1 + return lazy_value() if use_lazy else 0 + + cache = MemoryCacheEvaluator() + model = choose(lazy_value=source(d=dep(), a=1)) + + with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): + assert model.flow.compute(use_lazy=True, b=10, v=3).value == 14 + assert model.flow.compute(use_lazy=True, b=10, v=999).value == 1010 + + assert calls == {"dep": 2, "source": 2, "choose": 2} + + +def test_generated_model_cache_distinguishes_unresolved_lazy_dependency_models(): + calls = {"choose": 0} + + @Flow.model + def source_one(missing: FromContext[int]) -> int: + return missing * 10 + + @Flow.model + def source_two(missing: FromContext[int]) -> int: + return missing * 100 + + @Flow.model + def choose(x: int, lazy_value: Lazy[int], use_lazy: FromContext[bool]) -> int: + calls["choose"] += 1 + return lazy_value() if use_lazy else x + + cache = MemoryCacheEvaluator() + + with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): + assert choose(x=7, lazy_value=source_one()).flow.compute(use_lazy=False).value == 7 + assert choose(x=7, lazy_value=source_two()).flow.compute(use_lazy=False).value == 7 + + assert calls["choose"] == 2 + assert len(cache.cache) == 2 + + +def test_generated_model_diamond_cache_reuses_shared_source_and_ignores_unused_fields(): + calls = {"source": 0, "left": 0, "right": 0, "root": 0} + + @Flow.model + def source(value: FromContext[int]) -> int: + calls["source"] += 1 + return value + 10 + + @Flow.model + def left(x: int) -> int: + calls["left"] += 1 + return x * 2 + + @Flow.model + def right(x: int) -> int: + calls["right"] += 1 + return x * 5 + + @Flow.model + def root(left_value: int, right_value: int, bonus: FromContext[int]) -> int: + calls["root"] += 1 + return left_value + right_value + bonus + + shared = source() + model = root(left_value=left(x=shared), right_value=right(x=shared)) + cache = MemoryCacheEvaluator() + + with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): + assert model(FlowContext(value=3, bonus=7, unused="one")).value == 98 + assert model(FlowContext(value=3, bonus=7, unused="two")).value == 98 + + assert calls == {"source": 1, "left": 1, "right": 1, "root": 1} + assert len(cache.cache) == 4 + + +def test_bound_generated_sibling_dependencies_keep_distinct_rewritten_contexts_with_graph_cache(): + calls = {"source": 0, "root": 0} + + @Flow.model + def source(value: FromContext[int]) -> int: + calls["source"] += 1 + return value + + @Flow.model + def root(left: int, right: int) -> int: + calls["root"] += 1 + return left + right + + cache = MemoryCacheEvaluator() + evaluator = combine_evaluators(cache, GraphEvaluator()) + shared = source() + model = root(left=shared.flow.with_context(value=1), right=shared.flow.with_context(value=2)) + + with FlowOptionsOverride(options={"evaluator": evaluator, "cacheable": True}): + assert model.flow.compute(value=99, unused="ambient").value == 3 + + assert calls == {"source": 2, "root": 1} + assert len(cache.cache) >= 3 + + +def test_generated_dependency_graph_identity_ignores_unused_flow_context_fields(): + @Flow.model + def source(value: FromContext[int]) -> int: + return value * 10 + + @Flow.model + def root(x: int, bonus: FromContext[int]) -> int: + return x + bonus + + model = root(x=source()) + graph1 = get_dependency_graph(model.__call__.get_evaluation_context(model, FlowContext(value=3, bonus=7, unused="one"))) + graph2 = get_dependency_graph(model.__call__.get_evaluation_context(model, FlowContext(value=3, bonus=7, unused="two"))) + + assert graph1.root_id == graph2.root_id + assert set(graph1.graph.keys()) == set(graph2.graph.keys()) + assert set(graph1.ids.keys()) == set(graph2.ids.keys()) + + +def test_bound_generated_model_dependency_graph_has_no_self_loop(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + bound = add(a=1).flow.with_context(b=2) + graph = get_dependency_graph(bound.__call__.get_evaluation_context(bound, FlowContext(b=99))) + + assert graph.root_id not in graph.graph[graph.root_id] + + +def test_bound_generated_model_dependency_graph_traverses_collapsed_child_deps(): + @Flow.model + def source(value: FromContext[int]) -> int: + return value * 10 + + @Flow.model + def root(x: int, bonus: FromContext[int]) -> int: + return x + bonus + + bound = root(x=source()).flow.with_context(bonus=1) + graph = get_dependency_graph(bound.__call__.get_evaluation_context(bound, FlowContext(value=2, bonus=99))) + + assert graph.root_id not in graph.graph[graph.root_id] + assert len(graph.ids) == 3 + assert len(graph.graph[graph.root_id]) == 1 + + +def test_bound_model_cache_follows_rewritten_context_not_ambient_source_fields(): + calls = {"count": 0} + + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + calls["count"] += 1 + return a + b + + cache = MemoryCacheEvaluator() + bound = add(a=10).flow.with_context(b=parity_bucket()) + + with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): + assert bound(FlowContext(raw=1)).value == 11 + assert bound(FlowContext(raw=3)).value == 11 + + assert calls["count"] == 1 + assert len(cache.cache) == 2 + + +def test_bound_model_cache_respects_wrapped_model_scoped_evaluator(): + calls = {"add": 0, "evaluator": 0} + + class OffsetEvaluator(EvaluatorBase): + def __call__(self, context: ModelEvaluationContext): + calls["evaluator"] += 1 + result = context() + return result.model_copy(update={"value": result.value + 100}) + + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + calls["add"] += 1 + return a + b + + cache = MemoryCacheEvaluator() + base = add(a=1) + bound = base.flow.with_context(b=2) + + with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): + with FlowOptionsOverride(options={"evaluator": OffsetEvaluator()}, models=(base,)): + assert bound.flow.compute().value == 103 + assert bound.flow.compute().value == 3 + + assert calls == {"add": 2, "evaluator": 1} + + +def test_plain_callable_model_cache_remains_structural_for_flow_context(): + calls = {"count": 0} + + class Counter(CallableModel): + @Flow.call + def __call__(self, context: FlowContext) -> GenericResult[int]: + calls["count"] += 1 + return GenericResult(value=context.value) + + cache = MemoryCacheEvaluator() + model = Counter() + + with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): + assert model(FlowContext(value=10, unused="one")).value == 10 + assert model(FlowContext(value=10, unused="two")).value == 10 + + assert calls["count"] == 2 + assert len(cache.cache) == 2 + + +def test_generated_models_cross_process_pickle(): + """Module-level @Flow.model instances are deserializable in a separate process.""" + model = basic_loader(source="library", multiplier=3) + data = pickle.dumps(model, protocol=5) + encoded = base64.b64encode(data).decode() + script = ( + "import pickle, base64\n" + f"data = base64.b64decode('{encoded}')\n" + "model = pickle.loads(data)\n" + "from ccflow import FlowContext\n" + "result = model.flow.compute(value=4)\n" + "assert result.value == 12, f'Expected 12, got {result.value}'\n" + ) + result = subprocess.run([sys.executable, "-c", script], capture_output=True, text=True, timeout=30) + assert result.returncode == 0, f"Cross-process unpickle failed:\n{result.stderr}" + + +def test_generated_models_cross_process_cloudpickle(): + """Module-level @Flow.model instances are deserializable via cloudpickle in a separate process.""" + from ray.cloudpickle import dumps as rcpdumps + + model = basic_loader(source="library", multiplier=3) + data = rcpdumps(model, protocol=5) + encoded = base64.b64encode(data).decode() + script = ( + "import base64\n" + "from ray.cloudpickle import loads as rcploads\n" + f"data = base64.b64decode('{encoded}')\n" + "model = rcploads(data)\n" + "from ccflow import FlowContext\n" + "result = model.flow.compute(value=4)\n" + "assert result.value == 12, f'Expected 12, got {result.value}'\n" + ) + result = subprocess.run([sys.executable, "-c", script], capture_output=True, text=True, timeout=30) + assert result.returncode == 0, f"Cross-process cloudpickle failed:\n{result.stderr}" + + +def test_local_generated_models_cross_process_cloudpickle(): + """Local @Flow.model instances carry their generated class across processes.""" + from ray.cloudpickle import dumps as rcpdumps + + def make_model(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + return add(a=1) + + encoded = base64.b64encode(rcpdumps(make_model(), protocol=5)).decode() + script = ( + "import base64\n" + "from ray.cloudpickle import loads as rcploads\n" + f"data = base64.b64decode('{encoded}')\n" + "model = rcploads(data)\n" + "result = model.flow.compute(b=2)\n" + "assert result.value == 3, f'Expected 3, got {result.value}'\n" + ) + result = subprocess.run([sys.executable, "-c", script], capture_output=True, text=True, timeout=30) + assert result.returncode == 0, f"Cross-process local cloudpickle failed:\n{result.stderr}" + + +def test_model_base_fields_visible_in_bound_inputs(): + """model_base fields that are explicitly set should appear in bound_inputs.""" + + class CustomFlowBase(CallableModel): + multiplier: int = 1 + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.model(model_base=CustomFlowBase) + def add(a: int, b: FromContext[int]) -> int: + return a + b + + model = add(a=10, multiplier=3) + assert model.flow.bound_inputs == {"a": 10, "multiplier": 3} + + # Default-only model_base field is NOT in bound_inputs + model_default = add(a=10) + assert model_default.flow.bound_inputs == {"a": 10} + + +def test_model_base_fields_rejected_by_compute(): + """compute() should reject kwargs matching model_base field names.""" + + class CustomFlowBase(CallableModel): + multiplier: int = 1 + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.model(model_base=CustomFlowBase) + def add(a: int, b: FromContext[int]) -> int: + return a + b + + model = add(a=10, multiplier=3) + with pytest.raises(TypeError, match="does not accept model configuration override\\(s\\): multiplier"): + model.flow.compute(b=5, multiplier=99) + + +def test_flow_model_public_exports_exclude_context_spec_models(): + assert "StaticValueSpec" not in flow_model_module.__all__ + assert "FieldContextSpec" not in flow_model_module.__all__ + assert "PatchContextSpec" not in flow_model_module.__all__ + assert not hasattr(ccflow, "StaticValueSpec") + assert not hasattr(ccflow, "FieldContextSpec") + assert not hasattr(ccflow, "PatchContextSpec") diff --git a/ccflow/tests/test_flow_model_hydra.py b/ccflow/tests/test_flow_model_hydra.py new file mode 100644 index 0000000..6337c93 --- /dev/null +++ b/ccflow/tests/test_flow_model_hydra.py @@ -0,0 +1,135 @@ +"""Hydra integration tests for the FromContext-based Flow.model API.""" + +from datetime import date +from pathlib import Path + +from omegaconf import OmegaConf + +from ccflow import CallableModel, DateRangeContext, FlowContext, ModelRegistry + +from .test_flow_model import SimpleContext + +CONFIG_PATH = str(Path(__file__).parent / "config" / "conf_flow.yaml") + + +def setup_function(): + ModelRegistry.root().clear() + + +def teardown_function(): + ModelRegistry.root().clear() + + +def test_basic_loader_from_yaml(): + registry = ModelRegistry.root() + registry.load_config_from_path(CONFIG_PATH) + + loader = registry["flow_loader"] + assert isinstance(loader, CallableModel) + assert loader(SimpleContext(value=10)).value == 50 + + +def test_basic_processor_from_yaml(): + registry = ModelRegistry.root() + registry.load_config_from_path(CONFIG_PATH) + + processor = registry["flow_processor"] + assert processor(SimpleContext(value=42)).value == "value=42!" + + +def test_two_stage_pipeline_from_yaml(): + registry = ModelRegistry.root() + registry.load_config_from_path(CONFIG_PATH) + + transformer = registry["flow_transformer"] + assert transformer(SimpleContext(value=5)).value == 315 + + +def test_three_stage_pipeline_from_yaml(): + registry = ModelRegistry.root() + registry.load_config_from_path(CONFIG_PATH) + + stage3 = registry["flow_stage3"] + assert stage3(SimpleContext(value=10)).value == 90 + + +def test_diamond_dependency_from_yaml(): + registry = ModelRegistry.root() + registry.load_config_from_path(CONFIG_PATH) + + aggregator = registry["diamond_aggregator"] + assert aggregator(SimpleContext(value=10)).value == 140 + + +def test_date_range_pipeline_from_yaml(): + registry = ModelRegistry.root() + registry.load_config_from_path(CONFIG_PATH) + + processor = registry["flow_date_processor"] + ctx = DateRangeContext(start_date=date(2024, 1, 10), end_date=date(2024, 1, 31)) + result = processor(ctx) + + assert "normalized:" in result.value + assert "2024-01-09" in result.value + + +def test_from_context_pipeline_from_yaml(): + registry = ModelRegistry.root() + registry.load_config_from_path(CONFIG_PATH) + + loader = registry["contextual_loader_model"] + processor = registry["contextual_processor_model"] + + assert loader.flow.context_inputs == {"start_date": date, "end_date": date} + result = processor.flow.compute(start_date=date(2024, 3, 1), end_date=date(2024, 3, 31)) + assert result.value == "output:data_source:2024-03-01 to 2024-03-31" + assert processor.data is loader + + +def test_registry_name_references_share_instances(): + registry = ModelRegistry.root() + registry.load_config_from_path(CONFIG_PATH) + + transformer = registry["flow_transformer"] + source = registry["flow_source"] + assert transformer.source is source + + stage2 = registry["flow_stage2"] + stage3 = registry["flow_stage3"] + assert stage2.stage1_output is registry["flow_stage1"] + assert stage3.stage2_output is stage2 + + +def test_instantiate_with_omegaconf(): + cfg = OmegaConf.create( + { + "loader": { + "_target_": "ccflow.tests.test_flow_model.basic_loader", + "source": "generated_input", + "multiplier": 7, + }, + "contextual": { + "_target_": "ccflow.tests.test_flow_model.contextual_loader", + "source": "library", + }, + } + ) + + registry = ModelRegistry.root() + registry.load_config(cfg) + + assert registry["loader"](SimpleContext(value=3)).value == 21 + assert registry["contextual"].flow.compute(start_date=date(2024, 1, 1), end_date=date(2024, 1, 2)).value == { + "source": "library", + "start_date": "2024-01-01", + "end_date": "2024-01-02", + } + + +def test_flow_context_execution_with_yaml_models(): + registry = ModelRegistry.root() + registry.load_config_from_path(CONFIG_PATH) + + processor = registry["contextual_processor_model"] + result = processor.flow.compute(FlowContext(start_date=date(2024, 4, 1), end_date=date(2024, 4, 30))) + assert result.value == "output:data_source:2024-04-01 to 2024-04-30" diff --git a/docs/wiki/Flow-Model.md b/docs/wiki/Flow-Model.md new file mode 100644 index 0000000..6a9d389 --- /dev/null +++ b/docs/wiki/Flow-Model.md @@ -0,0 +1,488 @@ +# Flow Model + +`@Flow.model` turns a plain Python function into a real `CallableModel`. + +The design is intentionally narrow: + +- ordinary unmarked parameters are regular bound inputs, +- `FromContext[T]` marks the only runtime/contextual inputs, +- `@Flow.context_transform` defines reusable contextual rewrites, +- `.flow.compute(...)` is the execution entry point for the full DAG, +- `.flow.with_context(*patches, **field_overrides)` rewires contextual inputs on one dependency edge, +- upstream `CallableModel`s can still be passed as ordinary arguments. + +The goal is that a reader can look at one function signature and immediately +see: + +1. which values come from runtime context, +1. which values must be bound as regular configuration or dependencies, +1. how to rewrite contextual inputs for one branch of the graph. + +Generated models still plug into the existing evaluator, registry, cache, +Hydra, and serialization machinery — `@Flow.model` does not create a new +execution engine. + +## Core Example + +```python +from ccflow import Flow, FromContext + + +@Flow.model +def foo(a: int, b: FromContext[int]) -> int: + return a + b + + +# Build an instance with a=11 bound, then supply b=12 at runtime: +configured = foo(a=11) +result = configured.flow.compute(b=12) +assert result.value == 23 # .value unwraps the GenericResult wrapper + +# Or create a different instance that stores b=12 as its contextual default: +prefilled = foo(a=11, b=12) +result = prefilled.flow.compute() +assert result.value == 23 +``` + +> **Note:** When the function returns a plain value (like `int` above) instead +> of a `ResultBase` subclass, `@Flow.model` automatically wraps it in +> `GenericResult`. Access the inner value with `.value`. + +This is the core contract: + +- `a` is a regular parameter — it must be bound at construction time, +- `b` is contextual because it is marked with `FromContext[int]` — it can come + from runtime context, a contextual default stored on the model instance, or a + function default, +- `.flow.compute(...)` may carry extra ambient context for upstream graph + branches, but it never binds regular parameters. + +Nothing is being mutated at execution time in the second example. +`prefilled = foo(a=11, b=12)` constructs a different model instance whose +contextual default for `b` is already `12`. Because `b` is still contextual, +incoming runtime context can still override that default. + +This means the following is **invalid**: + +```python +foo().flow.compute(a=11, b=12) +# TypeError: compute() cannot satisfy unbound regular parameter(s): a. +# Bind them at construction time; compute() only supplies runtime context. +``` + +`a` is not contextual, so it must be bound at construction time (`foo(a=11)`). +By contrast, extra ambient fields that are only needed by upstream +`with_context(...)` rewrites are allowed in `compute(**kwargs)` because +`@Flow.model` generated models use `FlowContext` (an open bag) as their +runtime context type. + +## Regular Parameters vs Contextual Parameters + +### Regular Parameters + +Regular parameters are the unmarked ones. + +They can be satisfied by: + +- a literal value, +- a default value from the function signature, +- an upstream `CallableModel`. + +When an upstream model is supplied, `@Flow.model` evaluates it with the current +context and passes the resolved value into the function. This is how you wire +stages together — just pass one model as an argument to another: + +```python +from ccflow import Flow, FlowContext, FromContext + + +@Flow.model +def load_value(value: FromContext[int], offset: int) -> int: + return value + offset + + +@Flow.model +def add(a: int, b: FromContext[int]) -> int: + return a + b + + +# Wire load_value into add's 'a' parameter: +model = add(a=load_value(offset=5)) + +# At runtime, load_value runs first (value=7 + offset=5 = 12), +# then add runs (a=12 + b=12 = 24): +assert model.flow.compute(value=7, b=12).value == 24 +``` + +### Contextual Parameters + +Contextual parameters are the ones marked with `FromContext[...]`. + +They can be satisfied by: + +- runtime context, +- construction-time keyword arguments, stored as contextual defaults on the model instance, +- function defaults. + +They cannot be satisfied by `CallableModel` values. + +A construction-time value for a contextual parameter is still a default, not a +conversion into a regular bound parameter. + +Generated models reserve a few framework attribute names for the model API: +`flow`, `meta`, `context_type`, and `result_type`. Do not use these as +`@Flow.model` function parameter names. + +```python +from ccflow import Flow, FromContext + + +@Flow.model +def add(a: int, b: FromContext[int]) -> int: + return a + b + + +model = add(a=10, b=5) +assert model.flow.compute().value == 15 +assert model.flow.compute(b=7).value == 17 +``` + +Contextual precedence is: + +1. branch-local `.flow.with_context(...)` rewrites, +1. incoming runtime context, +1. contextual defaults stored on the model instance, +1. function defaults. + +## `.flow.compute(...)` + +`.flow.compute(...)` is the ergonomic execution entry point for contextual +execution of the whole DAG. + +For generated `@Flow.model` stages it accepts either: + +- keyword arguments that become the ambient runtime context bag, or +- one context object. + +It does not accept both at the same time. + +```python +from ccflow import Flow, FlowContext, FromContext + + +@Flow.model +def add(a: int, b: FromContext[int]) -> int: + return a + b + + +model = add(a=10) +assert model.flow.compute(b=5).value == 15 +assert model.flow.compute(FlowContext(b=6)).value == 16 +``` + +For `@Flow.model` generated models, the kwargs form is intentionally a DAG +entrypoint: it can include extra fields needed only by upstream transformed +dependencies. Regular parameters are still never read from runtime context. +`compute()` enforces two guardrails on keyword arguments: + +- If a key matches an **unbound** regular parameter, it raises early instead of + silently treating that value as configuration. +- If a key matches an **already-bound** regular parameter, it raises to prevent + accidental rebinding. Use a context object (`FlowContext`) when you need + ambient fields whose names collide with bound regular parameters. + +```python +from ccflow import Flow, FromContext + + +@Flow.model +def source(value: FromContext[int]) -> int: + return value + + +@Flow.model +def add(left: int, right: int, bonus: FromContext[int]) -> int: + return left + right + bonus + + +@Flow.context_transform +def add_offset(value: FromContext[int], amount: int) -> int: + return value + amount + + +base = source() +model = add( + left=base.flow.with_context(value=add_offset(amount=1)), + right=base.flow.with_context(value=add_offset(amount=10)), +) + +assert model.flow.context_inputs == {"bonus": int} +assert model.flow.compute(value=5, bonus=100).value == 121 +``` + +If a regular parameter is already bound on the root model and you need to pass +an ambient context field with the same name for upstream graph nodes, use a +context object instead of keyword arguments. The kwargs form rejects keys that +match already-bound regular parameters to prevent accidental rebinding: + +```python +from ccflow import Flow, FlowContext, FromContext + + +@Flow.model +def source(a: FromContext[int]) -> int: + return a + + +@Flow.model +def combine(a: int, left: int, bonus: FromContext[int]) -> int: + return a + left + bonus + + +model = combine(a=100, left=source()) + +# The context object form lets ambient 'a=7' flow to upstream nodes +# while root 'a' stays bound to 100: +assert model.flow.compute(FlowContext(a=7, bonus=5)).value == 112 + +# The kwargs form rejects 'a' because it is a bound regular parameter: +# model.flow.compute(a=7, bonus=5) +# → TypeError: compute() does not accept regular parameter override(s): a. +``` + +`compute()` returns the same result object you would get from `model(context)`, +unless `auto_unwrap=True` is enabled for an auto-wrapped plain return type: + +```python +from ccflow import Flow, FromContext + + +@Flow.model(auto_unwrap=True) +def add(a: int, b: FromContext[int]) -> int: + return a + b + + +result = add(a=10).flow.compute(b=5) +assert result == 15 +``` + +## `@Flow.context_transform` and `.flow.with_context(...)` + +`@Flow.context_transform` defines reusable, serializable contextual rewrites using the +same `FromContext[...]` language as `@Flow.model`. A transform's return type +determines how it can be used in `with_context()`: + +- **Patch transforms** return a `Mapping` (e.g. `dict[str, object]`) of + contextual field names to replacement values. They are passed as **positional + arguments** to `with_context()`. +- **Field transforms** return a single scalar value. They are passed as + **keyword arguments** to `with_context()`, keyed by the contextual field they + replace. + +```python +from datetime import date, timedelta + +from ccflow import DateRangeContext, Flow, FromContext + + +@Flow.model +def count_visitors(location: str, start_date: FromContext[date], end_date: FromContext[date]) -> int: + days = (end_date - start_date).days + 1 + return days * 12 + len(location) + + +@Flow.model +def visitor_delta(current: int, previous: int, start_date: FromContext[date], end_date: FromContext[date]) -> dict: + return {"window": f"{start_date} -> {end_date}", "change": current - previous} + + +# Patch transform — shifts multiple fields together, passed positionally: +@Flow.context_transform +def previous_window(start_date: FromContext[date], end_date: FromContext[date], days: int) -> dict[str, object]: + return { + "start_date": start_date - timedelta(days=days), + "end_date": end_date - timedelta(days=days), + } + + +# Field transform — shifts one field, passed as a keyword: +@Flow.context_transform +def shift_end(end_date: FromContext[date], days: int) -> date: + return end_date - timedelta(days=days) + + +current = count_visitors(location="library") + +# Patch transform: shift both dates back 30 days +previous = current.flow.with_context(previous_window(days=30)) + +# Mix both forms: patch shifts start_date, keyword shifts end_date independently +custom = current.flow.with_context( + previous_window(days=30), + end_date=shift_end(days=7), # keyword override replaces end_date from the patch +) + +delta = visitor_delta(current=current, previous=previous) +result = delta(DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31))) +``` + +In this example, `current` and `previous` share the same underlying +`count_visitors` configuration but see different date windows at runtime. +`previous` uses a patch transform to shift both dates back 30 days. +`custom` demonstrates combining both forms: the patch sets both dates, then the +keyword override replaces `end_date` with a different shift. Keyword overrides +always apply last. + +Key rules: + +- `with_context()` only targets contextual fields, +- positional arguments must be patch transforms, +- keyword overrides may be literals or field transforms, +- raw anonymous callables are rejected; use named `@Flow.context_transform` helpers, +- transforms are branch-local — they only affect the wrapped dependency, not + the entire pipeline, +- patch results merge left-to-right, then keyword overrides apply last, +- every transform evaluates against the original incoming runtime context; if + multiple fields must move together, put that logic inside one patch + transform. + +Importable transform functions are stored by module path. Local, nested, +`__main__`, and notebook-defined transform functions are stored with an +embedded cloudpickle payload so bound models can still move through pickle and +Ray workers. For long-lived YAML/JSON configuration, prefer importable module +functions so the serialized config stays small and inspectable. + +## `context_type=...` + +When you want the `FromContext[...]` fields to match an existing nominal +context shape, use `context_type=...`: + +```python +from datetime import date + +from ccflow import DateRangeContext, Flow, FromContext + + +@Flow.model(context_type=DateRangeContext) +def count_visitors(location: str, start_date: FromContext[date], end_date: FromContext[date]) -> int: + days = (end_date - start_date).days + 1 + return days * 12 + len(location) +``` + +That preserves the primary `FromContext[...]` authoring model while letting +callers pass richer context objects whose relevant fields satisfy the declared +`context_type`. + +`context_type=...` is a validation/coercion contract for the named +`FromContext[...]` fields. Generated `@Flow.model` instances still expose +`FlowContext` as their runtime `context_type`. + +If the function genuinely needs the runtime context object itself inside the +function body on each call, use a normal `CallableModel` subclass instead of +`@Flow.model`. + +For class-based `CallableModel` methods that want to declare context fields as +keyword-only parameters, see `Flow.call(auto_context=...)` in +[Workflows](Workflows#flow-decorator). + +## Introspection APIs + +Generated models expose three useful introspection helpers: + +- `model.flow.context_inputs`: the full contextual contract, +- `model.flow.unbound_inputs`: the contextual fields still required at runtime, +- `model.flow.bound_inputs`: regular bound inputs plus any construction-time + contextual defaults. + +Example: + +```python +from ccflow import Flow, FromContext + + +@Flow.model +def add(a: int, b: FromContext[int], c: FromContext[int] = 5) -> int: + return a + b + c + + +model = add(a=10) +assert model.flow.context_inputs == {"b": int, "c": int} +assert model.flow.unbound_inputs == {"b": int} +assert model.flow.bound_inputs == {"a": 10} +``` + +## Lazy Dependencies + +`Lazy[T]` defers evaluation of an upstream dependency until the function body +explicitly calls it. This is useful when a dependency is expensive and only +needed conditionally: + +```python +from ccflow import Flow, FlowContext, FromContext, Lazy + + +@Flow.model +def load_value(value: FromContext[int]) -> int: + return value * 10 + + +@Flow.model +def maybe_use(current: int, fallback: Lazy[int], threshold: FromContext[int]) -> int: + if current > threshold: + return current # fallback is never evaluated + return fallback() # evaluate only when needed + + +model = maybe_use(current=50, fallback=load_value()) + +# current (50) > threshold (10), so load_value never runs: +assert model.flow.compute(value=3, threshold=10).value == 50 + +# current (5) <= threshold (10), so load_value runs (3 * 10 = 30): +model2 = maybe_use(current=5, fallback=load_value()) +assert model2.flow.compute(value=3, threshold=10).value == 30 +``` + +Without `Lazy[T]`, the upstream model would always run. With it, the function +controls exactly when (and whether) the dependency executes. + +## When To Use `@Flow.model` + +Use `@Flow.model` when: + +- the stage logic is naturally a plain function, +- you want ordinary arguments to look like ordinary Python function parameters, +- the contextual contract is small and explicit, +- the main goal is easy graph authoring on top of existing ccflow machinery. + +Use a hand-written class-based `CallableModel` when: + +- the model needs custom methods or substantial internal state, +- the full context object is the natural primary interface, +- the stage is no longer best expressed as one function and a small amount of + wiring. + +## Troubleshooting + +**`compute()` says a field is not contextual** + +That field is a regular parameter. Bind it at construction time. Only +`FromContext[...]` fields belong in `compute()`. + +**`with_context()` rejects a field** + +`with_context()` only rewrites contextual inputs. If you are trying to attach one +stage to another, pass the upstream model as a regular argument at construction +time. + +**A contextual parameter still shows up in `context_inputs` after I bound it** + +That is expected. `context_inputs` reports the full contextual contract. +`unbound_inputs` reports only the contextual values still needed at runtime. + +**A shared dependency runs more than once** + +`@Flow.model` authors the graph cleanly, but execution still follows the normal +ccflow evaluator path. If you need deduplication or graph scheduling, use the +appropriate evaluators and cache settings just as you would for class-based +`CallableModel`s. diff --git a/docs/wiki/Key-Features.md b/docs/wiki/Key-Features.md index 616e3d8..724c853 100644 --- a/docs/wiki/Key-Features.md +++ b/docs/wiki/Key-Features.md @@ -22,6 +22,10 @@ The naming was inspired by the open source library [Pydantic](https://docs.pydan `CallableModel`'s are called with a context (something that derives from `ContextBase`) and returns a result (something that derives from `ResultBase`). As an example, you may have a `SQLReader` callable model that when called with a `DateRangeContext` returns a `ArrowResult` (wrapper around a Arrow table) with data in the date range defined by the context by querying some SQL database. +`@Flow.model` is a plain-function API for defining `CallableModel`s with normal Python function signatures. +It keeps the same evaluator, registry, cache, and result machinery while making contextual execution more ergonomic via helpers like `.flow.compute(...)` and `.flow.with_context(...)`. +See [Flow Model](Flow-Model) for the full guide. + ## Model Registry A `ModelRegistry` is a named collection of models. diff --git a/docs/wiki/Workflows.md b/docs/wiki/Workflows.md index 6789d79..cfe03d9 100644 --- a/docs/wiki/Workflows.md +++ b/docs/wiki/Workflows.md @@ -355,6 +355,12 @@ The behavior of the `@Flow.call` decorator can be controlled in several ways: - by setting `options` in the `meta` attribute of the CallableModel - by passing `_options` directly to the `__call__` method +For hand-written `CallableModel` classes, `@Flow.call(auto_context=True)` is +also available when the `__call__` method should declare context fields as +keyword-only parameters instead of accepting one explicit context object. This +is an opt-in `Flow.call` feature; it does not add the dependency wiring or +`FromContext[...]` semantics provided by `@Flow.model`. + An example of the first one (model-specific options) is to disable validation of the result type on a particular model ```python diff --git a/docs/wiki/_Sidebar.md b/docs/wiki/_Sidebar.md index 6b20b2f..75c31c0 100644 --- a/docs/wiki/_Sidebar.md +++ b/docs/wiki/_Sidebar.md @@ -13,6 +13,7 @@ Notes for editors: - [Installation](Installation) - [Design Goals](Design-Goals) - [Key Features](Key-Features) +- [Flow Model](Flow-Model) - [First Steps](First-Steps) **Tutorials** diff --git a/examples/config/flow_model_hydra_builder_demo.yaml b/examples/config/flow_model_hydra_builder_demo.yaml new file mode 100644 index 0000000..9755e5d --- /dev/null +++ b/examples/config/flow_model_hydra_builder_demo.yaml @@ -0,0 +1,26 @@ +# Hydra config for examples/flow_model_hydra_builder_demo.py +# +# Pattern: +# - configure static pipeline specs in YAML +# - use model_alias to pass already-registered models into a plain Python builder +# - keep runtime context as runtime inputs, supplied later at execution time + +library_visitors: + _target_: examples.flow_model_hydra_builder_demo.count_visitors + location: library + +previous_week: + _target_: examples.flow_model_hydra_builder_demo.build_visitor_delta + current: + _target_: ccflow.compose.model_alias + model_name: library_visitors + label: previous_week + days_back: 7 + +previous_two_weeks: + _target_: examples.flow_model_hydra_builder_demo.build_visitor_delta + current: + _target_: ccflow.compose.model_alias + model_name: library_visitors + label: previous_two_weeks + days_back: 14 diff --git a/examples/flow_model_example.py b/examples/flow_model_example.py new file mode 100644 index 0000000..f42ac17 --- /dev/null +++ b/examples/flow_model_example.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python +"""Small `@Flow.model` example. + +Shows how to: + +1. define stages as plain Python functions, +2. compose stages by passing upstream models as ordinary arguments, +3. rewrite contextual inputs on one dependency edge with `.flow.with_context(...)`, +4. execute either as `model(context)` or `model.flow.compute(...)`. + +Run with: + python examples/flow_model_example.py +""" + +from datetime import date, timedelta + +from ccflow import DateRangeContext, Flow, FromContext + + +@Flow.model(context_type=DateRangeContext) +def count_visitors( + location: str, + start_date: FromContext[date], + end_date: FromContext[date], +) -> int: + """Return a deterministic visitor count for one date window.""" + days = (end_date - start_date).days + 1 + location_offset = sum(ord(ch) for ch in location) % 17 + week_index = (end_date.toordinal() - date(2024, 1, 1).toordinal()) // 7 + return days * 12 + location_offset + week_index + + +@Flow.model(context_type=DateRangeContext) +def visitor_delta( + current: int, + previous: int, + label: str, + start_date: FromContext[date], + end_date: FromContext[date], +) -> dict[str, object]: + """Return both visitor counts plus their difference.""" + return { + "label": label, + "window": f"{start_date} -> {end_date}", + "current": current, + "previous": previous, + "change": current - previous, + } + + +@Flow.context_transform +def shift_window(start_date: FromContext[date], end_date: FromContext[date], days: int) -> dict[str, object]: + """Shift both date fields together.""" + return { + "start_date": start_date - timedelta(days=days), + "end_date": end_date - timedelta(days=days), + } + + +def build_visitor_pipeline(location: str): + """Build one reusable visitor-count pipeline.""" + current = count_visitors(location=location) + previous = current.flow.with_context(shift_window(days=7)) + return visitor_delta( + current=current, + previous=previous, + label="previous_week", + ) + + +def main() -> None: + print("=" * 64) + print("Flow.model Example") + print("=" * 64) + + pipeline = build_visitor_pipeline(location="library") + ctx = DateRangeContext( + start_date=date(2024, 3, 1), + end_date=date(2024, 3, 7), + ) + + direct = pipeline(ctx) + computed = pipeline.flow.compute( + start_date=ctx.start_date, + end_date=ctx.end_date, + ) + + print("\nPipeline:") + print(" current input:", pipeline.current) + print(" previous input:", pipeline.previous) + + print("\nExecution:") + print(f" direct == computed: {direct == computed}") + + print("\nResult:") + for key, value in computed.value.items(): + print(f" {key}: {value}") + + +if __name__ == "__main__": + main() diff --git a/examples/flow_model_hydra_builder_demo.py b/examples/flow_model_hydra_builder_demo.py new file mode 100644 index 0000000..2fd3653 --- /dev/null +++ b/examples/flow_model_hydra_builder_demo.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python +"""Hydra + Flow.model builder demo. + +This example shows a clean way to mix: + +1. ergonomic `@Flow.model` pipeline wiring in Python, and +2. Hydra / ModelRegistry configuration for static pipeline specs. + +The pattern is: + +- keep runtime context (`start_date`, `end_date`) as runtime inputs, +- use a plain Python builder function for graph construction, +- let Hydra instantiate that builder and register the returned model. + +Run with: + python examples/flow_model_hydra_builder_demo.py +""" + +from datetime import date, timedelta +from pathlib import Path + +from ccflow import CallableModel, DateRangeContext, Flow, FromContext, ModelRegistry + +CONFIG_PATH = Path(__file__).with_name("config") / "flow_model_hydra_builder_demo.yaml" + + +@Flow.model(context_type=DateRangeContext) +def count_visitors(location: str, start_date: FromContext[date], end_date: FromContext[date]) -> int: + """Return a deterministic visitor count for one date window.""" + days = (end_date - start_date).days + 1 + location_offset = sum(ord(ch) for ch in location) % 17 + week_index = (end_date.toordinal() - date(2024, 1, 1).toordinal()) // 7 + return days * 12 + location_offset + week_index + + +@Flow.model(context_type=DateRangeContext) +def visitor_delta( + current: int, + previous: int, + label: str, + start_date: FromContext[date], + end_date: FromContext[date], +) -> dict: + """Return both visitor counts plus their difference.""" + return { + "label": label, + "window": f"{start_date} -> {end_date}", + "current": current, + "previous": previous, + "change": current - previous, + } + + +@Flow.context_transform +def shift_window(start_date: FromContext[date], end_date: FromContext[date], days: int) -> dict[str, object]: + """Shift both date fields together.""" + return { + "start_date": start_date - timedelta(days=days), + "end_date": end_date - timedelta(days=days), + } + + +def build_visitor_delta(current: CallableModel, *, label: str, days_back: int): + """Hydra-friendly builder that returns a configured visitor-count model.""" + previous = current.flow.with_context(shift_window(days=days_back)) + return visitor_delta( + current=current, + previous=previous, + label=label, + ) + + +def main() -> None: + registry = ModelRegistry.root() + registry.clear() + try: + registry.load_config_from_path(str(CONFIG_PATH), overwrite=True) + + previous_week = registry["previous_week"] + previous_two_weeks = registry["previous_two_weeks"] + + ctx = DateRangeContext( + start_date=date(2024, 3, 1), + end_date=date(2024, 3, 7), + ) + + print("=" * 68) + print("Hydra + Flow.model Builder Demo") + print("=" * 68) + print("\nLoaded from config:") + print(" library_visitors:", registry["library_visitors"]) + print(" previous_week:", previous_week) + print(" previous_two_weeks:", previous_two_weeks) + + previous_week_result = previous_week.flow.compute( + start_date=ctx.start_date, + end_date=ctx.end_date, + ).value + previous_two_weeks_result = previous_two_weeks(ctx).value + + print("\nPrevious week:") + for key, value in previous_week_result.items(): + print(f" {key}: {value}") + + print("\nPrevious two weeks:") + for key, value in previous_two_weeks_result.items(): + print(f" {key}: {value}") + finally: + registry.clear() + + +if __name__ == "__main__": + main() From 6ff5409d9f1b03fdbcec4da2eaf70b4fc117f7c5 Mon Sep 17 00:00:00 2001 From: Nijat K Date: Wed, 13 May 2026 22:03:17 -0400 Subject: [PATCH 2/8] Handle FlowContext and existing context case and clean up code Signed-off-by: Nijat K --- ccflow/_flow_model_binding.py | 248 +++- ccflow/callable.py | 97 +- ccflow/context.py | 17 +- ccflow/evaluators/common.py | 143 ++- ccflow/flow_model.py | 1252 +++++++++++++++------ ccflow/tests/config/conf_flow.yaml | 30 +- ccflow/tests/evaluators/test_common.py | 33 +- ccflow/tests/flow_model_hydra_fixtures.py | 96 ++ ccflow/tests/test_callable.py | 12 + ccflow/tests/test_context.py | 8 + ccflow/tests/test_flow_context.py | 81 +- ccflow/tests/test_flow_model.py | 1078 ++++++++++++++++-- ccflow/tests/test_flow_model_hydra.py | 6 +- docs/wiki/Flow-Model.md | 85 +- docs/wiki/Workflows.md | 7 +- examples/flow_model_example.py | 31 +- examples/flow_model_hydra_builder_demo.py | 30 +- 17 files changed, 2550 insertions(+), 704 deletions(-) create mode 100644 ccflow/tests/flow_model_hydra_fixtures.py diff --git a/ccflow/_flow_model_binding.py b/ccflow/_flow_model_binding.py index 2e67eac..1be6cdf 100644 --- a/ccflow/_flow_model_binding.py +++ b/ccflow/_flow_model_binding.py @@ -1,10 +1,16 @@ -"""Shared signature and context-contract analysis for Flow authoring APIs.""" +"""Shared signature and context-contract analysis for Flow authoring APIs. + +This module also owns the portable serialization format for an already-analyzed +``@Flow.model`` contract. That payload is used by generated-model pickle/Ray +restore and serialized context transforms so workers do not re-resolve function +annotations from the original defining scope. +""" import inspect from dataclasses import dataclass, field from functools import wraps -from types import UnionType -from typing import Annotated, Any, Callable, Dict, Literal, Optional, Tuple, Type, Union, get_args, get_origin, get_type_hints +from types import FunctionType, UnionType +from typing import Annotated, Any, Callable, Dict, Literal, NamedTuple, Optional, Tuple, Type, Union, get_args, get_origin, get_type_hints from .base import ContextBase, ResultBase from .context import FlowContext @@ -33,19 +39,19 @@ def _get_internal_sentinel(name: str) -> _InternalSentinel: _INTERNAL_SENTINELS = { "_UNSET": _InternalSentinel("_UNSET"), - "_REMOVED_CONTEXT_ARGS": _InternalSentinel("_REMOVED_CONTEXT_ARGS"), } _UNSET = _INTERNAL_SENTINELS["_UNSET"] -_REMOVED_CONTEXT_ARGS = _INTERNAL_SENTINELS["_REMOVED_CONTEXT_ARGS"] -_RESERVED_FLOW_MODEL_PARAM_NAMES = frozenset({"flow", "meta", "context_type", "result_type"}) +_RESERVED_FLOW_MODEL_PARAM_NAMES = frozenset({"flow", "meta", "context_type", "result_type", "type_"}) class _LazyMarker: - pass + def __repr__(self) -> str: + return "Lazy" class _FromContextMarker: - pass + def __repr__(self) -> str: + return "FromContext" class FromContext: @@ -76,16 +82,12 @@ class _ParsedAnnotation: class _FlowModelParam: name: str annotation: Any - kind: str + is_contextual: bool is_lazy: bool has_function_default: bool function_default: Any = _UNSET context_validation_annotation: Any = _UNSET - @property - def is_contextual(self) -> bool: - return self.kind == "contextual" - @property def validation_annotation(self) -> Any: if self.context_validation_annotation is not _UNSET: @@ -155,6 +157,34 @@ class _AutoContextSpec: fields: Dict[str, Tuple[Any, Any]] +class _SerializedAnnotation(NamedTuple): + kind: str + value: Any + args: Tuple[Any, ...] = () + + +class _SerializedFlowModelParam(NamedTuple): + name: str + annotation: _SerializedAnnotation + is_contextual: bool + is_lazy: bool + has_function_default: bool + function_default: Any + context_validation_annotation: _SerializedAnnotation + + +class _SerializedFlowModelConfig(NamedTuple): + func: _AnyCallable + return_annotation: _SerializedAnnotation + context_type: _SerializedAnnotation + result_type: _SerializedAnnotation + auto_wrap_result: bool + auto_unwrap: bool + parameters: Tuple[_SerializedFlowModelParam, ...] + declared_context_type: _SerializedAnnotation + path: Optional[PyObjectPath] + + def _callable_name(func: _AnyCallable) -> str: return getattr(func, "__name__", type(func).__name__) @@ -163,6 +193,169 @@ def _callable_qualname(func: _AnyCallable) -> str: return getattr(func, "__qualname__", type(func).__qualname__) +def _clone_function_without_annotations(fn: _AnyCallable) -> _AnyCallable: + """Return a behavior-equivalent function whose annotations are not a pickle dependency. + + ``@Flow.model`` analyzes annotations eagerly when the decorator runs and + stores the resolved contract in ``_FlowModelConfig``. During local/generated + model pickling we want workers to execute the original function body, but we + do not want them to re-evaluate or unpickle the original annotations. Some + runtime annotations, notably Pydantic generic specializations such as + ``GenericResult[int]``, are valid Python objects in-process but do not have a + durable import path for fresh-process cloudpickle restore. + """ + + if not isinstance(fn, FunctionType): + return fn + + clone = FunctionType(fn.__code__, fn.__globals__, fn.__name__, fn.__defaults__, fn.__closure__) + clone.__kwdefaults__ = fn.__kwdefaults__ + clone.__module__ = fn.__module__ + clone.__qualname__ = fn.__qualname__ + clone.__doc__ = fn.__doc__ + clone.__dict__.update(getattr(fn, "__dict__", {})) + clone.__annotations__ = {} + return clone + + +def _is_pydantic_generic_type(annotation: Any) -> bool: + metadata = getattr(annotation, "__pydantic_generic_metadata__", None) + return isinstance(annotation, type) and bool(metadata) and metadata.get("origin") is not None + + +def _serialize_annotation(annotation: Any) -> Any: + """Serialize the annotation shapes that are part of a Flow.model contract. + + The goal is not to serialize arbitrary Python typing objects perfectly. It + is narrower and deliberate: keep the already-analyzed ``@Flow.model`` + contract portable across subprocess/Ray restore without asking workers to + resolve annotations from the original defining scope again. + + Raw annotations are still allowed as a fallback for ordinary importable + objects. The explicit cases below cover annotation containers that commonly + wrap ccflow/Pydantic model types and would otherwise embed fragile runtime + objects directly in the cloudpickle payload. + """ + + if annotation is _UNSET or annotation is inspect.Signature.empty: + return _SerializedAnnotation(kind="raw", value=annotation) + + if _is_pydantic_generic_type(annotation): + metadata = annotation.__pydantic_generic_metadata__ + return _SerializedAnnotation( + kind="pydantic_generic", + value=_serialize_annotation(metadata["origin"]), + args=tuple(_serialize_annotation(arg) for arg in metadata["args"]), + ) + + origin = get_origin(annotation) + if origin is Annotated: + args = get_args(annotation) + return _SerializedAnnotation(kind="annotated", value=_serialize_annotation(args[0]), args=args[1:]) + if origin in _UNION_ORIGINS: + return _SerializedAnnotation(kind="union", value=tuple(_serialize_annotation(arg) for arg in get_args(annotation))) + if origin is Literal: + return _SerializedAnnotation(kind="literal", value=get_args(annotation)) + if origin is not None: + return _SerializedAnnotation( + kind="generic_alias", + value=_serialize_annotation(origin), + args=tuple(_serialize_annotation(arg) for arg in get_args(annotation)), + ) + + return _SerializedAnnotation(kind="raw", value=annotation) + + +def _restore_annotation(payload: Any) -> Any: + """Restore an annotation payload produced by ``_serialize_annotation``.""" + + if not isinstance(payload, _SerializedAnnotation): + raise TypeError(f"Unknown serialized annotation payload: {payload!r}") + value = payload.value + if payload.kind == "raw": + return value + if payload.kind == "pydantic_generic": + origin = _restore_annotation(value) + args = tuple(_restore_annotation(arg) for arg in payload.args) + return origin[args[0] if len(args) == 1 else args] + if payload.kind == "annotated": + return Annotated.__class_getitem__((_restore_annotation(value), *payload.args)) + if payload.kind == "union": + return Union.__getitem__(tuple(_restore_annotation(arg) for arg in value)) + if payload.kind == "literal": + return Literal.__getitem__(value) + if payload.kind == "generic_alias": + origin = _restore_annotation(value) + args = tuple(_restore_annotation(arg) for arg in payload.args) + return origin[args[0] if len(args) == 1 else args] + raise TypeError(f"Unknown serialized annotation payload kind: {payload.kind!r}") + + +def _serialize_flow_model_param(param: _FlowModelParam) -> _SerializedFlowModelParam: + return _SerializedFlowModelParam( + name=param.name, + annotation=_serialize_annotation(param.annotation), + is_contextual=param.is_contextual, + is_lazy=param.is_lazy, + has_function_default=param.has_function_default, + function_default=param.function_default, + context_validation_annotation=_serialize_annotation(param.context_validation_annotation), + ) + + +def _restore_flow_model_param(payload: _SerializedFlowModelParam) -> _FlowModelParam: + if not isinstance(payload, _SerializedFlowModelParam): + raise TypeError(f"Unknown Flow.model parameter payload: {payload!r}") + return _FlowModelParam( + name=payload.name, + annotation=_restore_annotation(payload.annotation), + is_contextual=payload.is_contextual, + is_lazy=payload.is_lazy, + has_function_default=payload.has_function_default, + function_default=payload.function_default, + context_validation_annotation=_restore_annotation(payload.context_validation_annotation), + ) + + +def _serialize_flow_model_config(config: _FlowModelConfig) -> _SerializedFlowModelConfig: + """Return a tagged, portable description of an analyzed Flow.model config. + + This is intentionally explicit instead of relying on ``_FlowModelConfig`` or + ``_FlowModelParam`` pickle hooks. The payload is the persistence boundary + for local generated models and serialized context transforms: function + behavior plus the resolved contract needed to rebuild the generated + ``CallableModel`` class. It is not a second signature-analysis path. + """ + + return _SerializedFlowModelConfig( + func=_clone_function_without_annotations(config.func), + return_annotation=_serialize_annotation(config.return_annotation), + context_type=_serialize_annotation(config.context_type), + result_type=_serialize_annotation(config.result_type), + auto_wrap_result=config.auto_wrap_result, + auto_unwrap=config.auto_unwrap, + parameters=tuple(_serialize_flow_model_param(param) for param in config.parameters), + declared_context_type=_serialize_annotation(config.declared_context_type), + path=config.path, + ) + + +def _restore_flow_model_config(payload: _SerializedFlowModelConfig) -> _FlowModelConfig: + if not isinstance(payload, _SerializedFlowModelConfig): + raise TypeError(f"Unknown Flow.model config payload: {payload!r}") + return _FlowModelConfig( + func=payload.func, + return_annotation=_restore_annotation(payload.return_annotation), + context_type=_restore_annotation(payload.context_type), + result_type=_restore_annotation(payload.result_type), + auto_wrap_result=payload.auto_wrap_result, + auto_unwrap=payload.auto_unwrap, + parameters=tuple(_restore_flow_model_param(param) for param in payload.parameters), + declared_context_type=_restore_annotation(payload.declared_context_type), + path=payload.path, + ) + + def _resolved_flow_signature( fn: _AnyCallable, *, @@ -222,15 +415,16 @@ def _strip_annotated(annotation: Any) -> Any: def _is_result_annotation(annotation: Any) -> bool: + annotation = _strip_annotated(annotation) origin = get_origin(annotation) or annotation - if isinstance(origin, type) and issubclass(origin, ResultBase): - return True + return isinstance(origin, type) and issubclass(origin, ResultBase) - if get_origin(annotation) in _UNION_ORIGINS: - args = tuple(arg for arg in get_args(annotation) if arg is not type(None)) - return bool(args) and all(_is_result_annotation(arg) for arg in args) - return False +def _result_union_members(annotation: Any) -> Tuple[Any, ...]: + annotation = _strip_annotated(annotation) + if get_origin(annotation) not in _UNION_ORIGINS: + return () + return tuple(arg for arg in get_args(annotation) if arg is not type(None)) def _context_type_annotations_compatible(func_annotation: Any, context_annotation: Any) -> bool: @@ -330,7 +524,7 @@ def _analyze_flow_function( _FlowModelParam( name=param.name, annotation=parsed.base, - kind="contextual" if parsed.is_from_context else "regular", + is_contextual=parsed.is_from_context, is_lazy=parsed.is_lazy, has_function_default=has_default, function_default=param.default if has_default else _UNSET, @@ -403,7 +597,7 @@ def _analyze_flow_model( _FlowModelParam( name=param.name, annotation=param.annotation, - kind=param.kind, + is_contextual=param.is_contextual, is_lazy=param.is_lazy, has_function_default=param.has_function_default, function_default=param.function_default, @@ -412,8 +606,16 @@ def _analyze_flow_model( ) parameters = tuple(updated_params) - auto_wrap_result = not _is_result_annotation(sig.return_annotation) - result_type = GenericResult[sig.return_annotation] if auto_wrap_result else sig.return_annotation + return_annotation = _strip_annotated(sig.return_annotation) + union_result_members = _result_union_members(return_annotation) + if union_result_members and any(_is_result_annotation(arg) for arg in union_result_members): + raise TypeError( + "@Flow.model does not support Union or Optional ResultBase return annotations. " + "Return one concrete ResultBase subclass, or return an ordinary value and let @Flow.model wrap it." + ) + + auto_wrap_result = not _is_result_annotation(return_annotation) + result_type = GenericResult[return_annotation] if auto_wrap_result else return_annotation return _FlowModelConfig( func=fn, diff --git a/ccflow/callable.py b/ccflow/callable.py index 986c04b..6a113e2 100644 --- a/ccflow/callable.py +++ b/ccflow/callable.py @@ -432,7 +432,11 @@ def __exit__(self, exc_type, exc_value, exc_tb): class Flow(PydanticBaseModel): @staticmethod def call(*args, **kwargs): - """Decorator for methods on callable models""" + """Decorator for methods on callable models. + + Pass ``auto_context=True`` or a ``ContextBase`` subclass to derive the + method context from annotated parameters. + """ auto_context = kwargs.pop("auto_context", False) if auto_context is not False: @@ -477,65 +481,23 @@ def deps(*args, **kwargs): @staticmethod def model(*args, **kwargs): - """Decorator that generates a CallableModel class from a plain Python function. - - The generated model participates in the normal CallableModel execution - path, including evaluation, caching, dependency discovery, registry use, - and serialization. The function signature declares both the model's - construction-time inputs and its runtime/contextual inputs. - - Args: - context_type: Optional ContextBase subclass used only to validate/coerce - `FromContext[...]` inputs against an existing nominal context shape - auto_unwrap: When True, `.flow.compute(...)` unwraps auto-wrapped - `GenericResult(value=...)` outputs back to the annotated return type. - Explicit `ResultBase` returns are left unchanged. Default: False. - model_base: Optional custom `CallableModel` subclass to use as an - additional base for the generated model class. - cacheable: Enable caching of results (default: False) - volatile: Mark as volatile (default: False) - log_level: Logging verbosity (default: logging.DEBUG) - validate_result: Validate return type (default: True) - verbose: Verbose logging output (default: True) - evaluator: Custom evaluator (default: None) - - Primary authoring model: - Mark runtime/contextual inputs explicitly with `FromContext[...]`. - Ordinary unmarked parameters are regular bound inputs and are never - read implicitly from the runtime context. - - @Flow.model - def greeting(prefix: str, name: FromContext[str]) -> str: - return f"{prefix}, {name}" - - model = greeting(prefix="Hello") - assert model.flow.compute(name="Ada").value == "Hello, Ada" - - Dependencies: - Any ordinary parameter can be bound either to a literal value or - to another CallableModel. When a CallableModel is supplied, the - generated model treats it as an upstream dependency and resolves it - with the current context before calling the underlying function. - - `FromContext[...]` parameters are different: they may be satisfied by - runtime context, construction-time contextual defaults, or function - defaults, but not by CallableModel values. - - Usage: - @Flow.model - def length(text: FromContext[str]) -> int: - return len(text) - - @Flow.model - def score(base: int, bonus: FromContext[int]) -> int: - return base + bonus - - model = score(base=length()) - result = model.flow.compute(text="abcd", bonus=3) - assert result.value == 7 - - Returns: - A factory function that creates CallableModel instances + """Generate a ``CallableModel`` factory from a typed Python function. + + Unmarked function parameters become construction-time model inputs. + They can be bound to literal values or to upstream ``CallableModel`` + dependencies. Parameters annotated as ``FromContext[T]`` are runtime + inputs supplied by ``model.flow.compute(...)``, a context object, + construction-time contextual defaults, or ``with_context(...)``. + + Plain return annotations are wrapped in ``GenericResult`` so generated + models still follow the normal ccflow result/evaluator/cache path. + Functions that already return a ``ResultBase`` subclass are left as-is. + + Common options mirror ``Flow.call``: ``cacheable``, ``volatile``, + ``log_level``, ``validate_result``, ``verbose``, and ``evaluator``. + Generated-model options also include ``context_type`` for validating + contextual fields, ``auto_unwrap`` for ergonomic compute results, and + ``model_base`` for custom ``CallableModel`` bases. """ from .flow_model import flow_model @@ -543,7 +505,20 @@ def score(base: int, bonus: FromContext[int]) -> int: @staticmethod def context_transform(*args, **kwargs): - """Decorator that turns a top-level function into a serializable with_context() transform factory.""" + """Create a serializable transform factory for ``model.flow.with_context``. + + Mark transform inputs with ``FromContext[T]`` when they should be read + from the caller's runtime context. Unmarked parameters are regular + arguments bound when the transform factory is called. + + Mapping-returning transforms are patch transforms and are passed + positionally to ``with_context(...)``. Scalar-returning transforms are + field transforms and are passed as keyword overrides, for example + ``model.flow.with_context(score=shift_score(amount=1))``. + + Transform bindings serialize enough function metadata to survive model + serialization, including local or nested functions through cloudpickle. + """ from .flow_model import flow_context_transform return flow_context_transform(*args, **kwargs) diff --git a/ccflow/context.py b/ccflow/context.py index 56a3774..04967d3 100644 --- a/ccflow/context.py +++ b/ccflow/context.py @@ -5,7 +5,7 @@ from typing import Any, Generic, Hashable, Optional, Sequence, Set, TypeVar from deprecated import deprecated -from pydantic import ConfigDict, PrivateAttr, field_validator, model_validator +from pydantic import ConfigDict, field_validator, model_validator from .base import ContextBase from .exttypes import Frequency @@ -97,7 +97,8 @@ class FlowContext(ContextBase): Instead of generating a new ContextBase subclass for each @Flow.model, this single class with extra="allow" serves as the universal carrier. - Validation happens via TypedDict + TypeAdapter at compute() time. + Validation happens against the generated model's declared contextual input + types at compute() time. This design avoids: - Proliferation of dynamic _funcname_Context classes @@ -106,13 +107,9 @@ class FlowContext(ContextBase): """ model_config = ConfigDict(extra="allow", frozen=True) - _frozen_hash_key: Hashable | None = PrivateAttr(default=None) - _hash_value: int | None = PrivateAttr(default=None) def _hash_key(self) -> Hashable: - if self._frozen_hash_key is None: - self._frozen_hash_key = _freeze_for_hash(self.model_dump(mode="python")) - return self._frozen_hash_key + return _freeze_for_hash(self.model_dump(mode="python")) def __eq__(self, other: Any) -> bool: if self is other: @@ -122,9 +119,7 @@ def __eq__(self, other: Any) -> bool: return self._hash_key() == other._hash_key() def __hash__(self) -> int: - if self._hash_value is None: - self._hash_value = hash(self._hash_key()) - return self._hash_value + return hash(self._hash_key()) def _freeze_for_hash(value: Any) -> Hashable: @@ -139,8 +134,6 @@ def _freeze_for_hash(value: Any) -> Hashable: try: hash(value) except TypeError as exc: - if hasattr(value, "__dict__"): - return (type(value), _freeze_for_hash(vars(value))) raise TypeError(f"FlowContext contains an unhashable value of type {type(value).__name__}: {value!r}") from exc return value diff --git a/ccflow/evaluators/common.py b/ccflow/evaluators/common.py index 7baa529..06cdfad 100644 --- a/ccflow/evaluators/common.py +++ b/ccflow/evaluators/common.py @@ -7,7 +7,7 @@ from types import MappingProxyType from typing import Any, Callable, Dict, List, Optional, Set, Union -from pydantic import Field, PrivateAttr, ValidationError, field_validator +from pydantic import Field, PrivateAttr, field_validator from typing_extensions import override from ..base import BaseModel, make_lazy_result @@ -21,7 +21,7 @@ TransparentModelEvaluationContext, WrapperModel, ) -from ..utils.tokenize import compute_cache_token, compute_data_token +from ..utils.tokenize import compute_cache_token __all__ = [ "cache_key", @@ -43,8 +43,31 @@ class _EffectiveEvaluationKeyUnavailable(Exception): """Internal signal to use the existing structural evaluation key.""" -_EFFECTIVE_IDENTITY_DECLINED_ERRORS = (TypeError, ValueError, ValidationError) _EFFECTIVE_EVALUATION_KEY_VERSION = "ccflow_effective_evaluation_key_v1" +_RECURSIVE_EFFECTIVE_IDENTITY_SENTINEL = "recursive_effective_identity" + + +class _IdentityMemoKey: + """Identity-based key that keeps objects alive while effective keys recurse. + + Raw ``id(...)`` tuples are unsafe here because rewritten context objects can + be short-lived and Python may reuse their ids during a single graph build. + Holding references keeps the ids stable for the lifetime of the memo key; + equality still checks object identity, so hash collisions are harmless. + """ + + __slots__ = ("model", "context", "_hash") + + def __init__(self, model: CallableModel, context: Any): + self.model = model + self.context = context + self._hash = hash((id(model), id(context))) + + def __hash__(self) -> int: + return self._hash + + def __eq__(self, other: Any) -> bool: + return isinstance(other, _IdentityMemoKey) and self.model is other.model and self.context is other.context def combine_evaluators(first: Optional[EvaluatorBase], second: Optional[EvaluatorBase]) -> EvaluatorBase: @@ -69,15 +92,22 @@ def combine_evaluators(first: Optional[EvaluatorBase], second: Optional[Evaluato return MultiEvaluator(evaluators=[first, second]) -def _flatten_cache_key_context(flow_obj: ModelEvaluationContext) -> tuple[ModelEvaluationContext, str, List[EvaluatorBase]]: - fn = flow_obj.fn - non_transparent: List[EvaluatorBase] = [] - while isinstance(flow_obj.context, ModelEvaluationContext): - fn = flow_obj.fn if flow_obj.fn != "__call__" else fn - if not isinstance(flow_obj, TransparentModelEvaluationContext): - non_transparent.append(flow_obj.model) - flow_obj = flow_obj.context - return flow_obj, fn if fn != "__call__" else flow_obj.fn, non_transparent +def _flatten_cache_key_context(evaluation_context: ModelEvaluationContext) -> tuple[ModelEvaluationContext, str, List[CallableModel]]: + """Strip transparent evaluator wrappers and keep opaque wrappers in order. + + This preserves the structural cache-key behavior: transparent evaluators are + ignored, while non-transparent evaluators remain part of the identity. The + returned function name is the innermost non-``__call__`` name, so + ``__deps__`` does not collapse into ``__call__`` when wrapped. + """ + fn = evaluation_context.fn + outer_to_inner_evaluators: List[CallableModel] = [] + while isinstance(evaluation_context.context, ModelEvaluationContext): + fn = evaluation_context.fn if evaluation_context.fn != "__call__" else fn + if not isinstance(evaluation_context, TransparentModelEvaluationContext): + outer_to_inner_evaluators.append(evaluation_context.model) + evaluation_context = evaluation_context.context + return evaluation_context, fn if fn != "__call__" else evaluation_context.fn, outer_to_inner_evaluators class MultiEvaluator(EvaluatorBase): @@ -237,41 +267,11 @@ def _format_result(self, result: ResultType) -> str: return f"{msg_str}{pformat(result_dict, **self.format_config.pformat_config)}" -def _unwrap_evaluation_context(evaluation_context: ModelEvaluationContext) -> tuple[ModelEvaluationContext, str, List[CallableModel]]: - """Strip transparent evaluator wrappers and keep opaque wrappers in order. - - This preserves the existing structural cache-key behavior: transparent - evaluators are ignored, while non-transparent evaluators remain part of the - identity. The returned function name is the innermost non-``__call__`` name, - so ``__deps__`` does not collapse into ``__call__`` when wrapped. - """ - fn = evaluation_context.fn - outer_to_inner_evaluators = [] - while isinstance(evaluation_context.context, ModelEvaluationContext): - fn = evaluation_context.fn if evaluation_context.fn != "__call__" else fn - if not isinstance(evaluation_context, TransparentModelEvaluationContext): - outer_to_inner_evaluators.append(evaluation_context.model) - evaluation_context = evaluation_context.context - return evaluation_context, fn if fn != "__call__" else evaluation_context.fn, outer_to_inner_evaluators - - -def _evaluator_identity_payload(outer_to_inner_evaluators: List[CallableModel]) -> List[Dict[str, Any]]: - return [evaluator.model_dump(mode="python") for evaluator in outer_to_inner_evaluators] - - -def _memo_token(model: CallableModel, context: Any) -> tuple[int, str]: - if hasattr(context, "model_dump"): - context_value = context.model_dump(mode="python") - else: - context_value = context - return (id(model), compute_data_token((type(context), context_value))) - - def _effective_model_key( model: CallableModel, context: Any, - memo: Dict[tuple[int, str], bytes], - active: Set[tuple[int, str]], + memo: Dict[_IdentityMemoKey, bytes], + active: Set[int], ) -> Optional[bytes]: """Return a model's opt-in effective key, or ``None`` for normal opt-out. @@ -280,22 +280,16 @@ def _effective_model_key( payload are resolved by ``_resolve_effective_identity_payload()`` so models declare what matters without constructing recursive keys themselves. """ - token = _memo_token(model, context) + token = _IdentityMemoKey(model, context) if token in memo: return memo[token] - if token in active: + model_id = id(model) + if model_id in active: raise _EffectiveEvaluationKeyUnavailable("recursive effective identity") - active.add(token) + active.add(model_id) try: - try: - payload = model._evaluation_identity_payload(context) - except _EFFECTIVE_IDENTITY_DECLINED_ERRORS as exc: - # Identity derivation runs before the actual call and may encounter - # the same validation failures as evaluation context construction. - # Falling back preserves existing behavior instead of turning key - # computation into a new failure mode for ordinary models. - raise _EffectiveEvaluationKeyUnavailable(str(exc)) from exc + payload = model._evaluation_identity_payload(context) # For normal CallableModels, `_evaluation_identity_payload` defaults to # None, so we should hit this path if payload is None: @@ -308,21 +302,21 @@ def _effective_model_key( memo[token] = key return key finally: - active.discard(token) + active.discard(model_id) def _resolve_effective_identity_payload( value: Any, - memo: Dict[tuple[int, str], bytes], - active: Set[tuple[int, str]], + memo: Dict[_IdentityMemoKey, bytes], + active: Set[int], ) -> Any: """Replace dependency invocation markers with recursive effective keys.""" if isinstance(value, EvaluationDependency): + evaluation = value.model.__call__.get_evaluation_context(value.model, value.context) try: - evaluation = value.model.__call__.get_evaluation_context(value.model, value.context) - except _EFFECTIVE_IDENTITY_DECLINED_ERRORS as exc: - raise _EffectiveEvaluationKeyUnavailable(f"dependency {type(value.model).__name__} could not build evaluation context: {exc}") from exc - return _effective_evaluation_key(evaluation, memo=memo, active=active) + return _effective_evaluation_key(evaluation, memo=memo, active=active, fallback=False) + except _EffectiveEvaluationKeyUnavailable: + return (_RECURSIVE_EFFECTIVE_IDENTITY_SENTINEL, type(value.model).__module__, type(value.model).__qualname__) if isinstance(value, dict): return {key: _resolve_effective_identity_payload(item, memo, active) for key, item in value.items()} if isinstance(value, list): @@ -334,13 +328,14 @@ def _resolve_effective_identity_payload( def _effective_evaluation_key( evaluation_context: ModelEvaluationContext, - memo: Optional[Dict[tuple[int, str], bytes]] = None, - active: Optional[Set[tuple[int, str]]] = None, + memo: Optional[Dict[_IdentityMemoKey, bytes]] = None, + active: Optional[Set[int]] = None, + fallback: bool = True, ) -> bytes: """Use opt-in effective identity for ``__call__``; otherwise preserve ``cache_key()``.""" memo = {} if memo is None else memo active = set() if active is None else active - inner, fn, outer_to_inner_evaluators = _unwrap_evaluation_context(evaluation_context) + inner, fn, outer_to_inner_evaluators = _flatten_cache_key_context(evaluation_context) if fn != "__call__": # Keep non-call evaluations, especially ``__deps__``, on the exact # public structural key path. Effective identity is only meant to @@ -358,6 +353,8 @@ def _effective_evaluation_key( try: key = _effective_model_key(inner.model, inner.context, memo, active) except _EffectiveEvaluationKeyUnavailable as exc: + if not fallback: + raise # Effective identity is an optimization/semantic narrowing for opt-in # generated models. If deriving it is unclear, do not make cache/graph # key construction a new failure mode; use the old structural key. @@ -495,12 +492,12 @@ def _build_dependency_graph( graph: CallableModelGraph, parent_key: Optional[bytes] = None, parent_model: Optional[CallableModel] = None, -): +) -> bytes: # Generated/bound ``@Flow.model`` nodes can use effective identity so unused # ambient FlowContext fields do not split the graph. Normal CallableModel # nodes opt out and therefore still receive ``cache_key(evaluation_context)``. key = _effective_evaluation_key(evaluation_context) - unwrapped_evaluation_context, _, _ = _unwrap_evaluation_context(evaluation_context) + unwrapped_evaluation_context, _, _ = _flatten_cache_key_context(evaluation_context) current_model = unwrapped_evaluation_context.model is_same_evaluation_key = parent_key == key is_collapsed_wrapper_child = is_same_evaluation_key and _is_wrapper_to_wrapped_edge(parent_model, current_model) @@ -526,7 +523,7 @@ def _build_dependency_graph( # Preserve normal graph deduplication by key, and make the only exception # the exact same-key wrapper -> wrapped edge. if not is_new_graph_key and not is_collapsed_wrapper_child: - return + return key # Note that __deps__ will be evaluated using whatever evaluator is configured for the model, # which could include logging, caching, etc. @@ -541,6 +538,7 @@ def _build_dependency_graph( parent_key=key, parent_model=current_model, ) + return key def get_dependency_graph(evaluation_context: ModelEvaluationContext) -> CallableModelGraph: @@ -549,13 +547,8 @@ def get_dependency_graph(evaluation_context: ModelEvaluationContext) -> Callable Args: evaluation_context: The model and context to build the graph for. """ - # Keep the root id on the same identity function used for every graph node. - # For existing models this is still ``cache_key(evaluation_context)``; for - # generated flow models it is the narrowed key that ignores unused ambient - # context fields. - root_key = _effective_evaluation_key(evaluation_context) - graph = CallableModelGraph(ids={}, graph={}, root_id=root_key) - _build_dependency_graph(evaluation_context, graph) + graph = CallableModelGraph(ids={}, graph={}, root_id=b"") + graph.root_id = _build_dependency_graph(evaluation_context, graph) return graph diff --git a/ccflow/flow_model.py b/ccflow/flow_model.py index c358cec..6854101 100644 --- a/ccflow/flow_model.py +++ b/ccflow/flow_model.py @@ -44,11 +44,10 @@ """ import inspect -import logging import sys from base64 import b64decode, b64encode from collections import OrderedDict -from functools import lru_cache, wraps +from functools import lru_cache, singledispatch, wraps from typing import ( Annotated, Any, @@ -63,17 +62,17 @@ Set, Tuple, Type, + Union, cast, get_args, get_origin, get_type_hints, ) -from pydantic import BaseModel as PydanticModel, Field, TypeAdapter, ValidationError, model_validator +from pydantic import BaseModel as PydanticModel, Field, SkipValidation, TypeAdapter, ValidationError, model_validator from pydantic.errors import PydanticSchemaGenerationError, PydanticUndefinedAnnotation from ._flow_model_binding import ( - _REMOVED_CONTEXT_ARGS, _UNION_ORIGINS, _UNSET, FromContext, @@ -84,6 +83,8 @@ _FlowModelConfig, _FlowModelParam, _resolved_flow_signature, + _restore_flow_model_config, + _serialize_flow_model_config, _strip_annotated, ) from .base import BaseModel, ContextBase, ContextType, ResultBase @@ -92,6 +93,7 @@ from .exttypes import PyObjectPath from .local_persistence import register_ccflow_import_path from .result import GenericResult +from .utils.tokenize import compute_behavior_token, compute_data_token __all__ = ( "FlowAPI", @@ -99,12 +101,10 @@ "FromContext", "Lazy", "ContextTransform", - "clear_flow_model_caches", "flow_context_transform", ) _AnyCallable = Callable[..., Any] -log = logging.getLogger(__name__) # --------------------------------------------------------------------------- @@ -148,9 +148,9 @@ def _is_unset_flow_input(value: Any) -> bool: class ContextTransform(PydanticModel): """Serializable binding produced by ``@Flow.context_transform``. - Importable top-level transforms are stored by import path. Local/nested - transforms fall back to a cloudpickled config payload so bound models can - survive pickle, cloudpickle, and Ray round trips. + Transform bindings store either an import path or a cloudpickled config + payload so bound models can survive pickle, cloudpickle, and Ray round + trips. """ kind: Literal["context_transform"] = "context_transform" @@ -172,28 +172,92 @@ class StaticValueSpec(PydanticModel): value: Any -class FieldContextSpec(PydanticModel): - """A ``with_context(field=transform(...))`` contextual override.""" +_FieldOverrideSpec = Annotated[StaticValueSpec | ContextTransform, Field(discriminator="kind")] - kind: Literal["context_value"] = "context_value" + +class PatchContextOperation(PydanticModel): + """One ordered positional context patch in a ``with_context`` chain.""" + + kind: Literal["patch"] = "patch" binding: ContextTransform -class PatchContextSpec(PydanticModel): - """A positional ``with_context(transform(...))`` mapping patch.""" +class FieldContextOperation(PydanticModel): + """One ordered field override in a ``with_context`` chain.""" - kind: Literal["context_patch"] = "context_patch" - binding: ContextTransform + kind: Literal["field"] = "field" + name: str + spec: _FieldOverrideSpec -_FieldOverrideSpec = StaticValueSpec | FieldContextSpec +_ContextOperation = Annotated[PatchContextOperation | FieldContextOperation, Field(discriminator="kind")] class _BoundContextSpec(PydanticModel): """Normalized, serializable representation of all context bindings.""" - patches: List[PatchContextSpec] = Field(default_factory=list) - field_overrides: Dict[str, _FieldOverrideSpec] = Field(default_factory=dict) + operations: List[_ContextOperation] = Field(default_factory=list) + + +class _BoundModelContext(FlowContext): + """Flow.call carrier for BoundModel that preserves existing context objects.""" + + @model_validator(mode="wrap") + @classmethod + def _preserve_context_base(cls, value, handler, info): + if isinstance(value, ContextBase): + return value + return handler(value) + + +class _DependencyIdentity(NamedTuple): + kind: Literal["dependency"] + evaluation: EvaluationDependency + + +class _LiteralIdentity(NamedTuple): + kind: Literal["literal"] + value: Any + + +class _UnresolvedLazyDependencyIdentity(NamedTuple): + kind: Literal["unresolved_lazy_dependency"] + model_type: str + model: Dict[str, Any] + context_type: str + context: Dict[str, Any] + missing_context: Tuple[str, ...] + missing_transform_context: Tuple[Tuple[str, Tuple[str, ...]], ...] + + +class _RegularInputIdentity(NamedTuple): + kind: Literal["regular_input"] + name: str + lazy: bool + payload: Any + + +class _GeneratedModelIdentity(NamedTuple): + kind: Literal["generated_flow_model_v1"] + model_type: str + contextual_inputs: Dict[str, Any] + regular_inputs: Tuple[_RegularInputIdentity, ...] + model_base_fields: Dict[str, Any] + + +class _LocalFlowModelPicklePayload(NamedTuple): + serialized_config: Any + factory_kwargs: Dict[str, Any] + model_data: Dict[str, Any] + + +class _PortableBaseModelState(NamedTuple): + data: Dict[str, Any] + + +class _PortablePydanticModelState(NamedTuple): + model_type: str + data: Dict[str, Any] # --------------------------------------------------------------------------- @@ -222,6 +286,9 @@ def _context_transform_identifier(binding: ContextTransform) -> str: def _is_model_dependency(value: Any) -> bool: + # Keep this predicate in the module that can import CallableModel. The + # binding analyzers also need it, but _flow_model_binding.py is imported by + # callable.py and cannot import CallableModel directly without a cycle. return isinstance(value, CallableModel) @@ -280,7 +347,7 @@ def _type_adapter(annotation: Any) -> TypeAdapter: def _can_validate_type(annotation: Any) -> bool: try: _type_adapter(annotation) - except (PydanticSchemaGenerationError, PydanticUndefinedAnnotation, TypeError, ValueError): + except (PydanticSchemaGenerationError, PydanticUndefinedAnnotation): return False return True @@ -297,7 +364,7 @@ def _coerce_value(name: str, value: Any, annotation: Any, source: str) -> Any: return value try: return _type_adapter(annotation).validate_python(value) - except (ValidationError, ValueError, TypeError) as exc: + except ValidationError as exc: expected = _expected_type_repr(annotation) raise TypeError(f"{source} '{name}': expected {expected}, got {type(value).__name__} ({value!r})") from exc @@ -342,19 +409,6 @@ def _maybe_auto_unwrap_external_result(target: CallableModel, result: Any) -> An return result -def _type_accepts_str(annotation: Any) -> bool: - if annotation is Any or annotation is inspect.Parameter.empty: - return True - if annotation is str: - return True - origin = get_origin(annotation) - if origin is Annotated: - return _type_accepts_str(get_args(annotation)[0]) - if origin in _UNION_ORIGINS: - return any(_type_accepts_str(arg) for arg in get_args(annotation) if arg is not type(None)) - return False - - def _resolve_registry_candidate(value: str) -> Any: try: candidate = BaseModel.model_validate(value) @@ -375,7 +429,66 @@ def _registry_candidate_allowed(expected_type: Any, candidate: Any) -> bool: return True -def _ensure_top_level_named_function(fn: _AnyCallable, *, decorator_name: str) -> None: +def _resolve_bound_param_registry_ref(param: _FlowModelParam, value: Any) -> Any: + """Resolve registry references for bound regular parameters. + + Generated fields use SkipValidation, so registry lookup no longer needs to + run as a Pydantic before-validator just to beat field validation. Keeping it + here lets the generated model's after-validator own the construction + contract. Literal string validation gets first refusal for non-lazy + parameters; registry aliases are fallback dependency syntax. + """ + + if not isinstance(value, str): + return value + if not param.is_lazy and _can_validate_type(param.annotation): + try: + _type_adapter(param.annotation).validate_python(value) + except ValidationError: + pass + else: + return value + + candidate = _resolve_registry_candidate(value) + if candidate is None: + return value + if _registry_candidate_allowed(param.annotation, candidate): + return candidate + return value + + +def _resolve_serialized_dependency_ref( + value: Any, + *, + include_target_alias: bool = False, +) -> Any: + """Restore a direct serialized CallableModel reference.""" + + def serialized_model_type(item: Dict[str, Any]) -> Optional[type]: + marker = item.get("type_", _UNSET) + if marker is _UNSET and include_target_alias: + marker = item.get("_target_", _UNSET) + if marker is _UNSET: + return None + try: + candidate = marker.object if isinstance(marker, PyObjectPath) else PyObjectPath(marker).object + except (ImportError, AttributeError, TypeError, ValueError): + return None + return candidate if inspect.isclass(candidate) and issubclass(candidate, CallableModel) else None + + if type(value) is not dict or serialized_model_type(value) is None: + return value + # ``type_`` is ccflow's default serialized-model marker. ``_target_`` is + # also a ccflow alias, but it is Hydra's config language too, so callers + # only enable that spelling after normal literal validation has failed. + try: + restored = BaseModel.model_validate(value) + except (ValidationError, ImportError, AttributeError, TypeError, ValueError): + return value + return restored if _is_model_dependency(restored) else value + + +def _ensure_named_python_function(fn: _AnyCallable, *, decorator_name: str) -> None: if not inspect.isfunction(fn): raise TypeError(f"{decorator_name} only supports Python functions.") @@ -406,7 +519,7 @@ def _load_context_transform_config(path: str) -> _FlowModelConfig: def _serialize_context_transform_config(config: _FlowModelConfig) -> str: import cloudpickle - payload = cloudpickle.dumps(config, protocol=5) + payload = cloudpickle.dumps(_serialize_flow_model_config(config), protocol=5) return b64encode(payload).decode("ascii") @@ -414,8 +527,10 @@ def _serialize_context_transform_config(config: _FlowModelConfig) -> str: def _load_serialized_context_transform_config(serialized_config: str) -> _FlowModelConfig: import cloudpickle - config = cloudpickle.loads(b64decode(serialized_config.encode("ascii"))) - if not isinstance(config, _FlowModelConfig): + payload = cloudpickle.loads(b64decode(serialized_config.encode("ascii"))) + try: + config = _restore_flow_model_config(payload) + except (TypeError, ValueError): raise TypeError("Stored context transform payload does not contain a Flow.context_transform binding.") return config @@ -453,38 +568,158 @@ def _is_mapping_annotation(annotation: Any) -> bool: return False -def _restore_pickled_flow_model(type_path: str, state: Dict[str, Any]) -> BaseModel: +def _restore_pickled_flow_model(type_path: str, model_data: Dict[str, Any]) -> BaseModel: cls = cast(type[BaseModel], PyObjectPath(type_path).object) - instance = cls.__new__(cls) - instance.__setstate__(state) - return instance + return cls.model_validate(_restore_portable_generated_model_state(model_data)) -def _restore_pickled_local_flow_model(serialized_factory_payload: bytes, state: Dict[str, Any]) -> BaseModel: +def _restore_pickled_local_flow_model(serialized_factory_payload: bytes) -> BaseModel: import cloudpickle - fn, factory_kwargs = cloudpickle.loads(serialized_factory_payload) - factory = flow_model(fn, **factory_kwargs) + payload = cloudpickle.loads(serialized_factory_payload) + config = _restore_flow_model_config(payload.serialized_config) + # Do not call ``flow_model(config.func, **factory_kwargs)`` here. That would + # re-run type-hint resolution in the receiving process, which is exactly the + # path that fails for local/postponed annotations and runtime-only generic + # aliases such as GenericResult[int]. The serialized config is the resolved + # contract from the defining process; rebuild the generated class from it. + factory = _build_flow_model_factory_from_config(config, payload.factory_kwargs) cls = cast(type[BaseModel], getattr(factory, "_generated_model")) - instance = cls.__new__(cls) - instance.__setstate__(state) - return instance + return cls.model_validate(_restore_portable_generated_model_state(payload.model_data)) -def _restore_generated_flow_model(factory_path: str, state: Dict[str, Any]) -> BaseModel: +def _restore_generated_flow_model(factory_path: str, model_data: Dict[str, Any]) -> BaseModel: """Restore a generated flow model by importing its factory function. This is the cross-process-safe restore path: importing the factory's module triggers the ``@Flow.model`` decorator, which re-creates the GeneratedModel - class. We then reconstruct the instance from the pickled state. + class. The instance is reconstructed through normal validation data instead + of raw Pydantic state because raw state can embed process-local generic + classes. """ factory = PyObjectPath(factory_path).object generated_cls = getattr(factory, "_generated_model", None) if generated_cls is None: raise ImportError(f"Cannot restore generated flow model: '{factory_path}' does not have a _generated_model attribute.") - instance = generated_cls.__new__(generated_cls) - instance.__setstate__(state) - return instance + return generated_cls.model_validate(_restore_portable_generated_model_state(model_data)) + + +@singledispatch +def _portable_generated_model_state_value(value: Any) -> Any: + """Remove fragile Pydantic generic instance classes from local pickle state.""" + + if _is_unset_flow_input(value) or _is_model_dependency(value): + return value + return value + + +@_portable_generated_model_state_value.register +def _(value: BaseModel) -> Any: + if _is_model_dependency(value): + return value + return _PortableBaseModelState(data=value.model_dump(mode="python", by_alias=True)) + + +@_portable_generated_model_state_value.register +def _(value: PydanticModel) -> Any: + if _is_model_dependency(value): + return value + return _PortablePydanticModelState( + model_type=str(PyObjectPath.validate(type(value))), + data=_portable_generated_model_state_value(value.model_dump(mode="python", by_alias=True)), + ) + + +@_portable_generated_model_state_value.register +def _(value: tuple) -> tuple: + return tuple(_portable_generated_model_state_value(item) for item in value) + + +@_portable_generated_model_state_value.register +def _(value: list) -> list: + return [_portable_generated_model_state_value(item) for item in value] + + +@_portable_generated_model_state_value.register +def _(value: OrderedDict) -> OrderedDict: + return OrderedDict((key, _portable_generated_model_state_value(item)) for key, item in value.items()) + + +@_portable_generated_model_state_value.register +def _(value: dict) -> dict: + return {key: _portable_generated_model_state_value(item) for key, item in value.items()} + + +@_portable_generated_model_state_value.register +def _(value: frozenset) -> frozenset: + return frozenset(_portable_generated_model_state_value(item) for item in value) + + +@_portable_generated_model_state_value.register +def _(value: set) -> set: + return {_portable_generated_model_state_value(item) for item in value} + + +def _portable_generated_model_state(model: "_GeneratedFlowModelBase") -> Dict[str, Any]: + """Return validation data for a generated model without raw Pydantic state.""" + + data: Dict[str, Any] = {} + for name in type(model).model_fields: + value = getattr(model, name, _UNSET_FLOW_INPUT) + if _is_unset_flow_input(value): + continue + data[name] = _portable_generated_model_state_value(value) + return data + + +@singledispatch +def _restore_portable_generated_model_state_value(value: Any) -> Any: + return value + + +@_restore_portable_generated_model_state_value.register +def _(value: _PortableBaseModelState) -> BaseModel: + return BaseModel.model_validate(_restore_portable_generated_model_state_value(value.data)) + + +@_restore_portable_generated_model_state_value.register +def _(value: _PortablePydanticModelState) -> PydanticModel: + cls = cast(type[PydanticModel], PyObjectPath(value.model_type).object) + return cls.model_validate(_restore_portable_generated_model_state_value(value.data)) + + +@_restore_portable_generated_model_state_value.register +def _(value: tuple) -> tuple: + return tuple(_restore_portable_generated_model_state_value(item) for item in value) + + +@_restore_portable_generated_model_state_value.register +def _(value: list) -> list: + return [_restore_portable_generated_model_state_value(item) for item in value] + + +@_restore_portable_generated_model_state_value.register +def _(value: OrderedDict) -> OrderedDict: + return OrderedDict((key, _restore_portable_generated_model_state_value(item)) for key, item in value.items()) + + +@_restore_portable_generated_model_state_value.register +def _(value: dict) -> dict: + return {key: _restore_portable_generated_model_state_value(item) for key, item in value.items()} + + +@_restore_portable_generated_model_state_value.register +def _(value: frozenset) -> frozenset: + return frozenset(_restore_portable_generated_model_state_value(item) for item in value) + + +@_restore_portable_generated_model_state_value.register +def _(value: set) -> set: + return {_restore_portable_generated_model_state_value(item) for item in value} + + +def _restore_portable_generated_model_state(data: Dict[str, Any]) -> Dict[str, Any]: + return {name: _restore_portable_generated_model_state_value(value) for name, value in data.items()} def _is_importable_function(func: _AnyCallable) -> bool: @@ -514,13 +749,48 @@ def _generated_model_factory_path_for_pickle(config: _FlowModelConfig, generated return None +def _module_has_factory_for_generated_class(module: Any, generated_cls: type, *, excluding: str) -> bool: + """Return whether a module-level factory still owns ``generated_cls``. + + During ``importlib.reload()``, generated class attributes from the previous + import can remain on the module while the new decorator is running. Those + stale classes should be replaced, not force a suffixed path that a clean + process will never recreate. If a live factory still points at the class, + the slot is occupied by a real duplicate and should not be reclaimed. + """ + + return any(name != excluding and getattr(value, "_generated_model", None) is generated_cls for name, value in vars(module).items()) + + +def _same_generated_function_source(existing_cls: type, config: _FlowModelConfig) -> bool: + existing_config = getattr(existing_cls, "__flow_model_config__", None) + if existing_config is None: + return False + existing_func = existing_config.func + current_func = config.func + existing_code = getattr(existing_func, "__code__", None) + current_code = getattr(current_func, "__code__", None) + filename = getattr(current_code, "co_filename", "") + return ( + getattr(existing_func, "__module__", None) == getattr(current_func, "__module__", None) + and getattr(existing_func, "__qualname__", None) == getattr(current_func, "__qualname__", None) + and existing_code is not None + and current_code is not None + and not filename.startswith("<") + and existing_code.co_filename == current_code.co_filename + and existing_code.co_firstlineno == current_code.co_firstlineno + ) + + def _register_generated_model_class(config: _FlowModelConfig, generated_cls: type) -> None: """Make generated classes importable when their factory function is importable. Importable module-level ``@Flow.model`` functions should serialize by a stable module path. Local, nested, and ``__main__`` definitions still use local-persistence registration because there is no durable import path for - their generated class. + their generated class. Duplicate importable generated names are rejected + instead of suffixed because suffixed paths are not reliably reproducible + across ``importlib.reload()`` and clean-process config/Hydra round trips. """ if _importable_function_path(config.func) is None: @@ -541,18 +811,31 @@ def _register_generated_model_class(config: _FlowModelConfig, generated_cls: typ if obj is None: register_ccflow_import_path(generated_cls) return - setattr(obj, parts[-1], generated_cls) - -def _context_transform_should_use_import_path(config: _FlowModelConfig) -> bool: - path = config.path - if path is None or not _is_importable_function(config.func): - return False - try: - resolved = PyObjectPath(str(path)).object - except ImportError: - return True - return isinstance(getattr(resolved, "__flow_context_transform_config__", None), _FlowModelConfig) + name = parts[-1] + existing = getattr(obj, name, None) + if existing is None or existing is generated_cls: + setattr(obj, name, generated_cls) + return + if getattr(existing, "__flow_model_config__", None) is not None and _same_generated_function_source(existing, config): + # Reloaded modules can keep aliases to the previous factory until the + # assignment currently being evaluated completes. Matching the previous + # generated class by source location lets those stale aliases be replaced + # while still rejecting true duplicate function definitions. + setattr(obj, name, generated_cls) + return + if getattr(existing, "__flow_model_config__", None) is not None and not _module_has_factory_for_generated_class( + module, existing, excluding=_callable_name(config.func) + ): + # ``importlib.reload()`` can leave the previous generated class on the + # module while the factory function is being rebound. Replacing that + # stale class keeps the advertised path reproducible in a clean process. + setattr(obj, name, generated_cls) + return + raise ValueError( + f"Cannot register generated Flow model class at {module_name}.{'.'.join(parts)} because that path is already occupied. " + "Use a unique function name for each importable @Flow.model factory." + ) # --------------------------------------------------------------------------- @@ -583,12 +866,8 @@ def _project_context_values_for_model(model: CallableModel, values: Dict[str, An return {name: values[name] for name in contract.input_types if name in values} -def _dependency_context_values(model: CallableModel, context: ContextBase) -> Dict[str, Any]: - return _project_context_values_for_model(model, _context_values(context)) - - def _dependency_context_for_model(model: CallableModel, context: ContextBase) -> ContextBase: - return _runtime_context_for_model(model, _dependency_context_values(model, context)) + return _runtime_context_for_model(model, _project_context_values_for_model(model, _context_values(context))) def _resolved_dependency_invocation(value: CallableModel, context: ContextBase) -> Tuple[CallableModel, ContextBase]: @@ -604,13 +883,24 @@ def _resolved_dependency_invocation(value: CallableModel, context: ContextBase) return value, _dependency_context_for_model(value, context) -def _merge_context_specs( - existing: _BoundContextSpec, patches: List[PatchContextSpec], field_overrides: Dict[str, _FieldOverrideSpec] -) -> _BoundContextSpec: - return _BoundContextSpec( - patches=[*existing.patches, *patches], - field_overrides={**existing.field_overrides, **field_overrides}, - ) +def _effective_context_operations(context_spec: _BoundContextSpec) -> Tuple[_ContextOperation, ...]: + """Drop field operations overwritten by later field operations. + + Patches are kept conservative because their write keys may be dynamic. Field + operations have explicit targets, so an earlier field transform for ``a`` + should not run or require inputs if a later field binding overwrites ``a``. + """ + + seen_fields: Set[str] = set() + operations: List[_ContextOperation] = [] + for operation in reversed(context_spec.operations): + if isinstance(operation, FieldContextOperation): + if operation.name in seen_fields: + continue + seen_fields.add(operation.name) + operations.append(operation) + operations.reverse() + return tuple(operations) def _generated_model_instance(stage: Any) -> Optional["_GeneratedFlowModelBase"]: @@ -677,16 +967,45 @@ def _resolve_regular_param_value(model: "_GeneratedFlowModelBase", param: _FlowM f"Regular parameter '{param.name}' for {_callable_name(type(model).__flow_model_config__.func)} is still unbound. " "Bind it at construction time." ) - if _is_model_dependency(value): - if param.is_lazy: + if param.is_lazy: + if _is_model_dependency(value): return _make_lazy_thunk(value, context) + raise TypeError(f"Parameter '{param.name}' is marked Lazy[...] and must be bound to a CallableModel dependency.") + + if _is_model_dependency(value): dependency_model, dependency_context = _resolved_dependency_invocation(value, context) return _unwrap_model_result(dependency_model(dependency_context)) - if param.is_lazy: - raise TypeError(f"Parameter '{param.name}' is marked Lazy[...] and must be bound to a CallableModel dependency.") return value +def _regular_dependency_identity(value: CallableModel, context: ContextBase) -> _DependencyIdentity: + dependency_model, dependency_context = _resolved_dependency_invocation(value, context) + return _DependencyIdentity( + kind="dependency", + evaluation=EvaluationDependency(dependency_model, dependency_context), + ) + + +def _lazy_regular_dependency_identity(value: CallableModel, context: ContextBase) -> Any: + unresolved, dependency_model, dependency_context = _lazy_dependency_identity(value, context) + if unresolved is not None: + return unresolved + assert dependency_model is not None + assert dependency_context is not None + return _DependencyIdentity( + kind="dependency", + evaluation=EvaluationDependency(dependency_model, dependency_context), + ) + + +def _regular_input_identity(param: _FlowModelParam, value: Any, context: ContextBase) -> _RegularInputIdentity: + if _is_model_dependency(value): + payload = _lazy_regular_dependency_identity(value, context) if param.is_lazy else _regular_dependency_identity(value, context) + else: + payload = _LiteralIdentity(kind="literal", value=value) + return _RegularInputIdentity(kind="regular_input", name=param.name, lazy=param.is_lazy, payload=payload) + + def _collect_contextual_values( model: "_GeneratedFlowModelBase", config: _FlowModelConfig, @@ -745,22 +1064,27 @@ def _validate_declared_context_values(config: _FlowModelConfig, values: Dict[str return {param.name: getattr(validated, param.name) for param in config.contextual_params} -def _validate_declared_context_field(config: _FlowModelConfig, name: str, value: Any) -> Any: - if config.declared_context_type is None: - return _UNSET +def _declared_context_field_annotation(config: _FlowModelConfig, name: str) -> Any: + """Return a field-level annotation preserving declared context constraints.""" - try: - validated = config.declared_context_type.model_validate({name: value}) - except ValidationError as exc: - field_errors = [error for error in exc.errors() if error.get("loc") and error["loc"][0] == name] - if field_errors: - raise + assert config.declared_context_type is not None + field_info = config.declared_context_type.model_fields[name] + if field_info.metadata: + return Annotated.__class_getitem__((field_info.annotation, *field_info.metadata)) + return field_info.annotation + + +def _coerce_declared_context_field(config: _FlowModelConfig, name: str, value: Any) -> Any: + if config.declared_context_type is None: return _UNSET - return getattr(validated, name) + annotation = _declared_context_field_annotation(config, name) + if not _can_validate_type(annotation): + return value + return _type_adapter(annotation).validate_python(value) def _coerce_contextual_value(config: _FlowModelConfig, param: _FlowModelParam, value: Any, source: str) -> Any: - declared_value = _validate_declared_context_field(config, param.name, value) + declared_value = _coerce_declared_context_field(config, param.name, value) if declared_value is not _UNSET: return declared_value return _coerce_value(param.name, value, param.validation_annotation, source) @@ -785,9 +1109,15 @@ def _coerce_model_context_value(model: CallableModel, field_name: str, value: An # Effective identity helpers # --------------------------------------------------------------------------- - -def _identity_context_values_for_model(model: CallableModel, context: ContextBase) -> Dict[str, Any]: - return _identity_context_values_for_model_values(model, _context_values(context)) +# Identity terms used below: +# - config identity: stable hash of the analyzed Flow.model contract, fixed at +# generated-class construction time and carried through local restore. +# - behavior token: tokenizer hook for invalidating cache keys when generated +# model behavior changes. +# - model type identity: importable factory path when available, otherwise the +# fixed local config identity; used inside effective cache-key payloads. +# - effective invocation identity: the full context/input payload for one model +# evaluation, with unused ambient FlowContext fields removed. def _identity_context_values_for_model_values(model: CallableModel, values: Dict[str, Any]) -> Dict[str, Any]: @@ -829,9 +1159,9 @@ def _context_transform_missing_context_names(binding: ContextTransform, values: def _evaluate_context_transform_from_values(binding: ContextTransform, values: Dict[str, Any]) -> Any: """Run a context transform against a raw value mapping. - Transform contextual inputs are read from the original runtime context for a - binding layer. This keeps field transforms independent of earlier patches - in the same ``with_context(...)`` call and makes ordering rules explicit. + ``with_context`` transforms read from the original ambient context. Chained + bindings preserve write order, but they are not an implicit transform + pipeline; dependent rewrites should be expressed inside one transform. """ config = _load_context_transform_config_from_binding(binding) @@ -845,7 +1175,7 @@ def _evaluate_context_transform_from_values(binding: ContextTransform, values: D else: raise TypeError( f"Missing contextual input(s) for context transform {_callable_name(config.func)}: {param.name}. " - "Supply them via the runtime context or with_context() ordering." + "Supply them via the runtime context, a transform default, or combine dependent rewrites into one patch transform." ) return config.func(**kwargs) @@ -862,56 +1192,173 @@ def _apply_context_spec_values_for_identity( dependencies do not collapse accidentally. """ - current_values = _context_values(context) + original_values = _context_values(context) + current_values = dict(original_values) missing_transforms: List[Tuple[str, Tuple[str, ...]]] = [] - for patch in context_spec.patches: - missing = _context_transform_missing_context_names(patch.binding, _context_values(context)) - if missing: - missing_transforms.append((_context_transform_identifier(patch.binding), missing)) + for operation in _effective_context_operations(context_spec): + if isinstance(operation, PatchContextOperation): + missing = _context_transform_missing_context_names(operation.binding, original_values) + if missing: + missing_transforms.append((_context_transform_identifier(operation.binding), missing)) + continue + result = _evaluate_context_transform_from_values(operation.binding, original_values) + current_values.update(_validate_patch_result(model, result)) continue - result = _evaluate_context_transform_from_values(patch.binding, _context_values(context)) - current_values.update(_validate_patch_result(model, result)) - for name, spec in context_spec.field_overrides.items(): - if isinstance(spec, StaticValueSpec): - current_values[name] = spec.value + if isinstance(operation.spec, StaticValueSpec): + current_values[operation.name] = operation.spec.value continue - missing = _context_transform_missing_context_names(spec.binding, _context_values(context)) + missing = _context_transform_missing_context_names(operation.spec, original_values) if missing: - missing_transforms.append((name, missing)) - current_values.pop(name, None) + missing_transforms.append((operation.name, missing)) + current_values.pop(operation.name, None) continue - result = _evaluate_context_transform_from_values(spec.binding, _context_values(context)) - current_values[name] = _coerce_model_context_value(model, name, result, "with_context()") + result = _evaluate_context_transform_from_values(operation.spec, original_values) + current_values[operation.name] = _coerce_model_context_value(model, operation.name, result, "with_context()") return current_values, tuple(missing_transforms) +def _unresolved_lazy_model_identity(value: CallableModel) -> Dict[str, Any]: + """Return a stable structural model payload for unresolved lazy identity.""" + + if isinstance(value, BoundModel): + return { + "kind": "bound_model", + "model": _unresolved_lazy_model_identity(value.model), + "context_spec": value.context_spec.model_dump(mode="python"), + } + + dump = _stable_model_identity_dump(value) + return dump + + +def _contains_model_dependency(value: Any) -> bool: + if _is_model_dependency(value): + return True + if isinstance(value, (tuple, list, frozenset, set)): + return any(_contains_model_dependency(item) for item in value) + if isinstance(value, dict): + return any(_contains_model_dependency(item) for item in value.values()) + return False + + +def _stable_model_identity_dump(value: Any) -> Any: + if _is_model_dependency(value): + dump = value.model_dump(mode="python") + if isinstance(dump, dict): + dump = _stable_model_identity_dump(dump) + for name in type(value).model_fields: + field_value = getattr(value, name, _UNSET_FLOW_INPUT) + if name in dump and _contains_model_dependency(field_value): + dump[name] = _stable_model_identity_dump(field_value) + dump["type_"] = _model_type_identity(value) + return dump + if isinstance(value, tuple): + return tuple(_stable_model_identity_dump(item) for item in value) + if isinstance(value, list): + return [_stable_model_identity_dump(item) for item in value] + if isinstance(value, OrderedDict): + return OrderedDict((key, _stable_model_identity_dump(item)) for key, item in value.items()) + if isinstance(value, dict): + return {key: _stable_model_identity_dump(item) for key, item in value.items()} + if isinstance(value, frozenset): + return frozenset(_stable_model_identity_dump(item) for item in value) + if isinstance(value, set): + return {_stable_model_identity_dump(item) for item in value} + return value + + def _unresolved_lazy_dependency_descriptor( value: CallableModel, context_values: Dict[str, Any], missing_context: Tuple[str, ...], missing_transform_context: Tuple[Tuple[str, Tuple[str, ...]], ...] = (), -) -> Dict[str, Any]: +) -> _UnresolvedLazyDependencyIdentity: """Describe a lazy dependency whose runtime context cannot be resolved yet.""" - return { - "kind": "unresolved_lazy_dependency", - "model_type": str(PyObjectPath.validate(type(value))), - "model": value.model_dump(mode="python"), - "context_type": str(PyObjectPath.validate(FlowContext)), - "context": context_values, - "missing_context": missing_context, - "missing_transform_context": missing_transform_context, - } + return _UnresolvedLazyDependencyIdentity( + kind="unresolved_lazy_dependency", + model_type=_model_type_identity(value), + model=_unresolved_lazy_model_identity(value), + context_type=str(PyObjectPath.validate(FlowContext)), + context=context_values, + missing_context=missing_context, + missing_transform_context=missing_transform_context, + ) + + +def _flow_model_config_identity(config: _FlowModelConfig) -> str: + """Return a stable identity for a generated model's analyzed behavior. + + This is for local generated classes whose ``PyObjectPath`` points at a + random ``ccflow.local_persistence._Local_*`` name. The identity is computed + once when the generated class is created and then stored in factory kwargs + so cache keys survive pickle/cloudpickle restore. Do not recompute it on + every cache-key request: function closures can contain mutable state such as + call counters, and those values are runtime state, not model identity. + """ + + return compute_data_token( + ( + "local_generated_flow_model", + config.func, + config.return_annotation, + config.context_type, + config.result_type, + config.auto_wrap_result, + config.auto_unwrap, + tuple( + ( + param.name, + param.annotation, + param.is_contextual, + param.is_lazy, + param.has_function_default, + param.function_default, + param.context_validation_annotation, + ) + for param in config.parameters + ), + config.declared_context_type, + ) + ) + + +def _generated_model_behavior_token(config_identity: str, model_base: Type[CallableModel]) -> str: + return compute_data_token( + ( + "generated_flow_model_behavior", + config_identity, + f"{model_base.__module__}.{model_base.__qualname__}", + compute_behavior_token(model_base), + ) + ) + + +def _model_type_identity(model: CallableModel) -> str: + """Return a stable model-type identity for effective generated-model keys.""" + + generated = _generated_model_instance(model) + if generated is None: + return str(PyObjectPath.validate(type(model))) + + config = type(generated).__flow_model_config__ + factory_path = _generated_model_factory_path_for_pickle(config, type(generated)) + if factory_path is not None: + return factory_path + identity = getattr(type(generated), "__flow_model_identity__", None) + if identity is None: + identity = _flow_model_config_identity(config) + return f"local:{identity}" def _lazy_dependency_identity( value: CallableModel, context: ContextBase, -) -> Tuple[Optional[Dict[str, Any]], Optional[CallableModel], Optional[ContextBase]]: +) -> Tuple[Optional[Any], Optional[CallableModel], Optional[ContextBase]]: """Resolve or describe a lazy dependency for effective identity. If all required context is available, return the concrete dependency @@ -960,17 +1407,27 @@ def _validate_bound_param_value( ) return _coerce_contextual_value(config, param, value, source) - if param.is_lazy and not _is_model_dependency(value): - raise TypeError(f"Parameter '{param.name}' is marked Lazy[...] and must be bound to a CallableModel dependency.") + if param.is_lazy: + if not _is_model_dependency(value): + value = _resolve_serialized_dependency_ref(value, include_target_alias=True) + if not _is_model_dependency(value): + raise TypeError(f"Parameter '{param.name}' is marked Lazy[...] and must be bound to a CallableModel dependency.") + return value if _is_model_dependency(value): return value - return _coerce_value(param.name, value, param.annotation, source) + try: + return _coerce_value(param.name, value, param.annotation, source) + except TypeError: + value_with_alias_deps = _resolve_serialized_dependency_ref(value, include_target_alias=True) + if value_with_alias_deps is not value and _is_model_dependency(value_with_alias_deps): + return value_with_alias_deps + raise def _generated_model_identity_payload( model: "_GeneratedFlowModelBase", context: ContextBase, -) -> Optional[Dict[str, Any]]: +) -> Optional[_GeneratedModelIdentity]: """Describe the generated model's effective invocation for cache keys. Contract: @@ -993,33 +1450,17 @@ def _generated_model_identity_payload( value = getattr(model, param.name, _UNSET_FLOW_INPUT) if _is_unset_flow_input(value): return None - - descriptor = {"name": param.name, "lazy": param.is_lazy} - if _is_model_dependency(value): - if param.is_lazy: - unresolved, dependency_model, dependency_context = _lazy_dependency_identity(value, context) - if unresolved is not None: - descriptor.update(unresolved) - else: - assert dependency_model is not None - assert dependency_context is not None - descriptor.update({"kind": "dependency", "evaluation": EvaluationDependency(dependency_model, dependency_context)}) - else: - dependency_model, dependency_context = _resolved_dependency_invocation(value, context) - descriptor.update({"kind": "dependency", "evaluation": EvaluationDependency(dependency_model, dependency_context)}) - else: - descriptor.update({"kind": "literal", "value": value}) - regular_inputs.append(descriptor) + regular_inputs.append(_regular_input_identity(param, value, context)) model_base_fields = {name: getattr(model, name) for name in sorted(_model_base_field_names(model))} - return { - "kind": "generated_flow_model_v1", - "model_type": str(PyObjectPath.validate(type(model))), - "contextual_inputs": _resolved_contextual_inputs(model, config, context), - "regular_inputs": regular_inputs, - "model_base_fields": model_base_fields, - } + return _GeneratedModelIdentity( + kind="generated_flow_model_v1", + model_type=_model_type_identity(model), + contextual_inputs=_resolved_contextual_inputs(model, config, context), + regular_inputs=tuple(regular_inputs), + model_base_fields=model_base_fields, + ) # --------------------------------------------------------------------------- @@ -1030,10 +1471,9 @@ def _generated_model_identity_payload( def _resolved_static_contextual_values( model: "_GeneratedFlowModelBase", config: _FlowModelConfig, - static_overrides: Optional[Dict[str, StaticValueSpec]] = None, + static_overrides: Optional[Dict[str, Any]] = None, ) -> Optional[Dict[str, Any]]: - override_values = {name: spec.value for name, spec in (static_overrides or {}).items()} - resolved, missing = _collect_contextual_values(model, config, override_values) + resolved, missing = _collect_contextual_values(model, config, static_overrides or {}) return None if missing else resolved @@ -1056,7 +1496,7 @@ def _bound_context_transform_regular_kwargs(config: _FlowModelConfig, binding: C kwargs: Dict[str, Any] = {} for param in config.regular_params: if param.name in binding.bound_args: - kwargs[param.name] = binding.bound_args[param.name] + kwargs[param.name] = _coerce_value(param.name, binding.bound_args[param.name], param.annotation, "Context transform argument") elif param.has_function_default: kwargs[param.name] = param.function_default else: @@ -1079,54 +1519,59 @@ def _evaluate_static_context_transform(binding: ContextTransform) -> Any: return config.func(**kwargs) -def _static_field_override_value(model: CallableModel, field_name: str, spec: _FieldOverrideSpec) -> Any: - if isinstance(spec, StaticValueSpec): - return spec.value - - value = _evaluate_static_context_transform(spec.binding) - if value is _UNSET: - return _UNSET - - contract = _model_context_contract(model) - if contract.input_types is None or field_name not in contract.input_types: - return value - return _coerce_model_context_value(model, field_name, value, "with_context()") - - def _statically_resolved_context_values(model: CallableModel, context_spec: _BoundContextSpec) -> Optional[Dict[str, Any]]: """Return static binding values when the whole spec can be resolved without runtime context.""" values: Dict[str, Any] = {} - for patch in context_spec.patches: - result = _evaluate_static_context_transform(patch.binding) - if result is _UNSET: - return None - values.update(_validate_patch_result(model, result)) + for operation in _effective_context_operations(context_spec): + if isinstance(operation, PatchContextOperation): + result = _evaluate_static_context_transform(operation.binding) + if result is _UNSET: + return None + values.update(_validate_patch_result(model, result)) + continue - for name, spec in context_spec.field_overrides.items(): - value = _static_field_override_value(model, name, spec) + if isinstance(operation.spec, StaticValueSpec): + value = operation.spec.value + else: + value = _evaluate_static_context_transform(operation.spec) + if value is not _UNSET: + value = _coerce_model_context_value(model, operation.name, value, "with_context()") if value is _UNSET: return None - values[name] = value + values[operation.name] = value return values -def _statically_resolved_context_field_names(model: CallableModel, context_spec: _BoundContextSpec) -> Set[str]: - names: Set[str] = set() +def _statically_resolved_context_field_values(model: CallableModel, context_spec: _BoundContextSpec) -> Dict[str, Any]: + values: Dict[str, Any] = {} - for patch in context_spec.patches: - result = _evaluate_static_context_transform(patch.binding) - if result is _UNSET: + for operation in _effective_context_operations(context_spec): + if isinstance(operation, PatchContextOperation): + result = _evaluate_static_context_transform(operation.binding) + if result is _UNSET: + values.clear() + continue + values.update(_validate_patch_result(model, result)) continue - names.update(_validate_patch_result(model, result)) - for name, spec in context_spec.field_overrides.items(): - if _static_field_override_value(model, name, spec) is not _UNSET: - names.add(name) + if isinstance(operation.spec, StaticValueSpec): + values[operation.name] = operation.spec.value + continue - return names + value = _evaluate_static_context_transform(operation.spec) + if value is _UNSET: + values.pop(operation.name, None) + continue + values[operation.name] = _coerce_model_context_value(model, operation.name, value, "with_context()") + + return values + + +def _statically_resolved_context_field_names(model: CallableModel, context_spec: _BoundContextSpec) -> Set[str]: + return set(_statically_resolved_context_field_values(model, context_spec)) def _context_transform_input_types(binding: ContextTransform, *, required_only: bool) -> Dict[str, Any]: @@ -1135,6 +1580,36 @@ def _context_transform_input_types(binding: ContextTransform, *, required_only: return {name: config.context_input_types[name] for name in names} +def _merge_context_input_types(target: Dict[str, Any], updates: Dict[str, Any]) -> None: + """Merge context input annotations without silently hiding conflicts.""" + + for name, annotation in updates.items(): + if name in target and target[name] != annotation: + raise TypeError(f"Conflicting runtime context annotations for {name!r}: {target[name]!r} and {annotation!r}.") + target[name] = annotation + + +def _merge_dynamic_context_operation_inputs( + target: Dict[str, Any], model: CallableModel, context_spec: _BoundContextSpec, *, required_only: bool +) -> None: + for operation in _effective_context_operations(context_spec): + if isinstance(operation, PatchContextOperation): + patch_result = _evaluate_static_context_transform(operation.binding) + if patch_result is _UNSET: + _merge_context_input_types(target, _context_transform_input_types(operation.binding, required_only=required_only)) + continue + continue + + if isinstance(operation.spec, StaticValueSpec): + continue + + value = _evaluate_static_context_transform(operation.spec) + if value is _UNSET: + target.pop(operation.name, None) + _merge_context_input_types(target, _context_transform_input_types(operation.spec, required_only=required_only)) + continue + + def _validate_static_context_spec_declared_context(model: CallableModel, context_spec: _BoundContextSpec) -> _BoundContextSpec: generated = _generated_model_instance(model) if generated is None: @@ -1148,8 +1623,7 @@ def _validate_static_context_spec_declared_context(model: CallableModel, context if static_context_values is None: return context_spec - static_overrides = {name: StaticValueSpec(value=value) for name, value in static_context_values.items()} - resolved = _resolved_static_contextual_values(generated, config, static_overrides) + resolved = _resolved_static_contextual_values(generated, config, static_context_values) if resolved is None: return context_spec @@ -1215,36 +1689,39 @@ def _validate_patch_result(model: CallableModel, result: Any) -> Dict[str, Any]: def _normalize_with_context(model: CallableModel, patches: Tuple[Any, ...], field_overrides: Dict[str, Any]) -> _BoundContextSpec: """Validate and normalize user-facing ``with_context(...)`` arguments.""" - normalized_patches = [] + operations: List[_ContextOperation] = [] for patch in patches: if callable(patch): - raise TypeError("with_context() no longer accepts raw callables. Replace the callable with a top-level @Flow.context_transform binding.") + raise TypeError("Positional with_context() arguments must be bound @Flow.context_transform results that return a mapping.") if not isinstance(patch, ContextTransform): raise TypeError("Positional with_context() arguments must be @Flow.context_transform bindings that return a mapping.") if not _binding_uses_patch_shape(patch): raise TypeError( "Field context transforms must be passed by keyword to with_context(...). Patch transforms belong in positional arguments." ) - normalized_patches.append(PatchContextSpec(binding=patch)) + operations.append(PatchContextOperation(binding=patch)) _validate_with_context_field_names(model, list(field_overrides)) contract = _model_context_contract(model) - normalized_field_overrides: Dict[str, _FieldOverrideSpec] = {} for name, value in field_overrides.items(): - if callable(value): - raise TypeError("with_context() no longer accepts raw callables. Replace the callable with a top-level @Flow.context_transform binding.") if isinstance(value, ContextTransform): if _binding_uses_patch_shape(value): raise TypeError("Patch transforms must be passed positionally to with_context(...), not as keyword field overrides.") - normalized_field_overrides[name] = FieldContextSpec(binding=value) + operations.append(FieldContextOperation(name=name, spec=value)) continue - normalized_field_overrides[name] = StaticValueSpec( + if callable(value) and (contract.input_types is None or name not in contract.input_types): + raise TypeError( + "Callable keyword values in with_context() must either be bound @Flow.context_transform results " + "or validate against a declared contextual field type." + ) + spec = StaticValueSpec( value=value if contract.input_types is None or name not in contract.input_types else _coerce_model_context_value(model, name, value, "with_context()") ) + operations.append(FieldContextOperation(name=name, spec=spec)) - context_spec = _BoundContextSpec(patches=normalized_patches, field_overrides=normalized_field_overrides) + context_spec = _BoundContextSpec(operations=operations) return _validate_static_context_spec_declared_context(model, context_spec) @@ -1256,18 +1733,20 @@ def _normalize_with_context(model: CallableModel, patches: Tuple[Any, ...], fiel def _apply_context_spec_values(model: CallableModel, context_spec: _BoundContextSpec, context: ContextBase) -> Dict[str, Any]: """Apply a binding spec at execution time and return rewritten context values.""" - current_values = _context_values(context) + original_values = _context_values(context) + current_values = dict(original_values) - for patch in context_spec.patches: - result = _evaluate_context_transform_from_values(patch.binding, _context_values(context)) - current_values.update(_validate_patch_result(model, result)) + for operation in _effective_context_operations(context_spec): + if isinstance(operation, PatchContextOperation): + result = _evaluate_context_transform_from_values(operation.binding, original_values) + current_values.update(_validate_patch_result(model, result)) + continue - for name, spec in context_spec.field_overrides.items(): - if isinstance(spec, StaticValueSpec): - current_values[name] = spec.value + if isinstance(operation.spec, StaticValueSpec): + current_values[operation.name] = operation.spec.value continue - result = _evaluate_context_transform_from_values(spec.binding, _context_values(context)) - current_values[name] = _coerce_model_context_value(model, name, result, "with_context()") + result = _evaluate_context_transform_from_values(operation.spec, original_values) + current_values[operation.name] = _coerce_model_context_value(model, operation.name, result, "with_context()") return current_values @@ -1275,10 +1754,18 @@ def _apply_context_spec_values(model: CallableModel, context_spec: _BoundContext def _apply_context_spec(model: CallableModel, context_spec: _BoundContextSpec, context: ContextBase) -> ContextBase: """Apply bindings, project to the wrapped model, and build its runtime context.""" - if not context_spec.patches and not context_spec.field_overrides: + if not context_spec.operations: + if isinstance(context, _BoundModelContext): + return _dependency_context_for_model(model, context) + if _context_matches_type(context, model.context_type): + return context return _dependency_context_for_model(model, context) values = _apply_context_spec_values(model, context_spec, context) + if isinstance(context, _BoundModelContext): + return _runtime_context_for_model(model, _project_context_values_for_model(model, values)) + if _context_matches_type(context, model.context_type): + return type(context).model_validate(values) return _runtime_context_for_model(model, _project_context_values_for_model(model, values)) @@ -1294,7 +1781,7 @@ def _build_compute_context(model: CallableModel, context: Any, kwargs: Dict[str, raise TypeError("compute() accepts either one context object or contextual keyword arguments, but not both.") ctx_type = model.context_type - _ctx_is_optional = get_origin(ctx_type) in _UNION_ORIGINS and type(None) in get_args(ctx_type) + _ctx_is_optional = _is_optional_context_type(ctx_type) contract = _model_context_contract(model) @@ -1304,6 +1791,8 @@ def _build_compute_context(model: CallableModel, context: Any, kwargs: Dict[str, if isinstance(context, FlowContext): return context if isinstance(context, ContextBase): + if _context_matches_type(context, model.context_type): + return context return _runtime_context_for_model(model, _context_values(context)) return contract.runtime_context_type.model_validate(context) @@ -1350,12 +1839,18 @@ def _is_optional_context_type(context_type: Any) -> bool: return get_origin(context_type) in _UNION_ORIGINS and type(None) in get_args(context_type) +def _context_matches_type(context: Any, context_type: Any) -> bool: + """Return whether an existing context object is accepted by a context annotation.""" + + if context is None: + return _is_optional_context_type(context_type) + if get_origin(context_type) in _UNION_ORIGINS: + return any(_context_matches_type(context, arg) for arg in get_args(context_type) if arg is not type(None)) + return isinstance(context_type, type) and isinstance(context, context_type) + + def _bound_model_preserves_none_context(bound_model: "BoundModel") -> bool: - return ( - not bound_model.context_spec.patches - and not bound_model.context_spec.field_overrides - and _is_optional_context_type(bound_model.model.context_type) - ) + return not bound_model.context_spec.operations and _is_optional_context_type(bound_model.model.context_type) def _build_bound_compute_context(bound_model: "BoundModel", context: Any, kwargs: Dict[str, Any]) -> Optional[ContextBase]: @@ -1400,14 +1895,20 @@ def compute(self, context: Any = _UNSET, /, _options: Optional[FlowOptions] = No @property def context_inputs(self) -> Dict[str, Any]: - """Contextual input names and expected types for this model.""" + """Declared contextual input names and expected types for this model.""" contract = _model_context_contract(self._model) return dict(contract.input_types or {}) @property - def unbound_inputs(self) -> Dict[str, Any]: - """Required contextual inputs that are not already satisfied.""" + def runtime_inputs(self) -> Dict[str, Any]: + """Direct runtime context fields this model may read from the caller.""" + + return self.context_inputs + + @property + def required_inputs(self) -> Dict[str, Any]: + """Required direct runtime context fields still needed from the caller.""" contract = _model_context_contract(self._model) if contract.generated_model is None and _is_optional_context_type(self._model.context_type): @@ -1479,7 +1980,7 @@ class BoundModel(WrapperModel): context_spec: _BoundContextSpec = Field(default_factory=_BoundContextSpec, repr=False) def __reduce__(self): - return (_restore_pickled_flow_model, (str(PyObjectPath.validate(type(self))), self.__getstate__())) + return (_restore_pickled_flow_model, (str(PyObjectPath.validate(type(self))), _portable_generated_model_state(self))) def _rewrite_context(self, context: ContextBase) -> ContextBase: """Apply this wrapper's context bindings to an ambient runtime context.""" @@ -1490,7 +1991,7 @@ def _rewrite_context(self, context: ContextBase) -> ContextBase: def context_type(self) -> Any: if _bound_model_preserves_none_context(self): return self.model.context_type - return FlowContext + return _BoundModelContext @Flow.call def __call__(self, context: ContextType) -> ResultBase: @@ -1509,11 +2010,13 @@ def __deps__(self, context: ContextType) -> GraphDepList: return [(self.model, [self._rewrite_context(context)])] def __repr__(self) -> str: - args = [_context_transform_repr(patch.binding) for patch in self.context_spec.patches] - args.extend( - f"{name}={_context_transform_repr(spec.binding if isinstance(spec, FieldContextSpec) else spec.value)}" - for name, spec in self.context_spec.field_overrides.items() - ) + args = [] + for operation in _effective_context_operations(self.context_spec): + if isinstance(operation, PatchContextOperation): + args.append(_context_transform_repr(operation.binding)) + continue + value = operation.spec if isinstance(operation.spec, ContextTransform) else operation.spec.value + args.append(f"{operation.name}={_context_transform_repr(value)}") return f"{self.model!r}.flow.with_context({', '.join(args)})" def _evaluation_identity_payload( @@ -1551,50 +2054,51 @@ def compute(self, context: Any = _UNSET, /, _options: Optional[FlowOptions] = No @property def bound_inputs(self) -> Dict[str, Any]: + """Concrete values already fixed, including statically resolved context bindings.""" + result = super().bound_inputs - for patch in self._bound.context_spec.patches: - patch_result = _evaluate_static_context_transform(patch.binding) - if patch_result is not _UNSET: - result.update(_validate_patch_result(self._bound.model, patch_result)) - for name, spec in self._bound.context_spec.field_overrides.items(): - value = _static_field_override_value(self._bound.model, name, spec) - if value is not _UNSET: - result[name] = value - else: - result.pop(name, None) + for name in self.context_inputs: + result.pop(name, None) + result.update(_statically_resolved_context_field_values(self._bound.model, self._bound.context_spec)) return result @property def context_inputs(self) -> Dict[str, Any]: + """Declared contextual inputs of the wrapped model.""" + + return super().context_inputs + + @property + def runtime_inputs(self) -> Dict[str, Any]: + """Direct runtime context inputs after applying this wrapper's bindings. + + Static context transforms may be evaluated to identify resolved fields. + Dynamic transforms contribute their own runtime context inputs. + """ + result = super().context_inputs for name in _statically_resolved_context_field_names(self._bound.model, self._bound.context_spec): result.pop(name, None) - for patch in self._bound.context_spec.patches: - if _evaluate_static_context_transform(patch.binding) is _UNSET: - result.update(_context_transform_input_types(patch.binding, required_only=False)) - for name, spec in self._bound.context_spec.field_overrides.items(): - if isinstance(spec, FieldContextSpec) and _static_field_override_value(self._bound.model, name, spec) is _UNSET: - result.pop(name, None) - result.update(_context_transform_input_types(spec.binding, required_only=False)) + _merge_dynamic_context_operation_inputs(result, self._bound.model, self._bound.context_spec, required_only=False) return result @property - def unbound_inputs(self) -> Dict[str, Any]: - result = super().unbound_inputs + def required_inputs(self) -> Dict[str, Any]: + """Required direct runtime context inputs still missing after static bindings. + + Static context transforms may be evaluated to identify resolved fields. + Dynamic transforms contribute their own required runtime context inputs. + """ + + result = super().required_inputs for name in _statically_resolved_context_field_names(self._bound.model, self._bound.context_spec): result.pop(name, None) - for patch in self._bound.context_spec.patches: - if _evaluate_static_context_transform(patch.binding) is _UNSET: - result.update(_context_transform_input_types(patch.binding, required_only=True)) - for name, spec in self._bound.context_spec.field_overrides.items(): - if isinstance(spec, FieldContextSpec) and _static_field_override_value(self._bound.model, name, spec) is _UNSET: - result.pop(name, None) - result.update(_context_transform_input_types(spec.binding, required_only=True)) + _merge_dynamic_context_operation_inputs(result, self._bound.model, self._bound.context_spec, required_only=True) return result def with_context(self, *patches, **field_overrides) -> BoundModel: context_spec = _normalize_with_context(self._bound.model, patches, field_overrides) - merged = _merge_context_specs(self._bound.context_spec, context_spec.patches, context_spec.field_overrides) + merged = _BoundContextSpec(operations=[*self._bound.context_spec.operations, *context_spec.operations]) return BoundModel( model=self._bound.model, context_spec=_validate_static_context_spec_declared_context(self._bound.model, merged), @@ -1605,6 +2109,7 @@ class _GeneratedFlowModelBase(CallableModel): """Base class for all classes created by ``@Flow.model``.""" __flow_model_config__: ClassVar[_FlowModelConfig] + __flow_model_identity__: ClassVar[str] def __reduce__(self): """Prefer import-path restoration, falling back to serialized local factories.""" @@ -1612,39 +2117,19 @@ def __reduce__(self): config = type(self).__flow_model_config__ factory_path = _generated_model_factory_path_for_pickle(config, type(self)) if factory_path is not None: - return (_restore_generated_flow_model, (factory_path, self.__getstate__())) + return (_restore_generated_flow_model, (factory_path, _portable_generated_model_state(self))) import cloudpickle - payload = (config.func, type(self).__flow_model_factory_kwargs__) - return (_restore_pickled_local_flow_model, (cloudpickle.dumps(payload, protocol=5), self.__getstate__())) - - @model_validator(mode="before") - @classmethod - def _resolve_registry_refs(cls, values): - """Resolve registry string references for regular dependency fields.""" - - if not isinstance(values, dict): - return values - - config = getattr(cls, "__flow_model_config__", None) - if config is None: - return values - - resolved = dict(values) - for param in config.regular_params: - if param.name not in resolved: - continue - value = resolved[param.name] - if not isinstance(value, str): - continue - if _type_accepts_str(param.annotation): - continue - candidate = _resolve_registry_candidate(value) - if candidate is None: - continue - if _registry_candidate_allowed(param.annotation, candidate): - resolved[param.name] = candidate - return resolved + # Local generated classes are not normal importable class definitions: + # plain pickle cannot reconstruct them, and cloudpickle would otherwise + # walk generated signatures/model metadata that can contain fragile + # runtime-only annotations such as GenericResult[int]. + payload = _LocalFlowModelPicklePayload( + serialized_config=_serialize_flow_model_config(config), + factory_kwargs=type(self).__flow_model_factory_kwargs__, + model_data=_portable_generated_model_state(self), + ) + return (_restore_pickled_local_flow_model, (cloudpickle.dumps(payload, protocol=5),)) @model_validator(mode="after") def _validate_flow_model_fields(self): @@ -1656,6 +2141,8 @@ def _validate_flow_model_fields(self): value = getattr(self, param.name, _UNSET_FLOW_INPUT) if _is_unset_flow_input(value): continue + if not param.is_contextual: + value = _resolve_bound_param_registry_ref(param, value) object.__setattr__( self, param.name, @@ -1715,7 +2202,7 @@ def __call__(self, context): raw_result = config.func(**fn_kwargs) if config.auto_wrap_result: - return config.result_type.model_validate(raw_result) + return GenericResult(value=raw_result) return raw_result cast(Any, __call__).__signature__ = inspect.Signature( @@ -1744,7 +2231,7 @@ def __deps__(self, context): if param.is_lazy: continue value = getattr(self, param.name, _UNSET_FLOW_INPUT) - if isinstance(value, CallableModel): + if _is_model_dependency(value): dependency_model, dependency_context = _resolved_dependency_invocation(value, context) deps.append((dependency_model, [dependency_context])) return deps @@ -1760,6 +2247,14 @@ def __deps__(self, context): def _factory_param_annotation(param: _FlowModelParam) -> Any: + """Return the public factory signature annotation for a user function parameter. + + The factory signature keeps the user's surface syntax visible: contextual + params show as FromContext[T], lazy params show as Lazy[T], and regular + params keep T. The generated Pydantic field annotation below has a different + job: describing the stored construction value after binding. + """ + if param.is_contextual: return FromContext[param.annotation] if param.is_lazy: @@ -1767,6 +2262,38 @@ def _factory_param_annotation(param: _FlowModelParam) -> Any: return param.annotation +def _pydantic_schema_safe_annotation(annotation: Any) -> Any: + # Only the generated Pydantic field declaration falls back to Any for known + # Pydantic schema-build failures. Runtime coercion still builds the real + # TypeAdapter and propagates unexpected errors. + try: + _type_adapter(annotation) + except (PydanticSchemaGenerationError, PydanticUndefinedAnnotation): + return Any + return annotation + + +def _generated_field_annotation(param: _FlowModelParam) -> Any: + """Return the generated model field annotation used for schema only. + + This differs from _factory_param_annotation: the factory signature describes + user-facing inputs (FromContext[T], Lazy[T]), while this describes the value + stored on the generated Pydantic model after binding. SkipValidation keeps + this schema visible without letting Pydantic enforce it before ccflow can + distinguish literals, dependencies, lazy deps, and contextual defaults. + """ + + if param.is_contextual: + annotation = param.validation_annotation + elif param.is_lazy: + annotation = CallableModel + elif param.annotation is Any or param.annotation is inspect.Parameter.empty: + annotation = Any + else: + annotation = Union[param.annotation, CallableModel] + return SkipValidation[_pydantic_schema_safe_annotation(annotation)] + + def _factory_signature(config: _FlowModelConfig, generated_cls: Type[BaseModel]) -> inspect.Signature: """Return the public construction signature for a generated factory.""" @@ -1798,6 +2325,22 @@ def _factory_signature(config: _FlowModelConfig, generated_cls: Type[BaseModel]) return inspect.Signature(parameters=parameters, return_annotation=generated_cls) +def _context_transform_factory_signature(config: _FlowModelConfig) -> inspect.Signature: + """Return the public binding signature for a context transform factory.""" + + parameters = [] + for param in config.regular_params: + parameters.append( + inspect.Parameter( + param.name, + inspect.Parameter.KEYWORD_ONLY, + annotation=param.annotation, + default=param.function_default if param.has_function_default else inspect.Parameter.empty, + ) + ) + return inspect.Signature(parameters=parameters, return_annotation=ContextTransform) + + def _resolve_generated_model_bases(model_base: Type[CallableModel]) -> Tuple[type, ...]: """Return the class bases for a generated model, preserving custom model bases.""" @@ -1811,6 +2354,64 @@ def _resolve_generated_model_bases(model_base: Type[CallableModel]) -> Tuple[typ return (_GeneratedFlowModelBase, model_base) +def _build_flow_model_factory_from_config(config: _FlowModelConfig, factory_kwargs: Dict[str, Any]) -> _AnyCallable: + """Build the generated model class/factory from an analyzed flow-model config.""" + + fn = config.func + factory_kwargs = dict(factory_kwargs) + model_base = factory_kwargs["model_base"] + # Preserve one generated-model identity across local pickle/cloudpickle + # restore. See ``_flow_model_config_identity`` for why this must be fixed at + # class-construction time instead of recalculated on every cache-key build. + config_identity = factory_kwargs.setdefault("_flow_model_identity", _flow_model_config_identity(config)) + annotations: Dict[str, Any] = {} + namespace: Dict[str, Any] = { + "__module__": getattr(fn, "__module__", __name__), + "__qualname__": f"_{_callable_name(fn)}_Model", + "__call__": Flow.call( + **{ + name: value + for name in ("cacheable", "volatile", "log_level", "validate_result", "verbose", "evaluator") + if (value := factory_kwargs.get(name, _UNSET)) is not _UNSET + } + )(_make_call_impl(config)), + "__deps__": Flow.deps(_make_deps_impl(config)), + } + + for param in config.parameters: + annotations[param.name] = _generated_field_annotation(param) + if param.is_contextual: + namespace[param.name] = Field(default_factory=_unset_flow_input_factory, exclude_if=_is_unset_flow_input) + elif param.has_function_default: + namespace[param.name] = param.function_default + else: + namespace[param.name] = Field(default_factory=_unset_flow_input_factory, exclude_if=_is_unset_flow_input) + + namespace["__annotations__"] = annotations + + GeneratedModel = cast( + type[_GeneratedFlowModelBase], + type(f"_{_callable_name(fn)}_Model", _resolve_generated_model_bases(model_base), namespace), + ) + GeneratedModel.__flow_model_config__ = config + GeneratedModel.__flow_model_factory_kwargs__ = factory_kwargs + GeneratedModel.__flow_model_identity__ = config_identity + GeneratedModel.__ccflow_tokenizer_cache__ = _generated_model_behavior_token(config_identity, model_base) + _register_generated_model_class(config, GeneratedModel) + GeneratedModel.model_rebuild() + + @wraps(fn) + def factory(**kwargs) -> _GeneratedFlowModelBase: + """Create a generated model instance with regular/contextual defaults bound.""" + + return GeneratedModel(**kwargs) + + cast(Any, factory)._generated_model = GeneratedModel + cast(Any, factory).__signature__ = _factory_signature(config, GeneratedModel) + factory.__doc__ = fn.__doc__ + return factory + + def flow_context_transform(func: Optional[_AnyCallable] = None) -> _AnyCallable: """Decorator that turns a function into a serializable ``with_context`` transform factory. @@ -1823,11 +2424,8 @@ def flow_context_transform(func: Optional[_AnyCallable] = None) -> _AnyCallable: def decorator(fn: _AnyCallable) -> _AnyCallable: """Analyze one transform function and return its binding factory.""" - _ensure_top_level_named_function(fn, decorator_name="@Flow.context_transform") - try: - resolved_hints = get_type_hints(fn, include_extras=True) - except AttributeError: - resolved_hints = {} + _ensure_named_python_function(fn, decorator_name="@Flow.context_transform") + resolved_hints = get_type_hints(fn, include_extras=True) sig = _resolved_flow_signature( fn, resolved_hints=resolved_hints, @@ -1835,19 +2433,23 @@ def decorator(fn: _AnyCallable) -> _AnyCallable: function_name=_callable_name(fn), ) config = _analyze_flow_context_transform(fn, sig, is_model_dependency=_is_model_dependency) - serialized_config = None if _context_transform_should_use_import_path(config) else _serialize_context_transform_config(config) + # Store the analyzed transform contract directly. Import-path detection + # during decoration is brittle because module globals usually still point + # at the undecorated function until the decorator returns. + serialized_config = _serialize_context_transform_config(config) @wraps(fn) def factory(**kwargs) -> ContextTransform: """Bind regular transform arguments into a serializable spec.""" return ContextTransform( - path=config.path if serialized_config is None else None, + path=None, serialized_config=serialized_config, bound_args=_validate_context_transform_factory_kwargs(config, kwargs), ) cast(Any, factory).__flow_context_transform_config__ = config + cast(Any, factory).__signature__ = _context_transform_factory_signature(config) return factory if func is not None: @@ -1858,7 +2460,6 @@ def factory(**kwargs) -> ContextTransform: def flow_model( func: Optional[_AnyCallable] = None, *, - context_args: Any = _REMOVED_CONTEXT_ARGS, context_type: Optional[Type[ContextBase]] = None, auto_unwrap: bool = False, model_base: Type[CallableModel] = CallableModel, @@ -1878,16 +2479,10 @@ def flow_model( creates instances of the generated model class. """ - if context_args is not _REMOVED_CONTEXT_ARGS: - raise TypeError("context_args=... has been removed. Mark runtime/contextual parameters with FromContext[...] instead.") - def decorator(fn: _AnyCallable) -> _AnyCallable: """Analyze one user function and synthesize its generated model class.""" - try: - resolved_hints = get_type_hints(fn, include_extras=True) - except AttributeError: - resolved_hints = {} + resolved_hints = get_type_hints(fn, include_extras=True) sig = _resolved_flow_signature( fn, resolved_hints=resolved_hints, @@ -1901,45 +2496,7 @@ def decorator(fn: _AnyCallable) -> _AnyCallable: auto_unwrap=auto_unwrap, is_model_dependency=_is_model_dependency, ) - - annotations: Dict[str, Any] = {} - namespace: Dict[str, Any] = { - "__module__": getattr(fn, "__module__", __name__), - "__qualname__": f"_{_callable_name(fn)}_Model", - "__call__": Flow.call( - **{ - name: value - for name, value in [ - ("cacheable", cacheable), - ("volatile", volatile), - ("log_level", log_level), - ("validate_result", validate_result), - ("verbose", verbose), - ("evaluator", evaluator), - ] - if value is not _UNSET - } - )(_make_call_impl(config)), - "__deps__": Flow.deps(_make_deps_impl(config)), - } - - for param in config.parameters: - annotations[param.name] = Any - if param.is_contextual: - namespace[param.name] = Field(default_factory=_unset_flow_input_factory, exclude_if=_is_unset_flow_input) - elif param.has_function_default: - namespace[param.name] = param.function_default - else: - namespace[param.name] = Field(default_factory=_unset_flow_input_factory, exclude_if=_is_unset_flow_input) - - namespace["__annotations__"] = annotations - - GeneratedModel = cast( - type[_GeneratedFlowModelBase], - type(f"_{_callable_name(fn)}_Model", _resolve_generated_model_bases(model_base), namespace), - ) - GeneratedModel.__flow_model_config__ = config - GeneratedModel.__flow_model_factory_kwargs__ = { + factory_kwargs = { "context_type": context_type, "auto_unwrap": auto_unwrap, "model_base": model_base, @@ -1950,19 +2507,8 @@ def decorator(fn: _AnyCallable) -> _AnyCallable: "verbose": verbose, "evaluator": evaluator, } - _register_generated_model_class(config, GeneratedModel) - GeneratedModel.model_rebuild() - - @wraps(fn) - def factory(**kwargs) -> _GeneratedFlowModelBase: - """Create a generated model instance with regular/contextual defaults bound.""" - - return GeneratedModel(**kwargs) - - cast(Any, factory)._generated_model = GeneratedModel - cast(Any, factory).__signature__ = _factory_signature(config, GeneratedModel) - factory.__doc__ = fn.__doc__ - return factory + factory_kwargs["_flow_model_identity"] = _flow_model_config_identity(config) + return _build_flow_model_factory_from_config(config, factory_kwargs) if func is not None: return decorator(func) diff --git a/ccflow/tests/config/conf_flow.yaml b/ccflow/tests/config/conf_flow.yaml index a697229..51fc6ba 100644 --- a/ccflow/tests/config/conf_flow.yaml +++ b/ccflow/tests/config/conf_flow.yaml @@ -1,73 +1,73 @@ # Flow.model configurations for Hydra integration tests. flow_loader: - _target_: ccflow.tests.test_flow_model.basic_loader + _target_: ccflow.tests.flow_model_hydra_fixtures.basic_loader source: fixture_input multiplier: 5 flow_processor: - _target_: ccflow.tests.test_flow_model.string_processor + _target_: ccflow.tests.flow_model_hydra_fixtures.string_processor prefix: "value=" suffix: "!" flow_source: - _target_: ccflow.tests.test_flow_model.data_source + _target_: ccflow.tests.flow_model_hydra_fixtures.data_source base_value: 100 flow_transformer: - _target_: ccflow.tests.test_flow_model.data_transformer + _target_: ccflow.tests.flow_model_hydra_fixtures.data_transformer source: flow_source factor: 3 flow_stage1: - _target_: ccflow.tests.test_flow_model.pipeline_stage1 + _target_: ccflow.tests.flow_model_hydra_fixtures.pipeline_stage1 initial: 10 flow_stage2: - _target_: ccflow.tests.test_flow_model.pipeline_stage2 + _target_: ccflow.tests.flow_model_hydra_fixtures.pipeline_stage2 stage1_output: flow_stage1 multiplier: 2 flow_stage3: - _target_: ccflow.tests.test_flow_model.pipeline_stage3 + _target_: ccflow.tests.flow_model_hydra_fixtures.pipeline_stage3 stage2_output: flow_stage2 offset: 50 diamond_source: - _target_: ccflow.tests.test_flow_model.data_source + _target_: ccflow.tests.flow_model_hydra_fixtures.data_source base_value: 10 diamond_branch_a: - _target_: ccflow.tests.test_flow_model.data_transformer + _target_: ccflow.tests.flow_model_hydra_fixtures.data_transformer source: diamond_source factor: 2 diamond_branch_b: - _target_: ccflow.tests.test_flow_model.data_transformer + _target_: ccflow.tests.flow_model_hydra_fixtures.data_transformer source: diamond_source factor: 5 diamond_aggregator: - _target_: ccflow.tests.test_flow_model.data_aggregator + _target_: ccflow.tests.flow_model_hydra_fixtures.data_aggregator input_a: diamond_branch_a input_b: diamond_branch_b operation: add flow_date_loader: - _target_: ccflow.tests.test_flow_model.date_range_loader_previous_day + _target_: ccflow.tests.flow_model_hydra_fixtures.date_range_loader_previous_day source: calendar_feed include_weekends: false flow_date_processor: - _target_: ccflow.tests.test_flow_model.date_range_processor + _target_: ccflow.tests.flow_model_hydra_fixtures.date_range_processor raw_data: flow_date_loader normalize: true contextual_loader_model: - _target_: ccflow.tests.test_flow_model.contextual_loader + _target_: ccflow.tests.flow_model_hydra_fixtures.contextual_loader source: data_source contextual_processor_model: - _target_: ccflow.tests.test_flow_model.contextual_processor + _target_: ccflow.tests.flow_model_hydra_fixtures.contextual_processor data: contextual_loader_model prefix: output diff --git a/ccflow/tests/evaluators/test_common.py b/ccflow/tests/evaluators/test_common.py index dd39c97..37f3dd0 100644 --- a/ccflow/tests/evaluators/test_common.py +++ b/ccflow/tests/evaluators/test_common.py @@ -10,7 +10,10 @@ DateRangeContext, Evaluator, EvaluatorBase, + Flow, + FlowContext, FlowOptionsOverride, + FromContext, ModelEvaluationContext, NullContext, TransparentModelEvaluationContext, @@ -406,8 +409,8 @@ def test_plain_callable_key_fallback_does_not_log(self): with self.assertNoLogs("ccflow.evaluators.common", level="DEBUG"): self.assertEqual(evaluator.key(model_evaluation_context), cache_key(model_evaluation_context)) - def test_effective_key_exception_fallback_logs_debug(self): - """Unexpected effective-identity failures are visible without breaking existing calls.""" + def test_effective_key_identity_errors_propagate(self): + """Unexpected effective-identity failures should not be hidden by structural fallback.""" class BadIdentityCallable(MyDateCallable): def _evaluation_identity_payload(self, context): @@ -418,9 +421,8 @@ def _evaluation_identity_payload(self, context): context = DateContext(date=date(2022, 1, 1)) model_evaluation_context = ModelEvaluationContext(model=m1, context=context, options=dict(cacheable=True)) - with self.assertLogs("ccflow.evaluators.common", level="DEBUG") as captured: - self.assertEqual(evaluator.key(model_evaluation_context), cache_key(model_evaluation_context)) - self.assertIn("Falling back to structural evaluation key for BadIdentityCallable.__call__: identity broke", captured.output[0]) + with self.assertRaisesRegex(ValueError, "identity broke"): + evaluator.key(model_evaluation_context) def test_plain_callable_deps_key_matches_public_cache_key(self): """Non-__call__ evaluations stay structural.""" @@ -646,6 +648,27 @@ def test_graph_evaluator_basic(self): self.assertIn(("n0", date(2022, 1, 1)), graph_calls[:2]) self.assertIn(("n1", date(2022, 1, 1)), graph_calls[:2]) + def test_graph_evaluator_root_id_uses_built_graph_key(self): + calls = 0 + + @Flow.context_transform + def bump(seed: FromContext[int]) -> int: + nonlocal calls + calls += 1 + return seed + calls + + @Flow.model + def add(a: FromContext[int]) -> int: + return a + + model = add().flow.with_context(a=bump()) + graph = get_dependency_graph(model.__call__.get_evaluation_context(model, FlowContext(seed=10))) + result = model.flow.compute(FlowContext(seed=10), _options={"evaluator": GraphEvaluator()}) + + self.assertIn(graph.root_id, graph.ids) + self.assertIsNotNone(result) + self.assertGreater(result.value, 10) + def test_graph_evaluator_diamond(self): n0 = NodeModel(meta=dict(name="n0")) n1 = NodeModel(meta=dict(name="n1"), deps_model=[n0]) diff --git a/ccflow/tests/flow_model_hydra_fixtures.py b/ccflow/tests/flow_model_hydra_fixtures.py new file mode 100644 index 0000000..f99baf4 --- /dev/null +++ b/ccflow/tests/flow_model_hydra_fixtures.py @@ -0,0 +1,96 @@ +"""Flow.model fixtures used by Hydra integration tests.""" + +from datetime import date, timedelta + +from ccflow import ContextBase, Flow, FromContext, GenericResult + + +class SimpleContext(ContextBase): + value: int + + +@Flow.model +def basic_loader(source: str, multiplier: int, value: FromContext[int]) -> GenericResult[int]: + return GenericResult(value=value * multiplier) + + +@Flow.model +def string_processor(value: FromContext[int], prefix: str = "value=", suffix: str = "!") -> GenericResult[str]: + return GenericResult(value=f"{prefix}{value}{suffix}") + + +@Flow.model +def data_source(base_value: int, value: FromContext[int]) -> GenericResult[int]: + return GenericResult(value=value + base_value) + + +@Flow.model +def data_transformer(source: int, factor: int) -> GenericResult[int]: + return GenericResult(value=source * factor) + + +@Flow.model +def data_aggregator(input_a: int, input_b: int, operation: str = "add") -> GenericResult[int]: + if operation == "add": + return GenericResult(value=input_a + input_b) + raise ValueError(f"unsupported operation: {operation}") + + +@Flow.model +def pipeline_stage1(initial: int, value: FromContext[int]) -> GenericResult[int]: + return GenericResult(value=value + initial) + + +@Flow.model +def pipeline_stage2(stage1_output: int, multiplier: int) -> GenericResult[int]: + return GenericResult(value=stage1_output * multiplier) + + +@Flow.model +def pipeline_stage3(stage2_output: int, offset: int) -> GenericResult[int]: + return GenericResult(value=stage2_output + offset) + + +@Flow.model +def date_range_loader_previous_day( + source: str, + start_date: FromContext[date], + end_date: FromContext[date], + include_weekends: bool = False, +) -> GenericResult[dict]: + del include_weekends + return GenericResult( + value={ + "source": source, + "start_date": str(start_date - timedelta(days=1)), + "end_date": str(end_date), + } + ) + + +@Flow.model +def date_range_processor(raw_data: dict, normalize: bool = False) -> GenericResult[str]: + prefix = "normalized:" if normalize else "raw:" + return GenericResult(value=f"{prefix}{raw_data['source']}:{raw_data['start_date']} to {raw_data['end_date']}") + + +@Flow.model +def contextual_loader(source: str, start_date: FromContext[date], end_date: FromContext[date]) -> GenericResult[dict]: + return GenericResult( + value={ + "source": source, + "start_date": str(start_date), + "end_date": str(end_date), + } + ) + + +@Flow.model +def contextual_processor( + prefix: str, + data: dict, + start_date: FromContext[date], + end_date: FromContext[date], +) -> GenericResult[str]: + del start_date, end_date + return GenericResult(value=f"{prefix}:{data['source']}:{data['start_date']} to {data['end_date']}") diff --git a/ccflow/tests/test_callable.py b/ccflow/tests/test_callable.py index 8c62ac0..d39f3e3 100644 --- a/ccflow/tests/test_callable.py +++ b/ccflow/tests/test_callable.py @@ -789,6 +789,18 @@ def foo(self, context): class TestAutoContext(TestCase): """Tests for the opt-in @Flow.call(auto_context=...) path.""" + def test_direct_function_call_form(self): + class AutoContextCallable(CallableModel): + def __call__(self, *, x: int) -> GenericResult: + return GenericResult(value=x) + + __call__ = Flow.call(__call__, auto_context=True) + + auto_ctx = AutoContextCallable.__call__.__wrapped__.__auto_context__ + + self.assertTrue(issubclass(auto_ctx, ContextBase)) + self.assertEqual(AutoContextCallable()(auto_ctx(x=3)).value, 3) + def test_basic_usage_with_kwargs(self): class AutoContextCallable(CallableModel): @Flow.call(auto_context=True) diff --git a/ccflow/tests/test_context.py b/ccflow/tests/test_context.py index 64d71e8..611e447 100644 --- a/ccflow/tests/test_context.py +++ b/ccflow/tests/test_context.py @@ -13,6 +13,7 @@ DateContext, DateRangeContext, DatetimeContext, + FlowContext, FreqContext, FreqDateContext, FreqDateRangeContext, @@ -53,6 +54,13 @@ def test_null_context_validation(self): self.assertIsInstance(NullContext.model_validate(DateContext(date="0d")), NullContext) self.assertRaises(ValueError, NullContext.model_validate, [True]) + def test_flow_context_hash_freezes_nested_pydantic_values(self): + c1 = FlowContext(payload=DateContext(date=date(2024, 1, 1))) + c2 = FlowContext(payload=DateContext(date="2024-01-01")) + + self.assertEqual(c1, c2) + self.assertEqual(hash(c1), hash(c2)) + def test_context_with_defaults(self): # Contexts may define default values. Extending the assumptions above: # Any context inherits the behavior from NullContext, and can be diff --git a/ccflow/tests/test_flow_context.py b/ccflow/tests/test_flow_context.py index 1bccc55..2eec6f9 100644 --- a/ccflow/tests/test_flow_context.py +++ b/ccflow/tests/test_flow_context.py @@ -68,6 +68,17 @@ def test_flow_context_value_semantics_and_hash(): assert len({first, second, third}) == 2 +def test_flow_context_hash_reflects_nested_mutable_changes(): + ctx = FlowContext(values=[1]) + assert hash(ctx) == hash(FlowContext(values=[1])) + + ctx.values.append(2) + + assert ctx == FlowContext(values=[1, 2]) + assert ctx != FlowContext(values=[1]) + assert hash(ctx) == hash(FlowContext(values=[1, 2])) + + def test_flow_context_hash_handles_nested_models_and_rejects_opaque_unhashable_values(): class WithDict: __hash__ = None @@ -85,7 +96,9 @@ class UnhashableNoState: assert nested == nested assert nested != {"model": NumberContext(x=1)} assert hash(nested) == hash(same) - assert hash(FlowContext(value=WithDict())) == hash(FlowContext(value=WithDict())) + + with pytest.raises(TypeError, match="unhashable value"): + hash(FlowContext(value=WithDict())) with pytest.raises(TypeError, match="unhashable value"): hash(FlowContext(value=UnhashableNoState())) @@ -104,7 +117,8 @@ def add(a: int, b: FromContext[int], c: FromContext[int] = 5) -> int: model = add(a=10) assert model.flow.context_inputs == {"b": int, "c": int} - assert model.flow.unbound_inputs == {"b": int} + assert model.flow.runtime_inputs == {"b": int, "c": int} + assert model.flow.required_inputs == {"b": int} assert model.flow.bound_inputs == {"a": 10} assert model.flow.compute(b=2).value == 17 @@ -163,10 +177,10 @@ def load(start_date: FromContext[date], end_date: FromContext[date]) -> dict: # First with_context applies a patch, second chains another patch on top chained = base.flow.with_context(shift_window(days=7)).flow.with_context(shift_window(days=3)) - # Both patches should be present in the merged context spec. - assert len(chained.context_spec.patches) == 2 + # Both patches should be present in the ordered context spec. + assert [operation.kind for operation in chained.context_spec.operations] == ["patch", "patch"] - # Patches evaluate against the original context, merge left-to-right. + # Patches read the original context, then apply writes left-to-right. # patch1: start - 7, end - 7 => Jan 1, Jan 24 # patch2: start - 3, end - 3 => Jan 5, Jan 28 (overwrites patch1 keys) result = chained(FlowContext(start_date=date(2024, 1, 8), end_date=date(2024, 1, 31))) @@ -189,9 +203,49 @@ def combine(left: int, right: int, bonus: FromContext[int]) -> int: ) assert model.flow.context_inputs == {"bonus": int} + assert model.flow.runtime_inputs == {"bonus": int} assert model.flow.compute(value=5, bonus=100).value == (6 + 15 + 100) +def test_bound_flow_api_separates_declared_and_runtime_context_inputs(): + @Flow.model + def add(a: int, b: FromContext[int], c: FromContext[int] = 5) -> int: + return a + b + c + + @Flow.context_transform + def from_seed(seed: FromContext[int]) -> int: + return seed + 1 + + bound = add(a=10).flow.with_context(b=from_seed()) + + assert bound.flow.context_inputs == {"b": int, "c": int} + assert bound.flow.runtime_inputs == {"c": int, "seed": int} + assert bound.flow.required_inputs == {"seed": int} + for name, annotation in bound.flow.required_inputs.items(): + assert bound.flow.runtime_inputs[name] == annotation + assert bound.flow.bound_inputs == {"a": 10} + assert bound.flow.compute(seed=1).value == 17 + + +def test_bound_flow_api_rejects_conflicting_runtime_input_annotations(): + @Flow.model + def add(b: FromContext[int], c: FromContext[int]) -> int: + return b + c + + @Flow.context_transform + def from_int(seed: FromContext[int]) -> int: + return seed + 1 + + @Flow.context_transform + def from_str(seed: FromContext[str]) -> int: + return int(seed) + 1 + + bound = add().flow.with_context(b=from_int(), c=from_str()) + + with pytest.raises(TypeError, match="Conflicting runtime context annotations for 'seed'"): + bound.flow.runtime_inputs + + def test_bound_model_rejects_regular_field_context_overrides(): @Flow.model def add(a: int, b: FromContext[int]) -> int: @@ -219,7 +273,9 @@ def add(a: int, b: FromContext[int]) -> int: bound = add(a=10).flow.with_context(b=5) dumped = bound.model_dump(mode="python") - assert dumped["context_spec"] == {"patches": [], "field_overrides": {"b": {"kind": "static_value", "value": 5}}} + assert dumped["context_spec"] == { + "operations": [{"kind": "field", "name": "b", "spec": {"kind": "static_value", "value": 5}}], + } restored = type(bound).model_validate(dumped) assert restored.flow.compute().value == 15 @@ -233,7 +289,8 @@ def add(a: int, b: FromContext[int]) -> int: bound = add(a=10).flow.with_context(b=offset_b(amount=1)) dumped = bound.model_dump(mode="json") - assert dumped["context_spec"]["field_overrides"]["b"]["binding"]["path"].endswith(".offset_b") + binding = dumped["context_spec"]["operations"][0]["spec"] + assert binding["path"] is not None or binding["serialized_config"] is not None restored = type(bound).model_validate(dumped) assert restored.flow.compute(b=4).value == 15 @@ -289,7 +346,7 @@ def add(a: int, b: FromContext[int]) -> int: assert bound.flow.compute(b=4).value == 15 dumped = bound.model_dump(mode="python") - assert dumped["context_spec"]["field_overrides"]["b"]["binding"]["kind"] == "context_transform" + assert dumped["context_spec"]["operations"][0]["spec"]["kind"] == "context_transform" restored = type(bound).model_validate(dumped) assert restored.flow.compute(b=4).value == 15 @@ -303,8 +360,9 @@ def load_window(start_date: FromContext[date], end_date: FromContext[date]) -> G return GenericResult(value={"start": start_date, "end": end_date}) dumped = load_window().flow.with_context(shift_window(days=7), start_date=shift_start_date(days=1)).model_dump(mode="json") - assert dumped["context_spec"]["patches"][0]["kind"] == "context_patch" - assert dumped["context_spec"]["field_overrides"]["start_date"]["kind"] == "context_value" + assert dumped["context_spec"]["operations"][0]["binding"]["kind"] == "context_transform" + assert dumped["context_spec"]["operations"][1]["name"] == "start_date" + assert dumped["context_spec"]["operations"][1]["spec"]["kind"] == "context_transform" def test_regular_callable_models_still_support_with_context(): @@ -317,7 +375,8 @@ def test_flow_api_for_regular_callable_model(): model = OffsetModel(offset=10) assert model.flow.compute(x=5).value == 15 assert model.flow.context_inputs == {"x": int} - assert model.flow.unbound_inputs == {"x": int} + assert model.flow.runtime_inputs == {"x": int} + assert model.flow.required_inputs == {"x": int} assert model.flow.bound_inputs == {"offset": 10} diff --git a/ccflow/tests/test_flow_model.py b/ccflow/tests/test_flow_model.py index 6c82625..c73b79b 100644 --- a/ccflow/tests/test_flow_model.py +++ b/ccflow/tests/test_flow_model.py @@ -2,13 +2,15 @@ import base64 import graphlib +import importlib import inspect import pickle import subprocess import sys from datetime import date, timedelta +from pathlib import Path from types import ModuleType -from typing import Annotated, Any, Literal, Optional +from typing import Annotated, Any, Callable, Literal, Optional, get_args import pytest import ray @@ -249,7 +251,10 @@ def test_context_transform_internal_error_and_repr_paths(): assert flow_model_module._context_transform_repr(static_patch()) == "static_patch()" assert flow_model_module._context_transform_repr(increment_b(amount=2)) == "increment_b(amount=2)" assert flow_model_module._context_transform_repr(123) == "123" - assert flow_model_module._context_transform_identifier(increment_b(amount=1)).endswith(".increment_b") + importable_increment_b = Flow.context_transform(increment_b.__wrapped__) + importable_config = flow_model_module._load_context_transform_config_from_binding(importable_increment_b(amount=1)) + assert importable_config.path is not None + assert flow_model_module._context_transform_identifier(importable_increment_b(amount=1)) == "increment_b" with pytest.raises(ValidationError, match="exactly one"): flow_model_module.ContextTransform() @@ -272,32 +277,11 @@ def test_context_transform_internal_error_and_repr_paths(): flow_model_module._restore_generated_flow_model("ccflow.tests.test_flow_model.lazy_context_transform_for_rejection", {}) -def test_flow_model_low_level_value_helpers_cover_edge_paths(): - assert flow_model_module._bound_field_names(object()) == set() - assert flow_model_module._concrete_context_type(int | None) is None - no_name_annotation = int | str - assert flow_model_module._expected_type_repr(no_name_annotation) == repr(no_name_annotation) - assert flow_model_module._coerce_value("x", "still-raw", object(), "test") == "still-raw" - assert flow_model_module._unwrap_model_result(7) == 7 - assert flow_model_module._type_accepts_str(Annotated[str, "meta"]) is True - assert flow_model_module._type_accepts_str(int | str) is True - assert flow_binding_module._is_result_annotation(GenericResult[int] | None) is True - assert flow_model_module._registry_candidate_allowed(object(), data_source(base_value=1)) is True - assert flow_model_module._registry_candidate_allowed(int, GenericResult(value=1)) is False - assert flow_model_module._is_mapping_annotation(inspect.Signature.empty) is False - assert flow_model_module._is_mapping_annotation(123) is False - generated_type = type(basic_loader(source="s", multiplier=2)) - assert flow_model_module._resolve_generated_model_bases(generated_type) == (generated_type,) - assert callable(Flow.context_transform()) - - metadata: list[object] = [] - annotation = Annotated[int, metadata] - assert flow_model_module._type_adapter(annotation) is flow_model_module._type_adapter(annotation) - - with pytest.raises(TypeError, match="only supports Python functions"): - flow_model_module._ensure_top_level_named_function(123, decorator_name="@Flow.model") - with pytest.raises(TypeError, match="only supports named Python functions"): - flow_model_module._ensure_top_level_named_function(lambda: None, decorator_name="@Flow.model") +def test_flow_model_rejects_invalid_decorator_targets(): + with pytest.raises(TypeError): + Flow.model(123) + with pytest.raises(TypeError): + Flow.model(lambda: None) def test_lazy_thunks_and_regular_resolution_edge_paths(): @@ -363,8 +347,11 @@ def add(a: FromContext[int], b: FromContext[int]) -> int: flow_model_module._evaluate_context_transform_from_values(seed_plus_one(), {}) dynamic_spec = flow_model_module._BoundContextSpec( - patches=[flow_model_module.PatchContextSpec(binding=dynamic_patch())], - field_overrides={}, + operations=[ + flow_model_module.PatchContextOperation( + binding=dynamic_patch(), + ) + ], ) assert flow_model_module._statically_resolved_context_values(add(), dynamic_spec) is None assert flow_model_module._statically_resolved_context_field_names(add(), dynamic_spec) == set() @@ -394,11 +381,9 @@ class OpaqueModel: assert flow_model_module._validate_patch_result(OpaqueModel(), {"x": 1}) == {"x": 1} flow_model_module._validate_with_context_field_names(OpaqueModel(), ["anything"]) - assert ( - flow_model_module._static_field_override_value(OpaqueModel(), "anything", flow_model_module.FieldContextSpec(binding=default_amount())) == 5 - ) + assert flow_model_module._evaluate_static_context_transform(default_amount()) == 5 - with pytest.raises(TypeError, match="raw callables"): + with pytest.raises(TypeError, match="Positional with_context"): add().flow.with_context(lambda: {"a": 1}) with pytest.raises(TypeError, match="Positional with_context"): add().flow.with_context(123) @@ -425,12 +410,6 @@ def regular_required(x: int) -> int: def lazy_consumer(x: Lazy[int]) -> int: return x() - restored = flow_model_module._restore_generated_flow_model( - "ccflow.tests.test_flow_model.basic_loader", - basic_loader(source="s", multiplier=2).__getstate__(), - ) - assert restored.flow.compute(value=3).value == 6 - class FailingPath: def __init__(self, path): self.path = path @@ -439,29 +418,20 @@ def __init__(self, path): def object(self): raise ImportError(self.path) - original_path = flow_model_module.PyObjectPath - monkeypatch.setattr(flow_model_module, "PyObjectPath", FailingPath) - config = type(basic_loader(source="s", multiplier=2)).__flow_model_config__ - assert flow_model_module._generated_model_factory_path_for_pickle(config, type(basic_loader(source="s", multiplier=2))) is None - monkeypatch.setattr(flow_model_module, "PyObjectPath", original_path) - - assert flow_model_module._registry_candidate_allowed(int, 1) is True - opaque_model = type("OpaqueModel", (), {"context_type": object})() - assert flow_model_module._coerce_model_context_value(opaque_model, "anything", "raw", "test") == "raw" - assert flow_model_module._generated_model_identity_payload(regular_required(), FlowContext()) is None - - context_spec = flow_model_module._BoundContextSpec( - patches=[flow_model_module.PatchContextSpec(binding=dynamic_patch())], - field_overrides={"b": flow_model_module.FieldContextSpec(binding=default_seed())}, - ) - values, missing = flow_model_module._apply_context_spec_values_for_identity(add(), context_spec, FlowContext(seed=1)) - assert values == {"a": 1, "b": 2, "seed": 1} - assert missing == () - assert flow_model_module._statically_resolved_context_values(add(), context_spec) is None + generated_name = "_basic_loader_Model" + original_generated = getattr(sys.modules[__name__], generated_name) + try: + with monkeypatch.context() as path_patch: + path_patch.setattr(flow_model_module, "PyObjectPath", FailingPath) + restored = pickle.loads(pickle.dumps(basic_loader(source="s", multiplier=2), protocol=5)) + finally: + setattr(sys.modules[__name__], generated_name, original_generated) + assert restored.flow.compute(value=3).value == 6 bound = add().flow.with_context(dynamic_patch()) - assert bound.flow.context_inputs == {"a": int, "b": int, "seed": int} - assert bound.flow.unbound_inputs == {"a": int, "b": int, "seed": int} + assert bound.flow.context_inputs == {"a": int, "b": int} + assert bound.flow.runtime_inputs == {"a": int, "b": int, "seed": int} + assert bound.flow.required_inputs == {"a": int, "b": int, "seed": int} with pytest.raises(TypeError, match="missing required regular"): flow_model_module._bound_context_transform_regular_kwargs( @@ -471,8 +441,6 @@ def object(self): with pytest.raises(TypeError, match="Missing regular parameter"): regular_required().__deps__(FlowContext()) assert lazy_consumer(x=data_source(base_value=1)).__deps__(FlowContext(value=1)) == [] - assert getattr(basic_loader, "_generated_model")._resolve_registry_refs("raw") == "raw" - assert flow_model_module._GeneratedFlowModelBase._resolve_registry_refs({}) == {} def transform_with_bad_hints(value: FromContext[int]) -> int: return value @@ -481,7 +449,8 @@ def raise_attribute_error(*args, **kwargs): raise AttributeError("bad hints") monkeypatch.setattr(flow_model_module, "get_type_hints", raise_attribute_error) - assert Flow.context_transform(transform_with_bad_hints)().serialized_config is not None + with pytest.raises(AttributeError, match="bad hints"): + Flow.context_transform(transform_with_bad_hints) def test_plain_and_bound_optional_compute_paths_and_identity_helpers(): @@ -498,10 +467,10 @@ def __call__(self, context: Optional[SimpleContext] = None) -> GenericResult[int assert flow_model_module._model_context_contract(AnyContextModel()).input_types is None assert flow_model_module._model_context_contract(FlowContextModel()).input_types is None - assert flow_model_module._identity_context_values_for_model(AnyContextModel(), FlowContext(extra=1)) == {"extra": 1} + assert flow_model_module._identity_context_values_for_model_values(AnyContextModel(), {"extra": 1}) == {"extra": 1} assert OptionalContextModel().flow.compute(None).value == 0 assert OptionalContextModel().flow.compute().value == 0 - assert OptionalContextModel().flow.unbound_inputs == {} + assert OptionalContextModel().flow.required_inputs == {} bound = OptionalContextModel().flow.with_context() assert bound.flow.compute(FlowContext(value=3)).value == 3 @@ -561,6 +530,28 @@ def foo(a: int, b: FromContext[int]) -> int: assert model.flow.compute(FlowContext(value=7, b=12)).value == 24 +def test_regular_param_containers_are_literals(): + @Flow.model + def source(value: FromContext[int]) -> int: + return value + + @Flow.model + def inspect_values(values: list[Any]) -> int: + return sum(isinstance(value, CallableModel) for value in values) + + model = inspect_values(values=[source()]) + + assert model.__deps__(FlowContext(value=10)) == [] + assert model.flow.compute(value=10).value == 1 + + @Flow.model + def total(values: list[int]) -> int: + return sum(values) + + with pytest.raises(TypeError, match="Field 'values'"): + total(values=[source()]) + + def test_regular_param_upstream_dependency_coerced(): """Upstream model returning str should be coerced to downstream int annotation.""" @@ -668,7 +659,7 @@ def foo(a: int, b: FromContext[int]) -> int: model = foo(a=11, b=12) assert model.flow.bound_inputs == {"a": 11, "b": 12} assert model.flow.context_inputs == {"b": int} - assert model.flow.unbound_inputs == {} + assert model.flow.required_inputs == {} assert model.flow.compute().value == 23 @@ -680,7 +671,7 @@ def foo(a: int, b: FromContext[int] = 5) -> int: model = foo(a=2) assert model.flow.bound_inputs == {"a": 2} assert model.flow.context_inputs == {"b": int} - assert model.flow.unbound_inputs == {} + assert model.flow.required_inputs == {} assert model.flow.compute().value == 7 assert model.flow.compute(b=10).value == 12 @@ -716,7 +707,7 @@ def choose(mode: FromContext[str]) -> str: model = choose() assert model.flow.context_inputs == {"mode": Literal["a"]} - assert model.flow.unbound_inputs == {"mode": Literal["a"]} + assert model.flow.required_inputs == {"mode": Literal["a"]} assert model.flow.compute(mode="a").value == "a" with pytest.raises(ValidationError): model.flow.compute(mode="b") @@ -849,6 +840,51 @@ def load(value: FromContext[int]) -> GenericResult[int]: assert result.value == 6 +def test_flow_model_rejects_union_resultbase_return_annotations(): + with pytest.raises(TypeError, match="does not support Union or Optional ResultBase"): + + @Flow.model + def optional_result(value: FromContext[int]) -> Optional[GenericResult[int]]: + return GenericResult(value=value) + + with pytest.raises(TypeError, match="does not support Union or Optional ResultBase"): + + @Flow.model + def pep604_optional_result(value: FromContext[int]) -> GenericResult[int] | None: + return GenericResult(value=value) + + with pytest.raises(TypeError, match="does not support Union or Optional ResultBase"): + + @Flow.model + def annotated_optional_result(value: FromContext[int]) -> Annotated[GenericResult[int] | None, "meta"]: + return GenericResult(value=value) + + +def test_flow_model_allows_plain_union_return_annotations(): + @Flow.model + def choose(flag: FromContext[bool]) -> int | str: + return 1 if flag else "one" + + assert choose().flow.compute(flag=True).value == 1 + assert choose().flow.compute(flag=False).value == "one" + + +def test_flow_model_handles_annotated_result_annotations(): + @Flow.model + def explicit_result() -> Annotated[GenericResult[int], "meta"]: + return GenericResult(value=1) + + @Flow.model + def plain_union(flag: FromContext[bool]) -> Annotated[int | str, "meta"]: + return 1 if flag else "one" + + result = explicit_result().flow.compute() + assert type(result) is GenericResult[int] + assert result.value == 1 + assert plain_union().flow.compute(flag=True).value == 1 + assert plain_union().flow.compute(flag=False).value == "one" + + def test_auto_unwrap_can_be_enabled_for_auto_wrapped_results(): @Flow.model(auto_unwrap=True) def add(a: int, b: FromContext[int]) -> int: @@ -879,6 +915,24 @@ def bad(x: FromContext[int]) -> int: bad().flow.compute(x=1) +def test_auto_wrap_respects_validate_result_false(): + @Flow.model(validate_result=False) + def bad_decorator() -> int: + return "oops" + + @Flow.model + def bad_runtime() -> int: + return "oops" + + decorator_result = bad_decorator().flow.compute() + runtime_result = bad_runtime().flow.compute(_options=FlowOptions(validate_result=False)) + + assert isinstance(decorator_result, GenericResult) + assert decorator_result.value == "oops" + assert isinstance(runtime_result, GenericResult) + assert runtime_result.value == "oops" + + def test_auto_wrap_coerces_compatible_return(): @Flow.model def coerce(x: FromContext[int]) -> float: @@ -939,7 +993,7 @@ def mixed(context: SimpleContext, y: FromContext[int]) -> int: assert model.flow.compute(y=5).value == 15 -@pytest.mark.parametrize("reserved_name", ["flow", "meta", "context_type", "result_type"]) +@pytest.mark.parametrize("reserved_name", ["flow", "meta", "context_type", "result_type", "type_"]) def test_flow_model_rejects_reserved_parameter_names(reserved_name): namespace = {"Flow": Flow, "FromContext": FromContext} exec( @@ -951,14 +1005,6 @@ def test_flow_model_rejects_reserved_parameter_names(reserved_name): Flow.model(namespace["bad"]) -def test_context_args_keyword_is_removed(): - with pytest.raises(TypeError, match="context_args=... has been removed"): - - @Flow.model(context_args=["x"]) - def bad(x: int) -> int: - return x - - def test_context_type_requires_from_context(): with pytest.raises(TypeError, match="context_type=... requires FromContext"): @@ -1040,7 +1086,7 @@ def add(a: int, b: FromContext[int]) -> int: restored = type(model).model_validate(dumped) assert restored.flow.bound_inputs == {"a": 10} - assert restored.flow.unbound_inputs == {"b": int} + assert restored.flow.required_inputs == {"b": int} assert restored.flow.compute(b=5).value == 15 @@ -1096,6 +1142,226 @@ def multiply(a: int, b: FromContext[int]) -> int: assert param.validation_annotation is int +def test_local_generated_model_effective_cache_key_survives_pickle_roundtrip(): + def make_model(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + return add(a=1) + + model = make_model() + context = FlowContext(b=2) + before = cache_key(model.__call__.get_evaluation_context(model, context), effective=True) + + for dumps, loads in ((pickle.dumps, pickle.loads), (rcpdumps, rcploads)): + restored = loads(dumps(model, protocol=5)) + after = cache_key(restored.__call__.get_evaluation_context(restored, context), effective=True) + assert after == before + + +def test_local_generated_model_plain_pickle_handles_generic_result_state(): + def make_model(): + @Flow.model + def first(xs: list[GenericResult[int]], b: FromContext[int]) -> int: + return xs[0].value + b + + return first(xs=[GenericResult(value=1)]) + + restored = pickle.loads(pickle.dumps(make_model(), protocol=5)) + + assert restored.flow.compute(b=2).value == 3 + + +def test_bound_model_plain_pickle_handles_context_transform_generic_result_bound_args(): + @Flow.model + def source(a: FromContext[int]) -> int: + return a + + @Flow.context_transform + def fixed(value: GenericResult[int]) -> int: + return value.value + + bound = source().flow.with_context(a=fixed(value=GenericResult(value=5))) + restored = pickle.loads(pickle.dumps(bound, protocol=5)) + + assert restored.flow.compute().value == 5 + + +def test_local_generated_model_plain_pickle_bytes_in_ray_handles_generic_result_state(): + def make_model(): + @Flow.model + def first(xs: list[GenericResult[int]], b: FromContext[int]) -> int: + return xs[0].value + b + + return first(xs=[GenericResult(value=10)]) + + @ray.remote + class Runner: + def run(self, payload): + model = pickle.loads(payload) + context = FlowContext(b=3) + before = cache_key(model.__call__.get_evaluation_context(model, context), effective=True) + value = model.flow.compute(context).value + after = cache_key(model.__call__.get_evaluation_context(model, context), effective=True) + return value, before == after + + with ray.init(num_cpus=1): + runner = Runner.remote() + assert ray.get(runner.run.remote(pickle.dumps(make_model(), protocol=5))) == (13, True) + + +def test_importable_generated_model_plain_pickle_cross_process_handles_generic_result_state(tmp_path, monkeypatch): + module_dir = tmp_path / "generic_state_module" + module_dir.mkdir() + module_path = module_dir / "generic_state_mod.py" + module_path.write_text( + "\n".join( + [ + "from ccflow import Flow, FromContext, GenericResult", + "", + "@Flow.model", + "def first(xs: list[GenericResult[int]], b: FromContext[int]) -> int:", + " return xs[0].value + b", + "", + ] + ) + ) + monkeypatch.syspath_prepend(str(module_dir)) + + import generic_state_mod + + payload = base64.b64encode(pickle.dumps(generic_state_mod.first(xs=[GenericResult(value=10)]), protocol=5)).decode() + script = ( + "import base64, pickle, sys\n" + f"sys.path.insert(0, {str(module_dir)!r})\n" + f"model = pickle.loads(base64.b64decode({payload!r}))\n" + "assert model.flow.compute(b=3).value == 13\n" + ) + result = subprocess.run([sys.executable, "-c", script], capture_output=True, text=True, timeout=30) + + assert result.returncode == 0, result.stderr + + +def test_generated_model_plain_pickle_preserves_generic_result_in_loose_state(): + @Flow.model + def use_any(x: Any, b: FromContext[int]) -> int: + return x.value + b + + @Flow.model + def use_list(xs: list[Any]) -> int: + return xs[0].value + + restored_any = pickle.loads(pickle.dumps(use_any(x=GenericResult(value=10)), protocol=5)) + restored_list = pickle.loads(pickle.dumps(use_list(xs=[GenericResult(value=11)]), protocol=5)) + + assert restored_any.flow.compute(b=3).value == 13 + assert isinstance(restored_any.x, GenericResult) + assert restored_list.flow.compute().value == 11 + assert isinstance(restored_list.xs[0], GenericResult) + + +def test_importable_generated_model_plain_pickle_cross_process_preserves_generic_result_in_loose_state(tmp_path, monkeypatch): + module_dir = tmp_path / "generic_any_state_module" + module_dir.mkdir() + module_path = module_dir / "generic_any_state_mod.py" + module_path.write_text( + "\n".join( + [ + "from typing import Any", + "from ccflow import Flow, FromContext", + "", + "@Flow.model", + "def use_any(x: Any, b: FromContext[int]) -> int:", + " return x.value + b", + "", + ] + ) + ) + monkeypatch.syspath_prepend(str(module_dir)) + + import generic_any_state_mod + + payload = base64.b64encode(pickle.dumps(generic_any_state_mod.use_any(x=GenericResult(value=10)), protocol=5)).decode() + script = ( + "import base64, pickle, sys\n" + f"sys.path.insert(0, {str(module_dir)!r})\n" + f"model = pickle.loads(base64.b64decode({payload!r}))\n" + "assert model.flow.compute(b=3).value == 13\n" + ) + result = subprocess.run([sys.executable, "-c", script], capture_output=True, text=True, timeout=30) + + assert result.returncode == 0, result.stderr + + +def test_local_generated_model_plain_pickle_handles_generic_result_function_default(): + def make_model(): + @Flow.model + def first(xs: list[GenericResult[int]] = [GenericResult(value=1)], b: FromContext[int] = 2) -> int: + return xs[0].value + b + + return first() + + restored = pickle.loads(pickle.dumps(make_model(), protocol=5)) + + assert restored.flow.compute().value == 3 + + +def test_unresolved_lazy_local_generated_dependency_identity_survives_pickle_roundtrip(): + def make_model(): + @Flow.model + def source(a: FromContext[int]) -> int: + return a + + @Flow.model + def choose(x: int, lazy_value: Lazy[int], use_lazy: FromContext[bool]) -> int: + return lazy_value() if use_lazy else x + + return choose(x=7, lazy_value=source()) + + model = make_model() + context = FlowContext(use_lazy=False) + before_eval = model.__call__.get_evaluation_context(model, context) + before_key = cache_key(before_eval, effective=True) + before_root = get_dependency_graph(before_eval).root_id + + restored = pickle.loads(pickle.dumps(model, protocol=5)) + after_eval = restored.__call__.get_evaluation_context(restored, context) + + assert cache_key(after_eval, effective=True) == before_key + assert get_dependency_graph(after_eval).root_id == before_root + + +def test_unresolved_lazy_nested_local_generated_dependency_identity_survives_pickle_roundtrip(): + def make_model(): + @Flow.model + def dependency(d: FromContext[int]) -> int: + return d + + @Flow.model + def source(dep_value: int, a: FromContext[int]) -> int: + return dep_value + a + + @Flow.model + def choose(x: int, lazy_value: Lazy[int], use_lazy: FromContext[bool]) -> int: + return lazy_value() if use_lazy else x + + return choose(x=7, lazy_value=source(dep_value=dependency())) + + model = make_model() + context = FlowContext(use_lazy=False) + before_eval = model.__call__.get_evaluation_context(model, context) + before_key = cache_key(before_eval, effective=True) + before_root = get_dependency_graph(before_eval).root_id + + for dumps, loads in ((pickle.dumps, pickle.loads), (rcpdumps, rcploads)): + restored = loads(dumps(model, protocol=5)) + after_eval = restored.__call__.get_evaluation_context(restored, context) + + assert cache_key(after_eval, effective=True) == before_key + assert get_dependency_graph(after_eval).root_id == before_root + + def test_generated_model_pydantic_roundtrip_via_base_model(): @Flow.model def add(a: int, b: FromContext[int]) -> int: @@ -1128,6 +1394,112 @@ def add(a: int, b: FromContext[int]) -> int: assert restored.flow.compute(b=7).value == 17 +def test_generated_model_dependency_input_json_roundtrip(): + from ccflow import BaseModel + + model = data_transformer(source=data_source(base_value=1), factor=2) + restored = BaseModel.model_validate_json(model.model_dump_json()) + + assert restored.flow.compute(value=10).value == 22 + assert isinstance(restored.source, CallableModel) + + +def test_generated_model_lazy_dependency_input_json_roundtrip(): + @Flow.model + def choose(source: Lazy[int], use_value: FromContext[bool]) -> int: + return source() if use_value else 0 + + from ccflow import BaseModel + + model = choose(source=data_source(base_value=1)) + restored = BaseModel.model_validate_json(model.model_dump_json()) + + assert restored.flow.compute(value=10, use_value=True).value == 11 + assert restored.flow.compute(value=10, use_value=False).value == 0 + assert isinstance(restored.source, CallableModel) + + +def test_generated_model_type_dict_regular_input_stays_literal_when_annotation_accepts_dict(): + @Flow.model + def payload_type(payload: dict) -> str: + return type(payload).__name__ + + type_payload = data_source(base_value=1).model_dump(mode="python") + model = payload_type(payload=type_payload) + + assert model.flow.compute(value=10).value == "dict" + assert model.payload == type_payload + + +def test_generated_model_target_alias_dict_regular_input_stays_literal(): + @Flow.model + def payload_type(payload: Any) -> str: + return type(payload).__name__ + + alias_payload = data_source(base_value=1).model_dump(mode="python", by_alias=True) + model = payload_type(payload=alias_payload) + + assert model.flow.compute(value=10).value == "dict" + assert model.payload == alias_payload + + +def test_generated_model_target_alias_restores_dependency_after_literal_validation_fails(): + @Flow.model + def total(value: int) -> int: + return value + + alias_payload = data_source(base_value=1).model_dump(mode="python", by_alias=True) + model = total(value=alias_payload) + + assert model.flow.compute(FlowContext(value=10)).value == 11 + assert isinstance(model.value, CallableModel) + + +def test_generated_model_type_marker_restores_dependency_after_literal_validation_fails(): + @Flow.model + def total(value: int) -> int: + return value + + type_payload = data_source(base_value=1).model_dump(mode="python") + model = total(value=type_payload) + + assert model.flow.compute(FlowContext(value=10)).value == 11 + assert isinstance(model.value, CallableModel) + + +@pytest.mark.parametrize( + "payload", + [ + {"_target_": "does.not.exist.Class", "foo": 1}, + {"_target_": "builtins.dict", "foo": 1}, + {"type_": "does.not.exist.Class", "foo": 1}, + {"type_": "builtins.dict", "foo": 1}, + ], +) +def test_serialized_dependency_fallback_preserves_literal_validation_error(payload): + @Flow.model + def total(value: int) -> int: + return value + + with pytest.raises(TypeError, match="Field 'value': expected int, got dict"): + total(value=payload) + + +def test_registry_lookup_preserves_literal_error_when_candidate_type_is_incompatible(): + registry = ModelRegistry.root().clear() + registry.add("context", SimpleContext(value=1)) + + @Flow.model + def total(values: list[int]) -> int: + return sum(values) + + try: + with pytest.raises(TypeError, match="Field 'values': expected list, got str"): + total(values="context") + finally: + registry.clear() + + def test_importable_generated_model_uses_stable_module_path_for_type_serialization(): model = basic_loader(source="library", multiplier=3) stable_path = f"{__name__}._basic_loader_Model" @@ -1138,6 +1510,134 @@ def test_importable_generated_model_uses_stable_module_path_for_type_serializati assert str(model.model_dump(mode="python")["type_"]) == stable_path +def test_importable_generated_model_duplicate_names_raise_conflict(monkeypatch): + module = ModuleType("ccflow_test_duplicate_generated_models") + module.Flow = Flow + module.FromContext = FromContext + monkeypatch.setitem(sys.modules, module.__name__, module) + + exec( + """ +def stage(value: FromContext[int]) -> int: + return value + 1 +""", + module.__dict__, + ) + first_factory = Flow.model(module.stage) + module.first = first_factory + first = first_factory() + + exec( + """ +def stage(value: FromContext[int]) -> int: + return value + 2 +""", + module.__dict__, + ) + with pytest.raises(ValueError, match="already occupied"): + Flow.model(module.stage) + + assert getattr(module, "_stage_Model") is type(first) + assert not hasattr(module, "_stage_Model_2") + assert str(PyObjectPath.validate(type(first))) == f"{module.__name__}._stage_Model" + assert first.flow.compute(value=10).value == 11 + + +def test_reloaded_importable_generated_model_keeps_clean_process_path(tmp_path, monkeypatch): + module_dir = tmp_path / "reload_case" + module_dir.mkdir() + module_path = module_dir / "repro_mod.py" + module_path.write_text( + "\n".join( + [ + "from ccflow import Flow, FromContext", + "", + "@Flow.model", + "def foo(x: FromContext[int]) -> int:", + " return x + 1", + "", + ] + ) + ) + monkeypatch.syspath_prepend(str(module_dir)) + + import repro_mod + + assert str(repro_mod.foo().type_) == "repro_mod._foo_Model" + reloaded = importlib.reload(repro_mod) + model = reloaded.foo() + payload = model.model_dump_json() + script = ( + "import sys\n" + f"sys.path.insert(0, {str(module_dir)!r})\n" + "from ccflow import BaseModel\n" + f"model = BaseModel.model_validate_json({payload!r})\n" + "assert model.flow.compute(x=2).value == 3\n" + ) + result = subprocess.run([sys.executable, "-c", script], capture_output=True, text=True, timeout=30) + + assert str(model.type_) == "repro_mod._foo_Model" + assert result.returncode == 0, f"Clean-process reload JSON roundtrip failed:\n{result.stderr}" + + +def test_reloaded_importable_generated_model_allows_stale_factory_aliases(tmp_path, monkeypatch): + module_dir = tmp_path / "reload_alias_case" + module_dir.mkdir() + module_path = module_dir / "alias_mod.py" + module_path.write_text( + "\n".join( + [ + "from ccflow import Flow, FromContext", + "", + "def foo(x: FromContext[int]) -> int:", + " return x + 1", + "", + "bar = Flow.model(foo)", + "", + ] + ) + ) + monkeypatch.syspath_prepend(str(module_dir)) + + import alias_mod + + assert str(alias_mod.bar().type_) == "alias_mod._foo_Model" + reloaded = importlib.reload(alias_mod) + + assert str(reloaded.bar().type_) == "alias_mod._foo_Model" + assert reloaded.bar().flow.compute(x=2).value == 3 + + +def test_reloaded_importable_generated_model_allows_stale_decorator_aliases(tmp_path, monkeypatch): + module_dir = tmp_path / "reload_decorator_alias_case" + module_dir.mkdir() + module_path = module_dir / "decor_alias_mod.py" + module_path.write_text( + "\n".join( + [ + "from ccflow import Flow, FromContext", + "", + "@Flow.model", + "def foo(x: FromContext[int]) -> int:", + " return x + 1", + "", + "foo_alias = foo", + "", + ] + ) + ) + monkeypatch.syspath_prepend(str(module_dir)) + + import decor_alias_mod + + assert str(decor_alias_mod.foo().type_) == "decor_alias_mod._foo_Model" + reloaded = importlib.reload(decor_alias_mod) + + assert str(reloaded.foo().type_) == "decor_alias_mod._foo_Model" + assert str(reloaded.foo_alias().type_) == "decor_alias_mod._foo_Model" + assert reloaded.foo().flow.compute(x=2).value == 3 + + def test_importable_generated_model_json_roundtrip_cross_process(): model = basic_loader(source="library", multiplier=3) payload = model.model_dump_json() @@ -1205,14 +1705,14 @@ def add(a: int, b: FromContext[int]) -> int: add(a=1).flow.with_context(b="not_an_int") -def test_context_transform_serializes_import_path_and_bound_args(): - binding = increment_b(amount=3) +def test_context_transform_serializes_embedded_config_and_bound_args(): + transform_factory = Flow.context_transform(increment_b.__wrapped__) + binding = transform_factory(amount=3) assert isinstance(binding, flow_model_module.ContextTransform) assert binding.kind == "context_transform" - assert binding.path is not None - assert binding.serialized_config is None + assert binding.path is None + assert binding.serialized_config is not None assert binding.bound_args == {"amount": 3} - assert str(binding.path).endswith(".increment_b") def test_context_transform_rejects_none_for_required_param(): @@ -1251,6 +1751,24 @@ def add(a: int, b: FromContext[int]) -> int: assert restored.flow.compute(value=4).value == 15 +def test_context_transform_json_roundtrip_recoerces_regular_bound_args(): + from ccflow import BaseModel + + @Flow.model + def load(day: FromContext[date]) -> str: + return day.isoformat() + + @Flow.context_transform + def shift_from_anchor(anchor: date, days: FromContext[int]) -> date: + return anchor + timedelta(days=days) + + bound = load().flow.with_context(day=shift_from_anchor(anchor=date(2024, 1, 1))) + restored = BaseModel.model_validate_json(bound.model_dump_json()) + + assert bound.flow.compute(days=2).value == "2024-01-03" + assert restored.flow.compute(days=2).value == "2024-01-03" + + def test_context_transform_supports_nested_functions_with_serialized_payload(): @Flow.model def add(a: int, b: FromContext[int]) -> int: @@ -1314,10 +1832,23 @@ def test_with_context_rejects_raw_callables(): def add(a: int, b: FromContext[int]) -> int: return a + b - with pytest.raises(TypeError, match="no longer accepts raw callables"): + with pytest.raises(TypeError, match="with_context\\(\\) 'b': expected int"): add(a=1).flow.with_context(b=lambda ctx: ctx.b + 1) +def test_with_context_accepts_callable_literals_for_callable_context_fields(): + def increment(value: int) -> int: + return value + 1 + + @Flow.model + def apply(fn: FromContext[Callable[[int], int]], value: FromContext[int]) -> int: + return fn(value) + + assert apply(fn=increment).flow.compute(value=2).value == 3 + assert apply().flow.compute(fn=increment, value=2).value == 3 + assert apply().flow.with_context(fn=increment).flow.compute(value=2).value == 3 + + def test_with_context_rejects_wrong_transform_position(): @Flow.model def load(start_date: FromContext[int], end_date: FromContext[int]) -> int: @@ -1330,6 +1861,62 @@ def load(start_date: FromContext[int], end_date: FromContext[int]) -> int: load().flow.with_context(start_date=shift_integer_window(amount=10)) +def test_chained_with_context_preserves_patch_and_field_order(): + @Flow.context_transform + def patch_a() -> dict[str, int]: + return {"a": 2} + + @Flow.model + def source(a: FromContext[int]) -> int: + return a + + field_then_patch = source().flow.with_context(a=1).flow.with_context(patch_a()) + patch_then_field = source().flow.with_context(patch_a()).flow.with_context(a=1) + + assert ( + field_then_patch.model_dump(mode="json")["context_spec"]["operations"] + != patch_then_field.model_dump(mode="json")["context_spec"]["operations"] + ) + assert field_then_patch.flow.compute().value == 2 + assert patch_then_field.flow.compute().value == 1 + + +def test_chained_with_context_transform_reads_original_context(): + @Flow.context_transform + def patch_a() -> dict[str, int]: + return {"a": 2} + + @Flow.context_transform + def b_from_a(a: FromContext[int]) -> int: + return a + 3 + + @Flow.model + def source(a: FromContext[int], b: FromContext[int]) -> int: + return a + b + + bound = source().flow.with_context(patch_a()).flow.with_context(b=b_from_a()) + + assert bound.flow.compute(a=10).value == 15 + assert bound.flow.required_inputs == {"a": int} + + +def test_chained_with_context_later_field_override_skips_dead_field_transform(): + @Flow.model + def source(a: FromContext[int]) -> int: + return a + + bound = source().flow.with_context(a=seed_plus_one()).flow.with_context(a=1) + + assert bound.flow.bound_inputs == {"a": 1} + assert bound.flow.runtime_inputs == {} + assert bound.flow.required_inputs == {} + assert bound.flow.compute().value == 1 + + for dumps, loads in ((pickle.dumps, pickle.loads), (rcpdumps, rcploads)): + restored = loads(dumps(bound, protocol=5)) + assert restored.flow.compute().value == 1 + + def test_with_context_accepts_wrapped_mapping_patch_annotations(): @Flow.model def load(start_date: FromContext[int], end_date: FromContext[int]) -> int: @@ -1352,11 +1939,12 @@ def load(start_date: FromContext[int], end_date: FromContext[int]) -> int: assert result.value == 100_012 dumped = bound.model_dump(mode="json") - assert dumped["context_spec"]["patches"][0]["binding"]["bound_args"] == {"amount": 10} - assert dumped["context_spec"]["field_overrides"]["start_date"]["kind"] == "static_value" + assert dumped["context_spec"]["operations"][0]["binding"]["bound_args"] == {"amount": 10} + assert dumped["context_spec"]["operations"][1]["name"] == "start_date" + assert dumped["context_spec"]["operations"][1]["spec"]["kind"] == "static_value" -def test_transforms_evaluate_against_original_runtime_context(): +def test_chained_transforms_read_original_runtime_context(): @Flow.model def load(start_date: FromContext[int], end_date: FromContext[int]) -> int: return start_date * 1000 + end_date @@ -1448,22 +2036,67 @@ def uses_any(x: Any, y: FromContext[int]) -> int: assert result.value == 3, "literal string should not be replaced by registry entry for Any-typed param" +def test_registry_lookup_does_not_steal_coercible_string_literals(): + registry = ModelRegistry.root().clear() + registry.add("3", data_source(base_value=10)) + registry.add("2024-01-01", data_source(base_value=20)) + registry.add("data.csv", data_source(base_value=30)) + registry.add("loader", data_source(base_value=40)) + + @Flow.model + def uses_int(x: int) -> str: + return f"{type(x).__name__}:{x}" + + @Flow.model + def uses_date(day: date) -> str: + return f"{type(day).__name__}:{day}" + + @Flow.model + def uses_path(path: Path) -> str: + return f"{type(path).__name__}:{path}" + + try: + assert uses_int(x="3").flow.compute(value=1).value == "int:3" + assert uses_date(day="2024-01-01").flow.compute(value=1).value == "date:2024-01-01" + path_model = uses_path(path="data.csv") + assert path_model.path == Path("data.csv") + assert path_model.flow.compute(value=1).value.endswith(":data.csv") + + dependency_model = uses_int(x="loader") + assert isinstance(dependency_model.x, CallableModel) + assert dependency_model.flow.compute(value=1).value == "int:41" + finally: + registry.clear() + + def test_unexpected_type_adapter_errors_are_not_silently_swallowed(): class BrokenSchema: @classmethod def __get_pydantic_core_schema__(cls, source, handler): raise RuntimeError("boom") - @Flow.model def bad(x: BrokenSchema, y: FromContext[int]) -> int: del x, y return 0 with pytest.raises(RuntimeError, match="boom"): - bad(x=object()) + Flow.model(bad) + +@pytest.mark.parametrize("error", [RuntimeError("boom"), TypeError("boom"), ValueError("boom")]) +def test_schema_safe_annotation_does_not_swallow_unexpected_type_adapter_errors(monkeypatch, error): + def broken_type_adapter(annotation): + del annotation + raise error -def test_unexpected_type_validation_errors_are_not_rewritten(): + monkeypatch.setattr(flow_model_module, "_type_adapter", broken_type_adapter) + + with pytest.raises(type(error), match="boom"): + flow_model_module._pydantic_schema_safe_annotation(int) + + +@pytest.mark.parametrize("error", [RuntimeError("boom"), TypeError("boom")]) +def test_unexpected_type_validation_errors_are_not_rewritten(error): from pydantic_core import core_schema class BrokenValidation: @@ -1473,7 +2106,7 @@ def __get_pydantic_core_schema__(cls, source, handler): def validate(value): del value - raise RuntimeError("boom") + raise error return core_schema.no_info_plain_validator_function(validate) @@ -1482,22 +2115,29 @@ def bad(x: BrokenValidation, y: FromContext[int]) -> int: del x, y return 0 - with pytest.raises(RuntimeError, match="boom"): + with pytest.raises(type(error), match="boom"): bad(x=object()) -def test_unexpected_type_hint_resolution_errors_propagate(monkeypatch): +@pytest.mark.parametrize("error", [RuntimeError("boom"), AttributeError("boom")]) +def test_unexpected_type_hint_resolution_errors_propagate(monkeypatch, error): def broken_get_type_hints(*args, **kwargs): - raise RuntimeError("boom") + raise error monkeypatch.setattr(flow_model_module, "get_type_hints", broken_get_type_hints) def add(x: int) -> int: return x - with pytest.raises(RuntimeError, match="boom"): + with pytest.raises(type(error), match="boom"): Flow.model(add) + def transform(x: FromContext[int]) -> int: + return x + + with pytest.raises(type(error), match="boom"): + Flow.context_transform(transform) + def test_generated_model_flow_api_introspection_and_execution(): @Flow.model @@ -1507,7 +2147,7 @@ def add(a: int, b: FromContext[int]) -> int: model = add(a=10) assert model.flow.context_inputs == {"b": int} assert model.flow.bound_inputs == {"a": 10} - assert model.flow.unbound_inputs == {"b": int} + assert model.flow.required_inputs == {"b": int} assert model.flow.compute(b=5).value == 15 @@ -1541,6 +2181,19 @@ def add(a: int, b: FromContext[int]) -> int: assert custom_sig.parameters["multiplier"].default == 1 +def test_context_transform_factory_signature_only_exposes_regular_bindings(): + sig = inspect.signature(increment_b) + + assert list(sig.parameters) == ["amount"] + assert sig.parameters["amount"].kind is inspect.Parameter.KEYWORD_ONLY + assert sig.parameters["amount"].annotation is int + assert sig.parameters["amount"].default is inspect.Parameter.empty + assert sig.return_annotation is flow_model_module.ContextTransform + + with pytest.raises(TypeError, match="positional"): + increment_b(1) + + def test_type_adapter_caches_are_bounded_and_clearable(monkeypatch): monkeypatch.setattr(flow_model_module, "_TYPE_ADAPTER_CACHE_MAXSIZE", 2) flow_model_module.clear_flow_model_caches() @@ -1556,6 +2209,9 @@ def test_type_adapter_caches_are_bounded_and_clearable(monkeypatch): Annotated[str, []], Annotated[float, []], ) + repeated_unhashable = Annotated[bytes, []] + assert flow_model_module._type_adapter(repeated_unhashable) is flow_model_module._type_adapter(repeated_unhashable) + for annotation in unhashable_annotations: flow_model_module._type_adapter(annotation) @@ -1591,7 +2247,7 @@ def __deps__(self, context: SimpleContext): model = PlainModel() assert model.flow.context_inputs == {"value": int} - assert model.flow.unbound_inputs == {"value": int} + assert model.flow.required_inputs == {"value": int} assert model.flow.bound_inputs == {} assert model.flow.compute({"value": 3}).value == 3 @@ -1599,6 +2255,32 @@ def __deps__(self, context: SimpleContext): model.flow.compute(SimpleContext(value=1), value=2) +def test_plain_callable_flow_compute_preserves_matching_context_subclass(): + class RequestContext(SimpleContext): + request_id: str + + class PlainModel(CallableModel): + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[tuple[int, str]]: + return GenericResult(value=(context.value, getattr(context, "request_id", "missing"))) + + context = RequestContext(value=3, request_id="abc") + + model = PlainModel() + + assert model.flow.compute(context).value == (3, "abc") + assert model.flow.with_context().flow.compute(context).value == (3, "abc") + + bound = model.flow.with_context(value=4) + other_context = RequestContext(value=3, request_id="def") + + assert bound.flow.compute(context).value == (4, "abc") + assert bound.flow.compute(other_context).value == (4, "def") + first_key = cache_key(bound.__call__.get_evaluation_context(bound, context), effective=True) + second_key = cache_key(bound.__call__.get_evaluation_context(bound, other_context), effective=True) + assert first_key != second_key + + def test_unhashable_annotations_still_validate(): annotation = Annotated[int, []] @@ -1613,7 +2295,7 @@ def test_compute_accepts_context_object_for_from_context_models(): model = basic_loader(source="library", multiplier=3) assert model.flow.context_inputs == {"value": int} - assert model.flow.unbound_inputs == {"value": int} + assert model.flow.required_inputs == {"value": int} assert model.flow.compute({"value": 4}).value == 12 assert model.flow.compute(SimpleContext(value=5)).value == 15 @@ -1708,11 +2390,11 @@ def missing_hints(*args, **kwargs): monkeypatch.setattr(flow_model_module, "get_type_hints", missing_hints) - @Flow.model def add(x: int, y: FromContext[int]) -> int: return x + y - assert add(x=1).flow.compute(y=2).value == 3 + with pytest.raises(AttributeError, match="missing hints"): + Flow.model(add) def test_unresolved_forward_refs_do_not_silently_strip_from_context(): @@ -1985,6 +2667,18 @@ def __call__(self, context: RequiredContext) -> GenericResult[int]: assert calls["count"] == 1 +def test_bound_plain_callable_direct_kwargs_use_flow_context(): + class PlainSource(CallableModel): + @Flow.call + def __call__(self, context: FlowContext) -> GenericResult[str]: + return GenericResult(value=type(context).__name__) + + bound = PlainSource().flow.with_context(a=1) + + assert bound(b=2).value == "FlowContext" + assert bound.flow.compute(b=2).value == "FlowContext" + + def test_bound_plain_callable_compute_preserves_bound_scoped_options(): calls = {"source": 0, "evaluator": 0} @@ -2129,7 +2823,7 @@ def consumer(lazy_value: Lazy[int], use_lazy: FromContext[bool]) -> int: assert calls == {"source": 1, "consumer": 1, "evaluator": 1} -def test_bound_flow_unbound_inputs_subtracts_static_context(): +def test_bound_flow_required_inputs_subtracts_static_context(): class RequiredContext(ContextBase): a: int b: int @@ -2143,14 +2837,14 @@ def __call__(self, context: RequiredContext) -> GenericResult[int]: def add(a: FromContext[int], b: FromContext[int]) -> int: return a + b - assert PlainSource().flow.with_context(a=1).flow.unbound_inputs == {"b": int} - assert add().flow.with_context(a=1).flow.unbound_inputs == {"b": int} - assert add().flow.with_context(a=static_bad()).flow.unbound_inputs == {"b": int} - assert add().flow.with_context(static_patch()).flow.unbound_inputs == {"b": int} - assert add().flow.with_context(a=1, b=2).flow.unbound_inputs == {} + assert PlainSource().flow.with_context(a=1).flow.required_inputs == {"b": int} + assert add().flow.with_context(a=1).flow.required_inputs == {"b": int} + assert add().flow.with_context(a=static_bad()).flow.required_inputs == {"b": int} + assert add().flow.with_context(static_patch()).flow.required_inputs == {"b": int} + assert add().flow.with_context(a=1, b=2).flow.required_inputs == {} -def test_bound_flow_unbound_inputs_reflects_dynamic_field_transform_inputs(): +def test_bound_flow_required_inputs_reflects_dynamic_field_transform_inputs(): @Flow.model def add(a: FromContext[int], b: FromContext[int]) -> int: return a + b @@ -2158,8 +2852,9 @@ def add(a: FromContext[int], b: FromContext[int]) -> int: bound = add().flow.with_context(a=seed_plus_one()) assert bound.flow.compute(seed=1, b=10).value == 12 - assert bound.flow.context_inputs == {"b": int, "seed": int} - assert bound.flow.unbound_inputs == {"b": int, "seed": int} + assert bound.flow.context_inputs == {"a": int, "b": int} + assert bound.flow.runtime_inputs == {"b": int, "seed": int} + assert bound.flow.required_inputs == {"b": int, "seed": int} def test_bound_flow_bound_inputs_include_static_context_bindings(): @@ -2169,8 +2864,9 @@ def add(a: int, b: FromContext[int]) -> int: bound = add(a=1).flow.with_context(b=2) - assert bound.flow.context_inputs == {} - assert bound.flow.unbound_inputs == {} + assert bound.flow.context_inputs == {"b": int} + assert bound.flow.runtime_inputs == {} + assert bound.flow.required_inputs == {} assert bound.flow.bound_inputs == {"a": 1, "b": 2} @@ -2183,8 +2879,9 @@ def add(a: FromContext[int], b: FromContext[int]) -> int: assert bound.flow.compute(seed=3, b=10).value == 14 assert bound.flow.bound_inputs == {} - assert bound.flow.context_inputs == {"b": int, "seed": int} - assert bound.flow.unbound_inputs == {"b": int, "seed": int} + assert bound.flow.context_inputs == {"a": int, "b": int} + assert bound.flow.runtime_inputs == {"b": int, "seed": int} + assert bound.flow.required_inputs == {"b": int, "seed": int} def test_generated_model_cache_ignores_unused_flow_context_fields(): @@ -2252,6 +2949,51 @@ def root(x: int, bonus: FromContext[int]) -> int: assert cache.key(eval1) == cache_key(eval1, effective=True) +def test_effective_cache_key_ignores_untokenizable_unused_ambient_context(): + class BadToken: + def __deepcopy__(self, memo): + return self + + def __getstate__(self): + raise RuntimeError("unused field should not be tokenized") + + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + model = add(a=1) + clean_context = FlowContext(b=2) + noisy_context = FlowContext(b=2, unused=BadToken()) + clean_eval = model.__call__.get_evaluation_context(model, clean_context) + noisy_eval = model.__call__.get_evaluation_context(model, noisy_context) + + assert cache_key(noisy_eval, effective=True) == cache_key(clean_eval, effective=True) + + +def test_recursive_effective_cache_key_ignores_untokenizable_unused_ambient_context(): + class BadToken: + def __deepcopy__(self, memo): + return self + + def __getstate__(self): + raise RuntimeError("unused field should not be tokenized") + + @Flow.model + def cycle(value: int, a: FromContext[int]) -> int: + return a + + model = cycle(value=cycle(value=1)) + model.value = model + context = FlowContext(a=1, unused=BadToken()) + evaluation = model.__call__.get_evaluation_context(model, context) + + cache_key(evaluation, effective=True) + graph = get_dependency_graph(evaluation) + assert graph.root_id in graph.ids + with pytest.raises(graphlib.CycleError): + tuple(graphlib.TopologicalSorter(graph.graph).static_order()) + + def test_cache_key_effective_option_preserves_plain_callable_structural_identity(): calls = {"count": 0} @@ -2978,6 +3720,60 @@ def add(a: int, b: FromContext[int]) -> int: assert result.returncode == 0, f"Cross-process local cloudpickle failed:\n{result.stderr}" +def test_local_generated_model_explicit_generic_result_cross_process_cloudpickle(): + """Local generated models should not rehydrate from fragile GenericResult[T] annotations.""" + from ray.cloudpickle import dumps as rcpdumps + + @Flow.model + def add(a: int, b: FromContext[int]) -> GenericResult[int]: + return GenericResult(value=a + b) + + encoded = base64.b64encode(rcpdumps(add(a=1), protocol=5)).decode() + script = ( + "import base64\n" + "from ray.cloudpickle import loads as rcploads\n" + f"data = base64.b64decode('{encoded}')\n" + "model = rcploads(data)\n" + "result = model.flow.compute(b=2)\n" + "assert result.value == 3, f'Expected 3, got {result.value}'\n" + "assert repr(model.result_type) == \"\"\n" + ) + result = subprocess.run([sys.executable, "-c", script], capture_output=True, text=True, timeout=30) + assert result.returncode == 0, f"Cross-process explicit GenericResult cloudpickle failed:\n{result.stderr}" + + +def test_local_generated_model_postponed_annotations_cross_process_cloudpickle(): + """Local generated models should restore from analyzed config, not worker-side type-hint resolution.""" + from ray.cloudpickle import dumps as rcpdumps + + namespace: dict[str, Any] = {} + exec( + """ +from __future__ import annotations +from ccflow import Flow, FromContext + +def make_model(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + return add(a=1) +""", + namespace, + ) + + encoded = base64.b64encode(rcpdumps(namespace["make_model"](), protocol=5)).decode() + script = ( + "import base64\n" + "from ray.cloudpickle import loads as rcploads\n" + f"data = base64.b64decode('{encoded}')\n" + "model = rcploads(data)\n" + "result = model.flow.compute(b=2)\n" + "assert result.value == 3, f'Expected 3, got {result.value}'\n" + ) + result = subprocess.run([sys.executable, "-c", script], capture_output=True, text=True, timeout=30) + assert result.returncode == 0, f"Cross-process postponed-annotation cloudpickle failed:\n{result.stderr}" + + def test_model_base_fields_visible_in_bound_inputs(): """model_base fields that are explicitly set should appear in bound_inputs.""" @@ -3000,6 +3796,56 @@ def add(a: int, b: FromContext[int]) -> int: assert model_default.flow.bound_inputs == {"a": 10} +def _annotation_contains(annotation: object, expected: object) -> bool: + if annotation is expected: + return True + return any(_annotation_contains(arg, expected) for arg in get_args(annotation)) + + +def _schema_contains(schema: object, predicate) -> bool: + if isinstance(schema, dict): + if predicate(schema): + return True + return any(_schema_contains(value, predicate) for value in schema.values()) + if isinstance(schema, list): + return any(_schema_contains(value, predicate) for value in schema) + return False + + +def test_generated_model_fields_preserve_construction_schema(): + @Flow.model + def str_source(tag: FromContext[str]) -> str: + return tag + + @Flow.model + def consumer(x: int, lazy_value: Lazy[int], y: FromContext[int]) -> int: + return x + lazy_value() + y + + generated_cls = getattr(consumer, "_generated_model") + fields = generated_cls.model_fields + properties = generated_cls.model_json_schema()["properties"] + + assert _annotation_contains(fields["x"].annotation, int) + assert _annotation_contains(fields["x"].annotation, CallableModel) + assert _schema_contains(properties["x"], lambda node: node.get("type") == "integer") + assert _schema_contains(properties["x"], lambda node: node.get("$ref", "").endswith("/CallableModel")) + + assert _annotation_contains(fields["lazy_value"].annotation, CallableModel) + assert _schema_contains(properties["lazy_value"], lambda node: node.get("$ref", "").endswith("/CallableModel")) + + assert _annotation_contains(fields["y"].annotation, int) + assert _schema_contains(properties["y"], lambda node: node.get("type") == "integer") + + model = consumer(x=str_source(), lazy_value=str_source()) + assert model.flow.compute(tag="3", y="4").value == 10 + + with pytest.raises(TypeError, match="Lazy"): + consumer(x=1, lazy_value=1) + + with pytest.raises(TypeError, match="Regular parameter"): + model.flow.compute(tag="not_a_number", y=1) + + def test_model_base_fields_rejected_by_compute(): """compute() should reject kwargs matching model_base field names.""" @@ -3021,8 +3867,4 @@ def add(a: int, b: FromContext[int]) -> int: def test_flow_model_public_exports_exclude_context_spec_models(): assert "StaticValueSpec" not in flow_model_module.__all__ - assert "FieldContextSpec" not in flow_model_module.__all__ - assert "PatchContextSpec" not in flow_model_module.__all__ assert not hasattr(ccflow, "StaticValueSpec") - assert not hasattr(ccflow, "FieldContextSpec") - assert not hasattr(ccflow, "PatchContextSpec") diff --git a/ccflow/tests/test_flow_model_hydra.py b/ccflow/tests/test_flow_model_hydra.py index 6337c93..f47a58e 100644 --- a/ccflow/tests/test_flow_model_hydra.py +++ b/ccflow/tests/test_flow_model_hydra.py @@ -7,7 +7,7 @@ from ccflow import CallableModel, DateRangeContext, FlowContext, ModelRegistry -from .test_flow_model import SimpleContext +from .flow_model_hydra_fixtures import SimpleContext CONFIG_PATH = str(Path(__file__).parent / "config" / "conf_flow.yaml") @@ -104,12 +104,12 @@ def test_instantiate_with_omegaconf(): cfg = OmegaConf.create( { "loader": { - "_target_": "ccflow.tests.test_flow_model.basic_loader", + "_target_": "ccflow.tests.flow_model_hydra_fixtures.basic_loader", "source": "generated_input", "multiplier": 7, }, "contextual": { - "_target_": "ccflow.tests.test_flow_model.contextual_loader", + "_target_": "ccflow.tests.flow_model_hydra_fixtures.contextual_loader", "source": "library", }, } diff --git a/docs/wiki/Flow-Model.md b/docs/wiki/Flow-Model.md index 6a9d389..44e148e 100644 --- a/docs/wiki/Flow-Model.md +++ b/docs/wiki/Flow-Model.md @@ -114,6 +114,10 @@ model = add(a=load_value(offset=5)) assert model.flow.compute(value=7, b=12).value == 24 ``` +Only direct regular-parameter values are treated as upstream dependencies in +this first version. Containers such as `list`, `tuple`, and `dict` are ordinary +literal values; `@Flow.model` does not scan them for nested models. + ### Contextual Parameters Contextual parameters are the ones marked with `FromContext[...]`. @@ -323,7 +327,7 @@ custom = current.flow.with_context( ) delta = visitor_delta(current=current, previous=previous) -result = delta(DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31))) +result = delta.flow.compute(DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31))) ``` In this example, `current` and `previous` share the same underlying @@ -346,11 +350,12 @@ Key rules: multiple fields must move together, put that logic inside one patch transform. -Importable transform functions are stored by module path. Local, nested, -`__main__`, and notebook-defined transform functions are stored with an -embedded cloudpickle payload so bound models can still move through pickle and -Ray workers. For long-lived YAML/JSON configuration, prefer importable module -functions so the serialized config stays small and inspectable. +Context transforms serialize enough function metadata for bound models to move +through pickle and Ray workers. Importable module-level transforms may serialize +by module path, while local, nested, `__main__`, and notebook-defined transforms +use an embedded cloudpickle payload. For long-lived YAML/JSON configuration, +prefer small importable module functions and inspect the generated config shape +before treating it as a stable hand-written config format. ## `context_type=...` @@ -387,12 +392,36 @@ keyword-only parameters, see `Flow.call(auto_context=...)` in ## Introspection APIs -Generated models expose three useful introspection helpers: - -- `model.flow.context_inputs`: the full contextual contract, -- `model.flow.unbound_inputs`: the contextual fields still required at runtime, -- `model.flow.bound_inputs`: regular bound inputs plus any construction-time - contextual defaults. +Flow models expose a few useful introspection helpers: + +- `model.flow.context_inputs`: the declared contextual contract for the model + or wrapped model, +- `model.flow.runtime_inputs`: direct runtime context inputs this model or + wrapper may read after applying its own bindings, +- `model.flow.required_inputs`: required direct runtime context inputs that are + not already satisfied by defaults or bindings, +- `model.flow.bound_inputs`: concrete values already fixed on the model, such + as regular construction-time inputs, construction-time contextual defaults, + and literal keyword `with_context(field=value)` bindings. + +`context_inputs` intentionally stays faithful to the model's declared contract. +For bound models, `with_context(...)` bindings are reflected in +`runtime_inputs`, `required_inputs`, and `bound_inputs`. Literal bindings +satisfy their target fields. Transform bindings with runtime inputs add those +source context inputs to the effective runtime view. Static transforms, meaning +transforms whose inputs are already available, may be evaluated during +introspection so their output fields can be reported precisely. +`required_inputs` is always the required subset of `runtime_inputs`; if multiple +bindings expose the same runtime context field with conflicting annotations, +introspection raises an error instead of silently choosing one. + +These helpers report the direct API for the current model or wrapper. They do +not recursively expand every contextual input used by upstream dependencies in a +larger graph. + +Because these helpers may evaluate static `@Flow.context_transform` functions, +context transforms should be deterministic, side-effect-free, and cheap. This is +the same practical contract expected by cache identity and dependency analysis. Example: @@ -407,10 +436,29 @@ def add(a: int, b: FromContext[int], c: FromContext[int] = 5) -> int: model = add(a=10) assert model.flow.context_inputs == {"b": int, "c": int} -assert model.flow.unbound_inputs == {"b": int} +assert model.flow.runtime_inputs == {"b": int, "c": int} +assert model.flow.required_inputs == {"b": int} assert model.flow.bound_inputs == {"a": 10} + + +@Flow.context_transform +def from_seed(seed: FromContext[int]) -> int: + return seed + 1 + + +bound = add(a=10).flow.with_context(b=from_seed()) +assert bound.flow.context_inputs == {"b": int, "c": int} +assert bound.flow.runtime_inputs == {"c": int, "seed": int} +assert bound.flow.required_inputs == {"seed": int} +assert bound.flow.bound_inputs == {"a": 10} ``` +In the bound example, `b` remains in `context_inputs` because `add` still +declares `b` as part of its contextual contract. It is absent from +`runtime_inputs` because this wrapper supplies `b` from `from_seed()`. `seed` +appears in `runtime_inputs` because the transform reads it from the caller's +runtime context. + ## Lazy Dependencies `Lazy[T]` defers evaluation of an upstream dependency until the function body @@ -477,8 +525,15 @@ time. **A contextual parameter still shows up in `context_inputs` after I bound it** -That is expected. `context_inputs` reports the full contextual contract. -`unbound_inputs` reports only the contextual values still needed at runtime. +That is expected. `context_inputs` reports the declared contextual contract of +the model or wrapped model. It does not mean the current wrapper still requires +the caller to provide that field. + +Use `runtime_inputs` to see the effective direct runtime context inputs after +`with_context(...)` bindings. Use `required_inputs` to see what still must be +provided by the caller. Static transforms may be evaluated during introspection, +so their output fields can be removed from `runtime_inputs` and +`required_inputs` or added to `bound_inputs`. **A shared dependency runs more than once** diff --git a/docs/wiki/Workflows.md b/docs/wiki/Workflows.md index cfe03d9..6ef92b1 100644 --- a/docs/wiki/Workflows.md +++ b/docs/wiki/Workflows.md @@ -548,12 +548,17 @@ For a direct `CallableModel` call, the cache key depends on: - `model.model_dump(mode="python")` - `compute_behavior_token(type(model))` -For a `ModelEvaluationContext`, the cache key depends on: +For a `ModelEvaluationContext`, the default structural cache key depends on: - the underlying context payload plus the function name being evaluated (`__call__` vs `__deps__`) - the behavior token of the underlying model class - the data and behavior tokens of any **non-transparent** evaluators in the chain +`MemoryCacheEvaluator` and graph construction request an effective key for +generated `@Flow.model` nodes. That effective key can project a runtime +`FlowContext` down to the contextual fields a generated node actually reads, so +unused ambient fields do not necessarily split memory-cache or graph identity. + Transparent evaluators are skipped, so wrapping a model with logging or other pass-through evaluators does not change its cache identity. `compute_behavior_token()` hashes the class's Python-defined methods and also consults `__ccflow_tokenizer_deps__` for behavior that lives outside the class body. This is useful when the callable depends on module-level helpers or shared helper classes that should also invalidate the cache key when they change. diff --git a/examples/flow_model_example.py b/examples/flow_model_example.py index f42ac17..7767123 100644 --- a/examples/flow_model_example.py +++ b/examples/flow_model_example.py @@ -6,7 +6,7 @@ 1. define stages as plain Python functions, 2. compose stages by passing upstream models as ordinary arguments, 3. rewrite contextual inputs on one dependency edge with `.flow.with_context(...)`, -4. execute either as `model(context)` or `model.flow.compute(...)`. +4. execute the configured graph with `model.flow.compute(...)`. Run with: python examples/flow_model_example.py @@ -17,6 +17,19 @@ from ccflow import DateRangeContext, Flow, FromContext +def _format_input_names(inputs: dict[str, object]) -> str: + """Return a compact comma-separated list for example output.""" + return ", ".join(inputs) or "(none)" + + +def _format_bound_inputs(inputs: dict[str, object]) -> str: + parts = [] + for name, value in inputs.items(): + display = "model" if hasattr(value, "flow") else repr(value) + parts.append(f"{name}={display}") + return ", ".join(parts) or "(none)" + + @Flow.model(context_type=DateRangeContext) def count_visitors( location: str, @@ -79,21 +92,25 @@ def main() -> None: end_date=date(2024, 3, 7), ) - direct = pipeline(ctx) - computed = pipeline.flow.compute( + computed_from_context = pipeline.flow.compute(ctx) + computed_from_kwargs = pipeline.flow.compute( start_date=ctx.start_date, end_date=ctx.end_date, ) print("\nPipeline:") - print(" current input:", pipeline.current) - print(" previous input:", pipeline.previous) + print(" model: visitor_delta") + print(f" bound inputs: {_format_bound_inputs(pipeline.flow.bound_inputs)}") + print(f" declared context inputs: {_format_input_names(pipeline.flow.context_inputs)}") + print(f" runtime inputs: {_format_input_names(pipeline.flow.runtime_inputs)}") + print(f" current runtime inputs: {_format_input_names(pipeline.current.flow.runtime_inputs)}") + print(f" previous runtime inputs: {_format_input_names(pipeline.previous.flow.runtime_inputs)}") print("\nExecution:") - print(f" direct == computed: {direct == computed}") + print(f" context object == kwargs: {computed_from_context == computed_from_kwargs}") print("\nResult:") - for key, value in computed.value.items(): + for key, value in computed_from_kwargs.value.items(): print(f" {key}: {value}") diff --git a/examples/flow_model_hydra_builder_demo.py b/examples/flow_model_hydra_builder_demo.py index 2fd3653..30722b6 100644 --- a/examples/flow_model_hydra_builder_demo.py +++ b/examples/flow_model_hydra_builder_demo.py @@ -24,6 +24,26 @@ CONFIG_PATH = Path(__file__).with_name("config") / "flow_model_hydra_builder_demo.yaml" +def _format_input_names(inputs: dict[str, object]) -> str: + """Return a compact comma-separated list for example output.""" + return ", ".join(inputs) or "(none)" + + +def _format_bound_inputs(inputs: dict[str, object]) -> str: + parts = [] + for name, value in inputs.items(): + display = "model" if hasattr(value, "flow") else repr(value) + parts.append(f"{name}={display}") + return ", ".join(parts) or "(none)" + + +def _print_model_summary(label: str, model: CallableModel) -> None: + print(f" {label}:") + print(f" bound inputs: {_format_bound_inputs(model.flow.bound_inputs)}") + print(f" declared context inputs: {_format_input_names(model.flow.context_inputs)}") + print(f" runtime inputs: {_format_input_names(model.flow.runtime_inputs)}") + + @Flow.model(context_type=DateRangeContext) def count_visitors(location: str, start_date: FromContext[date], end_date: FromContext[date]) -> int: """Return a deterministic visitor count for one date window.""" @@ -40,7 +60,7 @@ def visitor_delta( label: str, start_date: FromContext[date], end_date: FromContext[date], -) -> dict: +) -> dict[str, object]: """Return both visitor counts plus their difference.""" return { "label": label, @@ -88,15 +108,15 @@ def main() -> None: print("Hydra + Flow.model Builder Demo") print("=" * 68) print("\nLoaded from config:") - print(" library_visitors:", registry["library_visitors"]) - print(" previous_week:", previous_week) - print(" previous_two_weeks:", previous_two_weeks) + _print_model_summary("library_visitors", registry["library_visitors"]) + _print_model_summary("previous_week", previous_week) + _print_model_summary("previous_two_weeks", previous_two_weeks) previous_week_result = previous_week.flow.compute( start_date=ctx.start_date, end_date=ctx.end_date, ).value - previous_two_weeks_result = previous_two_weeks(ctx).value + previous_two_weeks_result = previous_two_weeks.flow.compute(ctx).value print("\nPrevious week:") for key, value in previous_week_result.items(): From 08ca0aefce5b8aafb85e49512cb3a4f431306c48 Mon Sep 17 00:00:00 2001 From: Nijat K Date: Thu, 14 May 2026 11:37:14 -0400 Subject: [PATCH 3/8] Further simplify Flow.model internals Signed-off-by: Nijat K --- ccflow/_flow_model_binding.py | 6 ----- ccflow/flow_model.py | 38 ++++--------------------------- ccflow/tests/test_flow_context.py | 2 +- ccflow/tests/test_flow_model.py | 20 +++------------- docs/wiki/Flow-Model.md | 12 +++++----- 5 files changed, 14 insertions(+), 64 deletions(-) diff --git a/ccflow/_flow_model_binding.py b/ccflow/_flow_model_binding.py index 1be6cdf..e5464bd 100644 --- a/ccflow/_flow_model_binding.py +++ b/ccflow/_flow_model_binding.py @@ -14,7 +14,6 @@ from .base import ContextBase, ResultBase from .context import FlowContext -from .exttypes import PyObjectPath from .local_persistence import create_ccflow_model from .result import GenericResult @@ -105,7 +104,6 @@ class _FlowModelConfig: auto_unwrap: bool parameters: Tuple[_FlowModelParam, ...] declared_context_type: Optional[Type[ContextBase]] = None - path: Optional[PyObjectPath] = None _regular_params: Tuple[_FlowModelParam, ...] = field(init=False, repr=False) _contextual_params: Tuple[_FlowModelParam, ...] = field(init=False, repr=False) _regular_param_names: Tuple[str, ...] = field(init=False, repr=False) @@ -182,7 +180,6 @@ class _SerializedFlowModelConfig(NamedTuple): auto_unwrap: bool parameters: Tuple[_SerializedFlowModelParam, ...] declared_context_type: _SerializedAnnotation - path: Optional[PyObjectPath] def _callable_name(func: _AnyCallable) -> str: @@ -336,7 +333,6 @@ def _serialize_flow_model_config(config: _FlowModelConfig) -> _SerializedFlowMod auto_unwrap=config.auto_unwrap, parameters=tuple(_serialize_flow_model_param(param) for param in config.parameters), declared_context_type=_serialize_annotation(config.declared_context_type), - path=config.path, ) @@ -352,7 +348,6 @@ def _restore_flow_model_config(payload: _SerializedFlowModelConfig) -> _FlowMode auto_unwrap=payload.auto_unwrap, parameters=tuple(_restore_flow_model_param(param) for param in payload.parameters), declared_context_type=_restore_annotation(payload.declared_context_type), - path=payload.path, ) @@ -647,7 +642,6 @@ def _analyze_flow_context_transform( auto_wrap_result=False, auto_unwrap=False, parameters=parameters, - path=PyObjectPath(f"{getattr(fn, '__module__', __name__)}.{_callable_name(fn)}"), ) diff --git a/ccflow/flow_model.py b/ccflow/flow_model.py index 6854101..43be125 100644 --- a/ccflow/flow_model.py +++ b/ccflow/flow_model.py @@ -148,22 +148,15 @@ def _is_unset_flow_input(value: Any) -> bool: class ContextTransform(PydanticModel): """Serializable binding produced by ``@Flow.context_transform``. - Transform bindings store either an import path or a cloudpickled config - payload so bound models can survive pickle, cloudpickle, and Ray round - trips. + Transform bindings store the analyzed transform contract directly. We avoid + import-path restore here because decorators usually run before the module + global points at the returned transform factory. """ kind: Literal["context_transform"] = "context_transform" - path: Optional[PyObjectPath] = None - serialized_config: Optional[str] = None + serialized_config: str bound_args: Dict[str, Any] = Field(default_factory=dict) - @model_validator(mode="after") - def _validate_location(self): - if (self.path is None) == (self.serialized_config is None): - raise ValueError("ContextTransform must define exactly one of path or serialized_config.") - return self - class StaticValueSpec(PydanticModel): """A ``with_context(field=value)`` static contextual override.""" @@ -280,8 +273,6 @@ def _context_transform_repr(transform: Any) -> str: def _context_transform_identifier(binding: ContextTransform) -> str: - if binding.path is not None: - return str(binding.path) return _callable_name(_load_context_transform_config_from_binding(binding).func) @@ -502,20 +493,6 @@ def _ensure_named_python_function(fn: _AnyCallable, *, decorator_name: str) -> N # --------------------------------------------------------------------------- -@lru_cache(maxsize=None) -def _load_context_transform_factory(path: str) -> _AnyCallable: - return PyObjectPath(path).object - - -@lru_cache(maxsize=None) -def _load_context_transform_config(path: str) -> _FlowModelConfig: - factory = _load_context_transform_factory(path) - config = getattr(factory, "__flow_context_transform_config__", None) - if not isinstance(config, _FlowModelConfig): - raise TypeError(f"Stored context transform path '{path}' does not resolve to a Flow.context_transform binding.") - return config - - def _serialize_context_transform_config(config: _FlowModelConfig) -> str: import cloudpickle @@ -536,10 +513,6 @@ def _load_serialized_context_transform_config(serialized_config: str) -> _FlowMo def _load_context_transform_config_from_binding(binding: ContextTransform) -> _FlowModelConfig: - if binding.path is not None: - return _load_context_transform_config(str(binding.path)) - if binding.serialized_config is None: - raise TypeError("ContextTransform has neither path nor serialized_config.") return _load_serialized_context_transform_config(binding.serialized_config) @@ -548,8 +521,6 @@ def clear_flow_model_caches() -> None: _HASHABLE_TYPE_ADAPTER_CACHE.clear() _UNHASHABLE_TYPE_ADAPTER_CACHE.clear() - _load_context_transform_factory.cache_clear() - _load_context_transform_config.cache_clear() _load_serialized_context_transform_config.cache_clear() @@ -2443,7 +2414,6 @@ def factory(**kwargs) -> ContextTransform: """Bind regular transform arguments into a serializable spec.""" return ContextTransform( - path=None, serialized_config=serialized_config, bound_args=_validate_context_transform_factory_kwargs(config, kwargs), ) diff --git a/ccflow/tests/test_flow_context.py b/ccflow/tests/test_flow_context.py index 2eec6f9..7bb379d 100644 --- a/ccflow/tests/test_flow_context.py +++ b/ccflow/tests/test_flow_context.py @@ -290,7 +290,7 @@ def add(a: int, b: FromContext[int]) -> int: bound = add(a=10).flow.with_context(b=offset_b(amount=1)) dumped = bound.model_dump(mode="json") binding = dumped["context_spec"]["operations"][0]["spec"] - assert binding["path"] is not None or binding["serialized_config"] is not None + assert binding["serialized_config"] is not None restored = type(bound).model_validate(dumped) assert restored.flow.compute(b=4).value == 15 diff --git a/ccflow/tests/test_flow_model.py b/ccflow/tests/test_flow_model.py index c73b79b..f0c1d8d 100644 --- a/ccflow/tests/test_flow_model.py +++ b/ccflow/tests/test_flow_model.py @@ -247,32 +247,22 @@ def add(a: FromContext[int]) -> int: assert add().flow.with_context(static_patch()).flow.compute(a=1).value == 2 -def test_context_transform_internal_error_and_repr_paths(): +def test_context_transform_internal_error_and_repr_payloads(): assert flow_model_module._context_transform_repr(static_patch()) == "static_patch()" assert flow_model_module._context_transform_repr(increment_b(amount=2)) == "increment_b(amount=2)" assert flow_model_module._context_transform_repr(123) == "123" importable_increment_b = Flow.context_transform(increment_b.__wrapped__) importable_config = flow_model_module._load_context_transform_config_from_binding(importable_increment_b(amount=1)) - assert importable_config.path is not None + assert importable_config.func.__name__ == "increment_b" assert flow_model_module._context_transform_identifier(importable_increment_b(amount=1)) == "increment_b" - with pytest.raises(ValidationError, match="exactly one"): + with pytest.raises(ValidationError, match="serialized_config"): flow_model_module.ContextTransform() - with pytest.raises(ValidationError, match="exactly one"): - flow_model_module.ContextTransform(path="ccflow.tests.test_flow_model.increment_b", serialized_config="also-set") - - with pytest.raises(TypeError, match="does not resolve to a Flow.context_transform binding"): - flow_model_module._load_context_transform_config("ccflow.tests.test_flow_model.lazy_context_transform_for_rejection") - invalid_payload = base64.b64encode(pickle.dumps({"not": "a config"}, protocol=5)).decode("ascii") with pytest.raises(TypeError, match="payload does not contain"): flow_model_module._load_serialized_context_transform_config(invalid_payload) - invalid_binding = flow_model_module.ContextTransform.model_construct(path=None, serialized_config=None, bound_args={}) - with pytest.raises(TypeError, match="neither path nor serialized_config"): - flow_model_module._load_context_transform_config_from_binding(invalid_binding) - with pytest.raises(ImportError, match="does not have a _generated_model"): flow_model_module._restore_generated_flow_model("ccflow.tests.test_flow_model.lazy_context_transform_for_rejection", {}) @@ -1710,7 +1700,6 @@ def test_context_transform_serializes_embedded_config_and_bound_args(): binding = transform_factory(amount=3) assert isinstance(binding, flow_model_module.ContextTransform) assert binding.kind == "context_transform" - assert binding.path is None assert binding.serialized_config is not None assert binding.bound_args == {"amount": 3} @@ -1739,7 +1728,6 @@ def increment(value: FromContext[int]) -> int: transform_factory = Flow.context_transform(increment) binding = transform_factory() - assert binding.path is None assert binding.serialized_config is not None @Flow.model @@ -1779,7 +1767,6 @@ def nested_transform(b: FromContext[int], amount: int) -> int: return b + amount binding = nested_transform(amount=3) - assert binding.path is None assert binding.serialized_config is not None bound = add(a=1).flow.with_context(b=binding) @@ -1800,7 +1787,6 @@ def main_transform(value: FromContext[int]) -> int: transformed = Flow.context_transform(main_transform) binding = transformed() - assert binding.path is None assert binding.serialized_config is not None bound = add(a=1).flow.with_context(b=binding) diff --git a/docs/wiki/Flow-Model.md b/docs/wiki/Flow-Model.md index 44e148e..0007e38 100644 --- a/docs/wiki/Flow-Model.md +++ b/docs/wiki/Flow-Model.md @@ -350,12 +350,12 @@ Key rules: multiple fields must move together, put that logic inside one patch transform. -Context transforms serialize enough function metadata for bound models to move -through pickle and Ray workers. Importable module-level transforms may serialize -by module path, while local, nested, `__main__`, and notebook-defined transforms -use an embedded cloudpickle payload. For long-lived YAML/JSON configuration, -prefer small importable module functions and inspect the generated config shape -before treating it as a stable hand-written config format. +Context transforms serialize the analyzed transform contract directly so bound +models can move through pickle and Ray workers without re-resolving annotations +from the defining module. This applies to importable module functions, local +functions, nested functions, `__main__`, and notebook-defined transforms. For +long-lived YAML/JSON configuration, inspect the generated config shape before +treating it as a stable hand-written config format. ## `context_type=...` From cd25056eb3678d140aa772c033dccdf5c0ec63b9 Mon Sep 17 00:00:00 2001 From: Nijat K Date: Thu, 14 May 2026 11:50:38 -0400 Subject: [PATCH 4/8] Trim public export Signed-off-by: Nijat K --- ccflow/evaluators/common.py | 34 ++++++++++----------------------- ccflow/flow_model.py | 2 -- ccflow/tests/test_flow_model.py | 4 ++++ 3 files changed, 14 insertions(+), 26 deletions(-) diff --git a/ccflow/evaluators/common.py b/ccflow/evaluators/common.py index 06cdfad..c5ceb74 100644 --- a/ccflow/evaluators/common.py +++ b/ccflow/evaluators/common.py @@ -92,22 +92,15 @@ def combine_evaluators(first: Optional[EvaluatorBase], second: Optional[Evaluato return MultiEvaluator(evaluators=[first, second]) -def _flatten_cache_key_context(evaluation_context: ModelEvaluationContext) -> tuple[ModelEvaluationContext, str, List[CallableModel]]: - """Strip transparent evaluator wrappers and keep opaque wrappers in order. - - This preserves the structural cache-key behavior: transparent evaluators are - ignored, while non-transparent evaluators remain part of the identity. The - returned function name is the innermost non-``__call__`` name, so - ``__deps__`` does not collapse into ``__call__`` when wrapped. - """ - fn = evaluation_context.fn - outer_to_inner_evaluators: List[CallableModel] = [] - while isinstance(evaluation_context.context, ModelEvaluationContext): - fn = evaluation_context.fn if evaluation_context.fn != "__call__" else fn - if not isinstance(evaluation_context, TransparentModelEvaluationContext): - outer_to_inner_evaluators.append(evaluation_context.model) - evaluation_context = evaluation_context.context - return evaluation_context, fn if fn != "__call__" else evaluation_context.fn, outer_to_inner_evaluators +def _flatten_cache_key_context(flow_obj: ModelEvaluationContext) -> tuple[ModelEvaluationContext, str, List[EvaluatorBase]]: + fn = flow_obj.fn + non_transparent: List[EvaluatorBase] = [] + while isinstance(flow_obj.context, ModelEvaluationContext): + fn = flow_obj.fn if flow_obj.fn != "__call__" else fn + if not isinstance(flow_obj, TransparentModelEvaluationContext): + non_transparent.append(flow_obj.model) + flow_obj = flow_obj.context + return flow_obj, fn if fn != "__call__" else flow_obj.fn, non_transparent class MultiEvaluator(EvaluatorBase): @@ -480,13 +473,6 @@ class CallableModelGraph(BaseModel): root_id: bytes -def _is_wrapper_to_wrapped_edge(parent_model: Optional[CallableModel], current_model: CallableModel) -> bool: - # Effective identity can intentionally collapse a wrapper model and its - # wrapped model to the same graph/cache key. Only that wrapper-to-wrapped - # edge should be treated as a duplicate self-edge. - return isinstance(parent_model, WrapperModel) and parent_model.model is current_model - - def _build_dependency_graph( evaluation_context: ModelEvaluationContext, graph: CallableModelGraph, @@ -500,7 +486,7 @@ def _build_dependency_graph( unwrapped_evaluation_context, _, _ = _flatten_cache_key_context(evaluation_context) current_model = unwrapped_evaluation_context.model is_same_evaluation_key = parent_key == key - is_collapsed_wrapper_child = is_same_evaluation_key and _is_wrapper_to_wrapped_edge(parent_model, current_model) + is_collapsed_wrapper_child = is_same_evaluation_key and isinstance(parent_model, WrapperModel) and parent_model.model is current_model # Bound/wrapper models can share an effective graph key with their wrapped # model after context rewriting. Adding the wrapper -> wrapped edge in that diff --git a/ccflow/flow_model.py b/ccflow/flow_model.py index 43be125..e134fe4 100644 --- a/ccflow/flow_model.py +++ b/ccflow/flow_model.py @@ -100,8 +100,6 @@ "BoundModel", "FromContext", "Lazy", - "ContextTransform", - "flow_context_transform", ) _AnyCallable = Callable[..., Any] diff --git a/ccflow/tests/test_flow_model.py b/ccflow/tests/test_flow_model.py index f0c1d8d..0dea0d8 100644 --- a/ccflow/tests/test_flow_model.py +++ b/ccflow/tests/test_flow_model.py @@ -3853,4 +3853,8 @@ def add(a: int, b: FromContext[int]) -> int: def test_flow_model_public_exports_exclude_context_spec_models(): assert "StaticValueSpec" not in flow_model_module.__all__ + assert "ContextTransform" not in flow_model_module.__all__ + assert "flow_context_transform" not in flow_model_module.__all__ assert not hasattr(ccflow, "StaticValueSpec") + assert not hasattr(ccflow, "ContextTransform") + assert not hasattr(ccflow, "flow_context_transform") From 16aae1f6af83770728ca73b57c52c6c97fe77875 Mon Sep 17 00:00:00 2001 From: Nijat K Date: Thu, 14 May 2026 12:20:11 -0400 Subject: [PATCH 5/8] Bug fixes Signed-off-by: Nijat K --- ccflow/_flow_model_binding.py | 7 +- ccflow/callable.py | 4 +- ccflow/evaluators/common.py | 14 +- ccflow/flow_model.py | 399 +++++++++++++------------ ccflow/tests/evaluators/test_common.py | 2 +- ccflow/tests/test_flow_context.py | 39 +++ ccflow/tests/test_flow_model.py | 372 +++++++++++++++++++++-- docs/wiki/Flow-Model.md | 48 ++- 8 files changed, 640 insertions(+), 245 deletions(-) diff --git a/ccflow/_flow_model_binding.py b/ccflow/_flow_model_binding.py index e5464bd..a375617 100644 --- a/ccflow/_flow_model_binding.py +++ b/ccflow/_flow_model_binding.py @@ -64,7 +64,7 @@ class Lazy: """Lazy dependency marker used only as ``Lazy[T]`` in type annotations.""" def __new__(cls, *args, **kwargs): - raise TypeError("Lazy(model)(...) has been removed. Use model.flow.with_context(...) for contextual rewrites.") + raise TypeError("Lazy is an annotation marker; use Lazy[T] in @Flow.model signatures.") def __class_getitem__(cls, item): return Annotated[item, _LazyMarker()] @@ -400,6 +400,11 @@ def _parse_annotation(annotation: Any) -> _ParsedAnnotation: elif isinstance(metadata, _FromContextMarker): is_from_context = True + if annotation is FromContext: + raise TypeError("FromContext is an annotation marker; use FromContext[T] in @Flow.model signatures.") + if annotation is Lazy: + raise TypeError("Lazy is an annotation marker; use Lazy[T] in @Flow.model signatures.") + return _ParsedAnnotation(base=annotation, is_lazy=is_lazy, is_from_context=is_from_context) diff --git a/ccflow/callable.py b/ccflow/callable.py index 6a113e2..3dd8530 100644 --- a/ccflow/callable.py +++ b/ccflow/callable.py @@ -519,9 +519,9 @@ def context_transform(*args, **kwargs): Transform bindings serialize enough function metadata to survive model serialization, including local or nested functions through cloudpickle. """ - from .flow_model import flow_context_transform + from .flow_model import _flow_context_transform - return flow_context_transform(*args, **kwargs) + return _flow_context_transform(*args, **kwargs) # ***************************************************************************** diff --git a/ccflow/evaluators/common.py b/ccflow/evaluators/common.py index c5ceb74..bae133f 100644 --- a/ccflow/evaluators/common.py +++ b/ccflow/evaluators/common.py @@ -350,7 +350,7 @@ def _effective_evaluation_key( raise # Effective identity is an optimization/semantic narrowing for opt-in # generated models. If deriving it is unclear, do not make cache/graph - # key construction a new failure mode; use the old structural key. + # key construction a failure mode; fall back to the structural key. log.debug("Falling back to structural evaluation key for %s.__call__: %s", type(inner.model).__name__, exc) return cache_key(evaluation_context) if key is None: @@ -501,13 +501,11 @@ def _build_dependency_graph( if is_new_graph_key: graph.graph[key] = set() - # Main used ``key not in graph.graph`` as the traversal guard. That is no - # longer enough once effective identity can merge multiple model objects - # to one key: a bound wrapper and its wrapped model may share the graph node, - # but the wrapped model still has dependencies that must be traversed. - # - # Preserve normal graph deduplication by key, and make the only exception - # the exact same-key wrapper -> wrapped edge. + # Effective identity can merge multiple model objects into one graph key. + # A bound wrapper and its wrapped model may share the graph node, but the + # wrapped model still has dependencies that must be traversed. Preserve + # normal graph deduplication by key, and make the only exception the exact + # same-key wrapper -> wrapped edge. if not is_new_graph_key and not is_collapsed_wrapper_child: return key diff --git a/ccflow/flow_model.py b/ccflow/flow_model.py index e134fe4..2fec4b9 100644 --- a/ccflow/flow_model.py +++ b/ccflow/flow_model.py @@ -45,9 +45,10 @@ import inspect import sys +from abc import update_abstractmethods from base64 import b64decode, b64encode from collections import OrderedDict -from functools import lru_cache, singledispatch, wraps +from functools import lru_cache, wraps from typing import ( Annotated, Any, @@ -69,7 +70,7 @@ get_type_hints, ) -from pydantic import BaseModel as PydanticModel, Field, SkipValidation, TypeAdapter, ValidationError, model_validator +from pydantic import BaseModel as PydanticModel, Field, PrivateAttr, SkipValidation, TypeAdapter, ValidationError, create_model, model_validator from pydantic.errors import PydanticSchemaGenerationError, PydanticUndefinedAnnotation from ._flow_model_binding import ( @@ -193,6 +194,8 @@ class _BoundContextSpec(PydanticModel): class _BoundModelContext(FlowContext): """Flow.call carrier for BoundModel that preserves existing context objects.""" + _base_context: Optional[ContextBase] = PrivateAttr(default=None) + @model_validator(mode="wrap") @classmethod def _preserve_context_base(cls, value, handler, info): @@ -200,6 +203,12 @@ def _preserve_context_base(cls, value, handler, info): return value return handler(value) + @classmethod + def from_values(cls, values: Dict[str, Any], base_context: Optional[ContextBase] = None) -> "_BoundModelContext": + context = cls(**values) + context._base_context = base_context + return context + class _DependencyIdentity(NamedTuple): kind: Literal["dependency"] @@ -239,16 +248,7 @@ class _GeneratedModelIdentity(NamedTuple): class _LocalFlowModelPicklePayload(NamedTuple): serialized_config: Any factory_kwargs: Dict[str, Any] - model_data: Dict[str, Any] - - -class _PortableBaseModelState(NamedTuple): - data: Dict[str, Any] - - -class _PortablePydanticModelState(NamedTuple): - model_type: str - data: Dict[str, Any] + state_bytes: bytes # --------------------------------------------------------------------------- @@ -502,11 +502,11 @@ def _serialize_context_transform_config(config: _FlowModelConfig) -> str: def _load_serialized_context_transform_config(serialized_config: str) -> _FlowModelConfig: import cloudpickle - payload = cloudpickle.loads(b64decode(serialized_config.encode("ascii"))) try: + payload = cloudpickle.loads(b64decode(serialized_config.encode("ascii"))) config = _restore_flow_model_config(payload) - except (TypeError, ValueError): - raise TypeError("Stored context transform payload does not contain a Flow.context_transform binding.") + except Exception as exc: + raise TypeError("Stored context transform payload does not contain a Flow.context_transform binding.") from exc return config @@ -537,9 +537,12 @@ def _is_mapping_annotation(annotation: Any) -> bool: return False -def _restore_pickled_flow_model(type_path: str, model_data: Dict[str, Any]) -> BaseModel: - cls = cast(type[BaseModel], PyObjectPath(type_path).object) - return cls.model_validate(_restore_portable_generated_model_state(model_data)) +def _restore_model_from_cloudpickled_state(cls: type[BaseModel], state_bytes: bytes) -> BaseModel: + import cloudpickle + + obj = cls.__new__(cls) + obj.__setstate__(cloudpickle.loads(state_bytes)) + return obj def _restore_pickled_local_flow_model(serialized_factory_payload: bytes) -> BaseModel: @@ -554,141 +557,24 @@ def _restore_pickled_local_flow_model(serialized_factory_payload: bytes) -> Base # contract from the defining process; rebuild the generated class from it. factory = _build_flow_model_factory_from_config(config, payload.factory_kwargs) cls = cast(type[BaseModel], getattr(factory, "_generated_model")) - return cls.model_validate(_restore_portable_generated_model_state(payload.model_data)) + return _restore_model_from_cloudpickled_state(cls, payload.state_bytes) -def _restore_generated_flow_model(factory_path: str, model_data: Dict[str, Any]) -> BaseModel: +def _restore_generated_flow_model(factory_path: str, state_bytes: bytes) -> BaseModel: """Restore a generated flow model by importing its factory function. This is the cross-process-safe restore path: importing the factory's module triggers the ``@Flow.model`` decorator, which re-creates the GeneratedModel - class. The instance is reconstructed through normal validation data instead - of raw Pydantic state because raw state can embed process-local generic - classes. + class. The instance is reconstructed from cloudpickled Pydantic state, not + validation data, so private attrs and ordinary user literals survive Ray + worker handoff. Runtime-created Pydantic generic specializations still need + a base serialization fix before they are portable as loose state. """ factory = PyObjectPath(factory_path).object generated_cls = getattr(factory, "_generated_model", None) if generated_cls is None: raise ImportError(f"Cannot restore generated flow model: '{factory_path}' does not have a _generated_model attribute.") - return generated_cls.model_validate(_restore_portable_generated_model_state(model_data)) - - -@singledispatch -def _portable_generated_model_state_value(value: Any) -> Any: - """Remove fragile Pydantic generic instance classes from local pickle state.""" - - if _is_unset_flow_input(value) or _is_model_dependency(value): - return value - return value - - -@_portable_generated_model_state_value.register -def _(value: BaseModel) -> Any: - if _is_model_dependency(value): - return value - return _PortableBaseModelState(data=value.model_dump(mode="python", by_alias=True)) - - -@_portable_generated_model_state_value.register -def _(value: PydanticModel) -> Any: - if _is_model_dependency(value): - return value - return _PortablePydanticModelState( - model_type=str(PyObjectPath.validate(type(value))), - data=_portable_generated_model_state_value(value.model_dump(mode="python", by_alias=True)), - ) - - -@_portable_generated_model_state_value.register -def _(value: tuple) -> tuple: - return tuple(_portable_generated_model_state_value(item) for item in value) - - -@_portable_generated_model_state_value.register -def _(value: list) -> list: - return [_portable_generated_model_state_value(item) for item in value] - - -@_portable_generated_model_state_value.register -def _(value: OrderedDict) -> OrderedDict: - return OrderedDict((key, _portable_generated_model_state_value(item)) for key, item in value.items()) - - -@_portable_generated_model_state_value.register -def _(value: dict) -> dict: - return {key: _portable_generated_model_state_value(item) for key, item in value.items()} - - -@_portable_generated_model_state_value.register -def _(value: frozenset) -> frozenset: - return frozenset(_portable_generated_model_state_value(item) for item in value) - - -@_portable_generated_model_state_value.register -def _(value: set) -> set: - return {_portable_generated_model_state_value(item) for item in value} - - -def _portable_generated_model_state(model: "_GeneratedFlowModelBase") -> Dict[str, Any]: - """Return validation data for a generated model without raw Pydantic state.""" - - data: Dict[str, Any] = {} - for name in type(model).model_fields: - value = getattr(model, name, _UNSET_FLOW_INPUT) - if _is_unset_flow_input(value): - continue - data[name] = _portable_generated_model_state_value(value) - return data - - -@singledispatch -def _restore_portable_generated_model_state_value(value: Any) -> Any: - return value - - -@_restore_portable_generated_model_state_value.register -def _(value: _PortableBaseModelState) -> BaseModel: - return BaseModel.model_validate(_restore_portable_generated_model_state_value(value.data)) - - -@_restore_portable_generated_model_state_value.register -def _(value: _PortablePydanticModelState) -> PydanticModel: - cls = cast(type[PydanticModel], PyObjectPath(value.model_type).object) - return cls.model_validate(_restore_portable_generated_model_state_value(value.data)) - - -@_restore_portable_generated_model_state_value.register -def _(value: tuple) -> tuple: - return tuple(_restore_portable_generated_model_state_value(item) for item in value) - - -@_restore_portable_generated_model_state_value.register -def _(value: list) -> list: - return [_restore_portable_generated_model_state_value(item) for item in value] - - -@_restore_portable_generated_model_state_value.register -def _(value: OrderedDict) -> OrderedDict: - return OrderedDict((key, _restore_portable_generated_model_state_value(item)) for key, item in value.items()) - - -@_restore_portable_generated_model_state_value.register -def _(value: dict) -> dict: - return {key: _restore_portable_generated_model_state_value(item) for key, item in value.items()} - - -@_restore_portable_generated_model_state_value.register -def _(value: frozenset) -> frozenset: - return frozenset(_restore_portable_generated_model_state_value(item) for item in value) - - -@_restore_portable_generated_model_state_value.register -def _(value: set) -> set: - return {_restore_portable_generated_model_state_value(item) for item in value} - - -def _restore_portable_generated_model_state(data: Dict[str, Any]) -> Dict[str, Any]: - return {name: _restore_portable_generated_model_state_value(value) for name, value in data.items()} + return _restore_model_from_cloudpickled_state(generated_cls, state_bytes) def _is_importable_function(func: _AnyCallable) -> bool: @@ -839,16 +725,44 @@ def _dependency_context_for_model(model: CallableModel, context: ContextBase) -> return _runtime_context_for_model(model, _project_context_values_for_model(model, _context_values(context))) +def _bound_model_default_base_context(bound_model: "BoundModel") -> Optional[ContextBase]: + """Return the wrapped plain model's default context object when one exists.""" + + contract = _model_context_contract(bound_model.model) + if contract.generated_model is not None: + return None + default_context = _plain_model_default_context(bound_model.model) + if default_context is _UNSET or default_context is None: + return None + if isinstance(default_context, ContextBase) and _context_matches_type(default_context, bound_model.model.context_type): + return default_context + return contract.runtime_context_type.model_validate(default_context) + + +def _bound_model_ambient_context(bound_model: "BoundModel", values: Dict[str, Any]) -> _BoundModelContext: + """Return the ambient carrier a bound wrapper should rewrite.""" + + base_context = _bound_model_default_base_context(bound_model) + if base_context is not None: + ambient = _context_values(base_context) + ambient.update(values) + return _BoundModelContext.from_values(ambient, base_context=base_context) + return _BoundModelContext.from_values(values) + + def _resolved_dependency_invocation(value: CallableModel, context: ContextBase) -> Tuple[CallableModel, ContextBase]: """Return the concrete ``(model, context)`` pair for a dependency call. Bound models must receive the full ambient ``FlowContext`` so their binding - transforms can read source fields before narrowing to the wrapped model's - context. Unbound dependencies can be projected immediately. + transforms can read source fields before narrowing to the wrapped model's. + If the wrapper targets a handwritten model with a decorated default context, + dependency execution uses the same default baseline as ``bound.flow.compute``. + Unbound dependencies can be projected immediately. """ if isinstance(value, BoundModel): - return value, FlowContext(**_context_values(context)) + values = {} if context is None else _context_values(context) + return value, _bound_model_ambient_context(value, values) return value, _dependency_context_for_model(value, context) @@ -1474,17 +1388,13 @@ def _bound_context_transform_regular_kwargs(config: _FlowModelConfig, binding: C def _evaluate_static_context_transform(binding: ContextTransform) -> Any: - """Evaluate a transform at binding time if it has no required contextual inputs.""" + """Evaluate a transform only when it has no contextual inputs at all.""" config = _load_context_transform_config_from_binding(binding) - kwargs = _bound_context_transform_regular_kwargs(config, binding) - - for param in config.contextual_params: - if param.has_function_default: - kwargs[param.name] = param.function_default - continue + if config.contextual_params: return _UNSET + kwargs = _bound_context_transform_regular_kwargs(config, binding) return config.func(**kwargs) @@ -1558,14 +1468,15 @@ def _merge_context_input_types(target: Dict[str, Any], updates: Dict[str, Any]) target[name] = annotation -def _merge_dynamic_context_operation_inputs( - target: Dict[str, Any], model: CallableModel, context_spec: _BoundContextSpec, *, required_only: bool -) -> None: +def _dynamic_context_operation_effects(context_spec: _BoundContextSpec, *, required_only: bool) -> Tuple[Set[str], Dict[str, Any]]: + supplied_fields: Set[str] = set() + input_types: Dict[str, Any] = {} + for operation in _effective_context_operations(context_spec): if isinstance(operation, PatchContextOperation): patch_result = _evaluate_static_context_transform(operation.binding) if patch_result is _UNSET: - _merge_context_input_types(target, _context_transform_input_types(operation.binding, required_only=required_only)) + _merge_context_input_types(input_types, _context_transform_input_types(operation.binding, required_only=required_only)) continue continue @@ -1574,10 +1485,12 @@ def _merge_dynamic_context_operation_inputs( value = _evaluate_static_context_transform(operation.spec) if value is _UNSET: - target.pop(operation.name, None) - _merge_context_input_types(target, _context_transform_input_types(operation.spec, required_only=required_only)) + supplied_fields.add(operation.name) + _merge_context_input_types(input_types, _context_transform_input_types(operation.spec, required_only=required_only)) continue + return supplied_fields, input_types + def _validate_static_context_spec_declared_context(model: CallableModel, context_spec: _BoundContextSpec) -> _BoundContextSpec: generated = _generated_model_instance(model) @@ -1699,6 +1612,18 @@ def _normalize_with_context(model: CallableModel, patches: Tuple[Any, ...], fiel # --------------------------------------------------------------------------- +def _context_from_values_preserving_private_state(context: ContextBase, values: Dict[str, Any]) -> ContextBase: + """Validate updated public values while preserving private context state.""" + + if values == _context_values(context): + return context + validated = type(context).model_validate(values) + private = getattr(context, "__pydantic_private__", None) + if private is not None: + object.__setattr__(validated, "__pydantic_private__", dict(private)) + return validated + + def _apply_context_spec_values(model: CallableModel, context_spec: _BoundContextSpec, context: ContextBase) -> Dict[str, Any]: """Apply a binding spec at execution time and return rewritten context values.""" @@ -1725,6 +1650,9 @@ def _apply_context_spec(model: CallableModel, context_spec: _BoundContextSpec, c if not context_spec.operations: if isinstance(context, _BoundModelContext): + if context._base_context is not None: + values = _project_context_values_for_model(model, _context_values(context)) + return _context_from_values_preserving_private_state(context._base_context, values) return _dependency_context_for_model(model, context) if _context_matches_type(context, model.context_type): return context @@ -1732,12 +1660,65 @@ def _apply_context_spec(model: CallableModel, context_spec: _BoundContextSpec, c values = _apply_context_spec_values(model, context_spec, context) if isinstance(context, _BoundModelContext): - return _runtime_context_for_model(model, _project_context_values_for_model(model, values)) + values = _project_context_values_for_model(model, values) + if context._base_context is not None: + return _context_from_values_preserving_private_state(context._base_context, values) + return _runtime_context_for_model(model, values) if _context_matches_type(context, model.context_type): - return type(context).model_validate(values) + return _context_from_values_preserving_private_state(context, values) return _runtime_context_for_model(model, _project_context_values_for_model(model, values)) +def _plain_model_default_context(model: CallableModel) -> Any: + call = getattr(type(model), "__call__", None) + wrapped = getattr(call, "__wrapped__", None) + if wrapped is None: + return _UNSET + try: + parameter = inspect.signature(wrapped).parameters.get("context") + except (TypeError, ValueError): + return _UNSET + if parameter is None or parameter.default is inspect.Signature.empty: + return _UNSET + return parameter.default + + +def _plain_model_default_context_values( + model: CallableModel, + runtime_context_type: Type[ContextBase], +) -> Optional[Dict[str, Any]]: + default_context = _plain_model_default_context(model) + if default_context is _UNSET: + return None + if default_context is None: + if _is_optional_context_type(model.context_type): + return {} + return _context_values(runtime_context_type.model_validate(default_context)) + if isinstance(default_context, ContextBase): + return _context_values(default_context) + return _context_values(runtime_context_type.model_validate(default_context)) + + +def _plain_model_compute_context_from_default( + model: CallableModel, + default_context: Any, + default_values: Dict[str, Any], + kwargs: Dict[str, Any], + runtime_context_type: Type[ContextBase], +) -> Optional[ContextBase]: + if default_context is None and not kwargs and _is_optional_context_type(model.context_type): + return None + + if isinstance(default_context, ContextBase) and _context_matches_type(default_context, model.context_type): + values = dict(default_values) + values.update(kwargs) + return _context_from_values_preserving_private_state(default_context, values) + + values = dict(default_values) + values.update(kwargs) + return runtime_context_type.model_validate(values) + + def _build_compute_context(model: CallableModel, context: Any, kwargs: Dict[str, Any]) -> Optional[ContextBase]: """Construct the context used by ``FlowAPI.compute`` for a target model. @@ -1766,6 +1747,10 @@ def _build_compute_context(model: CallableModel, context: Any, kwargs: Dict[str, return contract.runtime_context_type.model_validate(context) if contract.generated_model is None: + default_context = _plain_model_default_context(model) + if default_context is not _UNSET: + default_values = _plain_model_default_context_values(model, contract.runtime_context_type) + return _plain_model_compute_context_from_default(model, default_context, default_values, kwargs, contract.runtime_context_type) if not kwargs and _ctx_is_optional: return None return contract.runtime_context_type.model_validate(kwargs) @@ -1831,7 +1816,7 @@ def _build_bound_compute_context(bound_model: "BoundModel", context: Any, kwargs return context if not kwargs and _bound_model_preserves_none_context(bound_model): return None - return FlowContext(**kwargs) + return _bound_model_ambient_context(bound_model, kwargs) # --------------------------------------------------------------------------- @@ -1883,7 +1868,14 @@ def required_inputs(self) -> Dict[str, Any]: if contract.generated_model is None and _is_optional_context_type(self._model.context_type): return {} if contract.generated_model is None: - return {} if contract.input_types is None else {name: contract.input_types[name] for name in contract.required_names} + if contract.input_types is None: + return {} + result = {name: contract.input_types[name] for name in contract.required_names} + default_values = _plain_model_default_context_values(self._model, contract.runtime_context_type) + if default_values is not None: + for name in default_values: + result.pop(name, None) + return result generated = contract.generated_model config = type(generated).__flow_model_config__ @@ -1948,9 +1940,6 @@ class BoundModel(WrapperModel): context_spec: _BoundContextSpec = Field(default_factory=_BoundContextSpec, repr=False) - def __reduce__(self): - return (_restore_pickled_flow_model, (str(PyObjectPath.validate(type(self))), _portable_generated_model_state(self))) - def _rewrite_context(self, context: ContextBase) -> ContextBase: """Apply this wrapper's context bindings to an ambient runtime context.""" @@ -2048,7 +2037,10 @@ def runtime_inputs(self) -> Dict[str, Any]: result = super().context_inputs for name in _statically_resolved_context_field_names(self._bound.model, self._bound.context_spec): result.pop(name, None) - _merge_dynamic_context_operation_inputs(result, self._bound.model, self._bound.context_spec, required_only=False) + supplied_fields, dynamic_inputs = _dynamic_context_operation_effects(self._bound.context_spec, required_only=False) + for name in supplied_fields: + result.pop(name, None) + _merge_context_input_types(result, dynamic_inputs) return result @property @@ -2062,7 +2054,16 @@ def required_inputs(self) -> Dict[str, Any]: result = super().required_inputs for name in _statically_resolved_context_field_names(self._bound.model, self._bound.context_spec): result.pop(name, None) - _merge_dynamic_context_operation_inputs(result, self._bound.model, self._bound.context_spec, required_only=True) + supplied_fields, dynamic_inputs = _dynamic_context_operation_effects(self._bound.context_spec, required_only=True) + for name in supplied_fields: + result.pop(name, None) + _merge_context_input_types(result, dynamic_inputs) + contract = _model_context_contract(self._bound.model) + if contract.generated_model is None: + default_values = _plain_model_default_context_values(self._bound.model, contract.runtime_context_type) + if default_values is not None: + for name in default_values: + result.pop(name, None) return result def with_context(self, *patches, **field_overrides) -> BoundModel: @@ -2084,19 +2085,25 @@ def __reduce__(self): """Prefer import-path restoration, falling back to serialized local factories.""" config = type(self).__flow_model_config__ + import cloudpickle + + state_bytes = cloudpickle.dumps(self.__getstate__(), protocol=5) factory_path = _generated_model_factory_path_for_pickle(config, type(self)) if factory_path is not None: - return (_restore_generated_flow_model, (factory_path, _portable_generated_model_state(self))) - import cloudpickle + return (_restore_generated_flow_model, (factory_path, state_bytes)) # Local generated classes are not normal importable class definitions: - # plain pickle cannot reconstruct them, and cloudpickle would otherwise - # walk generated signatures/model metadata that can contain fragile - # runtime-only annotations such as GenericResult[int]. + # plain pickle cannot reconstruct them, and Ray workers should not + # re-run fragile annotation resolution. Carry the analyzed contract + # separately, but keep instance state as Pydantic pickle state so user + # literals and private attrs survive cloudpickle. Runtime-created + # Pydantic generic specializations such as GenericResult[int] still need + # a base serialization fix before they are portable across fresh + # processes when they appear as loose user state. payload = _LocalFlowModelPicklePayload( serialized_config=_serialize_flow_model_config(config), factory_kwargs=type(self).__flow_model_factory_kwargs__, - model_data=_portable_generated_model_state(self), + state_bytes=state_bytes, ) return (_restore_pickled_local_flow_model, (cloudpickle.dumps(payload, protocol=5),)) @@ -2307,7 +2314,7 @@ def _context_transform_factory_signature(config: _FlowModelConfig) -> inspect.Si default=param.function_default if param.has_function_default else inspect.Parameter.empty, ) ) - return inspect.Signature(parameters=parameters, return_annotation=ContextTransform) + return inspect.Signature(parameters=parameters) def _resolve_generated_model_bases(model_base: Type[CallableModel]) -> Tuple[type, ...]: @@ -2333,35 +2340,37 @@ def _build_flow_model_factory_from_config(config: _FlowModelConfig, factory_kwar # restore. See ``_flow_model_config_identity`` for why this must be fixed at # class-construction time instead of recalculated on every cache-key build. config_identity = factory_kwargs.setdefault("_flow_model_identity", _flow_model_config_identity(config)) - annotations: Dict[str, Any] = {} - namespace: Dict[str, Any] = { - "__module__": getattr(fn, "__module__", __name__), - "__qualname__": f"_{_callable_name(fn)}_Model", - "__call__": Flow.call( - **{ - name: value - for name in ("cacheable", "volatile", "log_level", "validate_result", "verbose", "evaluator") - if (value := factory_kwargs.get(name, _UNSET)) is not _UNSET - } - )(_make_call_impl(config)), - "__deps__": Flow.deps(_make_deps_impl(config)), - } + field_definitions: Dict[str, Any] = {} for param in config.parameters: - annotations[param.name] = _generated_field_annotation(param) + annotation = _generated_field_annotation(param) if param.is_contextual: - namespace[param.name] = Field(default_factory=_unset_flow_input_factory, exclude_if=_is_unset_flow_input) + default = Field(default_factory=_unset_flow_input_factory, exclude_if=_is_unset_flow_input) elif param.has_function_default: - namespace[param.name] = param.function_default + default = param.function_default else: - namespace[param.name] = Field(default_factory=_unset_flow_input_factory, exclude_if=_is_unset_flow_input) - - namespace["__annotations__"] = annotations + default = Field(default_factory=_unset_flow_input_factory, exclude_if=_is_unset_flow_input) + field_definitions[param.name] = (annotation, default) GeneratedModel = cast( type[_GeneratedFlowModelBase], - type(f"_{_callable_name(fn)}_Model", _resolve_generated_model_bases(model_base), namespace), + create_model( + f"_{_callable_name(fn)}_Model", + __base__=_resolve_generated_model_bases(model_base), + __module__=getattr(fn, "__module__", __name__), + __qualname__=f"_{_callable_name(fn)}_Model", + **field_definitions, + ), ) + GeneratedModel.__call__ = Flow.call( + **{ + name: value + for name in ("cacheable", "volatile", "log_level", "validate_result", "verbose", "evaluator") + if (value := factory_kwargs.get(name, _UNSET)) is not _UNSET + } + )(_make_call_impl(config)) + GeneratedModel.__deps__ = Flow.deps(_make_deps_impl(config)) + update_abstractmethods(GeneratedModel) GeneratedModel.__flow_model_config__ = config GeneratedModel.__flow_model_factory_kwargs__ = factory_kwargs GeneratedModel.__flow_model_identity__ = config_identity @@ -2381,7 +2390,7 @@ def factory(**kwargs) -> _GeneratedFlowModelBase: return factory -def flow_context_transform(func: Optional[_AnyCallable] = None) -> _AnyCallable: +def _flow_context_transform(func: Optional[_AnyCallable] = None) -> _AnyCallable: """Decorator that turns a function into a serializable ``with_context`` transform factory. Regular parameters are bound when the transform factory is called. diff --git a/ccflow/tests/evaluators/test_common.py b/ccflow/tests/evaluators/test_common.py index 37f3dd0..8577207 100644 --- a/ccflow/tests/evaluators/test_common.py +++ b/ccflow/tests/evaluators/test_common.py @@ -598,7 +598,7 @@ def test_plain_callable_graph_keys_match_public_cache_key(self): self.assertEqual(graph.root_id, cache_key(ModelEvaluationContext(model=root, context=context))) def test_plain_callable_graph_deduplicates_equal_models_by_key(self): - """Ordinary graph traversal should keep main's key-only deduplication.""" + """Ordinary graph traversal deduplicates structurally equal nodes by key.""" leaf1 = NodeModel(meta=dict(name="leaf")) leaf2 = NodeModel(meta=dict(name="leaf")) root = NodeModel(meta=dict(name="root"), deps_model=[leaf1, leaf2]) diff --git a/ccflow/tests/test_flow_context.py b/ccflow/tests/test_flow_context.py index 7bb379d..d8c98bb 100644 --- a/ccflow/tests/test_flow_context.py +++ b/ccflow/tests/test_flow_context.py @@ -246,6 +246,45 @@ def from_str(seed: FromContext[str]) -> int: bound.flow.runtime_inputs +def test_bound_flow_api_keeps_dynamic_transform_source_inputs_after_later_field_bindings(): + @Flow.model + def add(x: FromContext[int], y: FromContext[int]) -> int: + return x + y + + @Flow.context_transform + def from_y(y: FromContext[int]) -> int: + return y + 1 + + @Flow.context_transform + def from_x(x: FromContext[int]) -> int: + return x + 10 + + bound = add().flow.with_context(x=from_y()).flow.with_context(y=from_x()) + + assert bound.flow.context_inputs == {"x": int, "y": int} + assert bound.flow.runtime_inputs == {"x": int, "y": int} + assert bound.flow.required_inputs == {"x": int, "y": int} + assert bound.flow.compute(x=1, y=2).value == 14 + + +def test_bound_flow_api_reports_optional_transform_context_inputs_as_runtime_only(): + @Flow.model + def add(a: FromContext[int], b: FromContext[int]) -> int: + return a + b + + @Flow.context_transform + def seed_plus_one(seed: FromContext[int] = 0) -> int: + return seed + 1 + + bound = add().flow.with_context(b=seed_plus_one()) + + assert bound.flow.runtime_inputs == {"a": int, "seed": int} + assert bound.flow.required_inputs == {"a": int} + assert bound.flow.bound_inputs == {} + assert bound.flow.compute(a=10).value == 11 + assert bound.flow.compute(a=10, seed=5).value == 16 + + def test_bound_model_rejects_regular_field_context_overrides(): @Flow.model def add(a: int, b: FromContext[int]) -> int: diff --git a/ccflow/tests/test_flow_model.py b/ccflow/tests/test_flow_model.py index 0dea0d8..f30980f 100644 --- a/ccflow/tests/test_flow_model.py +++ b/ccflow/tests/test_flow_model.py @@ -10,17 +10,18 @@ from datetime import date, timedelta from pathlib import Path from types import ModuleType -from typing import Annotated, Any, Callable, Literal, Optional, get_args +from typing import Annotated, Any, Callable, Generic, Literal, Optional, TypeVar, get_args import pytest import ray -from pydantic import Field, ValidationError, model_validator +from pydantic import BaseModel as PydanticBaseModel, Field, PrivateAttr, ValidationError, model_validator from ray.cloudpickle import dumps as rcpdumps, loads as rcploads import ccflow import ccflow._flow_model_binding as flow_binding_module import ccflow.flow_model as flow_model_module from ccflow import ( + BaseModel, CallableModel, ContextBase, DateRangeContext, @@ -43,6 +44,24 @@ class SimpleContext(ContextBase): value: int +class ExternalPydanticPayload(PydanticBaseModel): + x: int + _bonus: int = PrivateAttr(default=1) + + +class ExternalCcflowPayload(BaseModel): + x: int + _bonus: int = PrivateAttr(default=1) + + +T = TypeVar("T") + + +class ExternalGenericCcflowBox(BaseModel, Generic[T]): + item: T + _bonus: int = PrivateAttr(default=1) + + class ParentRangeContext(ContextBase): start_date: date end_date: date @@ -63,6 +82,12 @@ def _validate_order(self): return self +@pytest.fixture +def local_ray_runtime(): + with ray.init(num_cpus=1): + yield + + @Flow.model def basic_loader(source: str, multiplier: int, value: FromContext[int]) -> GenericResult[int]: return GenericResult(value=value * multiplier) @@ -263,6 +288,9 @@ def test_context_transform_internal_error_and_repr_payloads(): with pytest.raises(TypeError, match="payload does not contain"): flow_model_module._load_serialized_context_transform_config(invalid_payload) + with pytest.raises(TypeError, match="payload does not contain"): + flow_model_module._load_serialized_context_transform_config("not-base64") + with pytest.raises(ImportError, match="does not have a _generated_model"): flow_model_module._restore_generated_flow_model("ccflow.tests.test_flow_model.lazy_context_transform_for_rejection", {}) @@ -1054,7 +1082,7 @@ def test_lazy_runtime_helper_is_removed(): def source(value: FromContext[int]) -> GenericResult[int]: return GenericResult(value=value) - with pytest.raises(TypeError, match="Lazy\\(model\\)\\(\\.\\.\\.\\) has been removed"): + with pytest.raises(TypeError, match="Lazy is an annotation marker"): Lazy(source()) @@ -1066,6 +1094,32 @@ def bad(x: Lazy[FromContext[int]]) -> int: return x() +def test_bare_flow_marker_annotations_are_rejected(): + with pytest.raises(TypeError, match=r"FromContext\[T\]"): + + @Flow.model + def bad_context(x: FromContext) -> int: + return 1 + + with pytest.raises(TypeError, match=r"Lazy\[T\]"): + + @Flow.model + def bad_lazy(x: Lazy) -> int: + return x() + + with pytest.raises(TypeError, match=r"FromContext\[T\]"): + + @Flow.context_transform + def bad_transform_context(x: FromContext) -> int: + return 1 + + with pytest.raises(TypeError, match=r"Lazy\[T\]"): + + @Flow.context_transform + def bad_transform_lazy(x: Lazy) -> int: + return 1 + + def test_auto_wrap_and_serialization_roundtrip(): @Flow.model def add(a: int, b: FromContext[int]) -> int: @@ -1163,7 +1217,57 @@ def first(xs: list[GenericResult[int]], b: FromContext[int]) -> int: assert restored.flow.compute(b=2).value == 3 -def test_bound_model_plain_pickle_handles_context_transform_generic_result_bound_args(): +def test_generated_model_plain_pickle_preserves_external_pydantic_private_state(): + @Flow.model + def read(payload: object) -> int: + return payload.x + payload._bonus + + payload = ExternalPydanticPayload(x=2) + payload._bonus = 40 + + restored = pickle.loads(pickle.dumps(read(payload=payload), protocol=5)) + + assert restored.payload._bonus == 40 + assert restored.flow.compute().value == 42 + + +def test_generated_model_pickle_preserves_external_ccflow_private_state(): + @Flow.model + def read(payload: object) -> int: + return payload.x + payload._bonus + + payload = ExternalCcflowPayload(x=2) + payload._bonus = 40 + + for dumps, loads in ((pickle.dumps, pickle.loads), (rcpdumps, rcploads)): + restored = loads(dumps(read(payload=payload), protocol=5)) + + assert isinstance(restored.payload, ExternalCcflowPayload) + assert restored.payload._bonus == 40 + assert restored.flow.compute().value == 42 + + +def test_local_generated_model_cloudpickle_preserves_local_pydantic_literal_state(): + def make_model(): + class LocalPayload(PydanticBaseModel): + x: int + _bonus: int = PrivateAttr(default=1) + + @Flow.model + def read(payload: object) -> int: + return payload.x + payload._bonus + + payload = LocalPayload(x=2) + payload._bonus = 40 + return read(payload=payload) + + restored = rcploads(rcpdumps(make_model(), protocol=5)) + + assert restored.payload._bonus == 40 + assert restored.flow.compute().value == 42 + + +def test_bound_model_cloudpickle_handles_context_transform_generic_result_bound_args(): @Flow.model def source(a: FromContext[int]) -> int: return a @@ -1173,32 +1277,124 @@ def fixed(value: GenericResult[int]) -> int: return value.value bound = source().flow.with_context(a=fixed(value=GenericResult(value=5))) - restored = pickle.loads(pickle.dumps(bound, protocol=5)) + restored = rcploads(rcpdumps(bound, protocol=5)) assert restored.flow.compute().value == 5 -def test_local_generated_model_plain_pickle_bytes_in_ray_handles_generic_result_state(): - def make_model(): - @Flow.model - def first(xs: list[GenericResult[int]], b: FromContext[int]) -> int: - return xs[0].value + b +def test_generated_model_cloudpickle_preserves_user_generic_private_state(): + @Flow.model + def read(box: object) -> int: + return box.item + box._bonus + + box = ExternalGenericCcflowBox[int](item=2) + box._bonus = 40 + restored = rcploads(rcpdumps(read(box=box), protocol=5)) + + assert type(restored.box) is type(box) + assert restored.box._bonus == 40 + assert restored.flow.compute().value == 42 + - return first(xs=[GenericResult(value=10)]) +def test_bound_model_pickle_preserves_external_pydantic_static_context_value(): + @Flow.model + def read(payload: FromContext[object]) -> int: + return payload.x + payload._bonus + + payload = ExternalPydanticPayload(x=2) + payload._bonus = 40 + bound = read().flow.with_context(payload=payload) + assert bound.flow.compute().value == 42 + + for dumps, loads in ((pickle.dumps, pickle.loads), (rcpdumps, rcploads)): + restored = loads(dumps(bound, protocol=5)) + restored_payload = restored.context_spec.operations[0].spec.value + + assert isinstance(restored_payload, ExternalPydanticPayload) + assert restored_payload._bonus == 40 + assert restored.flow.compute().value == 42 + + +def test_bound_model_pickle_preserves_external_pydantic_context_transform_bound_arg(): + @Flow.model + def read(value: FromContext[int]) -> int: + return value + + @Flow.context_transform + def derive(payload: object) -> int: + return payload.x + payload._bonus + + payload = ExternalPydanticPayload(x=2) + payload._bonus = 40 + bound = read().flow.with_context(value=derive(payload=payload)) + + for dumps, loads in ((pickle.dumps, pickle.loads), (rcpdumps, rcploads)): + restored = loads(dumps(bound, protocol=5)) + restored_payload = restored.context_spec.operations[0].spec.bound_args["payload"] + + assert isinstance(restored_payload, ExternalPydanticPayload) + assert restored_payload._bonus == 40 + assert restored.flow.compute().value == 42 + + +def test_bound_model_pickle_preserves_external_ccflow_static_context_value(): + @Flow.model + def read(payload: FromContext[object]) -> int: + return payload.x + payload._bonus + + payload = ExternalCcflowPayload(x=2) + payload._bonus = 40 + bound = read().flow.with_context(payload=payload) + assert bound.flow.compute().value == 42 + + for dumps, loads in ((pickle.dumps, pickle.loads), (rcpdumps, rcploads)): + restored = loads(dumps(bound, protocol=5)) + restored_payload = restored.context_spec.operations[0].spec.value + + assert isinstance(restored_payload, ExternalCcflowPayload) + assert restored_payload._bonus == 40 + assert restored.flow.compute().value == 42 + + +def test_bound_model_pickle_preserves_external_ccflow_context_transform_bound_arg(): + @Flow.model + def read(value: FromContext[int]) -> int: + return value + + @Flow.context_transform + def derive(payload: object) -> int: + return payload.x + payload._bonus + + payload = ExternalCcflowPayload(x=2) + payload._bonus = 40 + bound = read().flow.with_context(value=derive(payload=payload)) + + for dumps, loads in ((pickle.dumps, pickle.loads), (rcpdumps, rcploads)): + restored = loads(dumps(bound, protocol=5)) + restored_payload = restored.context_spec.operations[0].spec.bound_args["payload"] + + assert isinstance(restored_payload, ExternalCcflowPayload) + assert restored_payload._bonus == 40 + assert restored.flow.compute().value == 42 + + +def test_ray_cloudpickle_preserves_user_generic_private_state_in_generated_model(local_ray_runtime): + @Flow.model + def read(box: object) -> int: + return box.item + box._bonus + + box = ExternalGenericCcflowBox[int](item=2) + box._bonus = 40 + model = read(box=box) @ray.remote class Runner: def run(self, payload): - model = pickle.loads(payload) - context = FlowContext(b=3) - before = cache_key(model.__call__.get_evaluation_context(model, context), effective=True) - value = model.flow.compute(context).value - after = cache_key(model.__call__.get_evaluation_context(model, context), effective=True) - return value, before == after + restored = rcploads(payload) + return type(restored.box).__name__, restored.box._bonus, restored.flow.compute().value - with ray.init(num_cpus=1): - runner = Runner.remote() - assert ray.get(runner.run.remote(pickle.dumps(make_model(), protocol=5))) == (13, True) + runner = Runner.remote() + assert ray.get(runner.run.remote(rcpdumps(model, protocol=5))) == ("ExternalGenericCcflowBox[int]", 40, 42) def test_importable_generated_model_plain_pickle_cross_process_handles_generic_result_state(tmp_path, monkeypatch): @@ -1757,6 +1953,20 @@ def shift_from_anchor(anchor: date, days: FromContext[int]) -> date: assert restored.flow.compute(days=2).value == "2024-01-03" +def test_context_transform_json_roundtrip_reports_malformed_payload_cleanly(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + bound = add(a=10).flow.with_context(b=increment_b(amount=1)) + dumped = bound.model_dump(mode="python") + dumped["context_spec"]["operations"][0]["spec"]["serialized_config"] = "not-base64" + + restored = type(bound).model_validate(dumped) + with pytest.raises(TypeError, match="payload does not contain"): + restored.flow.compute(b=4) + + def test_context_transform_supports_nested_functions_with_serialized_payload(): @Flow.model def add(a: int, b: FromContext[int]) -> int: @@ -1794,7 +2004,7 @@ def main_transform(value: FromContext[int]) -> int: assert restored.flow.compute(value=4).value == 6 -def test_context_transform_nested_function_survives_ray_task(): +def test_context_transform_nested_function_survives_ray_task(local_ray_runtime): @Flow.model def add(a: int, b: FromContext[int]) -> int: return a + b @@ -1809,8 +2019,7 @@ def nested_transform(b: FromContext[int], amount: int) -> int: def run_model(model): return model.flow.compute(b=4).value - with ray.init(num_cpus=1): - assert ray.get(run_model.remote(bound)) == 8 + assert ray.get(run_model.remote(bound)) == 8 def test_with_context_rejects_raw_callables(): @@ -2174,7 +2383,7 @@ def test_context_transform_factory_signature_only_exposes_regular_bindings(): assert sig.parameters["amount"].kind is inspect.Parameter.KEYWORD_ONLY assert sig.parameters["amount"].annotation is int assert sig.parameters["amount"].default is inspect.Parameter.empty - assert sig.return_annotation is flow_model_module.ContextTransform + assert sig.return_annotation is inspect.Signature.empty with pytest.raises(TypeError, match="positional"): increment_b(1) @@ -2267,6 +2476,120 @@ def __call__(self, context: SimpleContext) -> GenericResult[tuple[int, str]]: assert first_key != second_key +def test_plain_callable_flow_compute_uses_default_context_when_available(): + class DefaultContext(ContextBase): + value: int + tag: str = "default" + + class PlainModel(CallableModel): + @Flow.call + def __call__(self, context: DefaultContext = DefaultContext(value=7)) -> GenericResult[tuple[int, str]]: + return GenericResult(value=(context.value, context.tag)) + + model = PlainModel() + + assert model.flow.required_inputs == {} + assert model.flow.compute().value == (7, "default") + assert model.flow.compute(value=3).value == (3, "default") + assert model.flow.compute(tag="runtime").value == (7, "runtime") + + empty_bound = model.flow.with_context() + assert empty_bound.flow.required_inputs == {} + assert empty_bound.flow.compute().value == (7, "default") + + bound = model.flow.with_context(tag="bound") + assert bound.flow.required_inputs == {} + assert bound.flow.compute().value == (7, "bound") + assert bound.flow.compute(value=3).value == (3, "bound") + + +def test_plain_callable_default_context_private_state_is_preserved(): + class DefaultContext(ContextBase): + value: int + _bonus: int = PrivateAttr(default=1) + + default_context = DefaultContext(value=7) + default_context._bonus = 40 + + class PlainModel(CallableModel): + @Flow.call + def __call__(self, context: DefaultContext = default_context) -> GenericResult[tuple[int, int, bool]]: + return GenericResult(value=(context.value, context._bonus, context is default_context)) + + model = PlainModel() + + assert model.flow.compute().value == (7, 40, True) + assert model.flow.compute(value=8).value == (8, 40, False) + assert model.flow.with_context().flow.compute().value == (7, 40, True) + assert model.flow.with_context(value=8).flow.compute().value == (8, 40, False) + + +def test_bound_plain_callable_flow_compute_uses_default_context_for_dynamic_transforms(): + class DefaultContext(ContextBase): + value: int + seed: int + + class PlainModel(CallableModel): + @Flow.call + def __call__(self, context: DefaultContext = DefaultContext(value=7, seed=8)) -> GenericResult[tuple[int, int]]: + return GenericResult(value=(context.value, context.seed)) + + @Flow.context_transform + def from_seed(seed: FromContext[int]) -> int: + return seed + 1 + + bound = PlainModel().flow.with_context(value=from_seed()) + + assert bound.flow.required_inputs == {} + assert bound.flow.runtime_inputs == {"seed": int} + assert bound.flow.compute().value == (9, 8) + assert bound.flow.compute(seed=10).value == (11, 10) + + +def test_bound_plain_callable_dependency_uses_default_context_baseline(): + class DefaultContext(ContextBase): + value: int + seed: int + _bonus: int = PrivateAttr(default=1) + + default_context = DefaultContext(value=7, seed=8) + default_context._bonus = 40 + + class PlainModel(CallableModel): + @Flow.call + def __call__(self, context: DefaultContext = default_context) -> GenericResult[tuple[int, int, int]]: + return GenericResult(value=(context.value, context.seed, context._bonus)) + + @Flow.context_transform + def from_seed(seed: FromContext[int]) -> int: + return seed + 1 + + @Flow.model + def consume(x: tuple[int, int, int]) -> tuple[int, int, int]: + return x + + static_bound = PlainModel().flow.with_context(value=3) + assert static_bound.flow.compute().value == (3, 8, 40) + assert consume(x=static_bound).flow.compute().value == (3, 8, 40) + + dynamic_bound = PlainModel().flow.with_context(value=from_seed()) + assert dynamic_bound.flow.compute().value == (9, 8, 40) + assert consume(x=dynamic_bound).flow.compute().value == (9, 8, 40) + assert consume(x=dynamic_bound).flow.compute(seed=10).value == (11, 10, 40) + + model = consume(x=dynamic_bound) + graph = get_dependency_graph(model.__call__.get_evaluation_context(model, model.context_type())) + plain_contexts = [] + for evaluation_context in graph.ids.values(): + while isinstance(evaluation_context.context, ModelEvaluationContext): + evaluation_context = evaluation_context.context + if isinstance(evaluation_context.model, PlainModel): + plain_contexts.append(evaluation_context.context) + + assert plain_contexts == [DefaultContext(value=9, seed=8)] + assert plain_contexts[0]._bonus == 40 + + def test_unhashable_annotations_still_validate(): annotation = Annotated[int, []] @@ -3858,3 +4181,4 @@ def test_flow_model_public_exports_exclude_context_spec_models(): assert not hasattr(ccflow, "StaticValueSpec") assert not hasattr(ccflow, "ContextTransform") assert not hasattr(ccflow, "flow_context_transform") + assert not hasattr(flow_model_module, "flow_context_transform") diff --git a/docs/wiki/Flow-Model.md b/docs/wiki/Flow-Model.md index 0007e38..425e151 100644 --- a/docs/wiki/Flow-Model.md +++ b/docs/wiki/Flow-Model.md @@ -126,15 +126,22 @@ They can be satisfied by: - runtime context, - construction-time keyword arguments, stored as contextual defaults on the model instance, +- keyword callable literals for `FromContext[Callable[..., T]]` fields, - function defaults. -They cannot be satisfied by `CallableModel` values. +Construction-time contextual defaults cannot be `CallableModel` values, because +that would be ambiguous with dependency binding. If the contextual field itself +is typed to accept a `CallableModel`, pass the model as runtime context or via +`.flow.with_context(...)` as ordinary data. A raw callable passed positionally is +always treated as a regular argument candidate, not as a contextual default. In +other words, `FromContext[Callable[..., T]]` allows a callable literal only when +it is provided by keyword for that contextual field. A construction-time value for a contextual parameter is still a default, not a conversion into a regular bound parameter. Generated models reserve a few framework attribute names for the model API: -`flow`, `meta`, `context_type`, and `result_type`. Do not use these as +`flow`, `meta`, `context_type`, `result_type`, and `type_`. Do not use these as `@Flow.model` function parameter names. ```python @@ -170,6 +177,12 @@ For generated `@Flow.model` stages it accepts either: It does not accept both at the same time. +Plain handwritten `CallableModel` instances also expose `.flow.compute(...)`. +For those models, keyword arguments build or update the runtime context. If the +decorated `@Flow.call` method declares a default context object, no-argument +`.flow.compute()` uses that default, and keyword arguments override fields from +that default for the `.flow.compute(...)` call. + ```python from ccflow import Flow, FlowContext, FromContext @@ -342,7 +355,11 @@ Key rules: - `with_context()` only targets contextual fields, - positional arguments must be patch transforms, - keyword overrides may be literals or field transforms, -- raw anonymous callables are rejected; use named `@Flow.context_transform` helpers, +- raw positional callables are rejected; use named `@Flow.context_transform` + helpers for positional patch transforms, +- keyword callable literals are allowed only when the target field is typed as + `FromContext[Callable[..., T]]`; other keyword callables must be field + transforms, - transforms are branch-local — they only affect the wrapped dependency, not the entire pipeline, - patch results merge left-to-right, then keyword overrides apply last, @@ -350,12 +367,13 @@ Key rules: multiple fields must move together, put that logic inside one patch transform. -Context transforms serialize the analyzed transform contract directly so bound -models can move through pickle and Ray workers without re-resolving annotations -from the defining module. This applies to importable module functions, local -functions, nested functions, `__main__`, and notebook-defined transforms. For -long-lived YAML/JSON configuration, inspect the generated config shape before -treating it as a stable hand-written config format. +Context transforms serialize the analyzed transform contract directly in +`serialized_config` so bound models can move through pickle and Ray workers +without re-resolving annotations from the defining module. This applies to +importable module functions, local functions, nested functions, `__main__`, and +notebook-defined transforms. Treat that `serialized_config` as an opaque +generated artifact owned by ccflow, not as a stable hand-written YAML/JSON +configuration format. ## `context_type=...` @@ -409,8 +427,10 @@ For bound models, `with_context(...)` bindings are reflected in `runtime_inputs`, `required_inputs`, and `bound_inputs`. Literal bindings satisfy their target fields. Transform bindings with runtime inputs add those source context inputs to the effective runtime view. Static transforms, meaning -transforms whose inputs are already available, may be evaluated during -introspection so their output fields can be reported precisely. +transforms with no contextual parameters, may be evaluated during introspection +so their output fields can be reported precisely. A transform parameter like +`seed: FromContext[int] = 0` is still a runtime input; its default only means the +caller is not required to provide it. `required_inputs` is always the required subset of `runtime_inputs`; if multiple bindings expose the same runtime context field with conflicting annotations, introspection raises an error instead of silently choosing one. @@ -531,9 +551,9 @@ the caller to provide that field. Use `runtime_inputs` to see the effective direct runtime context inputs after `with_context(...)` bindings. Use `required_inputs` to see what still must be -provided by the caller. Static transforms may be evaluated during introspection, -so their output fields can be removed from `runtime_inputs` and -`required_inputs` or added to `bound_inputs`. +provided by the caller. Static transforms with no contextual parameters may be +evaluated during introspection, so their output fields can be removed from +`runtime_inputs` and `required_inputs` or added to `bound_inputs`. **A shared dependency runs more than once** From 7c620e2e0420c0347ce110440d83b39be59533d1 Mon Sep 17 00:00:00 2001 From: Nijat K Date: Fri, 15 May 2026 04:46:49 -0400 Subject: [PATCH 6/8] Simplify tests, harden detecting nested/local @Flow.model Signed-off-by: Nijat K --- ccflow/flow_model.py | 68 +- ccflow/tests/config/conf_flow.yaml | 29 - ccflow/tests/flow_model_hydra_fixtures.py | 51 +- ccflow/tests/test_flow_model.py | 1066 ++++----------------- ccflow/tests/test_flow_model_hydra.py | 101 +- 5 files changed, 236 insertions(+), 1079 deletions(-) diff --git a/ccflow/flow_model.py b/ccflow/flow_model.py index 2fec4b9..efe03f8 100644 --- a/ccflow/flow_model.py +++ b/ccflow/flow_model.py @@ -43,6 +43,7 @@ 8. Generated model class construction and decorators. """ +import importlib import inspect import sys from abc import update_abstractmethods @@ -70,6 +71,7 @@ get_type_hints, ) +import cloudpickle from pydantic import BaseModel as PydanticModel, Field, PrivateAttr, SkipValidation, TypeAdapter, ValidationError, create_model, model_validator from pydantic.errors import PydanticSchemaGenerationError, PydanticUndefinedAnnotation @@ -248,7 +250,6 @@ class _GeneratedModelIdentity(NamedTuple): class _LocalFlowModelPicklePayload(NamedTuple): serialized_config: Any factory_kwargs: Dict[str, Any] - state_bytes: bytes # --------------------------------------------------------------------------- @@ -492,16 +493,12 @@ def _ensure_named_python_function(fn: _AnyCallable, *, decorator_name: str) -> N def _serialize_context_transform_config(config: _FlowModelConfig) -> str: - import cloudpickle - payload = cloudpickle.dumps(_serialize_flow_model_config(config), protocol=5) return b64encode(payload).decode("ascii") @lru_cache(maxsize=None) def _load_serialized_context_transform_config(serialized_config: str) -> _FlowModelConfig: - import cloudpickle - try: payload = cloudpickle.loads(b64decode(serialized_config.encode("ascii"))) config = _restore_flow_model_config(payload) @@ -537,44 +534,33 @@ def _is_mapping_annotation(annotation: Any) -> bool: return False -def _restore_model_from_cloudpickled_state(cls: type[BaseModel], state_bytes: bytes) -> BaseModel: - import cloudpickle - - obj = cls.__new__(cls) - obj.__setstate__(cloudpickle.loads(state_bytes)) - return obj - - -def _restore_pickled_local_flow_model(serialized_factory_payload: bytes) -> BaseModel: - import cloudpickle - +def _new_local_flow_model_for_pickle(serialized_factory_payload: bytes) -> BaseModel: payload = cloudpickle.loads(serialized_factory_payload) config = _restore_flow_model_config(payload.serialized_config) # Do not call ``flow_model(config.func, **factory_kwargs)`` here. That would - # re-run type-hint resolution in the receiving process, which is exactly the - # path that fails for local/postponed annotations and runtime-only generic - # aliases such as GenericResult[int]. The serialized config is the resolved - # contract from the defining process; rebuild the generated class from it. + # re-run worker-side type-hint resolution for local/postponed annotations. + # The serialized config is the resolved contract from the defining process; + # rebuild the generated class from it, then let pickle apply the third + # reducer element through ``__setstate__``. factory = _build_flow_model_factory_from_config(config, payload.factory_kwargs) cls = cast(type[BaseModel], getattr(factory, "_generated_model")) - return _restore_model_from_cloudpickled_state(cls, payload.state_bytes) + return cls.__new__(cls) -def _restore_generated_flow_model(factory_path: str, state_bytes: bytes) -> BaseModel: - """Restore a generated flow model by importing its factory function. +def _new_generated_flow_model_for_pickle(factory_path: str) -> BaseModel: + """Allocate a generated flow model by importing its factory function. This is the cross-process-safe restore path: importing the factory's module triggers the ``@Flow.model`` decorator, which re-creates the GeneratedModel - class. The instance is reconstructed from cloudpickled Pydantic state, not - validation data, so private attrs and ordinary user literals survive Ray - worker handoff. Runtime-created Pydantic generic specializations still need - a base serialization fix before they are portable as loose state. + class. The reducer returns Pydantic state separately so pickle applies + ``__setstate__`` in the outer pickle stream, preserving normal memo + semantics for shared references, cycles, and protocol-5 buffers. """ - factory = PyObjectPath(factory_path).object + factory = _load_module_attribute_uncached(factory_path) generated_cls = getattr(factory, "_generated_model", None) if generated_cls is None: raise ImportError(f"Cannot restore generated flow model: '{factory_path}' does not have a _generated_model attribute.") - return _restore_model_from_cloudpickled_state(generated_cls, state_bytes) + return generated_cls.__new__(generated_cls) def _is_importable_function(func: _AnyCallable) -> bool: @@ -591,13 +577,18 @@ def _importable_function_path(func: _AnyCallable) -> Optional[str]: return f"{func.__module__}.{func.__name__}" +def _load_module_attribute_uncached(path: str) -> Any: + module_name, attribute_name = path.rsplit(".", 1) + return getattr(importlib.import_module(module_name), attribute_name) + + def _generated_model_factory_path_for_pickle(config: _FlowModelConfig, generated_cls: type) -> Optional[str]: path = _importable_function_path(config.func) if path is None: return None try: - factory = PyObjectPath(path).object - except ImportError: + factory = _load_module_attribute_uncached(path) + except (ImportError, AttributeError, ValueError): return None if getattr(factory, "_generated_model", None) is generated_cls: return path @@ -2085,27 +2076,22 @@ def __reduce__(self): """Prefer import-path restoration, falling back to serialized local factories.""" config = type(self).__flow_model_config__ - import cloudpickle - state_bytes = cloudpickle.dumps(self.__getstate__(), protocol=5) + state = self.__getstate__() factory_path = _generated_model_factory_path_for_pickle(config, type(self)) if factory_path is not None: - return (_restore_generated_flow_model, (factory_path, state_bytes)) + return (_new_generated_flow_model_for_pickle, (factory_path,), state) # Local generated classes are not normal importable class definitions: # plain pickle cannot reconstruct them, and Ray workers should not # re-run fragile annotation resolution. Carry the analyzed contract - # separately, but keep instance state as Pydantic pickle state so user - # literals and private attrs survive cloudpickle. Runtime-created - # Pydantic generic specializations such as GenericResult[int] still need - # a base serialization fix before they are portable across fresh - # processes when they appear as loose user state. + # separately, but leave instance state in the outer pickle stream so + # normal pickle memo semantics remain intact. payload = _LocalFlowModelPicklePayload( serialized_config=_serialize_flow_model_config(config), factory_kwargs=type(self).__flow_model_factory_kwargs__, - state_bytes=state_bytes, ) - return (_restore_pickled_local_flow_model, (cloudpickle.dumps(payload, protocol=5),)) + return (_new_local_flow_model_for_pickle, (cloudpickle.dumps(payload, protocol=5),), state) @model_validator(mode="after") def _validate_flow_model_fields(self): diff --git a/ccflow/tests/config/conf_flow.yaml b/ccflow/tests/config/conf_flow.yaml index 51fc6ba..7963d5f 100644 --- a/ccflow/tests/config/conf_flow.yaml +++ b/ccflow/tests/config/conf_flow.yaml @@ -5,11 +5,6 @@ flow_loader: source: fixture_input multiplier: 5 -flow_processor: - _target_: ccflow.tests.flow_model_hydra_fixtures.string_processor - prefix: "value=" - suffix: "!" - flow_source: _target_: ccflow.tests.flow_model_hydra_fixtures.data_source base_value: 100 @@ -19,20 +14,6 @@ flow_transformer: source: flow_source factor: 3 -flow_stage1: - _target_: ccflow.tests.flow_model_hydra_fixtures.pipeline_stage1 - initial: 10 - -flow_stage2: - _target_: ccflow.tests.flow_model_hydra_fixtures.pipeline_stage2 - stage1_output: flow_stage1 - multiplier: 2 - -flow_stage3: - _target_: ccflow.tests.flow_model_hydra_fixtures.pipeline_stage3 - stage2_output: flow_stage2 - offset: 50 - diamond_source: _target_: ccflow.tests.flow_model_hydra_fixtures.data_source base_value: 10 @@ -53,16 +34,6 @@ diamond_aggregator: input_b: diamond_branch_b operation: add -flow_date_loader: - _target_: ccflow.tests.flow_model_hydra_fixtures.date_range_loader_previous_day - source: calendar_feed - include_weekends: false - -flow_date_processor: - _target_: ccflow.tests.flow_model_hydra_fixtures.date_range_processor - raw_data: flow_date_loader - normalize: true - contextual_loader_model: _target_: ccflow.tests.flow_model_hydra_fixtures.contextual_loader source: data_source diff --git a/ccflow/tests/flow_model_hydra_fixtures.py b/ccflow/tests/flow_model_hydra_fixtures.py index f99baf4..2ebb022 100644 --- a/ccflow/tests/flow_model_hydra_fixtures.py +++ b/ccflow/tests/flow_model_hydra_fixtures.py @@ -1,12 +1,8 @@ """Flow.model fixtures used by Hydra integration tests.""" -from datetime import date, timedelta +from datetime import date -from ccflow import ContextBase, Flow, FromContext, GenericResult - - -class SimpleContext(ContextBase): - value: int +from ccflow import Flow, FromContext, GenericResult @Flow.model @@ -14,11 +10,6 @@ def basic_loader(source: str, multiplier: int, value: FromContext[int]) -> Gener return GenericResult(value=value * multiplier) -@Flow.model -def string_processor(value: FromContext[int], prefix: str = "value=", suffix: str = "!") -> GenericResult[str]: - return GenericResult(value=f"{prefix}{value}{suffix}") - - @Flow.model def data_source(base_value: int, value: FromContext[int]) -> GenericResult[int]: return GenericResult(value=value + base_value) @@ -36,44 +27,6 @@ def data_aggregator(input_a: int, input_b: int, operation: str = "add") -> Gener raise ValueError(f"unsupported operation: {operation}") -@Flow.model -def pipeline_stage1(initial: int, value: FromContext[int]) -> GenericResult[int]: - return GenericResult(value=value + initial) - - -@Flow.model -def pipeline_stage2(stage1_output: int, multiplier: int) -> GenericResult[int]: - return GenericResult(value=stage1_output * multiplier) - - -@Flow.model -def pipeline_stage3(stage2_output: int, offset: int) -> GenericResult[int]: - return GenericResult(value=stage2_output + offset) - - -@Flow.model -def date_range_loader_previous_day( - source: str, - start_date: FromContext[date], - end_date: FromContext[date], - include_weekends: bool = False, -) -> GenericResult[dict]: - del include_weekends - return GenericResult( - value={ - "source": source, - "start_date": str(start_date - timedelta(days=1)), - "end_date": str(end_date), - } - ) - - -@Flow.model -def date_range_processor(raw_data: dict, normalize: bool = False) -> GenericResult[str]: - prefix = "normalized:" if normalize else "raw:" - return GenericResult(value=f"{prefix}{raw_data['source']}:{raw_data['start_date']} to {raw_data['end_date']}") - - @Flow.model def contextual_loader(source: str, start_date: FromContext[date], end_date: FromContext[date]) -> GenericResult[dict]: return GenericResult( diff --git a/ccflow/tests/test_flow_model.py b/ccflow/tests/test_flow_model.py index f30980f..261977a 100644 --- a/ccflow/tests/test_flow_model.py +++ b/ccflow/tests/test_flow_model.py @@ -10,7 +10,7 @@ from datetime import date, timedelta from pathlib import Path from types import ModuleType -from typing import Annotated, Any, Callable, Generic, Literal, Optional, TypeVar, get_args +from typing import Annotated, Any, Callable, Literal, Optional, get_args import pytest import ray @@ -18,7 +18,7 @@ from ray.cloudpickle import dumps as rcpdumps, loads as rcploads import ccflow -import ccflow._flow_model_binding as flow_binding_module +import ccflow._flow_model_binding as binding_module import ccflow.flow_model as flow_model_module from ccflow import ( BaseModel, @@ -54,14 +54,6 @@ class ExternalCcflowPayload(BaseModel): _bonus: int = PrivateAttr(default=1) -T = TypeVar("T") - - -class ExternalGenericCcflowBox(BaseModel, Generic[T]): - item: T - _bonus: int = PrivateAttr(default=1) - - class ParentRangeContext(ContextBase): start_date: date end_date: date @@ -272,29 +264,6 @@ def add(a: FromContext[int]) -> int: assert add().flow.with_context(static_patch()).flow.compute(a=1).value == 2 -def test_context_transform_internal_error_and_repr_payloads(): - assert flow_model_module._context_transform_repr(static_patch()) == "static_patch()" - assert flow_model_module._context_transform_repr(increment_b(amount=2)) == "increment_b(amount=2)" - assert flow_model_module._context_transform_repr(123) == "123" - importable_increment_b = Flow.context_transform(increment_b.__wrapped__) - importable_config = flow_model_module._load_context_transform_config_from_binding(importable_increment_b(amount=1)) - assert importable_config.func.__name__ == "increment_b" - assert flow_model_module._context_transform_identifier(importable_increment_b(amount=1)) == "increment_b" - - with pytest.raises(ValidationError, match="serialized_config"): - flow_model_module.ContextTransform() - - invalid_payload = base64.b64encode(pickle.dumps({"not": "a config"}, protocol=5)).decode("ascii") - with pytest.raises(TypeError, match="payload does not contain"): - flow_model_module._load_serialized_context_transform_config(invalid_payload) - - with pytest.raises(TypeError, match="payload does not contain"): - flow_model_module._load_serialized_context_transform_config("not-base64") - - with pytest.raises(ImportError, match="does not have a _generated_model"): - flow_model_module._restore_generated_flow_model("ccflow.tests.test_flow_model.lazy_context_transform_for_rejection", {}) - - def test_flow_model_rejects_invalid_decorator_targets(): with pytest.raises(TypeError): Flow.model(123) @@ -302,47 +271,7 @@ def test_flow_model_rejects_invalid_decorator_targets(): Flow.model(lambda: None) -def test_lazy_thunks_and_regular_resolution_edge_paths(): - calls = {"dependency": 0, "inner": 0} - - @Flow.model - def source(value: FromContext[int]) -> int: - calls["dependency"] += 1 - return value + 10 - - thunk = flow_model_module._make_lazy_thunk(source(), FlowContext(value=2)) - assert thunk() == 12 - assert thunk() == 12 - assert calls["dependency"] == 1 - - def inner(): - calls["inner"] += 1 - return "13" - - coercing = flow_model_module._make_coercing_lazy_thunk(inner, "value", int) - assert coercing() == 13 - assert coercing() == 13 - assert calls["inner"] == 1 - - @Flow.model - def missing_regular(x: int) -> int: - return x - - missing_config = type(missing_regular()).__flow_model_config__ - with pytest.raises(TypeError, match="still unbound"): - flow_model_module._resolve_regular_param_value(missing_regular(), missing_config.param("x"), FlowContext()) - - @Flow.model - def lazy_consumer(x: Lazy[int]) -> int: - return x() - - lazy_model = getattr(lazy_consumer, "_generated_model").model_construct(x=1) - lazy_config = type(lazy_model).__flow_model_config__ - with pytest.raises(TypeError, match="must be bound to a CallableModel"): - flow_model_module._resolve_regular_param_value(lazy_model, lazy_config.param("x"), FlowContext()) - - -def test_context_transform_validation_and_static_resolution_edge_paths(): +def test_context_transform_defaults_and_public_validation(): @Flow.context_transform def default_amount(amount: int = 5) -> int: return amount @@ -360,27 +289,7 @@ def add(a: FromContext[int], b: FromContext[int]) -> int: return a + b assert add().flow.with_context(a=default_amount(), b=default_seed()).flow.compute().value == 15 - assert flow_model_module._evaluate_context_transform_from_values(default_seed(), {}) == 10 - with pytest.raises(TypeError, match="Missing contextual input"): - flow_model_module._evaluate_context_transform_from_values(seed_plus_one(), {}) - - dynamic_spec = flow_model_module._BoundContextSpec( - operations=[ - flow_model_module.PatchContextOperation( - binding=dynamic_patch(), - ) - ], - ) - assert flow_model_module._statically_resolved_context_values(add(), dynamic_spec) is None - assert flow_model_module._statically_resolved_context_field_names(add(), dynamic_spec) == set() - - identity_values, missing_transforms = flow_model_module._apply_context_spec_values_for_identity(add(), dynamic_spec, FlowContext(b=2)) - assert identity_values == {"b": 2} - assert missing_transforms == ((flow_model_module._context_transform_identifier(dynamic_patch()), ("seed",)),) - - missing_regular = default_amount() - config = flow_model_module._load_context_transform_config_from_binding(default_amount()) - assert flow_model_module._bound_context_transform_regular_kwargs(config, missing_regular) == {"amount": 5} + assert add().flow.with_context(dynamic_patch()).flow.compute(seed=1, b=2).value == 3 with pytest.raises(TypeError, match="unexpected keyword"): increment_b(amount=1, extra=2) @@ -389,103 +298,18 @@ def add(a: FromContext[int], b: FromContext[int]) -> int: with pytest.raises(TypeError, match="missing required regular"): increment_b() - with pytest.raises(TypeError, match="must return a mapping"): - flow_model_module._validate_patch_result(add(), 1) - with pytest.raises(TypeError, match="string field names"): - flow_model_module._validate_patch_result(add(), {1: 2}) - - class OpaqueModel: - context_type = object - - assert flow_model_module._validate_patch_result(OpaqueModel(), {"x": 1}) == {"x": 1} - flow_model_module._validate_with_context_field_names(OpaqueModel(), ["anything"]) - assert flow_model_module._evaluate_static_context_transform(default_amount()) == 5 - with pytest.raises(TypeError, match="Positional with_context"): add().flow.with_context(lambda: {"a": 1}) with pytest.raises(TypeError, match="Positional with_context"): add().flow.with_context(123) -def test_additional_flow_model_source_edge_paths(monkeypatch): - @Flow.context_transform - def default_seed(seed: FromContext[int] = 9) -> int: - return seed + 1 - - @Flow.context_transform - def dynamic_patch(seed: FromContext[int]) -> dict[str, object]: - return {"a": seed} - - @Flow.model - def add(a: FromContext[int], b: FromContext[int]) -> int: - return a + b - - @Flow.model - def regular_required(x: int) -> int: - return x - - @Flow.model - def lazy_consumer(x: Lazy[int]) -> int: - return x() - - class FailingPath: - def __init__(self, path): - self.path = path - - @property - def object(self): - raise ImportError(self.path) - - generated_name = "_basic_loader_Model" - original_generated = getattr(sys.modules[__name__], generated_name) - try: - with monkeypatch.context() as path_patch: - path_patch.setattr(flow_model_module, "PyObjectPath", FailingPath) - restored = pickle.loads(pickle.dumps(basic_loader(source="s", multiplier=2), protocol=5)) - finally: - setattr(sys.modules[__name__], generated_name, original_generated) - assert restored.flow.compute(value=3).value == 6 - - bound = add().flow.with_context(dynamic_patch()) - assert bound.flow.context_inputs == {"a": int, "b": int} - assert bound.flow.runtime_inputs == {"a": int, "b": int, "seed": int} - assert bound.flow.required_inputs == {"a": int, "b": int, "seed": int} - - with pytest.raises(TypeError, match="missing required regular"): - flow_model_module._bound_context_transform_regular_kwargs( - flow_model_module._load_context_transform_config_from_binding(increment_b(amount=1)), - increment_b(amount=1).model_copy(update={"bound_args": {}}), - ) - with pytest.raises(TypeError, match="Missing regular parameter"): - regular_required().__deps__(FlowContext()) - assert lazy_consumer(x=data_source(base_value=1)).__deps__(FlowContext(value=1)) == [] - - def transform_with_bad_hints(value: FromContext[int]) -> int: - return value - - def raise_attribute_error(*args, **kwargs): - raise AttributeError("bad hints") - - monkeypatch.setattr(flow_model_module, "get_type_hints", raise_attribute_error) - with pytest.raises(AttributeError, match="bad hints"): - Flow.context_transform(transform_with_bad_hints) - - -def test_plain_and_bound_optional_compute_paths_and_identity_helpers(): - class AnyContextModel: - context_type = object - - class FlowContextModel: - context_type = FlowContext - +def test_plain_and_bound_optional_compute_paths(): class OptionalContextModel(CallableModel): @Flow.call def __call__(self, context: Optional[SimpleContext] = None) -> GenericResult[int]: return GenericResult(value=0 if context is None else context.value) - assert flow_model_module._model_context_contract(AnyContextModel()).input_types is None - assert flow_model_module._model_context_contract(FlowContextModel()).input_types is None - assert flow_model_module._identity_context_values_for_model_values(AnyContextModel(), {"extra": 1}) == {"extra": 1} assert OptionalContextModel().flow.compute(None).value == 0 assert OptionalContextModel().flow.compute().value == 0 assert OptionalContextModel().flow.required_inputs == {} @@ -494,7 +318,6 @@ def __call__(self, context: Optional[SimpleContext] = None) -> GenericResult[int assert bound.flow.compute(FlowContext(value=3)).value == 3 with pytest.raises(TypeError, match="either one context object"): bound.flow.compute(FlowContext(value=3), value=4) - assert bound.flow._compute_target is bound def test_bound_optional_none_context_preserves_wrapped_dependencies(): @@ -694,16 +517,6 @@ def foo(a: int, b: FromContext[int] = 5) -> int: assert model.flow.compute(b=10).value == 12 -def test_compute_rejects_kwargs_for_already_bound_regular_params(): - @Flow.model - def add(a: int, b: FromContext[int]) -> int: - return a + b - - model = add(a=1) - with pytest.raises(TypeError, match="does not accept regular parameter override\\(s\\): a"): - model.flow.compute(a=999, b=2) - - def test_context_type_accepts_richer_subclass_for_from_context(): @Flow.model(context_type=ParentRangeContext) def span_days(multiplier: int, start_date: FromContext[date], end_date: FromContext[date]) -> int: @@ -1077,15 +890,6 @@ def choose(lazy_value: Lazy[int] = 1) -> int: return lazy_value() -def test_lazy_runtime_helper_is_removed(): - @Flow.model - def source(value: FromContext[int]) -> GenericResult[int]: - return GenericResult(value=value) - - with pytest.raises(TypeError, match="Lazy is an annotation marker"): - Lazy(source()) - - def test_lazy_and_from_context_combination_is_rejected(): with pytest.raises(TypeError, match="cannot combine Lazy"): @@ -1173,19 +977,6 @@ def multiply(a: int, b: FromContext[int]) -> int: assert restored.flow.compute(b=7).value == 42 -def test_generated_models_cloudpickle_preserves_unset_validation_sentinel(): - @Flow.model - def multiply(a: int, b: FromContext[int]) -> int: - return a * b - - model = multiply(a=6) - restored = rcploads(rcpdumps(model, protocol=5)) - param = type(restored).__flow_model_config__.contextual_params[0] - - assert param.context_validation_annotation is flow_model_module._UNSET - assert param.validation_annotation is int - - def test_local_generated_model_effective_cache_key_survives_pickle_roundtrip(): def make_model(): @Flow.model @@ -1204,19 +995,6 @@ def add(a: int, b: FromContext[int]) -> int: assert after == before -def test_local_generated_model_plain_pickle_handles_generic_result_state(): - def make_model(): - @Flow.model - def first(xs: list[GenericResult[int]], b: FromContext[int]) -> int: - return xs[0].value + b - - return first(xs=[GenericResult(value=1)]) - - restored = pickle.loads(pickle.dumps(make_model(), protocol=5)) - - assert restored.flow.compute(b=2).value == 3 - - def test_generated_model_plain_pickle_preserves_external_pydantic_private_state(): @Flow.model def read(payload: object) -> int: @@ -1247,6 +1025,22 @@ def read(payload: object) -> int: assert restored.flow.compute().value == 42 +def test_generated_model_pickle_preserves_outer_graph_identity_and_cycles(): + @Flow.model + def read(payload: object) -> int: + return 0 + + for dumps, loads in ((pickle.dumps, pickle.loads), (rcpdumps, rcploads)): + shared = [] + restored_model, restored_shared = loads(dumps((read(payload=shared), shared), protocol=5)) + assert restored_model.payload is restored_shared + + cycle = read(payload=None) + cycle.payload = cycle + restored_cycle = loads(dumps(cycle, protocol=5)) + assert restored_cycle.payload is restored_cycle + + def test_local_generated_model_cloudpickle_preserves_local_pydantic_literal_state(): def make_model(): class LocalPayload(PydanticBaseModel): @@ -1267,35 +1061,6 @@ def read(payload: object) -> int: assert restored.flow.compute().value == 42 -def test_bound_model_cloudpickle_handles_context_transform_generic_result_bound_args(): - @Flow.model - def source(a: FromContext[int]) -> int: - return a - - @Flow.context_transform - def fixed(value: GenericResult[int]) -> int: - return value.value - - bound = source().flow.with_context(a=fixed(value=GenericResult(value=5))) - restored = rcploads(rcpdumps(bound, protocol=5)) - - assert restored.flow.compute().value == 5 - - -def test_generated_model_cloudpickle_preserves_user_generic_private_state(): - @Flow.model - def read(box: object) -> int: - return box.item + box._bonus - - box = ExternalGenericCcflowBox[int](item=2) - box._bonus = 40 - restored = rcploads(rcpdumps(read(box=box), protocol=5)) - - assert type(restored.box) is type(box) - assert restored.box._bonus == 40 - assert restored.flow.compute().value == 42 - - def test_bound_model_pickle_preserves_external_pydantic_static_context_value(): @Flow.model def read(payload: FromContext[object]) -> int: @@ -1378,113 +1143,11 @@ def derive(payload: object) -> int: assert restored.flow.compute().value == 42 -def test_ray_cloudpickle_preserves_user_generic_private_state_in_generated_model(local_ray_runtime): - @Flow.model - def read(box: object) -> int: - return box.item + box._bonus - - box = ExternalGenericCcflowBox[int](item=2) - box._bonus = 40 - model = read(box=box) - - @ray.remote - class Runner: - def run(self, payload): - restored = rcploads(payload) - return type(restored.box).__name__, restored.box._bonus, restored.flow.compute().value - - runner = Runner.remote() - assert ray.get(runner.run.remote(rcpdumps(model, protocol=5))) == ("ExternalGenericCcflowBox[int]", 40, 42) - - -def test_importable_generated_model_plain_pickle_cross_process_handles_generic_result_state(tmp_path, monkeypatch): - module_dir = tmp_path / "generic_state_module" - module_dir.mkdir() - module_path = module_dir / "generic_state_mod.py" - module_path.write_text( - "\n".join( - [ - "from ccflow import Flow, FromContext, GenericResult", - "", - "@Flow.model", - "def first(xs: list[GenericResult[int]], b: FromContext[int]) -> int:", - " return xs[0].value + b", - "", - ] - ) - ) - monkeypatch.syspath_prepend(str(module_dir)) - - import generic_state_mod - - payload = base64.b64encode(pickle.dumps(generic_state_mod.first(xs=[GenericResult(value=10)]), protocol=5)).decode() - script = ( - "import base64, pickle, sys\n" - f"sys.path.insert(0, {str(module_dir)!r})\n" - f"model = pickle.loads(base64.b64decode({payload!r}))\n" - "assert model.flow.compute(b=3).value == 13\n" - ) - result = subprocess.run([sys.executable, "-c", script], capture_output=True, text=True, timeout=30) - - assert result.returncode == 0, result.stderr - - -def test_generated_model_plain_pickle_preserves_generic_result_in_loose_state(): - @Flow.model - def use_any(x: Any, b: FromContext[int]) -> int: - return x.value + b - - @Flow.model - def use_list(xs: list[Any]) -> int: - return xs[0].value - - restored_any = pickle.loads(pickle.dumps(use_any(x=GenericResult(value=10)), protocol=5)) - restored_list = pickle.loads(pickle.dumps(use_list(xs=[GenericResult(value=11)]), protocol=5)) - - assert restored_any.flow.compute(b=3).value == 13 - assert isinstance(restored_any.x, GenericResult) - assert restored_list.flow.compute().value == 11 - assert isinstance(restored_list.xs[0], GenericResult) - - -def test_importable_generated_model_plain_pickle_cross_process_preserves_generic_result_in_loose_state(tmp_path, monkeypatch): - module_dir = tmp_path / "generic_any_state_module" - module_dir.mkdir() - module_path = module_dir / "generic_any_state_mod.py" - module_path.write_text( - "\n".join( - [ - "from typing import Any", - "from ccflow import Flow, FromContext", - "", - "@Flow.model", - "def use_any(x: Any, b: FromContext[int]) -> int:", - " return x.value + b", - "", - ] - ) - ) - monkeypatch.syspath_prepend(str(module_dir)) - - import generic_any_state_mod - - payload = base64.b64encode(pickle.dumps(generic_any_state_mod.use_any(x=GenericResult(value=10)), protocol=5)).decode() - script = ( - "import base64, pickle, sys\n" - f"sys.path.insert(0, {str(module_dir)!r})\n" - f"model = pickle.loads(base64.b64decode({payload!r}))\n" - "assert model.flow.compute(b=3).value == 13\n" - ) - result = subprocess.run([sys.executable, "-c", script], capture_output=True, text=True, timeout=30) - - assert result.returncode == 0, result.stderr - - -def test_local_generated_model_plain_pickle_handles_generic_result_function_default(): +def test_local_generated_model_plain_pickle_handles_function_default_state(): def make_model(): @Flow.model - def first(xs: list[GenericResult[int]] = [GenericResult(value=1)], b: FromContext[int] = 2) -> int: - return xs[0].value + b + def first(xs: list[int] = [1], b: FromContext[int] = 2) -> int: + return xs[0] + b return first() @@ -1493,31 +1156,6 @@ def first(xs: list[GenericResult[int]] = [GenericResult(value=1)], b: FromContex assert restored.flow.compute().value == 3 -def test_unresolved_lazy_local_generated_dependency_identity_survives_pickle_roundtrip(): - def make_model(): - @Flow.model - def source(a: FromContext[int]) -> int: - return a - - @Flow.model - def choose(x: int, lazy_value: Lazy[int], use_lazy: FromContext[bool]) -> int: - return lazy_value() if use_lazy else x - - return choose(x=7, lazy_value=source()) - - model = make_model() - context = FlowContext(use_lazy=False) - before_eval = model.__call__.get_evaluation_context(model, context) - before_key = cache_key(before_eval, effective=True) - before_root = get_dependency_graph(before_eval).root_id - - restored = pickle.loads(pickle.dumps(model, protocol=5)) - after_eval = restored.__call__.get_evaluation_context(restored, context) - - assert cache_key(after_eval, effective=True) == before_key - assert get_dependency_graph(after_eval).root_id == before_root - - def test_unresolved_lazy_nested_local_generated_dependency_identity_survives_pickle_roundtrip(): def make_model(): @Flow.model @@ -1729,6 +1367,47 @@ def stage(value: FromContext[int]) -> int: assert first.flow.compute(value=10).value == 11 +def test_generated_model_pickle_path_ignores_stale_pyobjectpath_cache(monkeypatch): + module = ModuleType("ccflow_test_stale_generated_pickle_path") + module.Flow = Flow + module.FromContext = FromContext + monkeypatch.setitem(sys.modules, module.__name__, module) + + exec( + """ +def foo(a: int, x: FromContext[int]) -> int: + return a + x + 1 +""", + module.__dict__, + ) + first_factory = Flow.model(module.foo) + module.foo = first_factory + first = first_factory(a=10) + path = f"{module.__name__}.foo" + + # Prime PyObjectPath/import_string's cache with the first factory. Pickle + # path selection must still inspect the live module attribute after a reload + # or replacement, otherwise old models can serialize by a path that a clean + # process resolves to different behavior. + assert PyObjectPath(path).object is first_factory + + exec( + """ +def foo(a: int, x: FromContext[int]) -> int: + return a + x + 100 +""", + module.__dict__, + ) + second_factory = Flow.model(module.foo) + module.foo = second_factory + + config = type(first).__flow_model_config__ + assert flow_model_module._generated_model_factory_path_for_pickle(config, type(first)) is None + assert first.__reduce__()[0] is flow_model_module._new_local_flow_model_for_pickle + assert pickle.loads(pickle.dumps(first, protocol=5)).flow.compute(x=1).value == 12 + assert second_factory(a=10).flow.compute(x=1).value == 111 + + def test_reloaded_importable_generated_model_keeps_clean_process_path(tmp_path, monkeypatch): module_dir = tmp_path / "reload_case" module_dir.mkdir() @@ -1824,20 +1503,6 @@ def test_reloaded_importable_generated_model_allows_stale_decorator_aliases(tmp_ assert reloaded.foo().flow.compute(x=2).value == 3 -def test_importable_generated_model_json_roundtrip_cross_process(): - model = basic_loader(source="library", multiplier=3) - payload = model.model_dump_json() - script = ( - "from ccflow import BaseModel\n" - f"data = {payload!r}\n" - "model = BaseModel.model_validate_json(data)\n" - "result = model.flow.compute(value=4)\n" - "assert result.value == 12, f'Expected 12, got {result.value}'\n" - ) - result = subprocess.run([sys.executable, "-c", script], capture_output=True, text=True, timeout=30) - assert result.returncode == 0, f"Cross-process JSON roundtrip failed:\n{result.stderr}" - - def test_importable_bound_model_context_transform_json_roundtrip_cross_process(): model = basic_loader(source="library", multiplier=3).flow.with_context(value=increment_b(amount=3)) payload = model.model_dump_json() @@ -2133,11 +1798,6 @@ def load(start_date: FromContext[int], end_date: FromContext[int]) -> int: result = bound(FlowContext(start_date=1, end_date=2)) assert result.value == 100_012 - dumped = bound.model_dump(mode="json") - assert dumped["context_spec"]["operations"][0]["binding"]["bound_args"] == {"amount": 10} - assert dumped["context_spec"]["operations"][1]["name"] == "start_date" - assert dumped["context_spec"]["operations"][1]["spec"]["kind"] == "static_value" - def test_chained_transforms_read_original_runtime_context(): @Flow.model @@ -2153,59 +1813,6 @@ def load(start_date: FromContext[int], end_date: FromContext[int]) -> int: assert result.value == 101_012 -def test_graph_integration_fanout_fanin(): - @Flow.model - def source(base: int, value: FromContext[int]) -> int: - return value + base - - @Flow.model - def scale(data: int, factor: int) -> int: - return data * factor - - @Flow.model - def merge(left: int, right: int, bonus: FromContext[int]) -> int: - return left + right + bonus - - src = source(base=10) - left = scale(data=src, factor=2) - right = scale(data=src, factor=5) - model = merge(left=left, right=right) - - assert model.flow.compute(FlowContext(value=3, bonus=7)).value == ((3 + 10) * 2) + ((3 + 10) * 5) + 7 - - -def test_graph_integration_cycle_raises_cleanly(): - @Flow.model - def increment(x: int, n: FromContext[int]) -> int: - return x + n - - root = increment() - branch = increment(x=root) - object.__setattr__(root, "x", branch) - - with FlowOptionsOverride(options={"evaluator": GraphEvaluator()}): - with pytest.raises(graphlib.CycleError): - root.flow.compute(n=1) - - -def test_large_contextual_contract_stress(): - @Flow.model - def total( - base: int, - x1: FromContext[int], - x2: FromContext[int], - x3: FromContext[int], - x4: FromContext[int], - x5: FromContext[int], - x6: FromContext[int], - ) -> int: - return base + x1 + x2 + x3 + x4 + x5 + x6 - - model = total(base=10) - assert model.flow.context_inputs == {"x1": int, "x2": int, "x3": int, "x4": int, "x5": int, "x6": int} - assert model.flow.compute(x1=1, x2=2, x3=3, x4=4, x5=5, x6=6).value == 31 - - def test_registry_integration_for_generated_models(): registry = ModelRegistry.root().clear() model = basic_loader(source="library", multiplier=3) @@ -2278,18 +1885,6 @@ def bad(x: BrokenSchema, y: FromContext[int]) -> int: Flow.model(bad) -@pytest.mark.parametrize("error", [RuntimeError("boom"), TypeError("boom"), ValueError("boom")]) -def test_schema_safe_annotation_does_not_swallow_unexpected_type_adapter_errors(monkeypatch, error): - def broken_type_adapter(annotation): - del annotation - raise error - - monkeypatch.setattr(flow_model_module, "_type_adapter", broken_type_adapter) - - with pytest.raises(type(error), match="boom"): - flow_model_module._pydantic_schema_safe_annotation(int) - - @pytest.mark.parametrize("error", [RuntimeError("boom"), TypeError("boom")]) def test_unexpected_type_validation_errors_are_not_rewritten(error): from pydantic_core import core_schema @@ -2389,37 +1984,6 @@ def test_context_transform_factory_signature_only_exposes_regular_bindings(): increment_b(1) -def test_type_adapter_caches_are_bounded_and_clearable(monkeypatch): - monkeypatch.setattr(flow_model_module, "_TYPE_ADAPTER_CACHE_MAXSIZE", 2) - flow_model_module.clear_flow_model_caches() - - try: - for annotation in (int, str, float): - flow_model_module._type_adapter(annotation) - - assert list(flow_model_module._HASHABLE_TYPE_ADAPTER_CACHE) == [str, float] - - unhashable_annotations = ( - Annotated[int, []], - Annotated[str, []], - Annotated[float, []], - ) - repeated_unhashable = Annotated[bytes, []] - assert flow_model_module._type_adapter(repeated_unhashable) is flow_model_module._type_adapter(repeated_unhashable) - - for annotation in unhashable_annotations: - flow_model_module._type_adapter(annotation) - - assert len(flow_model_module._UNHASHABLE_TYPE_ADAPTER_CACHE) == 2 - assert [entry[0] for entry in flow_model_module._UNHASHABLE_TYPE_ADAPTER_CACHE.values()] == list(unhashable_annotations[-2:]) - - flow_model_module.clear_flow_model_caches() - assert not flow_model_module._HASHABLE_TYPE_ADAPTER_CACHE - assert not flow_model_module._UNHASHABLE_TYPE_ADAPTER_CACHE - finally: - flow_model_module.clear_flow_model_caches() - - def test_plain_callable_flow_api_paths(): class PlainModel(CallableModel): @property @@ -2600,6 +2164,79 @@ def add(x: annotation, y: FromContext[annotation]) -> int: assert add(x="2").flow.compute(y="3").value == 5 +def test_flow_model_internal_contract_helpers_cover_portable_edge_shapes(): + annotation = Annotated[dict[str, list[GenericResult[int]]], "contract"] + restored = binding_module._restore_annotation(binding_module._serialize_annotation(annotation)) + + assert get_args(restored)[1:] == ("contract",) + assert binding_module._restore_annotation(binding_module._serialize_annotation(Literal["a", "b"])) == Literal["a", "b"] + assert binding_module._restore_annotation(binding_module._serialize_annotation(int | None)) == Optional[int] + assert binding_module._clone_function_without_annotations(5) == 5 + assert repr(get_args(Lazy[int])[1]) == "Lazy" + assert repr(get_args(FromContext[int])[1]) == "FromContext" + with pytest.raises(TypeError, match="Lazy is an annotation marker"): + Lazy() + with pytest.raises(TypeError, match="Unknown serialized annotation payload"): + binding_module._restore_annotation(object()) + with pytest.raises(TypeError, match="Unknown serialized annotation payload kind"): + binding_module._restore_annotation(binding_module._SerializedAnnotation(kind="bad", value=None)) + with pytest.raises(TypeError, match="Unknown Flow.model parameter payload"): + binding_module._restore_flow_model_param(object()) + with pytest.raises(TypeError, match="Unknown Flow.model config payload"): + binding_module._restore_flow_model_config(object()) + + compatible = binding_module._context_type_annotations_compatible + assert compatible(Any, int) + assert compatible(int, Any) + assert compatible(int | str, int) + assert compatible(int | None, type(None)) + assert compatible(Literal["a", "b"], Literal["a"]) + assert compatible(str, Literal["a"]) + assert compatible(list[int], list[int]) + assert not compatible(int, int | None) + assert not compatible(Literal["a"], Literal["b"]) + assert not compatible(Literal["a"], str) + assert not compatible(list[int], list[str]) + + flow_model_module.clear_flow_model_caches() + unhashable_annotation = Annotated[int, []] + assert flow_model_module._type_adapter(unhashable_annotation).validate_python("1") == 1 + assert flow_model_module._type_adapter(unhashable_annotation).validate_python("2") == 2 + assert flow_model_module._concrete_context_type(Optional[SimpleContext]) is SimpleContext + assert flow_model_module._concrete_context_type(int) is None + assert flow_model_module._bound_field_names(object()) == set() + assert flow_model_module._expected_type_repr(int | str) == "int | str" + assert flow_model_module._coerce_value("payload", object(), object(), "Regular parameter").__class__ is object + assert flow_model_module._unwrap_model_result(3) == 3 + assert flow_model_module._resolve_registry_candidate("__missing_registry_entry__") is None + assert flow_model_module._registry_candidate_allowed(object(), "literal") + assert not flow_model_module._registry_candidate_allowed(int, "not-an-int") + assert flow_model_module._registry_candidate_allowed(int, "3") + with pytest.raises(ImportError, match="does not have a _generated_model"): + flow_model_module._new_generated_flow_model_for_pickle("ccflow.GenericResult") + with pytest.raises(TypeError, match="model_base must be a CallableModel subclass"): + flow_model_module._resolve_generated_model_bases(int) + assert flow_model_module._context_transform_identifier(static_patch()) == "static_patch" + assert flow_model_module._context_transform_repr(static_patch()) == "static_patch()" + assert flow_model_module._context_transform_repr(increment_b(amount=2)) == "increment_b(amount=2)" + assert flow_model_module._context_transform_repr(5) == "5" + + @Flow.model + def dep(value: FromContext[int]) -> int: + return value + + with pytest.raises(TypeError, match="must return a mapping"): + flow_model_module._validate_patch_result(dep(), 1) + with pytest.raises(TypeError, match="string field names"): + flow_model_module._validate_patch_result(dep(), {1: 2}) + lazy_thunk = flow_model_module._make_lazy_thunk(dep(), FlowContext(value=4)) + assert lazy_thunk() == 4 + assert lazy_thunk() == 4 + coercing_thunk = flow_model_module._make_coercing_lazy_thunk(lambda: "5", "value", int) + assert coercing_thunk() == 5 + assert coercing_thunk() == 5 + + def test_compute_accepts_context_object_for_from_context_models(): model = basic_loader(source="library", multiplier=3) @@ -2765,6 +2402,23 @@ def bad(vals: FromContext[list[int]]) -> int: return sum(vals) +def test_context_type_allows_compatible_union_literal_and_generic_fields(): + class RichContext(ContextBase): + flag: Literal["a"] + maybe_value: int | None + values: list[int] + + @Flow.model(context_type=RichContext) + def summarize( + flag: FromContext[Literal["a", "b"]], + maybe_value: FromContext[int | None], + values: FromContext[list[int]], + ) -> int: + return len(flag) + (maybe_value or 0) + sum(values) + + assert summarize().flow.compute(flag="a", maybe_value=None, values=[1, 2]).value == 4 + + def test_context_type_rejects_nullable_field_for_non_nullable_from_context(): class OptionalValueContext(ContextBase): value: int | None @@ -2776,44 +2430,6 @@ def add_one(value: FromContext[int]) -> int: return value + 1 -@pytest.mark.parametrize( - ("func_annotation", "context_annotation", "expected"), - [ - (int, int, True), - (int, str, False), - (int | None, int, True), - (int, int | None, False), - (int | None, int | None, True), - (list[int], list[str], False), - (list[int], list[int], True), - (int | str, int, True), - (int | str, int | None, False), - (int | None, type(None), True), - (int | None, int | str, False), - (int, int | str, False), - (object, int | str, True), - (Literal["a"], Literal["a"], True), - (Literal["a"], Literal["b"], False), - (str, Literal["a"], True), - (Literal["a", "b"], Literal["a"], True), - (Literal["a"], str, False), - (list[int], Literal["a"], False), - (Annotated[int, "meta"], int, True), - (dict[str, list[int]], dict[str, list[int]], True), - (dict[str, list[int]], dict[str, list[str]], False), - (list, list[int], False), - (tuple[int], tuple[int, str], False), - (Any, int, True), - (Any, str, True), - (int, Any, True), - (str, Any, True), - (Any, Any, True), - ], -) -def test_context_type_annotations_compatible_cases(func_annotation, context_annotation, expected): - assert flow_binding_module._context_type_annotations_compatible(func_annotation, context_annotation) is expected - - def test_compute_forwards_options_with_custom_evaluator(): calls = {"count": 0} @@ -2833,11 +2449,6 @@ def counter(value: FromContext[int]) -> int: assert calls["count"] == 1 -def test_unset_flow_input_pickle_roundtrip_preserves_singleton(): - restored = pickle.loads(pickle.dumps(flow_model_module._UNSET_FLOW_INPUT, protocol=5)) - assert restored is flow_model_module._UNSET_FLOW_INPUT - - def test_compute_forwards_options_with_graph_evaluator(): @Flow.model def source(value: FromContext[int]) -> int: @@ -3237,27 +2848,6 @@ def add(a: int, b: FromContext[int]) -> int: assert len(cache.cache) == 1 -def test_cache_key_effective_option_exposes_generated_model_identity(): - @Flow.model - def source(value: FromContext[int]) -> int: - return value * 10 - - @Flow.model - def root(x: int, bonus: FromContext[int]) -> int: - return x + bonus - - model = root(x=source()) - ctx1 = FlowContext(value=3, bonus=7, unused="one") - ctx2 = FlowContext(value=3, bonus=7, unused="two") - eval1 = model.__call__.get_evaluation_context(model, ctx1) - eval2 = model.__call__.get_evaluation_context(model, ctx2) - cache = MemoryCacheEvaluator() - - assert cache_key(eval1) != cache_key(eval2) - assert cache_key(eval1, effective=True) == cache_key(eval2, effective=True) - assert cache.key(eval1) == cache_key(eval1, effective=True) - - def test_effective_cache_key_ignores_untokenizable_unused_ambient_context(): class BadToken: def __deepcopy__(self, memo): @@ -3347,34 +2937,6 @@ def helper_v2(): assert key1 != key2 -def test_generated_model_effective_cache_key_includes_opaque_evaluator_behavior(): - class OpaqueA(EvaluatorBase): - tag: str = "same" - - def __call__(self, context: ModelEvaluationContext): - return context() - - class OpaqueB(EvaluatorBase): - tag: str = "same" - - def __call__(self, context: ModelEvaluationContext): - result = context() - return result - - @Flow.model - def add(a: int, b: FromContext[int]) -> int: - return a + b - - model = add(a=10) - inner = model.__call__.get_evaluation_context(model, FlowContext(b=1, unused="same"), _options={"cacheable": True}) - cache = MemoryCacheEvaluator() - - key1 = cache.key(ModelEvaluationContext(model=OpaqueA(), context=inner)) - key2 = cache.key(ModelEvaluationContext(model=OpaqueB(), context=inner)) - - assert key1 != key2 - - def test_generated_model_cache_changes_when_consumed_context_field_changes(): calls = {"count": 0} @@ -3412,33 +2974,6 @@ def add(a: int, b: FromContext[int]) -> int: assert len(cache.cache) == 2 -def test_generated_model_cache_uses_structural_key_with_nontransparent_evaluator(): - calls = {"count": 0} - - class OpaqueEvaluator(EvaluatorBase): - def __call__(self, context: ModelEvaluationContext): - return context() - - @Flow.model - def add(a: int, b: FromContext[int]) -> int: - calls["count"] += 1 - return a + b - - cache = MemoryCacheEvaluator() - evaluator = combine_evaluators(OpaqueEvaluator(), cache) - model = add(a=10) - - with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): - assert model.flow.compute(b=1, unused="plain").value == 11 - - with FlowOptionsOverride(options={"evaluator": evaluator, "cacheable": True}): - assert model.flow.compute(b=1, unused="one").value == 11 - assert model.flow.compute(b=1, unused="two").value == 11 - - assert calls["count"] == 3 - assert len(cache.cache) == 3 - - def test_generated_model_cache_does_not_ignore_context_read_by_nontransparent_evaluator(): class AddAmbient(EvaluatorBase): def __call__(self, context: ModelEvaluationContext): @@ -3460,25 +2995,6 @@ def add(a: int, b: FromContext[int]) -> int: assert len(cache.cache) == 2 -def test_generated_model_cache_uses_effective_key_when_result_validation_disabled(): - calls = {"count": 0} - - @Flow.model - def add(a: int, b: FromContext[int]) -> int: - calls["count"] += 1 - return a + b - - cache = MemoryCacheEvaluator() - model = add(a=10) - - with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True, "validate_result": False}): - assert model.flow.compute(b=1, unused="one").value == 11 - assert model.flow.compute(b=1, unused="two").value == 11 - - assert calls["count"] == 1 - assert len(cache.cache) == 1 - - def test_generated_model_cache_key_preserves_result_validation_option(): calls = {"count": 0} @@ -3555,78 +3071,13 @@ def choose(x: int, lazy_value: Lazy[int], use_lazy: FromContext[bool]) -> int: assert len(cache.cache) == 1 -def test_unused_lazy_bound_plain_dependency_applies_static_context_identity(): - calls = {"source": 0, "choose": 0} - - class RequiredContext(ContextBase): - a: int - b: int - - class PlainSource(CallableModel): - @Flow.call - def __call__(self, context: RequiredContext) -> GenericResult[int]: - calls["source"] += 1 - return GenericResult(value=context.a + context.b) - - @Flow.model - def choose(x: int, lazy_value: Lazy[int], use_lazy: FromContext[bool]) -> int: - calls["choose"] += 1 - return lazy_value() if use_lazy else x - - cache = MemoryCacheEvaluator() - model = choose(x=7, lazy_value=PlainSource().flow.with_context(a=1)) - - with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): - assert model.flow.compute(use_lazy=False, unused="one").value == 7 - assert model.flow.compute(use_lazy=False, unused="two").value == 7 - with pytest.raises(ValidationError): - model.flow.compute(use_lazy=True) - - assert calls == {"source": 0, "choose": 2} - assert len(cache.cache) == 1 - - -def test_unused_lazy_bound_plain_dependency_dynamic_transform_can_leave_missing_context(): +def test_unused_lazy_bound_dependency_uses_partial_context_identity(): calls = {"source": 0, "choose": 0} - class RequiredContext(ContextBase): - a: int - b: int - - class PlainSource(CallableModel): - @Flow.call - def __call__(self, context: RequiredContext) -> GenericResult[int]: - calls["source"] += 1 - return GenericResult(value=context.a + context.b) - @Flow.model - def choose(x: int, lazy_value: Lazy[int], use_lazy: FromContext[bool]) -> int: - calls["choose"] += 1 - return lazy_value() if use_lazy else x - - cache = MemoryCacheEvaluator() - model = choose(x=7, lazy_value=PlainSource().flow.with_context(b=seed_plus_one())) - - with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): - assert model.flow.compute(use_lazy=False, seed=1, unused="one").value == 7 - assert model.flow.compute(use_lazy=False, seed=1, unused="two").value == 7 - - assert calls == {"source": 0, "choose": 1} - assert len(cache.cache) == 1 - - -def test_unused_lazy_bound_plain_dependency_fully_resolved_identity_ignores_ambient_context(): - calls = {"source": 0, "choose": 0} - - class RequiredContext(ContextBase): - a: int - b: int - - class PlainSource(CallableModel): - @Flow.call - def __call__(self, context: RequiredContext) -> GenericResult[int]: - calls["source"] += 1 - return GenericResult(value=context.a + context.b) + def source(a: FromContext[int], b: FromContext[int]) -> int: + calls["source"] += 1 + return a + b @Flow.model def choose(x: int, lazy_value: Lazy[int], use_lazy: FromContext[bool]) -> int: @@ -3634,17 +3085,19 @@ def choose(x: int, lazy_value: Lazy[int], use_lazy: FromContext[bool]) -> int: return lazy_value() if use_lazy else x cache = MemoryCacheEvaluator() - model = choose(x=7, lazy_value=PlainSource().flow.with_context(a=1, b=2)) + model = choose(x=7, lazy_value=source().flow.with_context(a=1)) with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): - assert model.flow.compute(use_lazy=False, unused="one").value == 7 - assert model.flow.compute(use_lazy=False, unused="two").value == 7 + assert model.flow.compute(use_lazy=False, a=100).value == 7 + assert model.flow.compute(use_lazy=False, a=200).value == 7 + with pytest.raises(TypeError, match="Missing contextual input"): + model.flow.compute(use_lazy=True, a=300) - assert calls == {"source": 0, "choose": 1} + assert calls == {"source": 0, "choose": 2} assert len(cache.cache) == 1 -def test_unused_lazy_bound_dependency_uses_partial_context_identity(): +def test_unused_lazy_bound_dependency_with_unresolved_transform_has_stable_identity(): calls = {"source": 0, "choose": 0} @Flow.model @@ -3652,19 +3105,23 @@ def source(a: FromContext[int], b: FromContext[int]) -> int: calls["source"] += 1 return a + b + @Flow.context_transform + def a_from_seed(seed: FromContext[int]) -> int: + return seed + 1 + @Flow.model def choose(x: int, lazy_value: Lazy[int], use_lazy: FromContext[bool]) -> int: calls["choose"] += 1 return lazy_value() if use_lazy else x cache = MemoryCacheEvaluator() - model = choose(x=7, lazy_value=source().flow.with_context(a=1)) + model = choose(x=7, lazy_value=source().flow.with_context(a=a_from_seed())) with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): - assert model.flow.compute(use_lazy=False, a=100).value == 7 - assert model.flow.compute(use_lazy=False, a=200).value == 7 + assert model.flow.compute(use_lazy=False, b=1, unused="one").value == 7 + assert model.flow.compute(use_lazy=False, b=2, unused="two").value == 7 with pytest.raises(TypeError, match="Missing contextual input"): - model.flow.compute(use_lazy=True, a=300) + model.flow.compute(use_lazy=True, b=3) assert calls == {"source": 0, "choose": 2} assert len(cache.cache) == 1 @@ -3714,33 +3171,6 @@ def choose(lazy_value: Lazy[int], use_lazy: FromContext[bool]) -> int: assert model.flow.compute(use_lazy=True, seed=2, b=10).value == 13 assert calls == {"source": 2, "choose": 2} - assert len(cache.cache) == 6 - - -def test_unused_lazy_bound_dependency_records_missing_transform_context(): - calls = {"source": 0, "choose": 0} - - @Flow.model - def source(a: FromContext[int], b: FromContext[int]) -> int: - calls["source"] += 1 - return a + b - - @Flow.model - def choose(x: int, lazy_value: Lazy[int], use_lazy: FromContext[bool]) -> int: - calls["choose"] += 1 - return lazy_value() if use_lazy else x - - cache = MemoryCacheEvaluator() - model = choose(x=7, lazy_value=source().flow.with_context(a=seed_plus_one())) - - with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): - assert model.flow.compute(use_lazy=False, unused="one").value == 7 - assert model.flow.compute(use_lazy=False, unused="two").value == 7 - assert model.flow.compute(use_lazy=False, a=100, b=1).value == 7 - assert model.flow.compute(use_lazy=False, a=200, b=1).value == 7 - - assert calls == {"source": 0, "choose": 1} - assert len(cache.cache) == 1 def test_used_lazy_generated_dependency_identity_respects_contextual_defaults(): @@ -3797,41 +3227,6 @@ def choose(x: int, lazy_value: Lazy[int], use_lazy: FromContext[bool]) -> int: assert len(cache.cache) == 2 -def test_generated_model_diamond_cache_reuses_shared_source_and_ignores_unused_fields(): - calls = {"source": 0, "left": 0, "right": 0, "root": 0} - - @Flow.model - def source(value: FromContext[int]) -> int: - calls["source"] += 1 - return value + 10 - - @Flow.model - def left(x: int) -> int: - calls["left"] += 1 - return x * 2 - - @Flow.model - def right(x: int) -> int: - calls["right"] += 1 - return x * 5 - - @Flow.model - def root(left_value: int, right_value: int, bonus: FromContext[int]) -> int: - calls["root"] += 1 - return left_value + right_value + bonus - - shared = source() - model = root(left_value=left(x=shared), right_value=right(x=shared)) - cache = MemoryCacheEvaluator() - - with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): - assert model(FlowContext(value=3, bonus=7, unused="one")).value == 98 - assert model(FlowContext(value=3, bonus=7, unused="two")).value == 98 - - assert calls == {"source": 1, "left": 1, "right": 1, "root": 1} - assert len(cache.cache) == 4 - - def test_bound_generated_sibling_dependencies_keep_distinct_rewritten_contexts_with_graph_cache(): calls = {"source": 0, "root": 0} @@ -3857,35 +3252,6 @@ def root(left: int, right: int) -> int: assert len(cache.cache) >= 3 -def test_generated_dependency_graph_identity_ignores_unused_flow_context_fields(): - @Flow.model - def source(value: FromContext[int]) -> int: - return value * 10 - - @Flow.model - def root(x: int, bonus: FromContext[int]) -> int: - return x + bonus - - model = root(x=source()) - graph1 = get_dependency_graph(model.__call__.get_evaluation_context(model, FlowContext(value=3, bonus=7, unused="one"))) - graph2 = get_dependency_graph(model.__call__.get_evaluation_context(model, FlowContext(value=3, bonus=7, unused="two"))) - - assert graph1.root_id == graph2.root_id - assert set(graph1.graph.keys()) == set(graph2.graph.keys()) - assert set(graph1.ids.keys()) == set(graph2.ids.keys()) - - -def test_bound_generated_model_dependency_graph_has_no_self_loop(): - @Flow.model - def add(a: int, b: FromContext[int]) -> int: - return a + b - - bound = add(a=1).flow.with_context(b=2) - graph = get_dependency_graph(bound.__call__.get_evaluation_context(bound, FlowContext(b=99))) - - assert graph.root_id not in graph.graph[graph.root_id] - - def test_bound_generated_model_dependency_graph_traverses_collapsed_child_deps(): @Flow.model def source(value: FromContext[int]) -> int: @@ -3948,26 +3314,6 @@ def add(a: int, b: FromContext[int]) -> int: assert calls == {"add": 2, "evaluator": 1} -def test_plain_callable_model_cache_remains_structural_for_flow_context(): - calls = {"count": 0} - - class Counter(CallableModel): - @Flow.call - def __call__(self, context: FlowContext) -> GenericResult[int]: - calls["count"] += 1 - return GenericResult(value=context.value) - - cache = MemoryCacheEvaluator() - model = Counter() - - with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): - assert model(FlowContext(value=10, unused="one")).value == 10 - assert model(FlowContext(value=10, unused="two")).value == 10 - - assert calls["count"] == 2 - assert len(cache.cache) == 2 - - def test_generated_models_cross_process_pickle(): """Module-level @Flow.model instances are deserializable in a separate process.""" model = basic_loader(source="library", multiplier=3) @@ -3985,26 +3331,6 @@ def test_generated_models_cross_process_pickle(): assert result.returncode == 0, f"Cross-process unpickle failed:\n{result.stderr}" -def test_generated_models_cross_process_cloudpickle(): - """Module-level @Flow.model instances are deserializable via cloudpickle in a separate process.""" - from ray.cloudpickle import dumps as rcpdumps - - model = basic_loader(source="library", multiplier=3) - data = rcpdumps(model, protocol=5) - encoded = base64.b64encode(data).decode() - script = ( - "import base64\n" - "from ray.cloudpickle import loads as rcploads\n" - f"data = base64.b64decode('{encoded}')\n" - "model = rcploads(data)\n" - "from ccflow import FlowContext\n" - "result = model.flow.compute(value=4)\n" - "assert result.value == 12, f'Expected 12, got {result.value}'\n" - ) - result = subprocess.run([sys.executable, "-c", script], capture_output=True, text=True, timeout=30) - assert result.returncode == 0, f"Cross-process cloudpickle failed:\n{result.stderr}" - - def test_local_generated_models_cross_process_cloudpickle(): """Local @Flow.model instances carry their generated class across processes.""" from ray.cloudpickle import dumps as rcpdumps @@ -4029,28 +3355,6 @@ def add(a: int, b: FromContext[int]) -> int: assert result.returncode == 0, f"Cross-process local cloudpickle failed:\n{result.stderr}" -def test_local_generated_model_explicit_generic_result_cross_process_cloudpickle(): - """Local generated models should not rehydrate from fragile GenericResult[T] annotations.""" - from ray.cloudpickle import dumps as rcpdumps - - @Flow.model - def add(a: int, b: FromContext[int]) -> GenericResult[int]: - return GenericResult(value=a + b) - - encoded = base64.b64encode(rcpdumps(add(a=1), protocol=5)).decode() - script = ( - "import base64\n" - "from ray.cloudpickle import loads as rcploads\n" - f"data = base64.b64decode('{encoded}')\n" - "model = rcploads(data)\n" - "result = model.flow.compute(b=2)\n" - "assert result.value == 3, f'Expected 3, got {result.value}'\n" - "assert repr(model.result_type) == \"\"\n" - ) - result = subprocess.run([sys.executable, "-c", script], capture_output=True, text=True, timeout=30) - assert result.returncode == 0, f"Cross-process explicit GenericResult cloudpickle failed:\n{result.stderr}" - - def test_local_generated_model_postponed_annotations_cross_process_cloudpickle(): """Local generated models should restore from analyzed config, not worker-side type-hint resolution.""" from ray.cloudpickle import dumps as rcpdumps @@ -4083,6 +3387,28 @@ def add(a: int, b: FromContext[int]) -> int: assert result.returncode == 0, f"Cross-process postponed-annotation cloudpickle failed:\n{result.stderr}" +def test_local_generated_model_complex_annotations_same_process_cloudpickle(): + """Local restore should reuse the analyzed contract for nested typing shapes.""" + + def make_model(): + @Flow.model + def summarize( + values: Annotated[list[int], Field(min_length=1)], + mode: Literal["sum", "first"], + offsets: tuple[int, ...], + maybe_offset: int | None, + scale: FromContext[int], + ) -> GenericResult[int]: + total = values[0] if mode == "first" else sum(values) + return GenericResult(value=(total + sum(offsets) + (maybe_offset or 0)) * scale) + + return summarize(values=[1, 2], mode="sum", offsets=(3,), maybe_offset=None) + + restored = rcploads(rcpdumps(make_model(), protocol=5)) + + assert restored.flow.compute(scale=2) == GenericResult[int](value=12) + + def test_model_base_fields_visible_in_bound_inputs(): """model_base fields that are explicitly set should appear in bound_inputs.""" diff --git a/ccflow/tests/test_flow_model_hydra.py b/ccflow/tests/test_flow_model_hydra.py index f47a58e..e93bac8 100644 --- a/ccflow/tests/test_flow_model_hydra.py +++ b/ccflow/tests/test_flow_model_hydra.py @@ -3,11 +3,7 @@ from datetime import date from pathlib import Path -from omegaconf import OmegaConf - -from ccflow import CallableModel, DateRangeContext, FlowContext, ModelRegistry - -from .flow_model_hydra_fixtures import SimpleContext +from ccflow import CallableModel, FlowContext, ModelRegistry CONFIG_PATH = str(Path(__file__).parent / "config" / "conf_flow.yaml") @@ -26,51 +22,26 @@ def test_basic_loader_from_yaml(): loader = registry["flow_loader"] assert isinstance(loader, CallableModel) - assert loader(SimpleContext(value=10)).value == 50 + assert loader.flow.compute(value=10).value == 50 -def test_basic_processor_from_yaml(): - registry = ModelRegistry.root() - registry.load_config_from_path(CONFIG_PATH) - - processor = registry["flow_processor"] - assert processor(SimpleContext(value=42)).value == "value=42!" - - -def test_two_stage_pipeline_from_yaml(): +def test_registry_dependency_from_yaml(): registry = ModelRegistry.root() registry.load_config_from_path(CONFIG_PATH) transformer = registry["flow_transformer"] - assert transformer(SimpleContext(value=5)).value == 315 - - -def test_three_stage_pipeline_from_yaml(): - registry = ModelRegistry.root() - registry.load_config_from_path(CONFIG_PATH) - - stage3 = registry["flow_stage3"] - assert stage3(SimpleContext(value=10)).value == 90 + assert transformer.source is registry["flow_source"] + assert transformer.flow.compute(value=5).value == 315 -def test_diamond_dependency_from_yaml(): +def test_diamond_dependency_from_yaml_reuses_shared_source(): registry = ModelRegistry.root() registry.load_config_from_path(CONFIG_PATH) aggregator = registry["diamond_aggregator"] - assert aggregator(SimpleContext(value=10)).value == 140 - - -def test_date_range_pipeline_from_yaml(): - registry = ModelRegistry.root() - registry.load_config_from_path(CONFIG_PATH) - - processor = registry["flow_date_processor"] - ctx = DateRangeContext(start_date=date(2024, 1, 10), end_date=date(2024, 1, 31)) - result = processor(ctx) - - assert "normalized:" in result.value - assert "2024-01-09" in result.value + assert aggregator.input_a.source is registry["diamond_source"] + assert aggregator.input_b.source is registry["diamond_source"] + assert aggregator.flow.compute(value=10).value == 140 def test_from_context_pipeline_from_yaml(): @@ -80,56 +51,6 @@ def test_from_context_pipeline_from_yaml(): loader = registry["contextual_loader_model"] processor = registry["contextual_processor_model"] - assert loader.flow.context_inputs == {"start_date": date, "end_date": date} - result = processor.flow.compute(start_date=date(2024, 3, 1), end_date=date(2024, 3, 31)) - assert result.value == "output:data_source:2024-03-01 to 2024-03-31" assert processor.data is loader - - -def test_registry_name_references_share_instances(): - registry = ModelRegistry.root() - registry.load_config_from_path(CONFIG_PATH) - - transformer = registry["flow_transformer"] - source = registry["flow_source"] - assert transformer.source is source - - stage2 = registry["flow_stage2"] - stage3 = registry["flow_stage3"] - assert stage2.stage1_output is registry["flow_stage1"] - assert stage3.stage2_output is stage2 - - -def test_instantiate_with_omegaconf(): - cfg = OmegaConf.create( - { - "loader": { - "_target_": "ccflow.tests.flow_model_hydra_fixtures.basic_loader", - "source": "generated_input", - "multiplier": 7, - }, - "contextual": { - "_target_": "ccflow.tests.flow_model_hydra_fixtures.contextual_loader", - "source": "library", - }, - } - ) - - registry = ModelRegistry.root() - registry.load_config(cfg) - - assert registry["loader"](SimpleContext(value=3)).value == 21 - assert registry["contextual"].flow.compute(start_date=date(2024, 1, 1), end_date=date(2024, 1, 2)).value == { - "source": "library", - "start_date": "2024-01-01", - "end_date": "2024-01-02", - } - - -def test_flow_context_execution_with_yaml_models(): - registry = ModelRegistry.root() - registry.load_config_from_path(CONFIG_PATH) - - processor = registry["contextual_processor_model"] - result = processor.flow.compute(FlowContext(start_date=date(2024, 4, 1), end_date=date(2024, 4, 30))) - assert result.value == "output:data_source:2024-04-01 to 2024-04-30" + result = processor.flow.compute(FlowContext(start_date=date(2024, 3, 1), end_date=date(2024, 3, 31))) + assert result.value == "output:data_source:2024-03-01 to 2024-03-31" From cede744a37d2c12e396fbfaf3224b40c00ec62a0 Mon Sep 17 00:00:00 2001 From: Nijat K Date: Sun, 17 May 2026 05:39:35 -0400 Subject: [PATCH 7/8] Move examples inside ccflow, not top level Signed-off-by: Nijat K --- ccflow/examples/flow_model/__init__.py | 1 + .../flow_model}/config/flow_model_hydra_builder_demo.yaml | 8 ++++---- .../examples/flow_model}/flow_model_example.py | 2 +- .../examples/flow_model}/flow_model_hydra_builder_demo.py | 2 +- 4 files changed, 7 insertions(+), 6 deletions(-) create mode 100644 ccflow/examples/flow_model/__init__.py rename {examples => ccflow/examples/flow_model}/config/flow_model_hydra_builder_demo.yaml (62%) rename {examples => ccflow/examples/flow_model}/flow_model_example.py (98%) rename {examples => ccflow/examples/flow_model}/flow_model_hydra_builder_demo.py (98%) diff --git a/ccflow/examples/flow_model/__init__.py b/ccflow/examples/flow_model/__init__.py new file mode 100644 index 0000000..9b8bb74 --- /dev/null +++ b/ccflow/examples/flow_model/__init__.py @@ -0,0 +1 @@ +"""Flow.model examples.""" diff --git a/examples/config/flow_model_hydra_builder_demo.yaml b/ccflow/examples/flow_model/config/flow_model_hydra_builder_demo.yaml similarity index 62% rename from examples/config/flow_model_hydra_builder_demo.yaml rename to ccflow/examples/flow_model/config/flow_model_hydra_builder_demo.yaml index 9755e5d..af049fe 100644 --- a/examples/config/flow_model_hydra_builder_demo.yaml +++ b/ccflow/examples/flow_model/config/flow_model_hydra_builder_demo.yaml @@ -1,4 +1,4 @@ -# Hydra config for examples/flow_model_hydra_builder_demo.py +# Hydra config for ccflow/examples/flow_model/flow_model_hydra_builder_demo.py # # Pattern: # - configure static pipeline specs in YAML @@ -6,11 +6,11 @@ # - keep runtime context as runtime inputs, supplied later at execution time library_visitors: - _target_: examples.flow_model_hydra_builder_demo.count_visitors + _target_: ccflow.examples.flow_model.flow_model_hydra_builder_demo.count_visitors location: library previous_week: - _target_: examples.flow_model_hydra_builder_demo.build_visitor_delta + _target_: ccflow.examples.flow_model.flow_model_hydra_builder_demo.build_visitor_delta current: _target_: ccflow.compose.model_alias model_name: library_visitors @@ -18,7 +18,7 @@ previous_week: days_back: 7 previous_two_weeks: - _target_: examples.flow_model_hydra_builder_demo.build_visitor_delta + _target_: ccflow.examples.flow_model.flow_model_hydra_builder_demo.build_visitor_delta current: _target_: ccflow.compose.model_alias model_name: library_visitors diff --git a/examples/flow_model_example.py b/ccflow/examples/flow_model/flow_model_example.py similarity index 98% rename from examples/flow_model_example.py rename to ccflow/examples/flow_model/flow_model_example.py index 7767123..f6e2196 100644 --- a/examples/flow_model_example.py +++ b/ccflow/examples/flow_model/flow_model_example.py @@ -9,7 +9,7 @@ 4. execute the configured graph with `model.flow.compute(...)`. Run with: - python examples/flow_model_example.py + python ccflow/examples/flow_model/flow_model_example.py """ from datetime import date, timedelta diff --git a/examples/flow_model_hydra_builder_demo.py b/ccflow/examples/flow_model/flow_model_hydra_builder_demo.py similarity index 98% rename from examples/flow_model_hydra_builder_demo.py rename to ccflow/examples/flow_model/flow_model_hydra_builder_demo.py index 30722b6..64c1a45 100644 --- a/examples/flow_model_hydra_builder_demo.py +++ b/ccflow/examples/flow_model/flow_model_hydra_builder_demo.py @@ -13,7 +13,7 @@ - let Hydra instantiate that builder and register the returned model. Run with: - python examples/flow_model_hydra_builder_demo.py + python ccflow/examples/flow_model/flow_model_hydra_builder_demo.py """ from datetime import date, timedelta From 5d73ed6967abcd49e33741e5a4d2e6f5857726b5 Mon Sep 17 00:00:00 2001 From: Nijat K Date: Mon, 18 May 2026 07:39:41 -0400 Subject: [PATCH 8/8] Add 'inspect' to consolidate introspection for @Flow.model Signed-off-by: Nijat K --- ccflow/callable.py | 16 + .../examples/flow_model/flow_model_example.py | 13 +- .../flow_model_hydra_builder_demo.py | 7 +- ccflow/flow_model.py | 621 +++++++++++++++++- ccflow/tests/test_flow_context.py | 50 +- ccflow/tests/test_flow_model.py | 591 +++++++++++++++-- docs/wiki/Flow-Model.md | 174 +++-- 7 files changed, 1314 insertions(+), 158 deletions(-) diff --git a/ccflow/callable.py b/ccflow/callable.py index 3dd8530..cff2a32 100644 --- a/ccflow/callable.py +++ b/ccflow/callable.py @@ -498,6 +498,22 @@ def model(*args, **kwargs): Generated-model options also include ``context_type`` for validating contextual fields, ``auto_unwrap`` for ergonomic compute results, and ``model_base`` for custom ``CallableModel`` bases. + + Args: + func: The function being decorated. This is passed automatically in + bare decorator form, for example ``@Flow.model``. When using + options, for example ``@Flow.model(auto_unwrap=True)``, Python + first calls ``Flow.model(...)`` without a function and then + applies the returned decorator. + context_type: Optional ``ContextBase`` subclass used to validate + ``FromContext[...]`` fields together. + auto_unwrap: When ``True`` and the function's plain return value is + auto-wrapped in ``GenericResult[T]``, external + ``model.flow.compute(...)`` calls return the raw ``T`` value + instead of ``GenericResult[T]``. Internal model dependencies + still use normal ccflow result objects. + model_base: Custom ``CallableModel`` subclass to use as the base + class for the generated model. """ from .flow_model import flow_model diff --git a/ccflow/examples/flow_model/flow_model_example.py b/ccflow/examples/flow_model/flow_model_example.py index f6e2196..3762853 100644 --- a/ccflow/examples/flow_model/flow_model_example.py +++ b/ccflow/examples/flow_model/flow_model_example.py @@ -100,11 +100,14 @@ def main() -> None: print("\nPipeline:") print(" model: visitor_delta") - print(f" bound inputs: {_format_bound_inputs(pipeline.flow.bound_inputs)}") - print(f" declared context inputs: {_format_input_names(pipeline.flow.context_inputs)}") - print(f" runtime inputs: {_format_input_names(pipeline.flow.runtime_inputs)}") - print(f" current runtime inputs: {_format_input_names(pipeline.current.flow.runtime_inputs)}") - print(f" previous runtime inputs: {_format_input_names(pipeline.previous.flow.runtime_inputs)}") + pipeline_inspection = pipeline.flow.inspect() + current_inspection = pipeline.current.flow.inspect() + previous_inspection = pipeline.previous.flow.inspect() + print(f" bound inputs: {_format_bound_inputs(pipeline_inspection.bound_inputs)}") + print(f" declared context inputs: {_format_input_names(pipeline_inspection.context_inputs)}") + print(f" runtime inputs: {_format_input_names(pipeline_inspection.runtime_inputs)}") + print(f" current runtime inputs: {_format_input_names(current_inspection.runtime_inputs)}") + print(f" previous runtime inputs: {_format_input_names(previous_inspection.runtime_inputs)}") print("\nExecution:") print(f" context object == kwargs: {computed_from_context == computed_from_kwargs}") diff --git a/ccflow/examples/flow_model/flow_model_hydra_builder_demo.py b/ccflow/examples/flow_model/flow_model_hydra_builder_demo.py index 64c1a45..fbecd7b 100644 --- a/ccflow/examples/flow_model/flow_model_hydra_builder_demo.py +++ b/ccflow/examples/flow_model/flow_model_hydra_builder_demo.py @@ -38,10 +38,11 @@ def _format_bound_inputs(inputs: dict[str, object]) -> str: def _print_model_summary(label: str, model: CallableModel) -> None: + inspection = model.flow.inspect() print(f" {label}:") - print(f" bound inputs: {_format_bound_inputs(model.flow.bound_inputs)}") - print(f" declared context inputs: {_format_input_names(model.flow.context_inputs)}") - print(f" runtime inputs: {_format_input_names(model.flow.runtime_inputs)}") + print(f" bound inputs: {_format_bound_inputs(inspection.bound_inputs)}") + print(f" declared context inputs: {_format_input_names(inspection.context_inputs)}") + print(f" runtime inputs: {_format_input_names(inspection.runtime_inputs)}") @Flow.model(context_type=DateRangeContext) diff --git a/ccflow/flow_model.py b/ccflow/flow_model.py index efe03f8..101998f 100644 --- a/ccflow/flow_model.py +++ b/ccflow/flow_model.py @@ -101,6 +101,8 @@ __all__ = ( "FlowAPI", "BoundModel", + "FlowInspection", + "InputSpec", "FromContext", "Lazy", ) @@ -127,6 +129,105 @@ def __reduce__(self): _UNHASHABLE_TYPE_ADAPTER_CACHE: "OrderedDict[int, Tuple[Any, TypeAdapter]]" = OrderedDict() +class InputSpec(NamedTuple): + """Richer description of one direct function input in ``FlowInspection``.""" + + name: str + type: Any + required: bool + default: Any + value: Any + source: str + + @property + def default_repr(self) -> str: + """Compact representation of the declared function default.""" + + return _flow_debug_value_repr(self.default) + + @property + def value_repr(self) -> str: + """Compact representation of the effective direct value.""" + + return _flow_debug_value_repr(self.value) + + def __repr__(self) -> str: + return ( + f"InputSpec(name={self.name!r}, type={_expected_type_repr(self.type)}, " + f"required={self.required!r}, default={self.default_repr}, " + f"value={self.value_repr}, source={self.source!r})" + ) + + +class DependencySpec(NamedTuple): + """User-facing provenance for one direct dependency edge.""" + + path: str + model: CallableModel + context: Optional[ContextBase] = None + lazy: bool = False + + def __repr__(self) -> str: + return ( + f"DependencySpec(path={self.path!r}, model={_flow_debug_value_repr(self.model)}, " + f"context={_flow_debug_value_repr(self.context)}, lazy={self.lazy!r})" + ) + + +class FlowInspection(NamedTuple): + """Structured current-level debugging summary returned by ``model.flow.inspect()``.""" + + model: CallableModel + context_inputs: Dict[str, Any] + runtime_inputs: Dict[str, Any] + required_inputs: Dict[str, Any] + bound_inputs: Dict[str, Any] + inputs: Dict[str, InputSpec] + dependencies: Tuple[DependencySpec, ...] + + def __str__(self) -> str: + lines = [f"FlowInspection(model={_flow_debug_model_name(self.model)})"] + if self.inputs: + lines.append(" inputs:") + for spec in self.inputs.values(): + required = " required" if spec.required else "" + lines.append(f" {spec.name}: {_expected_type_repr(spec.type)} = {spec.value_repr} [{spec.source}{required}]") + else: + lines.append(" inputs: none") + if self.context_inputs: + context_inputs = ", ".join(f"{name}: {_expected_type_repr(annotation)}" for name, annotation in self.context_inputs.items()) + else: + context_inputs = "none" + lines.append(f" contextual inputs: {context_inputs}") + if self.runtime_inputs: + runtime_inputs = ", ".join(f"{name}: {_expected_type_repr(annotation)}" for name, annotation in self.runtime_inputs.items()) + else: + runtime_inputs = "none" + lines.append(f" runtime inputs: {runtime_inputs}") + if self.required_inputs: + required = ", ".join(f"{name}: {_expected_type_repr(annotation)}" for name, annotation in self.required_inputs.items()) + else: + required = "none" + lines.append(f" required runtime inputs: {required}") + lines.append(f" bound inputs: {', '.join(self.bound_inputs) if self.bound_inputs else 'none'}") + if self.dependencies: + lines.append(" dependencies:") + for dependency in self.dependencies: + target = _flow_debug_model_name(dependency.model) + suffix = " lazy" if dependency.lazy else "" + context = f" context={_flow_debug_value_repr(dependency.context)}" if dependency.context is not None else "" + lines.append(f" {dependency.path} -> {target}{suffix}{context}") + else: + lines.append(" dependencies: none") + return "\n".join(lines) + + def __repr__(self) -> str: + return str(self) + + def _repr_pretty_(self, printer, cycle: bool) -> None: + printer.text("..." if cycle else str(self)) + + def _unset_flow_input_factory() -> _UnsetFlowInput: return _UNSET_FLOW_INPUT @@ -135,6 +236,29 @@ def _is_unset_flow_input(value: Any) -> bool: return value is _UNSET_FLOW_INPUT +def _flow_debug_value_repr(value: Any) -> str: + if _is_unset_flow_input(value): + return repr(value) + + bound_context_type = globals().get("_BoundModelContext") + if bound_context_type is not None and isinstance(value, bound_context_type): + return repr(FlowContext(**_context_values(value))) + + callable_model_type = globals().get("CallableModel") + if callable_model_type is not None and isinstance(value, callable_model_type): + return f"" + + return repr(value) + + +def _flow_debug_model_name(model: CallableModel) -> str: + if isinstance(model, BoundModel): + return f"{_flow_debug_model_name(model.model)}.flow.with_context(...)" + name = model.meta.name + cls_name = type(model).__name__ + return f"{name} ({cls_name})" if name else cls_name + + _ModelContextContract = NamedTuple( "_ModelContextContract", [ @@ -848,7 +972,15 @@ def _resolve_regular_param_value(model: "_GeneratedFlowModelBase", param: _FlowM if _is_model_dependency(value): dependency_model, dependency_context = _resolved_dependency_invocation(value, context) - return _unwrap_model_result(dependency_model(dependency_context)) + try: + return _unwrap_model_result(dependency_model(dependency_context)) + except Exception as exc: + parent = _callable_name(type(model).__flow_model_config__.func) + child = dependency_model.meta.name or type(dependency_model).__name__ + add_note = getattr(exc, "add_note", None) + if callable(add_note): + add_note(f"Error while evaluating dependency {parent}.{param.name} -> {child}.") + raise return value @@ -1719,7 +1851,7 @@ def _build_compute_context(model: CallableModel, context: Any, kwargs: Dict[str, """ if context is not _UNSET and kwargs: - raise TypeError("compute() accepts either one context object or contextual keyword arguments, but not both.") + raise TypeError("compute() accepts either one context object or contextual keyword inputs, but not both.") ctx_type = model.context_type _ctx_is_optional = _is_optional_context_type(ctx_type) @@ -1802,7 +1934,7 @@ def _build_bound_compute_context(bound_model: "BoundModel", context: Any, kwargs """Construct the ambient context passed into a ``BoundModel`` by ``compute``.""" if context is not _UNSET and kwargs: - raise TypeError("compute() accepts either one context object or contextual keyword arguments, but not both.") + raise TypeError("compute() accepts either one context object or contextual keyword inputs, but not both.") if context is not _UNSET: return context if not kwargs and _bound_model_preserves_none_context(bound_model): @@ -1810,6 +1942,180 @@ def _build_bound_compute_context(bound_model: "BoundModel", context: Any, kwargs return _bound_model_ambient_context(bound_model, kwargs) +def _raw_input_values_for_debug(model: CallableModel, context: Any, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """Return caller-supplied context values without executing the model.""" + + if context is not _UNSET and kwargs: + raise TypeError("compute() accepts either one context object or contextual keyword inputs, but not both.") + if context is _UNSET: + return dict(kwargs) + if context is None: + return {} + if isinstance(context, ContextBase): + return _context_values(context) + if isinstance(context, Mapping): + return dict(context) + return _context_values(_model_context_contract(model).runtime_context_type.model_validate(context)) + + +def _partial_context_for_inspect(model: CallableModel, values: Dict[str, Any]) -> ContextBase: + contract = _model_context_contract(model) + if contract.input_types is None: + return FlowContext(**values) + return FlowContext(**{name: values[name] for name in contract.input_types if name in values}) + + +def _partial_dependency_context_for_inspect(model: CallableModel, context: ContextBase) -> ContextBase: + if isinstance(model, BoundModel): + return FlowContext(**_context_values(context)) + return _partial_context_for_inspect(model, _context_values(context)) + + +def _project_bound_dependency_context_for_inspect(model: "BoundModel", context: ContextBase) -> ContextBase: + values = _context_values(context) + projected = {name: values[name] for name in model.flow._runtime_inputs if name in values} + if isinstance(context, _BoundModelContext) and context._base_context is not None: + return _BoundModelContext.from_values(projected, base_context=context._base_context) + return FlowContext(**projected) + + +def _generated_context_argument_specs(generated: "_GeneratedFlowModelBase", input_types: Optional[Dict[str, Any]]) -> Dict[str, InputSpec]: + config = type(generated).__flow_model_config__ + explicit_fields = _bound_field_names(generated) + specs: Dict[str, InputSpec] = {} + for param in config.contextual_params: + annotation = param.annotation if input_types is None else input_types[param.name] + value = getattr(generated, param.name, _UNSET_FLOW_INPUT) + if param.name in explicit_fields and not _is_unset_flow_input(value): + specs[param.name] = InputSpec(param.name, annotation, False, _UNSET_FLOW_INPUT, value, "construction") + elif param.has_function_default: + specs[param.name] = InputSpec(param.name, annotation, False, param.function_default, param.function_default, "function_default") + else: + specs[param.name] = InputSpec(param.name, annotation, True, _UNSET_FLOW_INPUT, _UNSET_FLOW_INPUT, "runtime") + return specs + + +def _plain_context_argument_specs(model: CallableModel, contract: _ModelContextContract) -> Dict[str, InputSpec]: + if contract.input_types is None: + return {} + default_values = _plain_model_default_context_values(model, contract.runtime_context_type) + required_inputs = set(contract.required_names) + if default_values is not None: + required_inputs -= set(default_values) + specs = {} + for name, annotation in contract.input_types.items(): + if default_values is not None and name in default_values: + specs[name] = InputSpec(name, annotation, False, default_values[name], default_values[name], "context_default") + else: + specs[name] = InputSpec(name, annotation, name in required_inputs, _UNSET_FLOW_INPUT, _UNSET_FLOW_INPUT, "runtime") + return specs + + +def _direct_dependency_specs( + model: CallableModel, + context: Optional[ContextBase] = None, + *, + trim_context: bool = True, +) -> Tuple[DependencySpec, ...]: + if isinstance(model, BoundModel): + rewritten_context = None if context is None else model._rewrite_context(context) + return _direct_dependency_specs(model.model, rewritten_context, trim_context=trim_context) + + generated = _generated_model_instance(model) + if generated is None: + return () + + specs = [] + config = type(generated).__flow_model_config__ + for param in config.regular_params: + value = getattr(generated, param.name, _UNSET_FLOW_INPUT) + if _is_unset_flow_input(value) or not _is_model_dependency(value): + continue + dependency_model = value + dependency_context = None + if context is not None: + try: + dependency_model, dependency_context = _resolved_dependency_invocation(value, context) + except (TypeError, ValueError, ValidationError): + dependency_model = value + dependency_context = _partial_dependency_context_for_inspect(value, context) + else: + contract = _model_context_contract(dependency_model) + if trim_context and dependency_context is not None: + if isinstance(dependency_model, BoundModel): + dependency_context = _project_bound_dependency_context_for_inspect(dependency_model, dependency_context) + elif contract.input_types is not None: + values = _context_values(dependency_context) + dependency_context = _runtime_context_for_model( + dependency_model, + {name: values[name] for name in contract.input_types if name in values}, + ) + specs.append(DependencySpec(param.name, dependency_model, dependency_context, lazy=param.is_lazy)) + return tuple(specs) + + +def _normalize_inspect_dependencies(dependencies: str) -> Literal["none", "direct", "recursive"]: + if dependencies == "none": + return "none" + if dependencies == "direct": + return "direct" + if dependencies == "recursive": + return "recursive" + raise ValueError("dependencies must be one of: direct, none, recursive") + + +def _is_missing_contextual_input_error(error: TypeError) -> bool: + return str(error).startswith("Missing contextual input(s)") + + +def _recursive_dependency_specs_for_flow( + flow: "FlowAPI", + context: Optional[ContextBase], + *, + prefix: str = "", + lazy_parent: bool = False, + active: Optional[Set[int]] = None, +) -> Tuple[DependencySpec, ...]: + """Return inspect-visible dependency specs below ``flow`` with prefixed paths.""" + + active = set() if active is None else active + model_id = id(flow._compute_target) + if model_id in active: + return () + + active.add(model_id) + try: + try: + argument_context = flow._argument_context(context) + direct_dependencies = flow._dependency_specs_for_inspect( + context, + argument_context, + preserve_ambient_context=True, + ) + except TypeError as exc: + if not _is_missing_contextual_input_error(exc): + raise + return () + result = [] + for dependency in direct_dependencies: + path = f"{prefix}.{dependency.path}" if prefix else dependency.path + lazy = lazy_parent or dependency.lazy + prefixed = DependencySpec(path, dependency.model, dependency.context, lazy) + result.append(prefixed) + result.extend( + _recursive_dependency_specs_for_flow( + dependency.model.flow, + dependency.context, + prefix=path, + lazy_parent=lazy, + active=active, + ) + ) + return tuple(result) + finally: + active.remove(model_id) + + # --------------------------------------------------------------------------- # model.flow API and BoundModel wrapper # --------------------------------------------------------------------------- @@ -1824,9 +2130,25 @@ class FlowAPI: what can be inferred from their ``context_type`` and pydantic fields. """ + _PUBLIC_HELP: ClassVar[Dict[str, str]] = { + "compute": "Evaluate the model from a context object or runtime keyword context.", + "with_context": "Return a new wrapper that binds or rewrites runtime context before evaluation; it does not mutate this model.", + "inspect": "Return a readable debugging summary. Options: dependencies='direct|recursive|none'.", + } + _PUBLIC_NAMES: ClassVar[Tuple[str, ...]] = tuple(_PUBLIC_HELP) + def __init__(self, model: CallableModel): self._model = model + def __dir__(self) -> List[str]: + """Return a focused list of public helpers for interactive autocomplete.""" + + return sorted(self._PUBLIC_NAMES) + + def __repr__(self) -> str: + helpers = ", ".join(self._PUBLIC_NAMES) + return f"{type(self).__name__}(model={self._compute_target!r}, helpers=[{helpers}])" + @property def _compute_target(self) -> CallableModel: return self._model @@ -1839,20 +2161,20 @@ def compute(self, context: Any = _UNSET, /, _options: Optional[FlowOptions] = No return _maybe_auto_unwrap_external_result(target, target(built_context, _options=_options)) @property - def context_inputs(self) -> Dict[str, Any]: + def _context_inputs(self) -> Dict[str, Any]: """Declared contextual input names and expected types for this model.""" contract = _model_context_contract(self._model) return dict(contract.input_types or {}) @property - def runtime_inputs(self) -> Dict[str, Any]: + def _runtime_inputs(self) -> Dict[str, Any]: """Direct runtime context fields this model may read from the caller.""" - return self.context_inputs + return self._context_inputs @property - def required_inputs(self) -> Dict[str, Any]: + def _required_inputs(self) -> Dict[str, Any]: """Required direct runtime context fields still needed from the caller.""" contract = _model_context_contract(self._model) @@ -1880,7 +2202,172 @@ def required_inputs(self) -> Dict[str, Any]: return result @property - def bound_inputs(self) -> Dict[str, Any]: + def _context_argument_specs(self) -> Dict[str, InputSpec]: + """Rich descriptions of declared direct contextual inputs.""" + + contract = _model_context_contract(self._model) + if contract.generated_model is not None: + return _generated_context_argument_specs(contract.generated_model, contract.input_types) + return _plain_context_argument_specs(self._model, contract) + + @property + def _runtime_argument_specs(self) -> Dict[str, InputSpec]: + """Rich descriptions of direct runtime inputs read from the caller.""" + + specs = self._context_argument_specs + runtime_names = set(self._runtime_inputs) + required_names = set(self._required_inputs) + result = {} + for name in runtime_names: + spec = specs.get(name) + result[name] = InputSpec( + name, + self._runtime_inputs[name], + name in required_names, + spec.default if spec is not None else _UNSET_FLOW_INPUT, + spec.value if spec is not None else _UNSET_FLOW_INPUT, + spec.source if spec is not None else "runtime", + ) + return result + + @property + def _argument_specs(self) -> Dict[str, InputSpec]: + """Rich descriptions of direct construction and contextual inputs.""" + + generated = _model_context_contract(self._model).generated_model + if generated is None: + return self._context_argument_specs + + config = type(generated).__flow_model_config__ + specs: Dict[str, InputSpec] = {} + for param in config.regular_params: + value = getattr(generated, param.name, _UNSET_FLOW_INPUT) + if not _is_unset_flow_input(value): + specs[param.name] = InputSpec(param.name, param.annotation, False, _UNSET_FLOW_INPUT, value, "construction") + elif param.has_function_default: + specs[param.name] = InputSpec(param.name, param.annotation, False, param.function_default, param.function_default, "function_default") + else: + specs[param.name] = InputSpec(param.name, param.annotation, True, _UNSET_FLOW_INPUT, _UNSET_FLOW_INPUT, "construction") + specs.update(self._context_argument_specs) + return specs + + def inspect( + self, + context: Any = _UNSET, + /, + *, + dependencies: Literal["direct", "recursive", "none"] = "direct", + **kwargs, + ) -> FlowInspection: + """Return a readable direct debugging summary for this model. + + Args: + context: Optional runtime context object. + dependencies: Dependency inspection depth. ``"direct"`` includes + only immediate dependencies; ``"recursive"`` follows + inspect-visible dependencies recursively; ``"none"`` skips + dependency inspection. + **kwargs: Runtime context field values. + """ + + dependency_depth = _normalize_inspect_dependencies(dependencies) + dependency_context = None + argument_context = None + dependency_specs: Tuple[DependencySpec, ...] + if context is not _UNSET or kwargs: + raw_values = _raw_input_values_for_debug(self._compute_target, context, kwargs) + try: + dependency_context = _build_compute_context(self._compute_target, context, kwargs) + except (TypeError, ValidationError): + dependency_context = _partial_context_for_inspect(self._compute_target, raw_values) + try: + argument_context = self._argument_context(dependency_context) + except (TypeError, ValidationError) as exc: + if isinstance(exc, TypeError) and not _is_missing_contextual_input_error(exc): + raise + argument_context = None + dependency_specs = self._dependency_specs_for_inspect(dependency_context, argument_context) if dependency_depth == "direct" else () + if dependency_depth == "recursive": + dependency_specs = _recursive_dependency_specs_for_flow(self, dependency_context) + else: + try: + argument_context = self._argument_context(None) + except (TypeError, ValidationError) as exc: + if isinstance(exc, TypeError) and not _is_missing_contextual_input_error(exc): + raise + argument_context = None + dependency_specs = self._dependency_specs_for_inspect(None, argument_context) if dependency_depth == "direct" else () + if dependency_depth == "recursive": + dependency_specs = _recursive_dependency_specs_for_flow(self, None) + return FlowInspection( + model=self._compute_target, + context_inputs=self._context_inputs, + runtime_inputs=self._runtime_inputs, + required_inputs=self._required_inputs, + bound_inputs=self._bound_inputs, + inputs=self._argument_specs_for_context(argument_context), + dependencies=dependency_specs, + ) + + def _argument_context(self, context: Optional[ContextBase]) -> Optional[ContextBase]: + return context + + def _dependency_specs_for_inspect( + self, + dependency_context: Optional[ContextBase], + _argument_context: Optional[ContextBase], + *, + preserve_ambient_context: bool = False, + ) -> Tuple[DependencySpec, ...]: + return _direct_dependency_specs(self._compute_target, dependency_context, trim_context=not preserve_ambient_context) + + def _context_argument_specs_for_context(self, context: Optional[ContextBase]) -> Dict[str, InputSpec]: + result = dict(self._context_argument_specs) + if context is None: + return result + values = _context_values(context) + for name, spec in list(result.items()): + if name in values: + result[name] = InputSpec(name, spec.type, False, spec.default, values[name], "runtime") + return result + + def _runtime_argument_specs_for_context(self, context: Optional[ContextBase]) -> Dict[str, InputSpec]: + specs = self._context_argument_specs_for_context(context) + runtime_names = set(self._runtime_inputs) + required_names = set(self._required_inputs) + result = {} + for name in runtime_names: + spec = specs.get(name) + result[name] = InputSpec( + name, + self._runtime_inputs[name], + name in required_names, + spec.default if spec is not None else _UNSET_FLOW_INPUT, + spec.value if spec is not None else _UNSET_FLOW_INPUT, + spec.source if spec is not None else "runtime", + ) + return result + + def _argument_specs_for_context(self, context: Optional[ContextBase]) -> Dict[str, InputSpec]: + generated = _model_context_contract(self._model).generated_model + if generated is None: + return self._context_argument_specs_for_context(context) + + config = type(generated).__flow_model_config__ + specs: Dict[str, InputSpec] = {} + for param in config.regular_params: + value = getattr(generated, param.name, _UNSET_FLOW_INPUT) + if not _is_unset_flow_input(value): + specs[param.name] = InputSpec(param.name, param.annotation, False, _UNSET_FLOW_INPUT, value, "construction") + elif param.has_function_default: + specs[param.name] = InputSpec(param.name, param.annotation, False, param.function_default, param.function_default, "function_default") + else: + specs[param.name] = InputSpec(param.name, param.annotation, True, _UNSET_FLOW_INPUT, _UNSET_FLOW_INPUT, "construction") + specs.update(self._context_argument_specs_for_context(context)) + return specs + + @property + def _bound_inputs(self) -> Dict[str, Any]: """Inputs already fixed by construction-time values or static context bindings.""" generated = _model_context_contract(self._model).generated_model @@ -1981,6 +2468,8 @@ def _evaluation_identity_payload( @property def flow(self) -> "FlowAPI": + """Access bound flow helpers for execution, context transforms, and introspection.""" + return _BoundFlowAPI(self) @@ -2001,31 +2490,79 @@ def compute(self, context: Any = _UNSET, /, _options: Optional[FlowOptions] = No built_context = _build_bound_compute_context(self._bound, context, kwargs) return _maybe_auto_unwrap_external_result(self._bound, self._bound(built_context, _options=_options)) + def _argument_context(self, context: Optional[ContextBase]) -> Optional[ContextBase]: + if context is None: + if _bound_model_preserves_none_context(self._bound): + return None + _supplied_fields, required_dynamic_inputs = _dynamic_context_operation_effects(self._bound.context_spec, required_only=True) + if required_dynamic_inputs: + return None + return self._bound._rewrite_context(_bound_model_ambient_context(self._bound, {})) + return self._bound._rewrite_context(context) + + def _dependency_specs_for_inspect( + self, + _dependency_context: Optional[ContextBase], + argument_context: Optional[ContextBase], + *, + preserve_ambient_context: bool = False, + ) -> Tuple[DependencySpec, ...]: + return _direct_dependency_specs(self._bound.model, argument_context, trim_context=not preserve_ambient_context) + @property - def bound_inputs(self) -> Dict[str, Any]: + def _bound_inputs(self) -> Dict[str, Any]: """Concrete values already fixed, including statically resolved context bindings.""" - result = super().bound_inputs - for name in self.context_inputs: + result = super()._bound_inputs + for name in self._context_inputs: result.pop(name, None) result.update(_statically_resolved_context_field_values(self._bound.model, self._bound.context_spec)) return result @property - def context_inputs(self) -> Dict[str, Any]: + def _context_inputs(self) -> Dict[str, Any]: """Declared contextual inputs of the wrapped model.""" - return super().context_inputs + return super()._context_inputs @property - def runtime_inputs(self) -> Dict[str, Any]: + def _context_argument_specs(self) -> Dict[str, InputSpec]: + """Rich descriptions of wrapped-model contextual inputs after bindings.""" + + result = dict(super()._context_argument_specs) + for name, value in _statically_resolved_context_field_values(self._bound.model, self._bound.context_spec).items(): + if name in result: + spec = result[name] + result[name] = InputSpec(name, spec.type, False, spec.default, value, "with_context") + + supplied_fields, _dynamic_inputs = _dynamic_context_operation_effects(self._bound.context_spec, required_only=False) + for name in supplied_fields: + if name in result: + spec = result[name] + result[name] = InputSpec(name, spec.type, False, spec.default, _UNSET_FLOW_INPUT, "context_transform") + return result + + def _context_argument_specs_for_context(self, context: Optional[ContextBase]) -> Dict[str, InputSpec]: + result = dict(self._context_argument_specs) + if context is None: + return result + values = _context_values(context) + for name, spec in list(result.items()): + if name not in values: + continue + source = spec.source if spec.source in {"context_transform", "with_context"} else "runtime" + result[name] = InputSpec(name, spec.type, False, spec.default, values[name], source) + return result + + @property + def _runtime_inputs(self) -> Dict[str, Any]: """Direct runtime context inputs after applying this wrapper's bindings. Static context transforms may be evaluated to identify resolved fields. Dynamic transforms contribute their own runtime context inputs. """ - result = super().context_inputs + result = super()._context_inputs for name in _statically_resolved_context_field_names(self._bound.model, self._bound.context_spec): result.pop(name, None) supplied_fields, dynamic_inputs = _dynamic_context_operation_effects(self._bound.context_spec, required_only=False) @@ -2035,14 +2572,36 @@ def runtime_inputs(self) -> Dict[str, Any]: return result @property - def required_inputs(self) -> Dict[str, Any]: + def _runtime_argument_specs(self) -> Dict[str, InputSpec]: + """Rich descriptions of runtime inputs after applying this wrapper's bindings.""" + + result = {} + base_specs = self._context_argument_specs + for name, annotation in self._runtime_inputs.items(): + base = base_specs.get(name) + result[name] = InputSpec( + name, + annotation, + name in self._required_inputs, + base.default if base is not None else _UNSET_FLOW_INPUT, + base.value if base is not None else _UNSET_FLOW_INPUT, + base.source if base is not None else "runtime", + ) + _supplied_fields, dynamic_inputs = _dynamic_context_operation_effects(self._bound.context_spec, required_only=False) + required_dynamic = set(_dynamic_context_operation_effects(self._bound.context_spec, required_only=True)[1]) + for name, annotation in dynamic_inputs.items(): + result[name] = InputSpec(name, annotation, name in required_dynamic, _UNSET_FLOW_INPUT, _UNSET_FLOW_INPUT, "context_transform") + return result + + @property + def _required_inputs(self) -> Dict[str, Any]: """Required direct runtime context inputs still missing after static bindings. Static context transforms may be evaluated to identify resolved fields. Dynamic transforms contribute their own required runtime context inputs. """ - result = super().required_inputs + result = super()._required_inputs for name in _statically_resolved_context_field_names(self._bound.model, self._bound.context_spec): result.pop(name, None) supplied_fields, dynamic_inputs = _dynamic_context_operation_effects(self._bound.context_spec, required_only=True) @@ -2122,10 +2681,6 @@ def context_type(self) -> Type[ContextBase]: def result_type(self) -> Type[ResultBase]: return self.__class__.__flow_model_config__.result_type - @property - def flow(self) -> FlowAPI: - return FlowAPI(self) - def _evaluation_identity_payload( self, context: ContextBase, @@ -2440,6 +2995,30 @@ def flow_model( ``FlowContext``, a declared ``context_type``, ``compute(...)`` kwargs, or ``with_context(...)`` bindings. The returned object is a factory that creates instances of the generated model class. + + Args: + func: The function being decorated. This is passed automatically in + bare decorator form, for example ``@Flow.model``. When using + options, for example ``@Flow.model(auto_unwrap=True)``, Python first + calls ``Flow.model(...)`` without a function and then applies the + returned decorator. + context_type: Optional ``ContextBase`` subclass used to validate all + contextual inputs together after individual ``FromContext[...]`` + fields are resolved. + auto_unwrap: When ``True`` and ccflow auto-wraps a plain return + annotation in ``GenericResult[T]``, external + ``model.flow.compute(...)`` calls return the raw ``T`` value instead + of ``GenericResult[T]``. Dependency evaluation and direct model calls + keep the normal ccflow result contract. + model_base: Custom ``CallableModel`` subclass to use as the base class + for the generated model. + cacheable: Optional generated-model default for ``FlowOptions.cacheable``. + volatile: Optional generated-model default for ``FlowOptions.volatile``. + log_level: Optional generated-model default for ``FlowOptions.log_level``. + validate_result: Optional generated-model default for + ``FlowOptions.validate_result``. + verbose: Optional generated-model default for ``FlowOptions.verbose``. + evaluator: Optional generated-model default evaluator. """ def decorator(fn: _AnyCallable) -> _AnyCallable: diff --git a/ccflow/tests/test_flow_context.py b/ccflow/tests/test_flow_context.py index d8c98bb..506e8f6 100644 --- a/ccflow/tests/test_flow_context.py +++ b/ccflow/tests/test_flow_context.py @@ -116,10 +116,10 @@ def add(a: int, b: FromContext[int], c: FromContext[int] = 5) -> int: return a + b + c model = add(a=10) - assert model.flow.context_inputs == {"b": int, "c": int} - assert model.flow.runtime_inputs == {"b": int, "c": int} - assert model.flow.required_inputs == {"b": int} - assert model.flow.bound_inputs == {"a": 10} + assert model.flow.inspect().context_inputs == {"b": int, "c": int} + assert model.flow.inspect().runtime_inputs == {"b": int, "c": int} + assert model.flow.inspect().required_inputs == {"b": int} + assert model.flow.inspect().bound_inputs == {"a": 10} assert model.flow.compute(b=2).value == 17 @@ -132,7 +132,7 @@ def add(a: int, b: FromContext[int]) -> int: assert model.flow.compute(b=5).value == 15 assert model.flow.compute(FlowContext(b=6)).value == 16 - with pytest.raises(TypeError, match="either one context object or contextual keyword arguments"): + with pytest.raises(TypeError, match="either one context object or contextual keyword inputs"): model.flow.compute(FlowContext(b=5), b=6) @@ -202,8 +202,8 @@ def combine(left: int, right: int, bonus: FromContext[int]) -> int: right=base.flow.with_context(value=offset_value(amount=10)), ) - assert model.flow.context_inputs == {"bonus": int} - assert model.flow.runtime_inputs == {"bonus": int} + assert model.flow.inspect().context_inputs == {"bonus": int} + assert model.flow.inspect().runtime_inputs == {"bonus": int} assert model.flow.compute(value=5, bonus=100).value == (6 + 15 + 100) @@ -218,12 +218,12 @@ def from_seed(seed: FromContext[int]) -> int: bound = add(a=10).flow.with_context(b=from_seed()) - assert bound.flow.context_inputs == {"b": int, "c": int} - assert bound.flow.runtime_inputs == {"c": int, "seed": int} - assert bound.flow.required_inputs == {"seed": int} - for name, annotation in bound.flow.required_inputs.items(): - assert bound.flow.runtime_inputs[name] == annotation - assert bound.flow.bound_inputs == {"a": 10} + assert bound.flow.inspect().context_inputs == {"b": int, "c": int} + assert bound.flow.inspect().runtime_inputs == {"c": int, "seed": int} + assert bound.flow.inspect().required_inputs == {"seed": int} + for name, annotation in bound.flow.inspect().required_inputs.items(): + assert bound.flow.inspect().runtime_inputs[name] == annotation + assert bound.flow.inspect().bound_inputs == {"a": 10} assert bound.flow.compute(seed=1).value == 17 @@ -243,7 +243,7 @@ def from_str(seed: FromContext[str]) -> int: bound = add().flow.with_context(b=from_int(), c=from_str()) with pytest.raises(TypeError, match="Conflicting runtime context annotations for 'seed'"): - bound.flow.runtime_inputs + bound.flow.inspect().runtime_inputs def test_bound_flow_api_keeps_dynamic_transform_source_inputs_after_later_field_bindings(): @@ -261,9 +261,9 @@ def from_x(x: FromContext[int]) -> int: bound = add().flow.with_context(x=from_y()).flow.with_context(y=from_x()) - assert bound.flow.context_inputs == {"x": int, "y": int} - assert bound.flow.runtime_inputs == {"x": int, "y": int} - assert bound.flow.required_inputs == {"x": int, "y": int} + assert bound.flow.inspect().context_inputs == {"x": int, "y": int} + assert bound.flow.inspect().runtime_inputs == {"x": int, "y": int} + assert bound.flow.inspect().required_inputs == {"x": int, "y": int} assert bound.flow.compute(x=1, y=2).value == 14 @@ -278,9 +278,9 @@ def seed_plus_one(seed: FromContext[int] = 0) -> int: bound = add().flow.with_context(b=seed_plus_one()) - assert bound.flow.runtime_inputs == {"a": int, "seed": int} - assert bound.flow.required_inputs == {"a": int} - assert bound.flow.bound_inputs == {} + assert bound.flow.inspect().runtime_inputs == {"a": int, "seed": int} + assert bound.flow.inspect().required_inputs == {"a": int} + assert bound.flow.inspect().bound_inputs == {} assert bound.flow.compute(a=10).value == 11 assert bound.flow.compute(a=10, seed=5).value == 16 @@ -318,7 +318,7 @@ def add(a: int, b: FromContext[int]) -> int: restored = type(bound).model_validate(dumped) assert restored.flow.compute().value == 15 - assert restored.model.flow.bound_inputs == {"a": 10} + assert restored.model.flow.inspect().bound_inputs == {"a": 10} def test_bound_model_json_roundtrip_preserves_context_transforms(): @@ -413,10 +413,10 @@ def test_regular_callable_models_still_support_with_context(): def test_flow_api_for_regular_callable_model(): model = OffsetModel(offset=10) assert model.flow.compute(x=5).value == 15 - assert model.flow.context_inputs == {"x": int} - assert model.flow.runtime_inputs == {"x": int} - assert model.flow.required_inputs == {"x": int} - assert model.flow.bound_inputs == {"offset": 10} + assert model.flow.inspect().context_inputs == {"x": int} + assert model.flow.inspect().runtime_inputs == {"x": int} + assert model.flow.inspect().required_inputs == {"x": int} + assert model.flow.inspect().bound_inputs == {"offset": 10} def test_generated_flow_model_compute_is_thread_safe(): diff --git a/ccflow/tests/test_flow_model.py b/ccflow/tests/test_flow_model.py index 261977a..a6810d7 100644 --- a/ccflow/tests/test_flow_model.py +++ b/ccflow/tests/test_flow_model.py @@ -312,7 +312,7 @@ def __call__(self, context: Optional[SimpleContext] = None) -> GenericResult[int assert OptionalContextModel().flow.compute(None).value == 0 assert OptionalContextModel().flow.compute().value == 0 - assert OptionalContextModel().flow.required_inputs == {} + assert OptionalContextModel().flow.inspect().required_inputs == {} bound = OptionalContextModel().flow.with_context() assert bound.flow.compute(FlowContext(value=3)).value == 3 @@ -498,9 +498,9 @@ def foo(a: int, b: FromContext[int]) -> int: return a + b model = foo(a=11, b=12) - assert model.flow.bound_inputs == {"a": 11, "b": 12} - assert model.flow.context_inputs == {"b": int} - assert model.flow.required_inputs == {} + assert model.flow.inspect().bound_inputs == {"a": 11, "b": 12} + assert model.flow.inspect().context_inputs == {"b": int} + assert model.flow.inspect().required_inputs == {} assert model.flow.compute().value == 23 @@ -510,9 +510,9 @@ def foo(a: int, b: FromContext[int] = 5) -> int: return a + b model = foo(a=2) - assert model.flow.bound_inputs == {"a": 2} - assert model.flow.context_inputs == {"b": int} - assert model.flow.required_inputs == {} + assert model.flow.inspect().bound_inputs == {"a": 2} + assert model.flow.inspect().context_inputs == {"b": int} + assert model.flow.inspect().required_inputs == {} assert model.flow.compute().value == 7 assert model.flow.compute(b=10).value == 12 @@ -537,8 +537,8 @@ def choose(mode: FromContext[str]) -> str: model = choose() - assert model.flow.context_inputs == {"mode": Literal["a"]} - assert model.flow.required_inputs == {"mode": Literal["a"]} + assert model.flow.inspect().context_inputs == {"mode": Literal["a"]} + assert model.flow.inspect().required_inputs == {"mode": Literal["a"]} assert model.flow.compute(mode="a").value == "a" with pytest.raises(ValidationError): model.flow.compute(mode="b") @@ -636,8 +636,8 @@ def loader(context: DateRangeContext, source: str = "db") -> GenericResult[str]: return GenericResult(value=f"{source}:{context.start_date}:{context.end_date}") model = loader(context=DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 2)), source="api") - assert model.flow.bound_inputs["context"] == DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 2)) - assert model.flow.context_inputs == {} + assert model.flow.inspect().bound_inputs["context"] == DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 2)) + assert model.flow.inspect().context_inputs == {} assert model.flow.compute().value == "api:2024-01-01:2024-01-02" with pytest.raises(TypeError, match="Missing regular parameter\\(s\\) for loader: context"): @@ -819,8 +819,8 @@ def mixed(context: SimpleContext, y: FromContext[int]) -> int: return context.value + y model = mixed(context=SimpleContext(value=10)) - assert model.flow.bound_inputs == {"context": SimpleContext(value=10)} - assert model.flow.context_inputs == {"y": int} + assert model.flow.inspect().bound_inputs == {"context": SimpleContext(value=10)} + assert model.flow.inspect().context_inputs == {"y": int} assert model.flow.compute(y=5).value == 15 @@ -933,8 +933,8 @@ def add(a: int, b: FromContext[int]) -> int: dumped = model.model_dump(mode="python") restored = type(model).model_validate(dumped) - assert restored.flow.bound_inputs == {"a": 10} - assert restored.flow.required_inputs == {"b": int} + assert restored.flow.inspect().bound_inputs == {"a": 10} + assert restored.flow.inspect().required_inputs == {"b": int} assert restored.flow.compute(b=5).value == 15 @@ -1757,7 +1757,7 @@ def source(a: FromContext[int], b: FromContext[int]) -> int: bound = source().flow.with_context(patch_a()).flow.with_context(b=b_from_a()) assert bound.flow.compute(a=10).value == 15 - assert bound.flow.required_inputs == {"a": int} + assert bound.flow.inspect().required_inputs == {"a": int} def test_chained_with_context_later_field_override_skips_dead_field_transform(): @@ -1767,9 +1767,9 @@ def source(a: FromContext[int]) -> int: bound = source().flow.with_context(a=seed_plus_one()).flow.with_context(a=1) - assert bound.flow.bound_inputs == {"a": 1} - assert bound.flow.runtime_inputs == {} - assert bound.flow.required_inputs == {} + assert bound.flow.inspect().bound_inputs == {"a": 1} + assert bound.flow.inspect().runtime_inputs == {} + assert bound.flow.inspect().required_inputs == {} assert bound.flow.compute().value == 1 for dumps, loads in ((pickle.dumps, pickle.loads), (rcpdumps, rcploads)): @@ -1935,12 +1935,475 @@ def add(a: int, b: FromContext[int]) -> int: return a + b model = add(a=10) - assert model.flow.context_inputs == {"b": int} - assert model.flow.bound_inputs == {"a": 10} - assert model.flow.required_inputs == {"b": int} + assert model.flow.inspect().context_inputs == {"b": int} + assert model.flow.inspect().bound_inputs == {"a": 10} + assert model.flow.inspect().required_inputs == {"b": int} assert model.flow.compute(b=5).value == 15 +def test_flow_api_is_self_describing_for_interactive_sessions(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + flow = add(a=1).flow + + assert dir(flow) == ["compute", "inspect", "with_context"] + assert "inspect" in dir(flow) + assert "input_specs" not in dir(flow) + assert "argument_specs" not in dir(flow) + assert "inspect" in repr(flow) + + inspect_signature = inspect.signature(flow.inspect) + assert "inputs" not in inspect_signature.parameters + assert inspect_signature.parameters["dependencies"].annotation == Literal["direct", "recursive", "none"] + assert inspect_signature.return_annotation is flow_model_module.FlowInspection + + +def test_flow_api_completes_after_binding_property_value_in_ipython(): + pytest.importorskip("IPython") + from IPython.core.interactiveshell import InteractiveShell + + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + shell = InteractiveShell.instance() + shell.user_ns["flow_model_completion_target"] = add(a=1) + shell.user_ns["flow_model_completion_flow"] = shell.user_ns["flow_model_completion_target"].flow + + matches = shell.Completer.attr_matches("flow_model_completion_flow.") + assert "flow_model_completion_flow.compute" in matches + assert "flow_model_completion_flow.inspect" in matches + assert "flow_model_completion_flow.input_specs" not in matches + assert "flow_model_completion_flow.argument_specs" not in matches + assert "flow_model_completion_flow.help" not in matches + assert "flow_model_completion_flow.validate_inputs" not in matches + + +def test_flow_inspect_inputs_include_defaults_and_sources(): + @Flow.model + def add(a: int, b: FromContext[int], c: FromContext[int] = 5) -> int: + return a + b + c + + model = add(a=10) + inspection = model.flow.inspect() + + assert inspection.inputs["a"] == flow_model_module.InputSpec("a", int, False, flow_model_module._UNSET_FLOW_INPUT, 10, "construction") + assert inspection.inputs["b"] == flow_model_module.InputSpec( + "b", int, True, flow_model_module._UNSET_FLOW_INPUT, flow_model_module._UNSET_FLOW_INPUT, "runtime" + ) + assert inspection.inputs["c"] == flow_model_module.InputSpec("c", int, False, 5, 5, "function_default") + assert inspection.inputs["b"].required + assert not inspection.inputs["c"].required + assert model.flow.inspect(b=2).inputs["b"] == flow_model_module.InputSpec("b", int, False, flow_model_module._UNSET_FLOW_INPUT, 2, "runtime") + + +def test_flow_inspect_inputs_keep_dependency_value_but_use_compact_repr(): + @Flow.model + def add(a: int, b: int) -> int: + return a + b + + child = add(a=1, b=1) + model = add(a=1, b=child) + + spec = model.flow.inspect().inputs["b"] + assert spec.value is child + assert spec.value_repr == "" + assert "meta=" not in repr(spec) + assert "value=" in repr(spec) + + +def test_flow_inspect_with_runtime_values_is_structural_not_a_validator(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + model = add(a=10) + + inspection = model.flow.inspect(b=2, unused=1) + assert inspection.inputs["b"].value == 2 + assert inspection.inputs["b"].source == "runtime" + assert "unused" not in inspection.inputs + + regular_inspection = model.flow.inspect(a=1, b=2) + assert regular_inspection.inputs["b"].value == 2 + assert regular_inspection.inputs["b"].source == "runtime" + assert regular_inspection.inputs["a"].value == 10 + assert "input check" not in repr(regular_inspection) + + missing_inspection = model.flow.inspect(FlowContext()) + assert missing_inspection.inputs["b"].required + assert flow_model_module._is_unset_flow_input(missing_inspection.inputs["b"].value) + + +def test_flow_inspect_reports_direct_dependencies_and_unused_context(): + @Flow.model + def source(value: FromContext[int]) -> int: + return value * 2 + + @Flow.model + def root(x: int, bonus: FromContext[int]) -> int: + return x + bonus + + model = root(x=source()) + explanation = model.flow.inspect(value=3, bonus=4, unused=5) + + assert explanation.inputs["x"].value is model.flow.inspect().inputs["x"].value + assert explanation.required_inputs == {"bonus": int} + assert len(explanation.dependencies) == 1 + assert explanation.dependencies[0].path == "x" + assert explanation.dependencies[0].context == FlowContext(value=3) + child_inspection = explanation.dependencies[0].model.flow.inspect(explanation.dependencies[0].context) + assert child_inspection.runtime_inputs == {"value": int} + assert child_inspection.inputs["value"].value == 3 + assert "dependencies" in str(explanation) + assert repr(explanation) == str(explanation) + assert "FlowInspection(model=_root_Model)" in repr(explanation) + assert "inputs:" in repr(explanation) + assert "x -> _source_Model context=FlowContext(value=3)" in repr(explanation) + + assert model.flow.inspect(dependencies="none").dependencies == () + assert model.flow.inspect().runtime_inputs == {"bonus": int} + assert set(model.flow.inspect().bound_inputs) == {"x"} + with pytest.raises(ValueError, match="dependencies must be one of"): + model.flow.inspect(dependencies="full") + + +def test_flow_with_context_returns_new_bound_model_without_mutating_source(): + @Flow.model(auto_unwrap=False) + def add(x: int, y: int, z: FromContext[int] = 2) -> int: + return x + y + z + + @Flow.context_transform() + def shift_1(z: FromContext[int]) -> int: + return z + 1 + + model = add(x=1, y=add(x=1, y=1)) + bound = model.flow.with_context(z=shift_1()) + + assert model.flow.compute(z=2).value == 7 + assert bound.flow.compute(z=2).value == 9 + assert model.flow.inspect().bound_inputs.keys() == {"x", "y"} + assert bound.flow.inspect().runtime_inputs == {"z": int} + + +def test_bound_model_inspect_reports_wrapped_argument_dependencies(): + @Flow.model(auto_unwrap=False) + def add(x: int, y: int, z: FromContext[int] = 2) -> int: + return x + y + z + + @Flow.context_transform() + def shift_1(z: FromContext[int]) -> int: + return z + 1 + + bound = add(x=1, y=add(x=1, y=1)).flow.with_context(z=shift_1()) + + inspection = bound.flow.inspect(z=2) + + assert "FlowInspection(model=_add_Model.flow.with_context(...))" in repr(inspection) + assert inspection.inputs["z"] == flow_model_module.InputSpec("z", int, False, 2, 3, "context_transform") + assert len(inspection.dependencies) == 1 + assert inspection.dependencies[0].path == "y" + assert inspection.dependencies[0].context == FlowContext(z=3) + assert "" not in repr(inspection) + + root = add(x=1, y=bound) + root_inspection = root.flow.inspect() + assert root_inspection.required_inputs == {} + assert root_inspection.dependencies[0].model.flow.inspect().required_inputs == {"z": int} + + root_with_context = root.flow.inspect(z=2) + child_with_context = root_with_context.dependencies[0].model.flow.inspect(root_with_context.dependencies[0].context) + assert child_with_context.runtime_inputs == {"z": int} + assert child_with_context.inputs["z"].value == 3 + assert "context=FlowContext(z=2)" in repr(root_with_context) + + +def test_flow_inspect_direct_dependencies_do_not_traverse_grandchildren(): + @Flow.model + def leaf(v: FromContext[int]) -> int: + return v + + @Flow.model + def middle(x: int) -> int: + return x + + @Flow.model + def root(x: int) -> int: + return x + + model = root(x=middle(x=leaf())) + + root_inspection = model.flow.inspect() + child_inspection = root_inspection.dependencies[0].model.flow.inspect() + + assert root_inspection.dependencies[0].path == "x" + assert child_inspection.dependencies[0].model.flow.inspect().required_inputs == {"v": int} + + +def test_flow_inspect_recursive_dependencies_report_grandchild_requirements(): + @Flow.model + def leaf(v: FromContext[int]) -> int: + return v + + @Flow.model + def middle(x: int) -> int: + return x + + @Flow.model + def root(x: int) -> int: + return x + + model = root(x=middle(x=leaf())) + + missing = model.flow.inspect(dependencies="recursive") + + assert tuple(dependency.path for dependency in missing.dependencies) == ("x", "x.x") + assert missing.dependencies[1].model.flow.inspect().required_inputs == {"v": int} + + supplied = model.flow.inspect(v=2, dependencies="recursive") + + assert supplied.inputs == model.flow.inspect(v=2, dependencies="none").inputs + assert tuple(dependency.path for dependency in model.flow.inspect(v=2, dependencies="direct").dependencies) == ("x",) + assert tuple(dependency.context for dependency in supplied.dependencies) == (FlowContext(v=2), FlowContext(v=2)) + assert "v" not in supplied.inputs + leaf_with_context = supplied.dependencies[1].model.flow.inspect(supplied.dependencies[1].context) + assert leaf_with_context.inputs["v"].value == 2 + + +def test_flow_inspect_recursive_dependencies_treat_bound_transform_inputs_as_used(): + @Flow.model + def leaf(v: FromContext[int]) -> int: + return v + + @Flow.model + def middle(x: int) -> int: + return x + + @Flow.context_transform() + def shift(seed: FromContext[int]) -> int: + return seed + 1 + + model = middle(x=leaf().flow.with_context(v=shift())) + + missing = model.flow.inspect(dependencies="recursive") + assert missing.dependencies[0].model.flow.inspect().required_inputs == {"seed": int} + + supplied = model.flow.inspect(seed=1, dependencies="recursive") + assert tuple(dependency.path for dependency in supplied.dependencies) == ("x",) + assert flow_model_module._context_values(supplied.dependencies[0].context) == {"seed": 1} + child = supplied.dependencies[0].model.flow.inspect(supplied.dependencies[0].context) + assert child.inputs["v"].value == 2 + + +def test_flow_inspect_projects_bound_dependency_context_to_runtime_inputs(): + @Flow.model + def leaf(v: FromContext[int]) -> int: + return v + + @Flow.model + def root(x: int, bonus: FromContext[int]) -> int: + return x + bonus + + @Flow.context_transform() + def shift(seed: FromContext[int]) -> int: + return seed + 1 + + model = root(x=leaf().flow.with_context(v=shift())) + inspection = model.flow.inspect(seed=1, bonus=10, unused=99) + + assert flow_model_module._context_values(inspection.dependencies[0].context) == {"seed": 1} + child = inspection.dependencies[0].model.flow.inspect(inspection.dependencies[0].context) + assert child.inputs["v"].value == 2 + + +def test_flow_inspect_dependency_requirements_skip_lazy_edges(): + @Flow.model + def leaf(v: FromContext[int]) -> int: + return v + + @Flow.model + def root(x: Lazy[int]) -> int: + return 0 + + inspection = root(x=leaf()).flow.inspect(dependencies="recursive") + + assert inspection.dependencies[0].lazy + assert inspection.dependencies[0].model.flow.inspect().required_inputs == {"v": int} + assert "x -> _leaf_Model lazy" in repr(inspection) + + +def test_flow_inspect_dependency_requirements_skip_lazy_descendants(): + @Flow.model + def leaf(v: FromContext[int]) -> int: + return v + + @Flow.model + def middle(x: int) -> int: + return x + + @Flow.model + def root(x: Lazy[int]) -> int: + return 0 + + model = root(x=middle(x=leaf())) + inspection = model.flow.inspect(dependencies="recursive") + + assert model.flow.compute().value == 0 + assert tuple(dependency.path for dependency in inspection.dependencies) == ("x", "x.x") + assert all(dependency.lazy for dependency in inspection.dependencies) + assert inspection.dependencies[1].model.flow.inspect().required_inputs == {"v": int} + + +def test_flow_inspect_recursive_dependencies_tolerates_partial_transform_context(): + @Flow.model + def leaf(v: FromContext[int]) -> int: + return v + + @Flow.model + def middle(x: int) -> int: + return x + + @Flow.context_transform() + def shift(seed: FromContext[int], other: FromContext[int]) -> int: + return seed + other + + model = middle(x=leaf().flow.with_context(v=shift())) + + inspection = model.flow.inspect(seed=1, dependencies="recursive") + + assert tuple(dependency.path for dependency in inspection.dependencies) == ("x",) + child_inspection = inspection.dependencies[0].model.flow.inspect(inspection.dependencies[0].context) + assert child_inspection.runtime_inputs == {"seed": int, "other": int} + assert child_inspection.inputs["v"].source == "context_transform" + assert flow_model_module._is_unset_flow_input(child_inspection.inputs["v"].value) + + +def test_bound_flow_inspect_tolerates_partial_transform_context_without_dependencies(): + @Flow.model + def leaf(v: FromContext[int]) -> int: + return v + + @Flow.context_transform() + def shift(seed: FromContext[int], other: FromContext[int]) -> int: + return seed + other + + bound = leaf().flow.with_context(v=shift()) + + inspection = bound.flow.inspect(seed=1, dependencies="none") + + assert inspection.runtime_inputs == {"seed": int, "other": int} + assert inspection.required_inputs == {"seed": int, "other": int} + assert inspection.inputs["v"].source == "context_transform" + assert flow_model_module._is_unset_flow_input(inspection.inputs["v"].value) + assert inspection.dependencies == () + + +def test_flow_inspect_tolerates_partial_plain_callable_dependency_context(): + class PlainSource(CallableModel): + @property + def context_type(self): + return SimpleContext + + @property + def result_type(self): + return GenericResult[int] + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.model + def root(x: int, bonus: FromContext[int]) -> int: + return x + bonus + + inspection = root(x=PlainSource()).flow.inspect(bonus=1) + + assert len(inspection.dependencies) == 1 + assert inspection.dependencies[0].context == FlowContext() + assert inspection.dependencies[0].model.flow.inspect().required_inputs == {"value": int} + + +def test_plain_callable_flow_inspect_reports_partial_runtime_context(): + class PairContext(ContextBase): + a: int + b: int + + class PlainAdder(CallableModel): + @property + def context_type(self): + return PairContext + + @property + def result_type(self): + return GenericResult[int] + + @Flow.call + def __call__(self, context: PairContext) -> GenericResult[int]: + return GenericResult(value=context.a + context.b) + + inspection = PlainAdder().flow.inspect(a=1) + + assert inspection.inputs["a"].value == 1 + assert flow_model_module._is_unset_flow_input(inspection.inputs["b"].value) + + +def test_bound_flow_inspect_uses_static_context_for_dependency_requirements(): + @Flow.model(auto_unwrap=False) + def leaf(v: FromContext[int]) -> int: + return v + + @Flow.model(auto_unwrap=False) + def root(x: int, v: FromContext[int] = 0) -> int: + return x + 1 + + bound = root(x=leaf()).flow.with_context(v=1) + inspection = bound.flow.inspect() + + assert bound.flow.compute().value == 2 + assert inspection.dependencies[0].context == FlowContext(v=1) + assert inspection.dependencies[0].model.flow.inspect(inspection.dependencies[0].context).inputs["v"].value == 1 + + +def test_flow_inspect_recursive_dependencies_do_not_swallow_transform_type_errors(): + @Flow.model + def leaf(v: FromContext[int]) -> int: + return v + + @Flow.model + def root(x: int) -> int: + return x + + @Flow.context_transform() + def broken(seed: FromContext[int]) -> int: + raise TypeError("transform bug") + + model = root(x=leaf().flow.with_context(v=broken())) + + with pytest.raises(TypeError, match="transform bug"): + model.flow.inspect(seed=1, dependencies="recursive") + + +def test_dependency_evaluation_preserves_original_exception_type_with_context_note(): + class CustomDependencyError(RuntimeError): + pass + + @Flow.model + def child() -> int: + raise CustomDependencyError("boom") + + @Flow.model + def root(x: int) -> int: + return x + + with pytest.raises(CustomDependencyError) as exc_info: + root(x=child()).flow.compute() + + if hasattr(exc_info.value, "__notes__"): + assert exc_info.value.__notes__ == ["Error while evaluating dependency root.x -> _child_Model."] + + def test_generated_factory_signature_is_keyword_only_and_includes_model_base_fields(): sig = inspect.signature(basic_loader) @@ -2005,15 +2468,34 @@ def __deps__(self, context: SimpleContext): model = PlainModel() - assert model.flow.context_inputs == {"value": int} - assert model.flow.required_inputs == {"value": int} - assert model.flow.bound_inputs == {} + assert dir(model.flow) == ["compute", "inspect", "with_context"] + assert not hasattr(model.flow, "context_inputs") + assert not hasattr(model.flow, "runtime_inputs") + assert not hasattr(model.flow, "required_inputs") + assert not hasattr(model.flow, "bound_inputs") + assert model.flow.inspect().context_inputs == {"value": int} + assert model.flow.inspect().required_inputs == {"value": int} + assert model.flow.inspect().bound_inputs == {} assert model.flow.compute({"value": 3}).value == 3 - with pytest.raises(TypeError, match="either one context object or contextual keyword arguments"): + with pytest.raises(TypeError, match="either one context object or contextual keyword inputs"): model.flow.compute(SimpleContext(value=1), value=2) +def test_plain_callable_flow_api_is_base_property(): + class PlainModel(CallableModel): + offset: int = 7 + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=self.offset + context.value) + + model = PlainModel() + + assert dir(model.flow) == ["compute", "inspect", "with_context"] + assert model(SimpleContext(value=3)).value == 10 + + def test_plain_callable_flow_compute_preserves_matching_context_subclass(): class RequestContext(SimpleContext): request_id: str @@ -2052,17 +2534,17 @@ def __call__(self, context: DefaultContext = DefaultContext(value=7)) -> Generic model = PlainModel() - assert model.flow.required_inputs == {} + assert model.flow.inspect().required_inputs == {} assert model.flow.compute().value == (7, "default") assert model.flow.compute(value=3).value == (3, "default") assert model.flow.compute(tag="runtime").value == (7, "runtime") empty_bound = model.flow.with_context() - assert empty_bound.flow.required_inputs == {} + assert empty_bound.flow.inspect().required_inputs == {} assert empty_bound.flow.compute().value == (7, "default") bound = model.flow.with_context(tag="bound") - assert bound.flow.required_inputs == {} + assert bound.flow.inspect().required_inputs == {} assert bound.flow.compute().value == (7, "bound") assert bound.flow.compute(value=3).value == (3, "bound") @@ -2104,8 +2586,8 @@ def from_seed(seed: FromContext[int]) -> int: bound = PlainModel().flow.with_context(value=from_seed()) - assert bound.flow.required_inputs == {} - assert bound.flow.runtime_inputs == {"seed": int} + assert bound.flow.inspect().required_inputs == {} + assert bound.flow.inspect().runtime_inputs == {"seed": int} assert bound.flow.compute().value == (9, 8) assert bound.flow.compute(seed=10).value == (11, 10) @@ -2240,12 +2722,12 @@ def dep(value: FromContext[int]) -> int: def test_compute_accepts_context_object_for_from_context_models(): model = basic_loader(source="library", multiplier=3) - assert model.flow.context_inputs == {"value": int} - assert model.flow.required_inputs == {"value": int} + assert model.flow.inspect().context_inputs == {"value": int} + assert model.flow.inspect().required_inputs == {"value": int} assert model.flow.compute({"value": 4}).value == 12 assert model.flow.compute(SimpleContext(value=5)).value == 15 - with pytest.raises(TypeError, match="either one context object or contextual keyword arguments"): + with pytest.raises(TypeError, match="either one context object or contextual keyword inputs"): model.flow.compute(SimpleContext(value=1), value=2) @@ -2757,11 +3239,11 @@ def __call__(self, context: RequiredContext) -> GenericResult[int]: def add(a: FromContext[int], b: FromContext[int]) -> int: return a + b - assert PlainSource().flow.with_context(a=1).flow.required_inputs == {"b": int} - assert add().flow.with_context(a=1).flow.required_inputs == {"b": int} - assert add().flow.with_context(a=static_bad()).flow.required_inputs == {"b": int} - assert add().flow.with_context(static_patch()).flow.required_inputs == {"b": int} - assert add().flow.with_context(a=1, b=2).flow.required_inputs == {} + assert PlainSource().flow.with_context(a=1).flow.inspect().required_inputs == {"b": int} + assert add().flow.with_context(a=1).flow.inspect().required_inputs == {"b": int} + assert add().flow.with_context(a=static_bad()).flow.inspect().required_inputs == {"b": int} + assert add().flow.with_context(static_patch()).flow.inspect().required_inputs == {"b": int} + assert add().flow.with_context(a=1, b=2).flow.inspect().required_inputs == {} def test_bound_flow_required_inputs_reflects_dynamic_field_transform_inputs(): @@ -2772,9 +3254,9 @@ def add(a: FromContext[int], b: FromContext[int]) -> int: bound = add().flow.with_context(a=seed_plus_one()) assert bound.flow.compute(seed=1, b=10).value == 12 - assert bound.flow.context_inputs == {"a": int, "b": int} - assert bound.flow.runtime_inputs == {"b": int, "seed": int} - assert bound.flow.required_inputs == {"b": int, "seed": int} + assert bound.flow.inspect().context_inputs == {"a": int, "b": int} + assert bound.flow.inspect().runtime_inputs == {"b": int, "seed": int} + assert bound.flow.inspect().required_inputs == {"b": int, "seed": int} def test_bound_flow_bound_inputs_include_static_context_bindings(): @@ -2784,10 +3266,10 @@ def add(a: int, b: FromContext[int]) -> int: bound = add(a=1).flow.with_context(b=2) - assert bound.flow.context_inputs == {"b": int} - assert bound.flow.runtime_inputs == {} - assert bound.flow.required_inputs == {} - assert bound.flow.bound_inputs == {"a": 1, "b": 2} + assert bound.flow.inspect().context_inputs == {"b": int} + assert bound.flow.inspect().runtime_inputs == {} + assert bound.flow.inspect().required_inputs == {} + assert bound.flow.inspect().bound_inputs == {"a": 1, "b": 2} def test_bound_flow_bound_inputs_drops_static_patch_after_dynamic_override(): @@ -2798,10 +3280,10 @@ def add(a: FromContext[int], b: FromContext[int]) -> int: bound = add().flow.with_context(static_patch()).flow.with_context(a=seed_plus_one()) assert bound.flow.compute(seed=3, b=10).value == 14 - assert bound.flow.bound_inputs == {} - assert bound.flow.context_inputs == {"a": int, "b": int} - assert bound.flow.runtime_inputs == {"b": int, "seed": int} - assert bound.flow.required_inputs == {"b": int, "seed": int} + assert bound.flow.inspect().bound_inputs == {} + assert bound.flow.inspect().context_inputs == {"a": int, "b": int} + assert bound.flow.inspect().runtime_inputs == {"b": int, "seed": int} + assert bound.flow.inspect().required_inputs == {"b": int, "seed": int} def test_generated_model_cache_ignores_unused_flow_context_fields(): @@ -3424,11 +3906,11 @@ def add(a: int, b: FromContext[int]) -> int: return a + b model = add(a=10, multiplier=3) - assert model.flow.bound_inputs == {"a": 10, "multiplier": 3} + assert model.flow.inspect().bound_inputs == {"a": 10, "multiplier": 3} # Default-only model_base field is NOT in bound_inputs model_default = add(a=10) - assert model_default.flow.bound_inputs == {"a": 10} + assert model_default.flow.inspect().bound_inputs == {"a": 10} def _annotation_contains(annotation: object, expected: object) -> bool: @@ -3504,7 +3986,12 @@ def test_flow_model_public_exports_exclude_context_spec_models(): assert "StaticValueSpec" not in flow_model_module.__all__ assert "ContextTransform" not in flow_model_module.__all__ assert "flow_context_transform" not in flow_model_module.__all__ + assert "DependencySpec" not in flow_model_module.__all__ + assert "InputCheck" not in flow_model_module.__all__ assert not hasattr(ccflow, "StaticValueSpec") assert not hasattr(ccflow, "ContextTransform") assert not hasattr(ccflow, "flow_context_transform") + assert not hasattr(ccflow, "DependencySpec") + assert not hasattr(ccflow, "InputCheck") assert not hasattr(flow_model_module, "flow_context_transform") + assert not hasattr(flow_model_module, "InputCheck") diff --git a/docs/wiki/Flow-Model.md b/docs/wiki/Flow-Model.md index 425e151..4bed310 100644 --- a/docs/wiki/Flow-Model.md +++ b/docs/wiki/Flow-Model.md @@ -9,7 +9,7 @@ The design is intentionally narrow: - `@Flow.context_transform` defines reusable contextual rewrites, - `.flow.compute(...)` is the execution entry point for the full DAG, - `.flow.with_context(*patches, **field_overrides)` rewires contextual inputs on one dependency edge, -- upstream `CallableModel`s can still be passed as ordinary arguments. +- upstream `CallableModel`s can still be passed as ordinary inputs. The goal is that a reader can look at one function signature and immediately see: @@ -125,7 +125,7 @@ Contextual parameters are the ones marked with `FromContext[...]`. They can be satisfied by: - runtime context, -- construction-time keyword arguments, stored as contextual defaults on the model instance, +- construction-time keyword inputs, stored as contextual defaults on the model instance, - keyword callable literals for `FromContext[Callable[..., T]]` fields, - function defaults. @@ -172,15 +172,15 @@ execution of the whole DAG. For generated `@Flow.model` stages it accepts either: -- keyword arguments that become the ambient runtime context bag, or +- keyword inputs that become the ambient runtime context bag, or - one context object. It does not accept both at the same time. Plain handwritten `CallableModel` instances also expose `.flow.compute(...)`. -For those models, keyword arguments build or update the runtime context. If the +For those models, keyword inputs build or update the runtime context. If the decorated `@Flow.call` method declares a default context object, no-argument -`.flow.compute()` uses that default, and keyword arguments override fields from +`.flow.compute()` uses that default, and keyword inputs override fields from that default for the `.flow.compute(...)` call. ```python @@ -200,7 +200,7 @@ assert model.flow.compute(FlowContext(b=6)).value == 16 For `@Flow.model` generated models, the kwargs form is intentionally a DAG entrypoint: it can include extra fields needed only by upstream transformed dependencies. Regular parameters are still never read from runtime context. -`compute()` enforces two guardrails on keyword arguments: +`compute()` enforces two guardrails on keyword inputs: - If a key matches an **unbound** regular parameter, it raises early instead of silently treating that value as configuration. @@ -233,13 +233,13 @@ model = add( right=base.flow.with_context(value=add_offset(amount=10)), ) -assert model.flow.context_inputs == {"bonus": int} +assert model.flow.inspect().context_inputs == {"bonus": int} assert model.flow.compute(value=5, bonus=100).value == 121 ``` If a regular parameter is already bound on the root model and you need to pass an ambient context field with the same name for upstream graph nodes, use a -context object instead of keyword arguments. The kwargs form rejects keys that +context object instead of keyword inputs. The kwargs form rejects keys that match already-bound regular parameters to prevent accidental rebinding: ```python @@ -291,9 +291,9 @@ determines how it can be used in `with_context()`: - **Patch transforms** return a `Mapping` (e.g. `dict[str, object]`) of contextual field names to replacement values. They are passed as **positional - arguments** to `with_context()`. + inputs** to `with_context()`. - **Field transforms** return a single scalar value. They are passed as - **keyword arguments** to `with_context()`, keyed by the contextual field they + **keyword inputs** to `with_context()`, keyed by the contextual field they replace. ```python @@ -353,7 +353,7 @@ always apply last. Key rules: - `with_context()` only targets contextual fields, -- positional arguments must be patch transforms, +- positional inputs must be patch transforms, - keyword overrides may be literals or field transforms, - raw positional callables are rejected; use named `@Flow.context_transform` helpers for positional patch transforms, @@ -408,38 +408,90 @@ For class-based `CallableModel` methods that want to declare context fields as keyword-only parameters, see `Flow.call(auto_context=...)` in [Workflows](Workflows#flow-decorator). -## Introspection APIs +## Introspection -Flow models expose a few useful introspection helpers: +The public `.flow` surface is intentionally small: -- `model.flow.context_inputs`: the declared contextual contract for the model - or wrapped model, -- `model.flow.runtime_inputs`: direct runtime context inputs this model or - wrapper may read after applying its own bindings, -- `model.flow.required_inputs`: required direct runtime context inputs that are - not already satisfied by defaults or bindings, -- `model.flow.bound_inputs`: concrete values already fixed on the model, such - as regular construction-time inputs, construction-time contextual defaults, - and literal keyword `with_context(field=value)` bindings. +- `model.flow.compute(...)`: evaluate the model, +- `model.flow.with_context(...)`: return a branch-local contextual wrapper, +- `model.flow.inspect(...)`: return a structured debugging summary. + +Use `inspect(...)` when you want to understand what is bound, what is still +contextual, and which direct dependencies are attached: + +```python +inspection = model.flow.inspect() + +inspection.inputs # direct function inputs and their sources +inspection.context_inputs # declared contextual contract +inspection.runtime_inputs # direct runtime inputs after wrapper bindings +inspection.required_inputs # required runtime inputs still needed +inspection.bound_inputs # construction/static values already fixed +inspection.dependencies # dependency edges, controlled by dependencies=... +``` + +The top-level input fields are intentionally current-level only. They describe +the model or wrapper you inspected, not a flattened view of the whole dependency +graph. `inspection.inputs` is a dict from that model's function input name to an +`InputSpec`. Each `InputSpec` reports the expected type, whether the input is +required, the declared default, the effective direct value if known, and whether +that value came from construction, a function default, runtime context, or +`with_context(...)`. + +Dependency information lives under `inspection.dependencies`. Each dependency +edge reports the parameter path, target model, projected context values when +known, and whether the edge is lazy. To inspect a child, inspect that child +model directly: + +```python +inspection = model.flow.inspect(value=3) +dependency = inspection.dependencies[0] + +if dependency.context is None: + child = dependency.model.flow.inspect() +else: + child = dependency.model.flow.inspect(dependency.context) + +child.inputs +child.runtime_inputs +child.required_inputs +``` + +This explicit nesting avoids merging unrelated models into one ambiguous input +namespace. It also avoids name collisions when multiple dependencies use the +same context field name for different branches. `context_inputs` intentionally stays faithful to the model's declared contract. For bound models, `with_context(...)` bindings are reflected in -`runtime_inputs`, `required_inputs`, and `bound_inputs`. Literal bindings -satisfy their target fields. Transform bindings with runtime inputs add those -source context inputs to the effective runtime view. Static transforms, meaning -transforms with no contextual parameters, may be evaluated during introspection -so their output fields can be reported precisely. A transform parameter like -`seed: FromContext[int] = 0` is still a runtime input; its default only means the -caller is not required to provide it. -`required_inputs` is always the required subset of `runtime_inputs`; if multiple -bindings expose the same runtime context field with conflicting annotations, -introspection raises an error instead of silently choosing one. - -These helpers report the direct API for the current model or wrapper. They do -not recursively expand every contextual input used by upstream dependencies in a -larger graph. - -Because these helpers may evaluate static `@Flow.context_transform` functions, +`runtime_inputs`, `required_inputs`, `bound_inputs`, and `inputs`. Literal +bindings satisfy their target fields. Transform bindings with runtime inputs add +those source context inputs to the effective runtime view. Static transforms, +meaning transforms with no contextual parameters, may be evaluated during +introspection so their output fields can be reported precisely. A transform +parameter like `seed: FromContext[int] = 0` is still a runtime input; its default +only means the caller is not required to provide it. + +`inspect(...)` reports the direct API for the current model or wrapper. By +default, it also inspects immediate dependencies: + +```python +model.flow.inspect(dependencies="direct") # default +model.flow.inspect(dependencies="recursive") # follow inspect-visible dependencies +model.flow.inspect(dependencies="none") # skip dependency inspection +``` + +Recursive inspection follows dependencies visible from constructed +`@Flow.model` inputs and `with_context(...)` wrappers. It is intended for +debugging generated-model trees, not as a complete evaluator graph. A +handwritten `CallableModel` can still appear as a dependency target when it is +bound to an `@Flow.model` regular input, but `inspect(...)` does not expand that +handwritten model's custom `CallableModel.__deps__` implementation. Recursive +mode changes which dependency edges are listed; it does not change the meaning +of the top-level input fields. Lazy dependency edges are listed as `lazy`; +inspect the lazy target directly if you want to see what it could require when +called. + +Because introspection may evaluate static `@Flow.context_transform` functions, context transforms should be deterministic, side-effect-free, and cheap. This is the same practical contract expected by cache identity and dependency analysis. @@ -455,10 +507,12 @@ def add(a: int, b: FromContext[int], c: FromContext[int] = 5) -> int: model = add(a=10) -assert model.flow.context_inputs == {"b": int, "c": int} -assert model.flow.runtime_inputs == {"b": int, "c": int} -assert model.flow.required_inputs == {"b": int} -assert model.flow.bound_inputs == {"a": 10} +inspection = model.flow.inspect() +assert inspection.context_inputs == {"b": int, "c": int} +assert inspection.runtime_inputs == {"b": int, "c": int} +assert inspection.required_inputs == {"b": int} +assert inspection.bound_inputs == {"a": 10} +assert set(inspection.inputs) == {"a", "b", "c"} @Flow.context_transform @@ -467,10 +521,11 @@ def from_seed(seed: FromContext[int]) -> int: bound = add(a=10).flow.with_context(b=from_seed()) -assert bound.flow.context_inputs == {"b": int, "c": int} -assert bound.flow.runtime_inputs == {"c": int, "seed": int} -assert bound.flow.required_inputs == {"seed": int} -assert bound.flow.bound_inputs == {"a": 10} +bound_inspection = bound.flow.inspect() +assert bound_inspection.context_inputs == {"b": int, "c": int} +assert bound_inspection.runtime_inputs == {"c": int, "seed": int} +assert bound_inspection.required_inputs == {"seed": int} +assert bound_inspection.bound_inputs == {"a": 10} ``` In the bound example, `b` remains in `context_inputs` because `add` still @@ -479,6 +534,20 @@ declares `b` as part of its contextual contract. It is absent from appears in `runtime_inputs` because the transform reads it from the caller's runtime context. +When you pass a proposed context object or runtime kwargs to `inspect(...)`, +inspection uses those values to fill known direct `inputs` and project context +onto dependency edges. It does not validate the proposed context or report +unused fields: + +```python +inspection = add(a=10).flow.inspect(b=2, unused=1) +assert inspection.inputs["b"].value == 2 +assert "unused" not in inspection.inputs +``` + +This keeps `inspect(...)` structural. A stricter debug-time input checker can be +added later with explicit current-model versus graph-wide semantics. + ## Lazy Dependencies `Lazy[T]` defers evaluation of an upstream dependency until the function body @@ -519,7 +588,7 @@ controls exactly when (and whether) the dependency executes. Use `@Flow.model` when: - the stage logic is naturally a plain function, -- you want ordinary arguments to look like ordinary Python function parameters, +- you want ordinary inputs to look like ordinary Python function parameters, - the contextual contract is small and explicit, - the main goal is easy graph authoring on top of existing ccflow machinery. @@ -549,11 +618,12 @@ That is expected. `context_inputs` reports the declared contextual contract of the model or wrapped model. It does not mean the current wrapper still requires the caller to provide that field. -Use `runtime_inputs` to see the effective direct runtime context inputs after -`with_context(...)` bindings. Use `required_inputs` to see what still must be -provided by the caller. Static transforms with no contextual parameters may be -evaluated during introspection, so their output fields can be removed from -`runtime_inputs` and `required_inputs` or added to `bound_inputs`. +Use `model.flow.inspect().runtime_inputs` to see the effective direct runtime +context inputs after `with_context(...)` bindings. Use +`model.flow.inspect().required_inputs` to see what still must be provided by the +caller. Static transforms with no contextual parameters may be evaluated during +introspection, so their output fields can be removed from `runtime_inputs` and +`required_inputs` or added to `bound_inputs`. **A shared dependency runs more than once**