diff --git a/ccflow/__init__.py b/ccflow/__init__.py index 71263b5..30ffd9c 100644 --- a/ccflow/__init__.py +++ b/ccflow/__init__.py @@ -11,6 +11,7 @@ from .callable import * from .context import * from .enums import Enum +from .flow_model import * from .global_state import * from .local_persistence import * from .models import * diff --git a/ccflow/_flow_model_binding.py b/ccflow/_flow_model_binding.py new file mode 100644 index 0000000..a375617 --- /dev/null +++ b/ccflow/_flow_model_binding.py @@ -0,0 +1,756 @@ +"""Shared signature and context-contract analysis for Flow authoring APIs. + +This module also owns the portable serialization format for an already-analyzed +``@Flow.model`` contract. That payload is used by generated-model pickle/Ray +restore and serialized context transforms so workers do not re-resolve function +annotations from the original defining scope. +""" + +import inspect +from dataclasses import dataclass, field +from functools import wraps +from types import FunctionType, UnionType +from typing import Annotated, Any, Callable, Dict, Literal, NamedTuple, Optional, Tuple, Type, Union, get_args, get_origin, get_type_hints + +from .base import ContextBase, ResultBase +from .context import FlowContext +from .local_persistence import create_ccflow_model +from .result import GenericResult + +_AnyCallable = Callable[..., Any] +_UNION_ORIGINS = (Union, UnionType) + + +class _InternalSentinel: + def __init__(self, name: str): + self._name = name + + def __repr__(self) -> str: + return self._name + + def __reduce__(self): + return (_get_internal_sentinel, (self._name,)) + + +def _get_internal_sentinel(name: str) -> _InternalSentinel: + return _INTERNAL_SENTINELS[name] + + +_INTERNAL_SENTINELS = { + "_UNSET": _InternalSentinel("_UNSET"), +} +_UNSET = _INTERNAL_SENTINELS["_UNSET"] +_RESERVED_FLOW_MODEL_PARAM_NAMES = frozenset({"flow", "meta", "context_type", "result_type", "type_"}) + + +class _LazyMarker: + def __repr__(self) -> str: + return "Lazy" + + 
@dataclass(frozen=True)
class _FlowModelConfig:
    """Immutable, already-analyzed contract of a function wrapped by the Flow authoring APIs.

    ``parameters`` holds one ``_FlowModelParam`` per analyzed function
    parameter. The underscore-prefixed fields are lookup caches derived from
    ``parameters`` in ``__post_init__``; they are not constructor arguments and
    are exposed read-only through the properties below.
    """

    func: _AnyCallable
    return_annotation: Any
    context_type: Type[ContextBase]
    result_type: Type[ResultBase]
    auto_wrap_result: bool
    auto_unwrap: bool
    parameters: Tuple[_FlowModelParam, ...]
    declared_context_type: Optional[Type[ContextBase]] = None
    # Derived caches. ``compare=False`` keeps them out of the generated
    # ``__eq__``/``__hash__``: they are fully determined by ``parameters`` (so
    # equality is unchanged), and the dict-valued cache would otherwise make
    # ``hash(config)`` raise ``TypeError`` for every instance.
    _regular_params: Tuple[_FlowModelParam, ...] = field(init=False, repr=False, compare=False)
    _contextual_params: Tuple[_FlowModelParam, ...] = field(init=False, repr=False, compare=False)
    _regular_param_names: Tuple[str, ...] = field(init=False, repr=False, compare=False)
    _contextual_param_names: Tuple[str, ...] = field(init=False, repr=False, compare=False)
    _params_by_name: Dict[str, _FlowModelParam] = field(init=False, repr=False, compare=False)

    def __post_init__(self) -> None:
        """Populate the derived caches once, preserving the order of ``parameters``."""
        regular = tuple(param for param in self.parameters if not param.is_contextual)
        contextual = tuple(param for param in self.parameters if param.is_contextual)
        # The dataclass is frozen, so the caches must bypass the generated __setattr__.
        object.__setattr__(self, "_regular_params", regular)
        object.__setattr__(self, "_contextual_params", contextual)
        object.__setattr__(self, "_regular_param_names", tuple(param.name for param in regular))
        object.__setattr__(self, "_contextual_param_names", tuple(param.name for param in contextual))
        object.__setattr__(self, "_params_by_name", {param.name: param for param in self.parameters})

    @property
    def regular_params(self) -> Tuple[_FlowModelParam, ...]:
        """Parameters that are not marked contextual."""
        return self._regular_params

    @property
    def contextual_params(self) -> Tuple[_FlowModelParam, ...]:
        """Parameters that are marked contextual."""
        return self._contextual_params

    @property
    def regular_param_names(self) -> Tuple[str, ...]:
        """Names of the non-contextual parameters."""
        return self._regular_param_names

    @property
    def contextual_param_names(self) -> Tuple[str, ...]:
        """Names of the contextual parameters."""
        return self._contextual_param_names

    @property
    def context_input_types(self) -> Dict[str, Any]:
        """Map each contextual parameter name to its ``validation_annotation``."""
        return {param.name: param.validation_annotation for param in self.contextual_params}

    @property
    def context_required_names(self) -> Tuple[str, ...]:
        """Names of contextual parameters that have no function default."""
        return tuple(param.name for param in self.contextual_params if not param.has_function_default)

    def param(self, name: str) -> _FlowModelParam:
        """Return the parameter called ``name``.

        Raises:
            KeyError: If no analyzed parameter has that name.
        """
        return self._params_by_name[name]
def _clone_function_without_annotations(fn: _AnyCallable) -> _AnyCallable:
    """Return a behavior-equivalent function whose annotations are not a pickle dependency.

    ``@Flow.model`` analyzes annotations eagerly when the decorator runs and
    stores the resolved contract in ``_FlowModelConfig``. During local/generated
    model pickling we want workers to execute the original function body, but we
    do not want them to re-evaluate or unpickle the original annotations. Some
    runtime annotations, notably Pydantic generic specializations such as
    ``GenericResult[int]``, are valid Python objects in-process but do not have a
    durable import path for fresh-process cloudpickle restore.

    Args:
        fn: The callable to clone.

    Returns:
        ``fn`` itself when it is not a plain Python function; otherwise a new
        function sharing the code, globals, positional defaults, and closure of
        ``fn``, with copied keyword-only defaults and attributes and an empty
        ``__annotations__``.
    """

    if not isinstance(fn, FunctionType):
        return fn

    clone = FunctionType(fn.__code__, fn.__globals__, fn.__name__, fn.__defaults__, fn.__closure__)
    # Copy the keyword-only defaults: assigning the original dict would alias
    # it, so mutating the clone's defaults would silently change ``fn`` too.
    kwdefaults = fn.__kwdefaults__
    clone.__kwdefaults__ = None if kwdefaults is None else dict(kwdefaults)
    clone.__module__ = fn.__module__
    clone.__qualname__ = fn.__qualname__
    clone.__doc__ = fn.__doc__
    clone.__dict__.update(fn.__dict__)
    # Reset last so nothing copied above can reintroduce annotations.
    clone.__annotations__ = {}
    return clone
def _restore_annotation(payload: Any) -> Any:
    """Restore an annotation payload produced by ``_serialize_annotation``.

    Args:
        payload: A ``_SerializedAnnotation`` whose ``kind`` is one of ``raw``,
            ``pydantic_generic``, ``annotated``, ``union``, ``literal`` or
            ``generic_alias``.

    Returns:
        The rebuilt annotation object.

    Raises:
        TypeError: If ``payload`` is not a ``_SerializedAnnotation`` or its
            ``kind`` is not recognized.
    """

    if not isinstance(payload, _SerializedAnnotation):
        raise TypeError(f"Unknown serialized annotation payload: {payload!r}")

    kind = payload.kind
    value = payload.value
    if kind == "raw":
        return value
    if kind in ("pydantic_generic", "generic_alias"):
        # Both kinds store the serialized origin in ``value`` and the type
        # arguments in ``args``; a single argument is passed unwrapped.
        origin = _restore_annotation(value)
        args = tuple(_restore_annotation(arg) for arg in payload.args)
        return origin[args[0] if len(args) == 1 else args]

    # The typing forms below are rebuilt with plain subscription rather than by
    # calling ``__class_getitem__``/``__getitem__`` directly. Subscription is
    # the only spelling that works on every Python version: ``Annotated``
    # stopped being a class with ``__class_getitem__`` in Python 3.13, and
    # ``Union`` is no longer a special form with a bound ``__getitem__`` in
    # Python 3.14. Subscripting with a tuple is equivalent to ``X[a, b, ...]``.
    if kind == "annotated":
        return Annotated[(_restore_annotation(value), *payload.args)]
    if kind == "union":
        return Union[tuple(_restore_annotation(arg) for arg in value)]
    if kind == "literal":
        return Literal[value]
    raise TypeError(f"Unknown serialized annotation payload kind: {payload.kind!r}")
def _resolved_flow_signature(
    fn: _AnyCallable,
    *,
    resolved_hints: Optional[Dict[str, Any]] = None,
    skip_self: bool = False,
    require_return_annotation: bool = False,
    annotation_error_suffix: str = "",
    return_error_suffix: str = "",
    function_name: Optional[str] = None,
) -> inspect.Signature:
    """Build a validated signature for ``fn`` with resolved annotations substituted in.

    Args:
        fn: Callable to inspect.
        resolved_hints: Optional mapping of parameter name (and ``"return"``)
            to an already-resolved annotation; entries override the raw
            annotations found on the signature.
        skip_self: Drop a parameter literally named ``self`` from the result.
        require_return_annotation: Reject a missing return annotation.
        annotation_error_suffix: Text appended to the missing-annotation error.
        return_error_suffix: Text appended to the missing-return-annotation error.
        function_name: Name used in error messages; defaults to the callable's name.

    Returns:
        A new ``inspect.Signature`` containing only the accepted parameters.

    Raises:
        TypeError: For positional-only or variadic parameters, for a parameter
            without an annotation, or for a missing required return annotation.
    """
    raw_signature = inspect.signature(fn)
    hints = resolved_hints if resolved_hints else {}
    display_name = function_name if function_name else _callable_name(fn)
    variadic_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

    accepted = []
    for param in raw_signature.parameters.values():
        name = param.name
        if skip_self and name == "self":
            continue

        # Shape checks come first so an unsupported parameter kind is reported
        # even when the parameter is also missing an annotation.
        kind = param.kind
        if kind is inspect.Parameter.POSITIONAL_ONLY:
            raise TypeError(f"Function {display_name} does not support positional-only parameter '{name}'.")
        if kind in variadic_kinds:
            raise TypeError(f"Function {display_name} does not support {kind.description} parameter '{name}'.")

        resolved = hints.get(name, param.annotation)
        if resolved is inspect.Parameter.empty:
            raise TypeError(f"Parameter '{name}' must have a type annotation{annotation_error_suffix}")
        accepted.append(param.replace(annotation=resolved))

    resolved_return = hints.get("return", raw_signature.return_annotation)
    if require_return_annotation and resolved_return is inspect.Signature.empty:
        raise TypeError(f"Function {display_name} must have a return type annotation{return_error_suffix}")

    return raw_signature.replace(parameters=accepted, return_annotation=resolved_return)
def _context_type_annotations_compatible(func_annotation: Any, context_annotation: Any) -> bool:
    """Report whether a value typed ``context_annotation`` is acceptable where ``func_annotation`` is expected.

    ``Annotated`` wrappers are stripped from both sides first. ``Any`` on
    either side, identity, and equality all count as compatible. Unions are
    compared member-wise (every offered member must fit some accepted member,
    and ``None`` is only offered safely when the function side accepts it),
    ``Literal`` values must be a subset or instances of the accepted type,
    classes are compared with ``issubclass``, and generic arguments are
    compared pairwise.
    """
    none_type = type(None)
    accepted = _strip_annotated(func_annotation)
    offered = _strip_annotated(context_annotation)

    # Trivial matches: one side is unconstrained, or both are the same annotation.
    if accepted is Any or offered is Any:
        return True
    if accepted is offered or accepted == offered:
        return True

    accepted_origin = get_origin(accepted)
    offered_origin = get_origin(offered)

    # Function side is a union: each offered alternative needs at least one accepting member.
    if accepted_origin in _UNION_ORIGINS:
        accepted_members = get_args(accepted)
        allows_none = none_type in accepted_members
        concrete_accepted = tuple(member for member in accepted_members if member is not none_type)

        if offered_origin in _UNION_ORIGINS:
            offered_members = get_args(offered)
            if none_type in offered_members and not allows_none:
                return False
            concrete_offered = tuple(member for member in offered_members if member is not none_type)
            if not concrete_offered:
                return allows_none
            for candidate in concrete_offered:
                if not any(_context_type_annotations_compatible(member, candidate) for member in concrete_accepted):
                    return False
            return True

        if offered is none_type:
            return allows_none
        return any(_context_type_annotations_compatible(member, offered) for member in concrete_accepted)

    # Only the context side is a union: every alternative must fit the single accepted type.
    if offered_origin in _UNION_ORIGINS:
        offered_members = get_args(offered)
        if none_type in offered_members:
            return False
        concrete_offered = tuple(member for member in offered_members if member is not none_type)
        if not concrete_offered:
            return False
        return all(_context_type_annotations_compatible(accepted, candidate) for candidate in concrete_offered)

    # Literal handling: a Literal on the function side only accepts a Literal subset.
    if accepted_origin is Literal:
        if offered_origin is Literal:
            return set(get_args(offered)).issubset(set(get_args(accepted)))
        return False

    accepted_base = accepted_origin or accepted
    if offered_origin is Literal:
        if isinstance(accepted_base, type):
            return all(isinstance(value, accepted_base) for value in get_args(offered))
        return False

    # Plain classes / generic aliases: compare the base types, then the type arguments.
    offered_base = offered_origin or offered
    if isinstance(accepted_base, type) and isinstance(offered_base, type):
        if not issubclass(offered_base, accepted_base):
            return False
    elif accepted_base != offered_base:
        return False

    accepted_args = get_args(accepted)
    offered_args = get_args(offered)
    if bool(accepted_args) != bool(offered_args) or len(accepted_args) != len(offered_args):
        return False

    for accepted_arg, offered_arg in zip(accepted_args, offered_args):
        needs_structural_check = (
            get_origin(accepted_arg) is not None
            or get_origin(offered_arg) is not None
            or isinstance(accepted_arg, type)
            or isinstance(offered_arg, type)
        )
        if needs_structural_check:
            if not _context_type_annotations_compatible(accepted_arg, offered_arg):
                return False
        elif accepted_arg != offered_arg:
            return False

    return True
def _validate_declared_context_type(context_type: Any, contextual_params: Tuple[_FlowModelParam, ...]) -> Type[ContextBase]:
    """Check that an explicitly declared ``context_type`` matches the contextual parameters.

    Args:
        context_type: Candidate context class supplied by the caller.
        contextual_params: Parameters that must be satisfiable from the context.

    Returns:
        ``context_type`` unchanged when every check passes.

    Raises:
        TypeError: If ``context_type`` is not a ``ContextBase`` subclass, lacks
            a field for a contextual parameter, has a required field that no
            contextual parameter covers, or declares a field annotation that
            is incompatible with the parameter's annotation.
    """
    is_context_class = isinstance(context_type, type) and issubclass(context_type, ContextBase)
    if not is_context_class:
        raise TypeError(f"context_type must be a ContextBase subclass, got {context_type!r}")

    declared_fields = getattr(context_type, "model_fields", {})
    wanted_names = {param.name for param in contextual_params}

    # Every contextual parameter needs a same-named field on the context class.
    absent = sorted(wanted_names.difference(declared_fields))
    if absent:
        raise TypeError(f"context_type {context_type.__name__} must define fields for all FromContext parameters: {', '.join(absent)}")

    # Required fields beyond ContextBase's own must be covered by a contextual parameter.
    uncovered_required = []
    for field_name, info in declared_fields.items():
        if field_name in ContextBase.model_fields or field_name in wanted_names:
            continue
        if info.is_required():
            uncovered_required.append(field_name)
    uncovered_required.sort()
    if uncovered_required:
        raise TypeError(
            f"context_type {context_type.__name__} has required fields that are not declared as FromContext parameters: "
            f"{', '.join(uncovered_required)}"
        )

    # Finally, each parameter annotation must accept what the context field declares.
    for param in contextual_params:
        declared = declared_fields[param.name]
        if _context_type_annotations_compatible(param.annotation, declared.annotation):
            continue
        raise TypeError(
            f"FromContext parameter '{param.name}' annotates {param.annotation!r}, but "
            f"context_type {context_type.__name__} declares {declared.annotation!r}."
        )

    return context_type
def _analyze_auto_context_function(
    func: _AnyCallable,
    *,
    parent: Optional[Type[ContextBase]],
    resolved_hints: Dict[str, Any],
    is_model_dependency: Callable[[Any], bool],
) -> _AutoContextSpec:
    """Analyze a method signature and describe the context class to generate for it.

    Args:
        func: The method whose parameters (excluding ``self``) become context fields.
        parent: Optional context class the generated context derives from;
            ``ContextBase`` is used as the base when it is ``None``.
        resolved_hints: Already-resolved annotations for ``func``.
        is_model_dependency: Predicate identifying defaults that are model dependencies.

    Returns:
        An ``_AutoContextSpec`` with the resolved signature, base class,
        generated class name, and the ``(annotation, default)`` field mapping.

    Raises:
        TypeError: If the signature is unsupported or unannotated, if fields
            added by ``parent`` are missing from or incompatible with the
            signature, if any parameter is ``Lazy[...]``, or if any default is
            a model dependency.
    """
    sig = _resolved_flow_signature(
        func,
        resolved_hints=resolved_hints,
        skip_self=True,
        require_return_annotation=True,
        annotation_error_suffix=" when auto_context=True",
        return_error_suffix=" when auto_context=True",
        function_name=_callable_qualname(func),
    )
    base_class = parent or ContextBase

    if parent is not None:
        # Only fields the parent adds on top of ContextBase must appear in the signature.
        parent_fields = set(parent.model_fields) - set(ContextBase.model_fields)
        missing = parent_fields - set(sig.parameters)
        if missing:
            # Render the names sorted, in set-like braces: interpolating the raw
            # set made the message order depend on string hash randomization.
            missing_display = "{" + ", ".join(repr(name) for name in sorted(missing)) + "}"
            raise TypeError(f"Parent context fields {missing_display} must be included in function signature")

        # Sorted iteration makes the first reported incompatibility deterministic.
        for fname in sorted(parent_fields):
            parent_annotation = parent.model_fields[fname].annotation
            func_annotation = sig.parameters[fname].annotation
            if func_annotation is inspect.Parameter.empty:
                continue
            if not _context_type_annotations_compatible(func_annotation, parent_annotation):
                raise TypeError(
                    f"auto_context field '{fname}' has annotation {func_annotation!r} which is incompatible "
                    f"with parent field annotation {parent_annotation!r}"
                )

    lazy_params = [name for name, param in sig.parameters.items() if _parse_annotation(param.annotation).is_lazy]
    if lazy_params:
        raise TypeError(f"Flow.call(auto_context=...) does not support Lazy[...] parameter(s): {', '.join(lazy_params)}")

    model_defaults = [
        name for name, param in sig.parameters.items() if param.default is not inspect.Parameter.empty and is_model_dependency(param.default)
    ]
    if model_defaults:
        raise TypeError(f"Flow.call(auto_context=...) parameters cannot default to CallableModel dependencies: {', '.join(model_defaults)}")

    fields = {}
    parent_model_fields = {} if parent is None else parent.model_fields
    for name, param in sig.parameters.items():
        if name in parent_model_fields:
            # Reuse the parent's field definition so its annotation and field info win.
            field_info = parent_model_fields[name]
            fields[name] = (field_info.annotation, field_info)
            continue
        # ``...`` marks a field with no default.
        fields[name] = (param.annotation, ... if param.default is inspect.Parameter.empty else param.default)
    return _AutoContextSpec(
        signature=sig,
        base_class=base_class,
        class_name=f"{_callable_qualname(func)}_AutoContext",
        fields=fields,
    )
Signature, isclass, signature -from typing import Any, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Dict, + Generic, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, + get_args, + get_origin, +) from pydantic import BaseModel as PydanticBaseModel, ConfigDict, Field, InstanceOf, PrivateAttr, TypeAdapter, field_validator, model_validator from typing_extensions import override +from ._flow_model_binding import _normalize_auto_context_parent, _wrap_auto_context_call from .base import ( BaseModel, ContextBase, @@ -29,6 +46,9 @@ ) from .validators import str_to_log_level +if TYPE_CHECKING: + from .flow_model import FlowAPI + __all__ = ( "GraphDepType", "GraphDepList", @@ -50,6 +70,14 @@ log = logging.getLogger(__name__) +@dataclass(frozen=True) +class EvaluationDependency: + """Internal marker for a dependency invocation in an effective identity payload.""" + + model: Any + context: Any + + # ***************************************************************************** # Base CallableModel definitions, before introducing the Flow decorator or # any evaluators @@ -175,6 +203,18 @@ def __deps__( Implementations should be decorated with Flow.call. """ + def _evaluation_identity_payload( + self, + context: Any, + ) -> Optional[Any]: + """Return an effective evaluation identity payload when available. + + Returning ``None`` keeps the model on the existing structural key path. + This is intentionally narrow and internal: only models whose effective + invocation can be described declaratively should override it. + """ + return None + CallableModelType = TypeVar("CallableModelType", bound=_CallableModel) @@ -392,7 +432,28 @@ def __exit__(self, exc_type, exc_value, exc_tb): class Flow(PydanticBaseModel): @staticmethod def call(*args, **kwargs): - """Decorator for methods on callable models""" + """Decorator for methods on callable models. 
+ + Pass ``auto_context=True`` or a ``ContextBase`` subclass to derive the + method context from annotated parameters. + """ + auto_context = kwargs.pop("auto_context", False) + + if auto_context is not False: + context_parent = _normalize_auto_context_parent(auto_context) + + def auto_context_decorator(fn: Callable[..., Any]) -> Callable[..., Any]: + wrapped = _wrap_auto_context_call( + fn, + parent=context_parent, + is_model_dependency=lambda value: isinstance(value, CallableModel), + ) + return FlowOptions(**kwargs)(wrapped) + + if len(args) == 1 and callable(args[0]): + return auto_context_decorator(args[0]) + return auto_context_decorator + if len(args) == 1 and callable(args[0]): # No arguments to decorator, this is the decorator fn = args[0] @@ -418,6 +479,66 @@ def deps(*args, **kwargs): # Note that the code below is executed only once return FlowOptionsDeps(**kwargs) + @staticmethod + def model(*args, **kwargs): + """Generate a ``CallableModel`` factory from a typed Python function. + + Unmarked function parameters become construction-time model inputs. + They can be bound to literal values or to upstream ``CallableModel`` + dependencies. Parameters annotated as ``FromContext[T]`` are runtime + inputs supplied by ``model.flow.compute(...)``, a context object, + construction-time contextual defaults, or ``with_context(...)``. + + Plain return annotations are wrapped in ``GenericResult`` so generated + models still follow the normal ccflow result/evaluator/cache path. + Functions that already return a ``ResultBase`` subclass are left as-is. + + Common options mirror ``Flow.call``: ``cacheable``, ``volatile``, + ``log_level``, ``validate_result``, ``verbose``, and ``evaluator``. + Generated-model options also include ``context_type`` for validating + contextual fields, ``auto_unwrap`` for ergonomic compute results, and + ``model_base`` for custom ``CallableModel`` bases. + + Args: + func: The function being decorated. 
This is passed automatically in + bare decorator form, for example ``@Flow.model``. When using + options, for example ``@Flow.model(auto_unwrap=True)``, Python + first calls ``Flow.model(...)`` without a function and then + applies the returned decorator. + context_type: Optional ``ContextBase`` subclass used to validate + ``FromContext[...]`` fields together. + auto_unwrap: When ``True`` and the function's plain return value is + auto-wrapped in ``GenericResult[T]``, external + ``model.flow.compute(...)`` calls return the raw ``T`` value + instead of ``GenericResult[T]``. Internal model dependencies + still use normal ccflow result objects. + model_base: Custom ``CallableModel`` subclass to use as the base + class for the generated model. + """ + from .flow_model import flow_model + + return flow_model(*args, **kwargs) + + @staticmethod + def context_transform(*args, **kwargs): + """Create a serializable transform factory for ``model.flow.with_context``. + + Mark transform inputs with ``FromContext[T]`` when they should be read + from the caller's runtime context. Unmarked parameters are regular + arguments bound when the transform factory is called. + + Mapping-returning transforms are patch transforms and are passed + positionally to ``with_context(...)``. Scalar-returning transforms are + field transforms and are passed as keyword overrides, for example + ``model.flow.with_context(score=shift_score(amount=1))``. + + Transform bindings serialize enough function metadata to survive model + serialization, including local or nested functions through cloudpickle. 
+ """ + from .flow_model import _flow_context_transform + + return _flow_context_transform(*args, **kwargs) + # ***************************************************************************** # Define "Evaluators" and associated types @@ -672,6 +793,13 @@ def __deps__( """ return [] + @property + def flow(self) -> "FlowAPI": + """Access flow helpers for execution, context transforms, and introspection.""" + from .flow_model import FlowAPI + + return FlowAPI(self) + class WrapperModel(CallableModel, Generic[CallableModelType], abc.ABC): """Abstract class that represents a wrapper around an underlying model, with the same context and return types. diff --git a/ccflow/context.py b/ccflow/context.py index 9a04fad..04967d3 100644 --- a/ccflow/context.py +++ b/ccflow/context.py @@ -1,16 +1,18 @@ """This module defines re-usable contexts for the "Callable Model" framework defined in flow.callable.py.""" +from collections.abc import Mapping from datetime import date, datetime -from typing import Generic, Hashable, Optional, Sequence, Set, TypeVar +from typing import Any, Generic, Hashable, Optional, Sequence, Set, TypeVar from deprecated import deprecated -from pydantic import field_validator, model_validator +from pydantic import ConfigDict, field_validator, model_validator from .base import ContextBase from .exttypes import Frequency from .validators import normalize_date, normalize_datetime __all__ = ( + "FlowContext", "NullContext", "GenericContext", "DateContext", @@ -89,6 +91,53 @@ # Starting 0.8.0 Nullcontext is an alias to ContextBase NullContext = ContextBase + +class FlowContext(ContextBase): + """Universal context for @Flow.model functions. + + Instead of generating a new ContextBase subclass for each @Flow.model, + this single class with extra="allow" serves as the universal carrier. + Validation happens against the generated model's declared contextual input + types at compute() time. 
+ + This design avoids: + - Proliferation of dynamic _funcname_Context classes + - Class registration overhead for serialization + - Pickling issues with Ray/distributed computing + """ + + model_config = ConfigDict(extra="allow", frozen=True) + + def _hash_key(self) -> Hashable: + return _freeze_for_hash(self.model_dump(mode="python")) + + def __eq__(self, other: Any) -> bool: + if self is other: + return True + if not isinstance(other, FlowContext): + return False + return self._hash_key() == other._hash_key() + + def __hash__(self) -> int: + return hash(self._hash_key()) + + +def _freeze_for_hash(value: Any) -> Hashable: + if isinstance(value, Mapping): + return tuple(sorted(((key, _freeze_for_hash(item)) for key, item in value.items()), key=lambda item: repr(item[0]))) + if isinstance(value, (list, tuple)): + return tuple(_freeze_for_hash(item) for item in value) + if isinstance(value, (set, frozenset)): + return frozenset(_freeze_for_hash(item) for item in value) + if hasattr(value, "model_dump"): + return (type(value), _freeze_for_hash(value.model_dump(mode="python"))) + try: + hash(value) + except TypeError as exc: + raise TypeError(f"FlowContext contains an unhashable value of type {type(value).__name__}: {value!r}") from exc + return value + + C = TypeVar("C", bound=Hashable) diff --git a/ccflow/evaluators/common.py b/ccflow/evaluators/common.py index 2cab984..bae133f 100644 --- a/ccflow/evaluators/common.py +++ b/ccflow/evaluators/common.py @@ -14,11 +14,14 @@ from ..callable import ( CallableModel, ContextBase, + EvaluationDependency, EvaluatorBase, ModelEvaluationContext, ResultType, TransparentModelEvaluationContext, + WrapperModel, ) +from ..utils.tokenize import compute_cache_token __all__ = [ "cache_key", @@ -36,6 +39,37 @@ log = logging.getLogger(__name__) +class _EffectiveEvaluationKeyUnavailable(Exception): + """Internal signal to use the existing structural evaluation key.""" + + +_EFFECTIVE_EVALUATION_KEY_VERSION = 
"ccflow_effective_evaluation_key_v1" +_RECURSIVE_EFFECTIVE_IDENTITY_SENTINEL = "recursive_effective_identity" + + +class _IdentityMemoKey: + """Identity-based key that keeps objects alive while effective keys recurse. + + Raw ``id(...)`` tuples are unsafe here because rewritten context objects can + be short-lived and Python may reuse their ids during a single graph build. + Holding references keeps the ids stable for the lifetime of the memo key; + equality still checks object identity, so hash collisions are harmless. + """ + + __slots__ = ("model", "context", "_hash") + + def __init__(self, model: CallableModel, context: Any): + self.model = model + self.context = context + self._hash = hash((id(model), id(context))) + + def __hash__(self) -> int: + return self._hash + + def __eq__(self, other: Any) -> bool: + return isinstance(other, _IdentityMemoKey) and self.model is other.model and self.context is other.context + + def combine_evaluators(first: Optional[EvaluatorBase], second: Optional[EvaluatorBase]) -> EvaluatorBase: """Helper function to combine evaluators into a new evaluator. @@ -226,21 +260,147 @@ def _format_result(self, result: ResultType) -> str: return f"{msg_str}{pformat(result_dict, **self.format_config.pformat_config)}" -def cache_key(flow_obj: Union[ModelEvaluationContext, ContextBase, CallableModel]) -> bytes: - """Returns a key suitable for use in caching and dependency graph deduplication. +def _effective_model_key( + model: CallableModel, + context: Any, + memo: Dict[_IdentityMemoKey, bytes], + active: Set[int], +) -> Optional[bytes]: + """Return a model's opt-in effective key, or ``None`` for normal opt-out. + + Plain ``CallableModel`` instances opt out by returning ``None`` from + ``_evaluation_identity_payload()``. Dependency invocations inside the + payload are resolved by ``_resolve_effective_identity_payload()`` so models + declare what matters without constructing recursive keys themselves. 
+ """ + token = _IdentityMemoKey(model, context) + if token in memo: + return memo[token] + model_id = id(model) + if model_id in active: + raise _EffectiveEvaluationKeyUnavailable("recursive effective identity") + + active.add(model_id) + try: + payload = model._evaluation_identity_payload(context) + # For normal CallableModels, `_evaluation_identity_payload` defaults to + # None, so we should hit this path + if payload is None: + return None + payload = _resolve_effective_identity_payload(payload, memo, active) + key = compute_cache_token( + data_values=[(_EFFECTIVE_EVALUATION_KEY_VERSION, payload)], + behavior_classes=[type(model)], + ).encode("utf-8") + memo[token] = key + return key + finally: + active.discard(model_id) + + +def _resolve_effective_identity_payload( + value: Any, + memo: Dict[_IdentityMemoKey, bytes], + active: Set[int], +) -> Any: + """Replace dependency invocation markers with recursive effective keys.""" + if isinstance(value, EvaluationDependency): + evaluation = value.model.__call__.get_evaluation_context(value.model, value.context) + try: + return _effective_evaluation_key(evaluation, memo=memo, active=active, fallback=False) + except _EffectiveEvaluationKeyUnavailable: + return (_RECURSIVE_EFFECTIVE_IDENTITY_SENTINEL, type(value.model).__module__, type(value.model).__qualname__) + if isinstance(value, dict): + return {key: _resolve_effective_identity_payload(item, memo, active) for key, item in value.items()} + if isinstance(value, list): + return [_resolve_effective_identity_payload(item, memo, active) for item in value] + if isinstance(value, tuple): + return tuple(_resolve_effective_identity_payload(item, memo, active) for item in value) + return value + + +def _effective_evaluation_key( + evaluation_context: ModelEvaluationContext, + memo: Optional[Dict[_IdentityMemoKey, bytes]] = None, + active: Optional[Set[int]] = None, + fallback: bool = True, +) -> bytes: + """Use opt-in effective identity for ``__call__``; otherwise preserve 
``cache_key()``.""" + memo = {} if memo is None else memo + active = set() if active is None else active + inner, fn, outer_to_inner_evaluators = _flatten_cache_key_context(evaluation_context) + if fn != "__call__": + # Keep non-call evaluations, especially ``__deps__``, on the exact + # public structural key path. Effective identity is only meant to + # narrow normal model execution where generated ``@Flow.model`` code + # knows which ambient ``FlowContext`` fields affect the result. + return cache_key(evaluation_context) + if outer_to_inner_evaluators: + # Non-transparent evaluators can inspect the full ModelEvaluationContext + # and change the returned value based on ambient context fields that an + # opt-in model would otherwise ignore. Use the structural key whenever + # such an evaluator is part of the call chain; missing an optimization is + # preferable to returning a value cached under a narrower model identity. + return cache_key(evaluation_context) + + try: + key = _effective_model_key(inner.model, inner.context, memo, active) + except _EffectiveEvaluationKeyUnavailable as exc: + if not fallback: + raise + # Effective identity is an optimization/semantic narrowing for opt-in + # generated models. If deriving it is unclear, do not make cache/graph + # key construction a failure mode; fall back to the structural key. + log.debug("Falling back to structural evaluation key for %s.__call__: %s", type(inner.model).__name__, exc) + return cache_key(evaluation_context) + if key is None: + # This is the ordinary path for existing CallableModel classes. The + # base implementation returns None, so their cache and graph identities + # remain byte-for-byte equivalent to ``cache_key(evaluation_context)``. + return cache_key(evaluation_context) + + # Preserve the existing evaluation-context semantics around the narrowed + # model key: options still distinguish evaluations and transparent + # evaluators are ignored. 
+ return compute_cache_token( + data_values=[ + ( + _EFFECTIVE_EVALUATION_KEY_VERSION, + "evaluation_context", + { + "fn": fn, + "options": inner.options, + }, + key, + ) + ], + ).encode("utf-8") + + +def cache_key(flow_obj: Union[ModelEvaluationContext, ContextBase, CallableModel], *, effective: bool = False) -> bytes: + """Returns a key suitable for caching and dependency graph deduplication. For ``ModelEvaluationContext`` inputs, strips ``TransparentModelEvaluationContext`` layers (evaluators that don't modify the return value) so that the key depends only on the underlying model, context, fn, options, and any non-transparent evaluators in the chain. + By default, this key is structural. Passing ``effective=True`` for a + ``ModelEvaluationContext`` enables generated-model effective identity, which + lets opt-in models ignore unused ambient context fields. Non-opt-in models + and non-evaluation inputs still use the structural key. + When the underlying model has callable methods, a behavior token (SHA-256 of method bytecode) is included so that code changes invalidate the cache. Args: flow_obj: The object to be tokenized to form the cache key. + effective: Whether to use generated-model effective identity for model + evaluations that opt into it. Defaults to ``False`` to preserve the + public structural semantics. """ - from ..utils.tokenize import compute_cache_token + if effective and isinstance(flow_obj, ModelEvaluationContext): + return _effective_evaluation_key(flow_obj) if isinstance(flow_obj, ModelEvaluationContext): flow_obj, fn, non_transparent = _flatten_cache_key_context(flow_obj) @@ -273,9 +433,11 @@ def is_transparent(self, context: ModelEvaluationContext) -> bool: def key(self, context: ModelEvaluationContext): """Function to convert a ModelEvaluationContext to a cache key. - Delegates to ``cache_key()`` which strips transparent evaluator layers. 
+ Generated ``@Flow.model`` instances can use a narrower effective key + that ignores unused ambient context fields; ordinary ``CallableModel`` + paths fall back to the structural key. """ - return cache_key(context) + return cache_key(context, effective=True) @property def cache(self): @@ -311,22 +473,56 @@ class CallableModelGraph(BaseModel): root_id: bytes -def _build_dependency_graph(evaluation_context: ModelEvaluationContext, graph: CallableModelGraph, parent_key: Optional[bytes] = None): - key = cache_key(evaluation_context) - if parent_key: +def _build_dependency_graph( + evaluation_context: ModelEvaluationContext, + graph: CallableModelGraph, + parent_key: Optional[bytes] = None, + parent_model: Optional[CallableModel] = None, +) -> bytes: + # Generated/bound ``@Flow.model`` nodes can use effective identity so unused + # ambient FlowContext fields do not split the graph. Normal CallableModel + # nodes opt out and therefore still receive ``cache_key(evaluation_context)``. + key = _effective_evaluation_key(evaluation_context) + unwrapped_evaluation_context, _, _ = _flatten_cache_key_context(evaluation_context) + current_model = unwrapped_evaluation_context.model + is_same_evaluation_key = parent_key == key + is_collapsed_wrapper_child = is_same_evaluation_key and isinstance(parent_model, WrapperModel) and parent_model.model is current_model + + # Bound/wrapper models can share an effective graph key with their wrapped + # model after context rewriting. Adding the wrapper -> wrapped edge in that + # case would create a fake self-loop in the public graph because both ends + # have the same key. Suppress only that edge; real cycles between ordinary + # models are still recorded. 
+ if parent_key and not is_collapsed_wrapper_child: graph.graph[parent_key].add(key) if key not in graph.ids: graph.ids[key] = evaluation_context - if key not in graph.graph: + is_new_graph_key = key not in graph.graph + if is_new_graph_key: graph.graph[key] = set() - # Note that __deps__ will be evaluated using whatever evaluator is configured for the model, - # which could include logging, caching, etc. - deps = evaluation_context.model.__deps__(evaluation_context.context) - # Sequential evaluation of dependencies of dependencies (could have other implementations) - for model, contexts in deps: - for context in contexts: - sub_evaluation_context = model.__call__.get_evaluation_context(model, context) - _build_dependency_graph(sub_evaluation_context, graph, parent_key=key) + + # Effective identity can merge multiple model objects into one graph key. + # A bound wrapper and its wrapped model may share the graph node, but the + # wrapped model still has dependencies that must be traversed. Preserve + # normal graph deduplication by key, and make the only exception the exact + # same-key wrapper -> wrapped edge. + if not is_new_graph_key and not is_collapsed_wrapper_child: + return key + + # Note that __deps__ will be evaluated using whatever evaluator is configured for the model, + # which could include logging, caching, etc. + deps = evaluation_context.model.__deps__(evaluation_context.context) + # Recursively walk dependency contexts depth-first to build the complete graph. 
+ for model, contexts in deps: + for context in contexts: + sub_evaluation_context = model.__call__.get_evaluation_context(model, context) + _build_dependency_graph( + sub_evaluation_context, + graph, + parent_key=key, + parent_model=current_model, + ) + return key def get_dependency_graph(evaluation_context: ModelEvaluationContext) -> CallableModelGraph: @@ -335,9 +531,8 @@ def get_dependency_graph(evaluation_context: ModelEvaluationContext) -> Callable Args: evaluation_context: The model and context to build the graph for. """ - root_key = cache_key(evaluation_context) - graph = CallableModelGraph(ids={}, graph={}, root_id=root_key) - _build_dependency_graph(evaluation_context, graph) + graph = CallableModelGraph(ids={}, graph={}, root_id=b"") + graph.root_id = _build_dependency_graph(evaluation_context, graph) return graph diff --git a/ccflow/examples/flow_model/__init__.py b/ccflow/examples/flow_model/__init__.py new file mode 100644 index 0000000..9b8bb74 --- /dev/null +++ b/ccflow/examples/flow_model/__init__.py @@ -0,0 +1 @@ +"""Flow.model examples.""" diff --git a/ccflow/examples/flow_model/config/flow_model_hydra_builder_demo.yaml b/ccflow/examples/flow_model/config/flow_model_hydra_builder_demo.yaml new file mode 100644 index 0000000..af049fe --- /dev/null +++ b/ccflow/examples/flow_model/config/flow_model_hydra_builder_demo.yaml @@ -0,0 +1,26 @@ +# Hydra config for ccflow/examples/flow_model/flow_model_hydra_builder_demo.py +# +# Pattern: +# - configure static pipeline specs in YAML +# - use model_alias to pass already-registered models into a plain Python builder +# - keep runtime context as runtime inputs, supplied later at execution time + +library_visitors: + _target_: ccflow.examples.flow_model.flow_model_hydra_builder_demo.count_visitors + location: library + +previous_week: + _target_: ccflow.examples.flow_model.flow_model_hydra_builder_demo.build_visitor_delta + current: + _target_: ccflow.compose.model_alias + model_name: library_visitors + 
label: previous_week + days_back: 7 + +previous_two_weeks: + _target_: ccflow.examples.flow_model.flow_model_hydra_builder_demo.build_visitor_delta + current: + _target_: ccflow.compose.model_alias + model_name: library_visitors + label: previous_two_weeks + days_back: 14 diff --git a/ccflow/examples/flow_model/flow_model_example.py b/ccflow/examples/flow_model/flow_model_example.py new file mode 100644 index 0000000..3762853 --- /dev/null +++ b/ccflow/examples/flow_model/flow_model_example.py @@ -0,0 +1,121 @@ +#!/usr/bin/env python +"""Small `@Flow.model` example. + +Shows how to: + +1. define stages as plain Python functions, +2. compose stages by passing upstream models as ordinary arguments, +3. rewrite contextual inputs on one dependency edge with `.flow.with_context(...)`, +4. execute the configured graph with `model.flow.compute(...)`. + +Run with: + python ccflow/examples/flow_model/flow_model_example.py +""" + +from datetime import date, timedelta + +from ccflow import DateRangeContext, Flow, FromContext + + +def _format_input_names(inputs: dict[str, object]) -> str: + """Return a compact comma-separated list for example output.""" + return ", ".join(inputs) or "(none)" + + +def _format_bound_inputs(inputs: dict[str, object]) -> str: + parts = [] + for name, value in inputs.items(): + display = "model" if hasattr(value, "flow") else repr(value) + parts.append(f"{name}={display}") + return ", ".join(parts) or "(none)" + + +@Flow.model(context_type=DateRangeContext) +def count_visitors( + location: str, + start_date: FromContext[date], + end_date: FromContext[date], +) -> int: + """Return a deterministic visitor count for one date window.""" + days = (end_date - start_date).days + 1 + location_offset = sum(ord(ch) for ch in location) % 17 + week_index = (end_date.toordinal() - date(2024, 1, 1).toordinal()) // 7 + return days * 12 + location_offset + week_index + + +@Flow.model(context_type=DateRangeContext) +def visitor_delta( + current: int, + previous: 
int, + label: str, + start_date: FromContext[date], + end_date: FromContext[date], +) -> dict[str, object]: + """Return both visitor counts plus their difference.""" + return { + "label": label, + "window": f"{start_date} -> {end_date}", + "current": current, + "previous": previous, + "change": current - previous, + } + + +@Flow.context_transform +def shift_window(start_date: FromContext[date], end_date: FromContext[date], days: int) -> dict[str, object]: + """Shift both date fields together.""" + return { + "start_date": start_date - timedelta(days=days), + "end_date": end_date - timedelta(days=days), + } + + +def build_visitor_pipeline(location: str): + """Build one reusable visitor-count pipeline.""" + current = count_visitors(location=location) + previous = current.flow.with_context(shift_window(days=7)) + return visitor_delta( + current=current, + previous=previous, + label="previous_week", + ) + + +def main() -> None: + print("=" * 64) + print("Flow.model Example") + print("=" * 64) + + pipeline = build_visitor_pipeline(location="library") + ctx = DateRangeContext( + start_date=date(2024, 3, 1), + end_date=date(2024, 3, 7), + ) + + computed_from_context = pipeline.flow.compute(ctx) + computed_from_kwargs = pipeline.flow.compute( + start_date=ctx.start_date, + end_date=ctx.end_date, + ) + + print("\nPipeline:") + print(" model: visitor_delta") + pipeline_inspection = pipeline.flow.inspect() + current_inspection = pipeline.current.flow.inspect() + previous_inspection = pipeline.previous.flow.inspect() + print(f" bound inputs: {_format_bound_inputs(pipeline_inspection.bound_inputs)}") + print(f" declared context inputs: {_format_input_names(pipeline_inspection.context_inputs)}") + print(f" runtime inputs: {_format_input_names(pipeline_inspection.runtime_inputs)}") + print(f" current runtime inputs: {_format_input_names(current_inspection.runtime_inputs)}") + print(f" previous runtime inputs: {_format_input_names(previous_inspection.runtime_inputs)}") + + 
print("\nExecution:") + print(f" context object == kwargs: {computed_from_context == computed_from_kwargs}") + + print("\nResult:") + for key, value in computed_from_kwargs.value.items(): + print(f" {key}: {value}") + + +if __name__ == "__main__": + main() diff --git a/ccflow/examples/flow_model/flow_model_hydra_builder_demo.py b/ccflow/examples/flow_model/flow_model_hydra_builder_demo.py new file mode 100644 index 0000000..fbecd7b --- /dev/null +++ b/ccflow/examples/flow_model/flow_model_hydra_builder_demo.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python +"""Hydra + Flow.model builder demo. + +This example shows a clean way to mix: + +1. ergonomic `@Flow.model` pipeline wiring in Python, and +2. Hydra / ModelRegistry configuration for static pipeline specs. + +The pattern is: + +- keep runtime context (`start_date`, `end_date`) as runtime inputs, +- use a plain Python builder function for graph construction, +- let Hydra instantiate that builder and register the returned model. + +Run with: + python ccflow/examples/flow_model/flow_model_hydra_builder_demo.py +""" + +from datetime import date, timedelta +from pathlib import Path + +from ccflow import CallableModel, DateRangeContext, Flow, FromContext, ModelRegistry + +CONFIG_PATH = Path(__file__).with_name("config") / "flow_model_hydra_builder_demo.yaml" + + +def _format_input_names(inputs: dict[str, object]) -> str: + """Return a compact comma-separated list for example output.""" + return ", ".join(inputs) or "(none)" + + +def _format_bound_inputs(inputs: dict[str, object]) -> str: + parts = [] + for name, value in inputs.items(): + display = "model" if hasattr(value, "flow") else repr(value) + parts.append(f"{name}={display}") + return ", ".join(parts) or "(none)" + + +def _print_model_summary(label: str, model: CallableModel) -> None: + inspection = model.flow.inspect() + print(f" {label}:") + print(f" bound inputs: {_format_bound_inputs(inspection.bound_inputs)}") + print(f" declared context inputs: 
{_format_input_names(inspection.context_inputs)}") + print(f" runtime inputs: {_format_input_names(inspection.runtime_inputs)}") + + +@Flow.model(context_type=DateRangeContext) +def count_visitors(location: str, start_date: FromContext[date], end_date: FromContext[date]) -> int: + """Return a deterministic visitor count for one date window.""" + days = (end_date - start_date).days + 1 + location_offset = sum(ord(ch) for ch in location) % 17 + week_index = (end_date.toordinal() - date(2024, 1, 1).toordinal()) // 7 + return days * 12 + location_offset + week_index + + +@Flow.model(context_type=DateRangeContext) +def visitor_delta( + current: int, + previous: int, + label: str, + start_date: FromContext[date], + end_date: FromContext[date], +) -> dict[str, object]: + """Return both visitor counts plus their difference.""" + return { + "label": label, + "window": f"{start_date} -> {end_date}", + "current": current, + "previous": previous, + "change": current - previous, + } + + +@Flow.context_transform +def shift_window(start_date: FromContext[date], end_date: FromContext[date], days: int) -> dict[str, object]: + """Shift both date fields together.""" + return { + "start_date": start_date - timedelta(days=days), + "end_date": end_date - timedelta(days=days), + } + + +def build_visitor_delta(current: CallableModel, *, label: str, days_back: int): + """Hydra-friendly builder that returns a configured visitor-count model.""" + previous = current.flow.with_context(shift_window(days=days_back)) + return visitor_delta( + current=current, + previous=previous, + label=label, + ) + + +def main() -> None: + registry = ModelRegistry.root() + registry.clear() + try: + registry.load_config_from_path(str(CONFIG_PATH), overwrite=True) + + previous_week = registry["previous_week"] + previous_two_weeks = registry["previous_two_weeks"] + + ctx = DateRangeContext( + start_date=date(2024, 3, 1), + end_date=date(2024, 3, 7), + ) + + print("=" * 68) + print("Hydra + Flow.model Builder 
Demo") + print("=" * 68) + print("\nLoaded from config:") + _print_model_summary("library_visitors", registry["library_visitors"]) + _print_model_summary("previous_week", previous_week) + _print_model_summary("previous_two_weeks", previous_two_weeks) + + previous_week_result = previous_week.flow.compute( + start_date=ctx.start_date, + end_date=ctx.end_date, + ).value + previous_two_weeks_result = previous_two_weeks.flow.compute(ctx).value + + print("\nPrevious week:") + for key, value in previous_week_result.items(): + print(f" {key}: {value}") + + print("\nPrevious two weeks:") + for key, value in previous_two_weeks_result.items(): + print(f" {key}: {value}") + finally: + registry.clear() + + +if __name__ == "__main__": + main() diff --git a/ccflow/flow_model.py b/ccflow/flow_model.py new file mode 100644 index 0000000..101998f --- /dev/null +++ b/ccflow/flow_model.py @@ -0,0 +1,3057 @@ +"""Generated ``@Flow.model`` implementation. + +This module is intentionally the owner of the generated-model API. The public +surface is small: + +* ``@Flow.model`` turns a typed Python function into a ``CallableModel`` factory. +* ``FromContext[T]`` marks function parameters that should come from runtime + context instead of model construction. +* ``Lazy[T]`` marks a dependency that should be passed as a thunk and evaluated + only if user code calls it. +* ``model.flow.compute(...)`` and ``model.flow.with_context(...)`` provide the + ergonomic execution and contextual binding API. + +The implementation has four moving parts that should stay conceptually +separate: + +* Signature analysis lives in ``_flow_model_binding.py`` so ``callable.py`` does + not become a generated-model implementation module. +* Runtime context construction turns arbitrary context objects/kwargs into the + narrow context shape required by a target model. +* Context bindings are represented as serializable specs, then applied at + execution time before the wrapped model validates its context. 
+* Effective identity describes the parts of generated and bound model + invocations that are known before evaluation so cache/graph keys can ignore + unused ambient ``FlowContext`` fields. + +Two invariants matter more than cleverness here: + +* Existing ``CallableModel`` behavior must remain structural unless a generated + or bound model explicitly opts into effective identity. +* Public ``cache_key(...)`` stays structural by default; evaluators and graph + construction can request effective identity explicitly for opt-in models. + +File layout: + +1. Internal data structures and small value helpers. +2. Type coercion, result unwrapping, and registry-reference helpers. +3. Context-transform and generated-model serialization helpers. +4. Runtime context contracts and dependency context projection. +5. Contextual value resolution and effective identity. +6. ``with_context`` binding validation/application. +7. ``model.flow`` APIs and ``BoundModel``. +8. Generated model class construction and decorators. 
+""" + +import importlib +import inspect +import sys +from abc import update_abstractmethods +from base64 import b64decode, b64encode +from collections import OrderedDict +from functools import lru_cache, wraps +from typing import ( + Annotated, + Any, + Callable, + ClassVar, + Dict, + List, + Literal, + Mapping, + NamedTuple, + Optional, + Set, + Tuple, + Type, + Union, + cast, + get_args, + get_origin, + get_type_hints, +) + +import cloudpickle +from pydantic import BaseModel as PydanticModel, Field, PrivateAttr, SkipValidation, TypeAdapter, ValidationError, create_model, model_validator +from pydantic.errors import PydanticSchemaGenerationError, PydanticUndefinedAnnotation + +from ._flow_model_binding import ( + _UNION_ORIGINS, + _UNSET, + FromContext, + Lazy, + _analyze_flow_context_transform, + _analyze_flow_model, + _callable_name, + _FlowModelConfig, + _FlowModelParam, + _resolved_flow_signature, + _restore_flow_model_config, + _serialize_flow_model_config, + _strip_annotated, +) +from .base import BaseModel, ContextBase, ContextType, ResultBase +from .callable import CallableModel, EvaluationDependency, Flow, FlowOptions, GraphDepList, WrapperModel +from .context import FlowContext +from .exttypes import PyObjectPath +from .local_persistence import register_ccflow_import_path +from .result import GenericResult +from .utils.tokenize import compute_behavior_token, compute_data_token + +__all__ = ( + "FlowAPI", + "BoundModel", + "FlowInspection", + "InputSpec", + "FromContext", + "Lazy", +) + +_AnyCallable = Callable[..., Any] + + +# --------------------------------------------------------------------------- +# Internal data structures +# --------------------------------------------------------------------------- + + +class _UnsetFlowInput: + def __repr__(self) -> str: + return "" + + def __reduce__(self): + return (_unset_flow_input_factory, ()) + + +_UNSET_FLOW_INPUT = _UnsetFlowInput() +_TYPE_ADAPTER_CACHE_MAXSIZE = 256 +_HASHABLE_TYPE_ADAPTER_CACHE: 
"OrderedDict[Any, TypeAdapter]" = OrderedDict() +_UNHASHABLE_TYPE_ADAPTER_CACHE: "OrderedDict[int, Tuple[Any, TypeAdapter]]" = OrderedDict() + + +class InputSpec(NamedTuple): + """Richer description of one direct function input in ``FlowInspection``.""" + + name: str + type: Any + required: bool + default: Any + value: Any + source: str + + @property + def default_repr(self) -> str: + """Compact representation of the declared function default.""" + + return _flow_debug_value_repr(self.default) + + @property + def value_repr(self) -> str: + """Compact representation of the effective direct value.""" + + return _flow_debug_value_repr(self.value) + + def __repr__(self) -> str: + return ( + f"InputSpec(name={self.name!r}, type={_expected_type_repr(self.type)}, " + f"required={self.required!r}, default={self.default_repr}, " + f"value={self.value_repr}, source={self.source!r})" + ) + + +class DependencySpec(NamedTuple): + """User-facing provenance for one direct dependency edge.""" + + path: str + model: CallableModel + context: Optional[ContextBase] = None + lazy: bool = False + + def __repr__(self) -> str: + return ( + f"DependencySpec(path={self.path!r}, model={_flow_debug_value_repr(self.model)}, " + f"context={_flow_debug_value_repr(self.context)}, lazy={self.lazy!r})" + ) + + +class FlowInspection(NamedTuple): + """Structured current-level debugging summary returned by ``model.flow.inspect()``.""" + + model: CallableModel + context_inputs: Dict[str, Any] + runtime_inputs: Dict[str, Any] + required_inputs: Dict[str, Any] + bound_inputs: Dict[str, Any] + inputs: Dict[str, InputSpec] + dependencies: Tuple[DependencySpec, ...] 
+ + def __str__(self) -> str: + lines = [f"FlowInspection(model={_flow_debug_model_name(self.model)})"] + if self.inputs: + lines.append(" inputs:") + for spec in self.inputs.values(): + required = " required" if spec.required else "" + lines.append(f" {spec.name}: {_expected_type_repr(spec.type)} = {spec.value_repr} [{spec.source}{required}]") + else: + lines.append(" inputs: none") + if self.context_inputs: + context_inputs = ", ".join(f"{name}: {_expected_type_repr(annotation)}" for name, annotation in self.context_inputs.items()) + else: + context_inputs = "none" + lines.append(f" contextual inputs: {context_inputs}") + if self.runtime_inputs: + runtime_inputs = ", ".join(f"{name}: {_expected_type_repr(annotation)}" for name, annotation in self.runtime_inputs.items()) + else: + runtime_inputs = "none" + lines.append(f" runtime inputs: {runtime_inputs}") + if self.required_inputs: + required = ", ".join(f"{name}: {_expected_type_repr(annotation)}" for name, annotation in self.required_inputs.items()) + else: + required = "none" + lines.append(f" required runtime inputs: {required}") + lines.append(f" bound inputs: {', '.join(self.bound_inputs) if self.bound_inputs else 'none'}") + if self.dependencies: + lines.append(" dependencies:") + for dependency in self.dependencies: + target = _flow_debug_model_name(dependency.model) + suffix = " lazy" if dependency.lazy else "" + context = f" context={_flow_debug_value_repr(dependency.context)}" if dependency.context is not None else "" + lines.append(f" {dependency.path} -> {target}{suffix}{context}") + else: + lines.append(" dependencies: none") + return "\n".join(lines) + + def __repr__(self) -> str: + return str(self) + + def _repr_pretty_(self, printer, cycle: bool) -> None: + printer.text("..." 
if cycle else str(self)) + + +def _unset_flow_input_factory() -> _UnsetFlowInput: + return _UNSET_FLOW_INPUT + + +def _is_unset_flow_input(value: Any) -> bool: + return value is _UNSET_FLOW_INPUT + + +def _flow_debug_value_repr(value: Any) -> str: + if _is_unset_flow_input(value): + return repr(value) + + bound_context_type = globals().get("_BoundModelContext") + if bound_context_type is not None and isinstance(value, bound_context_type): + return repr(FlowContext(**_context_values(value))) + + callable_model_type = globals().get("CallableModel") + if callable_model_type is not None and isinstance(value, callable_model_type): + return f"" + + return repr(value) + + +def _flow_debug_model_name(model: CallableModel) -> str: + if isinstance(model, BoundModel): + return f"{_flow_debug_model_name(model.model)}.flow.with_context(...)" + name = model.meta.name + cls_name = type(model).__name__ + return f"{name} ({cls_name})" if name else cls_name + + +_ModelContextContract = NamedTuple( + "_ModelContextContract", + [ + ("runtime_context_type", Type[ContextBase]), + ("input_types", Optional[Dict[str, Any]]), + ("required_names", Tuple[str, ...]), + ("generated_model", Optional["_GeneratedFlowModelBase"]), + ], +) + + +class ContextTransform(PydanticModel): + """Serializable binding produced by ``@Flow.context_transform``. + + Transform bindings store the analyzed transform contract directly. We avoid + import-path restore here because decorators usually run before the module + global points at the returned transform factory. 
+ """ + + kind: Literal["context_transform"] = "context_transform" + serialized_config: str + bound_args: Dict[str, Any] = Field(default_factory=dict) + + +class StaticValueSpec(PydanticModel): + """A ``with_context(field=value)`` static contextual override.""" + + kind: Literal["static_value"] = "static_value" + value: Any + + +_FieldOverrideSpec = Annotated[StaticValueSpec | ContextTransform, Field(discriminator="kind")] + + +class PatchContextOperation(PydanticModel): + """One ordered positional context patch in a ``with_context`` chain.""" + + kind: Literal["patch"] = "patch" + binding: ContextTransform + + +class FieldContextOperation(PydanticModel): + """One ordered field override in a ``with_context`` chain.""" + + kind: Literal["field"] = "field" + name: str + spec: _FieldOverrideSpec + + +_ContextOperation = Annotated[PatchContextOperation | FieldContextOperation, Field(discriminator="kind")] + + +class _BoundContextSpec(PydanticModel): + """Normalized, serializable representation of all context bindings.""" + + operations: List[_ContextOperation] = Field(default_factory=list) + + +class _BoundModelContext(FlowContext): + """Flow.call carrier for BoundModel that preserves existing context objects.""" + + _base_context: Optional[ContextBase] = PrivateAttr(default=None) + + @model_validator(mode="wrap") + @classmethod + def _preserve_context_base(cls, value, handler, info): + if isinstance(value, ContextBase): + return value + return handler(value) + + @classmethod + def from_values(cls, values: Dict[str, Any], base_context: Optional[ContextBase] = None) -> "_BoundModelContext": + context = cls(**values) + context._base_context = base_context + return context + + +class _DependencyIdentity(NamedTuple): + kind: Literal["dependency"] + evaluation: EvaluationDependency + + +class _LiteralIdentity(NamedTuple): + kind: Literal["literal"] + value: Any + + +class _UnresolvedLazyDependencyIdentity(NamedTuple): + kind: Literal["unresolved_lazy_dependency"] + 
model_type: str + model: Dict[str, Any] + context_type: str + context: Dict[str, Any] + missing_context: Tuple[str, ...] + missing_transform_context: Tuple[Tuple[str, Tuple[str, ...]], ...] + + +class _RegularInputIdentity(NamedTuple): + kind: Literal["regular_input"] + name: str + lazy: bool + payload: Any + + +class _GeneratedModelIdentity(NamedTuple): + kind: Literal["generated_flow_model_v1"] + model_type: str + contextual_inputs: Dict[str, Any] + regular_inputs: Tuple[_RegularInputIdentity, ...] + model_base_fields: Dict[str, Any] + + +class _LocalFlowModelPicklePayload(NamedTuple): + serialized_config: Any + factory_kwargs: Dict[str, Any] + + +# --------------------------------------------------------------------------- +# Small value helpers +# --------------------------------------------------------------------------- + + +def _context_values(context: ContextBase) -> Dict[str, Any]: + return dict(context) + + +def _context_transform_repr(transform: Any) -> str: + if isinstance(transform, ContextTransform): + name = _callable_name(_load_context_transform_config_from_binding(transform).func) + if not transform.bound_args: + return f"{name}()" + args = ", ".join(f"{key}={value!r}" for key, value in sorted(transform.bound_args.items())) + return f"{name}({args})" + return repr(transform) + + +def _context_transform_identifier(binding: ContextTransform) -> str: + return _callable_name(_load_context_transform_config_from_binding(binding).func) + + +def _is_model_dependency(value: Any) -> bool: + # Keep this predicate in the module that can import CallableModel. The + # binding analyzers also need it, but _flow_model_binding.py is imported by + # callable.py and cannot import CallableModel directly without a cycle. 
+ return isinstance(value, CallableModel) + + +def _bound_field_names(model: Any) -> set[str]: + fields_set = getattr(model, "model_fields_set", None) + if fields_set is not None: + return set(fields_set) + return set() + + +def _concrete_context_type(context_type: Any) -> Optional[Type[ContextBase]]: + if isinstance(context_type, type) and issubclass(context_type, ContextBase): + return context_type + + if get_origin(context_type) in _UNION_ORIGINS: + for arg in get_args(context_type): + if arg is type(None): + continue + if isinstance(arg, type) and issubclass(arg, ContextBase): + return arg + + return None + + +# --------------------------------------------------------------------------- +# Type coercion, lazy thunks, and registry references +# --------------------------------------------------------------------------- + + +def _remember_type_adapter(cache: "OrderedDict[Any, Any]", key: Any, value: Any) -> Any: + cache[key] = value + cache.move_to_end(key) + if len(cache) > _TYPE_ADAPTER_CACHE_MAXSIZE: + cache.popitem(last=False) + return value + + +def _type_adapter(annotation: Any) -> TypeAdapter: + try: + adapter = _HASHABLE_TYPE_ADAPTER_CACHE.pop(annotation) + except TypeError: + key = id(annotation) + cached = _UNHASHABLE_TYPE_ADAPTER_CACHE.pop(key, None) + if cached is not None and cached[0] is annotation: + _UNHASHABLE_TYPE_ADAPTER_CACHE[key] = cached + return cached[1] + adapter = TypeAdapter(annotation) + return _remember_type_adapter(_UNHASHABLE_TYPE_ADAPTER_CACHE, key, (annotation, adapter))[1] + except KeyError: + adapter = TypeAdapter(annotation) + return _remember_type_adapter(_HASHABLE_TYPE_ADAPTER_CACHE, annotation, adapter) + _HASHABLE_TYPE_ADAPTER_CACHE[annotation] = adapter + return adapter + + +def _can_validate_type(annotation: Any) -> bool: + try: + _type_adapter(annotation) + except (PydanticSchemaGenerationError, PydanticUndefinedAnnotation): + return False + return True + + +def _expected_type_repr(annotation: Any) -> str: + try: + 
return annotation.__name__ + except AttributeError: + return repr(annotation) + + +def _coerce_value(name: str, value: Any, annotation: Any, source: str) -> Any: + if not _can_validate_type(annotation): + return value + try: + return _type_adapter(annotation).validate_python(value) + except ValidationError as exc: + expected = _expected_type_repr(annotation) + raise TypeError(f"{source} '{name}': expected {expected}, got {type(value).__name__} ({value!r})") from exc + + +def _unwrap_model_result(value: Any) -> Any: + if isinstance(value, GenericResult): + return value.value + return value + + +def _make_lazy_thunk(value: CallableModel, context: ContextBase) -> Callable[[], Any]: + cache: Dict[str, Any] = {} + + def thunk(): + if "result" not in cache: + dependency_model, dependency_context = _resolved_dependency_invocation(value, context) + cache["result"] = _unwrap_model_result(dependency_model(dependency_context)) + return cache["result"] + + return thunk + + +def _make_coercing_lazy_thunk(inner_thunk: Callable[[], Any], name: str, annotation: Any) -> Callable[[], Any]: + cache: Dict[str, Any] = {} + + def thunk(): + if "result" not in cache: + cache["result"] = _coerce_value(name, inner_thunk(), annotation, "Regular parameter") + return cache["result"] + + return thunk + + +def _maybe_auto_unwrap_external_result(target: CallableModel, result: Any) -> Any: + generated = _generated_model_instance(target) + if generated is None: + return result + + config = type(generated).__flow_model_config__ + if config.auto_wrap_result and config.auto_unwrap: + return _unwrap_model_result(result) + return result + + +def _resolve_registry_candidate(value: str) -> Any: + try: + candidate = BaseModel.model_validate(value) + except ValidationError: + return None + return candidate if isinstance(candidate, BaseModel) else None + + +def _registry_candidate_allowed(expected_type: Any, candidate: Any) -> bool: + if _is_model_dependency(candidate): + return True + if not 
_can_validate_type(expected_type): + return True + try: + _type_adapter(expected_type).validate_python(candidate) + except ValidationError: + return False + return True + + +def _resolve_bound_param_registry_ref(param: _FlowModelParam, value: Any) -> Any: + """Resolve registry references for bound regular parameters. + + Generated fields use SkipValidation, so registry lookup no longer needs to + run as a Pydantic before-validator just to beat field validation. Keeping it + here lets the generated model's after-validator own the construction + contract. Literal string validation gets first refusal for non-lazy + parameters; registry aliases are fallback dependency syntax. + """ + + if not isinstance(value, str): + return value + if not param.is_lazy and _can_validate_type(param.annotation): + try: + _type_adapter(param.annotation).validate_python(value) + except ValidationError: + pass + else: + return value + + candidate = _resolve_registry_candidate(value) + if candidate is None: + return value + if _registry_candidate_allowed(param.annotation, candidate): + return candidate + return value + + +def _resolve_serialized_dependency_ref( + value: Any, + *, + include_target_alias: bool = False, +) -> Any: + """Restore a direct serialized CallableModel reference.""" + + def serialized_model_type(item: Dict[str, Any]) -> Optional[type]: + marker = item.get("type_", _UNSET) + if marker is _UNSET and include_target_alias: + marker = item.get("_target_", _UNSET) + if marker is _UNSET: + return None + try: + candidate = marker.object if isinstance(marker, PyObjectPath) else PyObjectPath(marker).object + except (ImportError, AttributeError, TypeError, ValueError): + return None + return candidate if inspect.isclass(candidate) and issubclass(candidate, CallableModel) else None + + if type(value) is not dict or serialized_model_type(value) is None: + return value + # ``type_`` is ccflow's default serialized-model marker. 
``_target_`` is + # also a ccflow alias, but it is Hydra's config language too, so callers + # only enable that spelling after normal literal validation has failed. + try: + restored = BaseModel.model_validate(value) + except (ValidationError, ImportError, AttributeError, TypeError, ValueError): + return value + return restored if _is_model_dependency(restored) else value + + +def _ensure_named_python_function(fn: _AnyCallable, *, decorator_name: str) -> None: + if not inspect.isfunction(fn): + raise TypeError(f"{decorator_name} only supports Python functions.") + + name = getattr(fn, "__name__", "") + if name == "": + raise TypeError(f"{decorator_name} only supports named Python functions.") + + +# --------------------------------------------------------------------------- +# Context-transform serialization and generated-model persistence +# --------------------------------------------------------------------------- + + +def _serialize_context_transform_config(config: _FlowModelConfig) -> str: + payload = cloudpickle.dumps(_serialize_flow_model_config(config), protocol=5) + return b64encode(payload).decode("ascii") + + +@lru_cache(maxsize=None) +def _load_serialized_context_transform_config(serialized_config: str) -> _FlowModelConfig: + try: + payload = cloudpickle.loads(b64decode(serialized_config.encode("ascii"))) + config = _restore_flow_model_config(payload) + except Exception as exc: + raise TypeError("Stored context transform payload does not contain a Flow.context_transform binding.") from exc + return config + + +def _load_context_transform_config_from_binding(binding: ContextTransform) -> _FlowModelConfig: + return _load_serialized_context_transform_config(binding.serialized_config) + + +def clear_flow_model_caches() -> None: + """Clear module-level caches used by Flow.model internals.""" + + _HASHABLE_TYPE_ADAPTER_CACHE.clear() + _UNHASHABLE_TYPE_ADAPTER_CACHE.clear() + _load_serialized_context_transform_config.cache_clear() + + +def 
_is_mapping_annotation(annotation: Any) -> bool: + if annotation is inspect.Signature.empty: + return False + annotation = _strip_annotated(annotation) + origin = get_origin(annotation) + if origin in _UNION_ORIGINS: + variants = [arg for arg in get_args(annotation) if arg is not type(None)] + return bool(variants) and all(_is_mapping_annotation(arg) for arg in variants) + origin = origin or annotation + try: + return issubclass(origin, Mapping) + except TypeError: + return False + + +def _new_local_flow_model_for_pickle(serialized_factory_payload: bytes) -> BaseModel: + payload = cloudpickle.loads(serialized_factory_payload) + config = _restore_flow_model_config(payload.serialized_config) + # Do not call ``flow_model(config.func, **factory_kwargs)`` here. That would + # re-run worker-side type-hint resolution for local/postponed annotations. + # The serialized config is the resolved contract from the defining process; + # rebuild the generated class from it, then let pickle apply the third + # reducer element through ``__setstate__``. + factory = _build_flow_model_factory_from_config(config, payload.factory_kwargs) + cls = cast(type[BaseModel], getattr(factory, "_generated_model")) + return cls.__new__(cls) + + +def _new_generated_flow_model_for_pickle(factory_path: str) -> BaseModel: + """Allocate a generated flow model by importing its factory function. + + This is the cross-process-safe restore path: importing the factory's module + triggers the ``@Flow.model`` decorator, which re-creates the GeneratedModel + class. The reducer returns Pydantic state separately so pickle applies + ``__setstate__`` in the outer pickle stream, preserving normal memo + semantics for shared references, cycles, and protocol-5 buffers. 
+ """ + factory = _load_module_attribute_uncached(factory_path) + generated_cls = getattr(factory, "_generated_model", None) + if generated_cls is None: + raise ImportError(f"Cannot restore generated flow model: '{factory_path}' does not have a _generated_model attribute.") + return generated_cls.__new__(generated_cls) + + +def _is_importable_function(func: _AnyCallable) -> bool: + """Return True if *func* is a top-level, importable named function.""" + module = getattr(func, "__module__", None) + name = getattr(func, "__name__", None) + qualname = getattr(func, "__qualname__", None) + return bool(module and module != "__main__" and name and qualname and qualname == name and "" not in qualname) + + +def _importable_function_path(func: _AnyCallable) -> Optional[str]: + if not _is_importable_function(func): + return None + return f"{func.__module__}.{func.__name__}" + + +def _load_module_attribute_uncached(path: str) -> Any: + module_name, attribute_name = path.rsplit(".", 1) + return getattr(importlib.import_module(module_name), attribute_name) + + +def _generated_model_factory_path_for_pickle(config: _FlowModelConfig, generated_cls: type) -> Optional[str]: + path = _importable_function_path(config.func) + if path is None: + return None + try: + factory = _load_module_attribute_uncached(path) + except (ImportError, AttributeError, ValueError): + return None + if getattr(factory, "_generated_model", None) is generated_cls: + return path + return None + + +def _module_has_factory_for_generated_class(module: Any, generated_cls: type, *, excluding: str) -> bool: + """Return whether a module-level factory still owns ``generated_cls``. + + During ``importlib.reload()``, generated class attributes from the previous + import can remain on the module while the new decorator is running. Those + stale classes should be replaced, not force a suffixed path that a clean + process will never recreate. 
If a live factory still points at the class, + the slot is occupied by a real duplicate and should not be reclaimed. + """ + + return any(name != excluding and getattr(value, "_generated_model", None) is generated_cls for name, value in vars(module).items()) + + +def _same_generated_function_source(existing_cls: type, config: _FlowModelConfig) -> bool: + existing_config = getattr(existing_cls, "__flow_model_config__", None) + if existing_config is None: + return False + existing_func = existing_config.func + current_func = config.func + existing_code = getattr(existing_func, "__code__", None) + current_code = getattr(current_func, "__code__", None) + filename = getattr(current_code, "co_filename", "") + return ( + getattr(existing_func, "__module__", None) == getattr(current_func, "__module__", None) + and getattr(existing_func, "__qualname__", None) == getattr(current_func, "__qualname__", None) + and existing_code is not None + and current_code is not None + and not filename.startswith("<") + and existing_code.co_filename == current_code.co_filename + and existing_code.co_firstlineno == current_code.co_firstlineno + ) + + +def _register_generated_model_class(config: _FlowModelConfig, generated_cls: type) -> None: + """Make generated classes importable when their factory function is importable. + + Importable module-level ``@Flow.model`` functions should serialize by a + stable module path. Local, nested, and ``__main__`` definitions still use + local-persistence registration because there is no durable import path for + their generated class. Duplicate importable generated names are rejected + instead of suffixed because suffixed paths are not reliably reproducible + across ``importlib.reload()`` and clean-process config/Hydra round trips. 
+ """ + + if _importable_function_path(config.func) is None: + register_ccflow_import_path(generated_cls) + return + + module_name = getattr(config.func, "__module__", None) + module = sys.modules.get(module_name or "") + qualname = getattr(generated_cls, "__qualname__", "") + if module is None or not qualname or "" in qualname: + register_ccflow_import_path(generated_cls) + return + + obj = module + parts = qualname.split(".") + for part in parts[:-1]: + obj = getattr(obj, part, None) + if obj is None: + register_ccflow_import_path(generated_cls) + return + + name = parts[-1] + existing = getattr(obj, name, None) + if existing is None or existing is generated_cls: + setattr(obj, name, generated_cls) + return + if getattr(existing, "__flow_model_config__", None) is not None and _same_generated_function_source(existing, config): + # Reloaded modules can keep aliases to the previous factory until the + # assignment currently being evaluated completes. Matching the previous + # generated class by source location lets those stale aliases be replaced + # while still rejecting true duplicate function definitions. + setattr(obj, name, generated_cls) + return + if getattr(existing, "__flow_model_config__", None) is not None and not _module_has_factory_for_generated_class( + module, existing, excluding=_callable_name(config.func) + ): + # ``importlib.reload()`` can leave the previous generated class on the + # module while the factory function is being rebound. Replacing that + # stale class keeps the advertised path reproducible in a clean process. + setattr(obj, name, generated_cls) + return + raise ValueError( + f"Cannot register generated Flow model class at {module_name}.{'.'.join(parts)} because that path is already occupied. " + "Use a unique function name for each importable @Flow.model factory." 
+ ) + + +# --------------------------------------------------------------------------- +# Runtime context contracts and dependency projection +# --------------------------------------------------------------------------- + + +def _runtime_context_for_model(model: CallableModel, values: Dict[str, Any]) -> ContextBase: + """Build the runtime context object expected by ``model`` from raw values.""" + + contract = _model_context_contract(model) + if contract.runtime_context_type is FlowContext: + return FlowContext(**values) + return contract.runtime_context_type.model_validate(values) + + +def _project_context_values_for_model(model: CallableModel, values: Dict[str, Any]) -> Dict[str, Any]: + """Keep only the context fields a target model knows how to consume. + + Generated models and ordinary ``ContextBase`` subclasses have declared input + fields. ``FlowContext`` and opaque context types do not, so their context is + forwarded unchanged. + """ + + contract = _model_context_contract(model) + if contract.runtime_context_type is FlowContext or contract.input_types is None: + return values + return {name: values[name] for name in contract.input_types if name in values} + + +def _dependency_context_for_model(model: CallableModel, context: ContextBase) -> ContextBase: + return _runtime_context_for_model(model, _project_context_values_for_model(model, _context_values(context))) + + +def _bound_model_default_base_context(bound_model: "BoundModel") -> Optional[ContextBase]: + """Return the wrapped plain model's default context object when one exists.""" + + contract = _model_context_contract(bound_model.model) + if contract.generated_model is not None: + return None + default_context = _plain_model_default_context(bound_model.model) + if default_context is _UNSET or default_context is None: + return None + if isinstance(default_context, ContextBase) and _context_matches_type(default_context, bound_model.model.context_type): + return default_context + return 
contract.runtime_context_type.model_validate(default_context) + + +def _bound_model_ambient_context(bound_model: "BoundModel", values: Dict[str, Any]) -> _BoundModelContext: + """Return the ambient carrier a bound wrapper should rewrite.""" + + base_context = _bound_model_default_base_context(bound_model) + if base_context is not None: + ambient = _context_values(base_context) + ambient.update(values) + return _BoundModelContext.from_values(ambient, base_context=base_context) + return _BoundModelContext.from_values(values) + + +def _resolved_dependency_invocation(value: CallableModel, context: ContextBase) -> Tuple[CallableModel, ContextBase]: + """Return the concrete ``(model, context)`` pair for a dependency call. + + Bound models must receive the full ambient ``FlowContext`` so their binding + transforms can read source fields before narrowing to the wrapped model's. + If the wrapper targets a handwritten model with a decorated default context, + dependency execution uses the same default baseline as ``bound.flow.compute``. + Unbound dependencies can be projected immediately. + """ + + if isinstance(value, BoundModel): + values = {} if context is None else _context_values(context) + return value, _bound_model_ambient_context(value, values) + return value, _dependency_context_for_model(value, context) + + +def _effective_context_operations(context_spec: _BoundContextSpec) -> Tuple[_ContextOperation, ...]: + """Drop field operations overwritten by later field operations. + + Patches are kept conservative because their write keys may be dynamic. Field + operations have explicit targets, so an earlier field transform for ``a`` + should not run or require inputs if a later field binding overwrites ``a``. 
+ """ + + seen_fields: Set[str] = set() + operations: List[_ContextOperation] = [] + for operation in reversed(context_spec.operations): + if isinstance(operation, FieldContextOperation): + if operation.name in seen_fields: + continue + seen_fields.add(operation.name) + operations.append(operation) + operations.reverse() + return tuple(operations) + + +def _generated_model_instance(stage: Any) -> Optional["_GeneratedFlowModelBase"]: + model = stage.model if isinstance(stage, BoundModel) else stage + if isinstance(model, _GeneratedFlowModelBase): + return model + return None + + +def _model_context_contract( + model: CallableModel, +) -> _ModelContextContract: + """Describe how ``model`` consumes runtime context. + + This is the central adapter between generated models, plain CallableModels, + optional/opaque context annotations, and ``FlowContext``. Callers use it to + decide which fields are contextual inputs, which are required, and which + concrete ``ContextBase`` subclass should validate runtime values. 
+ """ + + generated = _generated_model_instance(model) + if generated is not None: + config = type(generated).__flow_model_config__ + return _ModelContextContract(FlowContext, dict(config.context_input_types), config.context_required_names, generated) + + context_cls = _concrete_context_type(model.context_type) + if context_cls is None: + return _ModelContextContract(FlowContext, None, (), None) + if context_cls is FlowContext or not hasattr(context_cls, "model_fields"): + return _ModelContextContract(context_cls, None, (), None) + return _ModelContextContract( + context_cls, + {name: info.annotation for name, info in context_cls.model_fields.items()}, + tuple(name for name, info in context_cls.model_fields.items() if info.is_required()), + None, + ) + + +def _model_base_field_names(generated: "_GeneratedFlowModelBase") -> set[str]: + """Return field names from model_base that aren't function parameters or internal fields.""" + config = type(generated).__flow_model_config__ + param_names = {param.name for param in config.parameters} + return {name for name in type(generated).model_fields if name not in param_names and name != "meta"} + + +def _missing_regular_param_names(model: "_GeneratedFlowModelBase", config: _FlowModelConfig) -> List[str]: + missing = [] + for param in config.regular_params: + value = getattr(model, param.name, _UNSET_FLOW_INPUT) + if _is_unset_flow_input(value): + missing.append(param.name) + return missing + + +# --------------------------------------------------------------------------- +# Generated model input resolution +# --------------------------------------------------------------------------- + + +def _resolve_regular_param_value(model: "_GeneratedFlowModelBase", param: _FlowModelParam, context: ContextBase) -> Any: + value = getattr(model, param.name, _UNSET_FLOW_INPUT) + if _is_unset_flow_input(value): + raise TypeError( + f"Regular parameter '{param.name}' for {_callable_name(type(model).__flow_model_config__.func)} is still 
unbound. " + "Bind it at construction time." + ) + if param.is_lazy: + if _is_model_dependency(value): + return _make_lazy_thunk(value, context) + raise TypeError(f"Parameter '{param.name}' is marked Lazy[...] and must be bound to a CallableModel dependency.") + + if _is_model_dependency(value): + dependency_model, dependency_context = _resolved_dependency_invocation(value, context) + try: + return _unwrap_model_result(dependency_model(dependency_context)) + except Exception as exc: + parent = _callable_name(type(model).__flow_model_config__.func) + child = dependency_model.meta.name or type(dependency_model).__name__ + add_note = getattr(exc, "add_note", None) + if callable(add_note): + add_note(f"Error while evaluating dependency {parent}.{param.name} -> {child}.") + raise + return value + + +def _regular_dependency_identity(value: CallableModel, context: ContextBase) -> _DependencyIdentity: + dependency_model, dependency_context = _resolved_dependency_invocation(value, context) + return _DependencyIdentity( + kind="dependency", + evaluation=EvaluationDependency(dependency_model, dependency_context), + ) + + +def _lazy_regular_dependency_identity(value: CallableModel, context: ContextBase) -> Any: + unresolved, dependency_model, dependency_context = _lazy_dependency_identity(value, context) + if unresolved is not None: + return unresolved + assert dependency_model is not None + assert dependency_context is not None + return _DependencyIdentity( + kind="dependency", + evaluation=EvaluationDependency(dependency_model, dependency_context), + ) + + +def _regular_input_identity(param: _FlowModelParam, value: Any, context: ContextBase) -> _RegularInputIdentity: + if _is_model_dependency(value): + payload = _lazy_regular_dependency_identity(value, context) if param.is_lazy else _regular_dependency_identity(value, context) + else: + payload = _LiteralIdentity(kind="literal", value=value) + return _RegularInputIdentity(kind="regular_input", name=param.name, 
lazy=param.is_lazy, payload=payload) + + +def _collect_contextual_values( + model: "_GeneratedFlowModelBase", + config: _FlowModelConfig, + explicit_values: Dict[str, Any], +) -> Tuple[Dict[str, Any], List[str]]: + """Resolve ``FromContext`` values from runtime values, model defaults, and function defaults.""" + + resolved: Dict[str, Any] = {} + missing: List[str] = [] + + for param in config.contextual_params: + if param.name in explicit_values: + resolved[param.name] = explicit_values[param.name] + continue + + value = getattr(model, param.name, _UNSET_FLOW_INPUT) + if not _is_unset_flow_input(value): + resolved[param.name] = value + continue + + if param.has_function_default: + resolved[param.name] = param.function_default + continue + + missing.append(param.name) + + return resolved, missing + + +def _resolved_contextual_inputs(model: "_GeneratedFlowModelBase", config: _FlowModelConfig, context: ContextBase) -> Dict[str, Any]: + """Validate and return the contextual kwargs passed to the user function.""" + + resolved, missing_contextual = _collect_contextual_values(model, config, _context_values(context)) + + if missing_contextual: + missing = ", ".join(sorted(missing_contextual)) + raise TypeError( + f"Missing contextual input(s) for {_callable_name(config.func)}: {missing}. " + "Supply them via the runtime context, compute(), with_context(), or construction-time contextual defaults." 
def _declared_context_field_annotation(config: _FlowModelConfig, name: str) -> Any:
    """Return a field-level annotation preserving declared context constraints.

    Looks up ``name`` in ``config.declared_context_type.model_fields``. When
    the field carries metadata, the result is ``Annotated[annotation, *metadata]``;
    otherwise it is the bare field annotation.

    Raises:
        KeyError: if ``name`` is not a field of the declared context type.
    """

    assert config.declared_context_type is not None
    field_info = config.declared_context_type.model_fields[name]
    if field_info.metadata:
        # Subscripting with a tuple is the public spelling of
        # ``Annotated[annotation, *metadata]``. The private
        # ``Annotated.__class_getitem__`` hook is an implementation detail and
        # is not available on every interpreter version.
        return Annotated[(field_info.annotation, *field_info.metadata)]
    return field_info.annotation
def _identity_context_values_for_model_values(model: CallableModel, values: Dict[str, Any]) -> Dict[str, Any]:
    """Project context values for identity, not execution.

    Kept separate from execution-time context projection on purpose: identity
    asks "what affects the result", execution asks "what context object can
    validate for the target model".

    When the model's context contract has no ``input_types``, ``values`` is
    returned unchanged (same object). Otherwise a new dict is returned holding
    only the contract's input names that are present in ``values``.
    """

    input_types = _model_context_contract(model).input_types
    if input_types is None:
        return values

    projected: Dict[str, Any] = {}
    for field_name in input_types:
        if field_name in values:
            projected[field_name] = values[field_name]
    return projected
def _apply_context_spec_values_for_identity(
    model: CallableModel, context_spec: "_BoundContextSpec", context: ContextBase
) -> Tuple[Dict[str, Any], Tuple[Tuple[str, Tuple[str, ...]], ...]]:
    """Apply a binding spec for identity derivation.

    Execution needs every transform to run; identity only needs to describe
    the binding. A transform whose source context fields are missing is
    therefore not evaluated. Its missing inputs are recorded instead, so two
    unresolved lazy dependencies do not collapse into the same identity.

    Returns ``(values, missing_transforms)``. ``missing_transforms`` pairs a
    transform identifier (patch operations) or a field name (field operations)
    with the tuple of context names that transform still lacks.
    """

    # Transforms always read from the ambient values, never from earlier rewrites.
    ambient = _context_values(context)
    rewritten = dict(ambient)
    unresolved: List[Tuple[str, Tuple[str, ...]]] = []

    for op in _effective_context_operations(context_spec):
        if isinstance(op, PatchContextOperation):
            absent = _context_transform_missing_context_names(op.binding, ambient)
            if absent:
                unresolved.append((_context_transform_identifier(op.binding), absent))
            else:
                patch = _evaluate_context_transform_from_values(op.binding, ambient)
                rewritten.update(_validate_patch_result(model, patch))
        elif isinstance(op.spec, StaticValueSpec):
            rewritten[op.name] = op.spec.value
        else:
            absent = _context_transform_missing_context_names(op.spec, ambient)
            if absent:
                unresolved.append((op.name, absent))
                # The field would be overwritten at run time, so the ambient value is dropped.
                rewritten.pop(op.name, None)
            else:
                produced = _evaluate_context_transform_from_values(op.spec, ambient)
                rewritten[op.name] = _coerce_model_context_value(model, op.name, produced, "with_context()")

    return rewritten, tuple(unresolved)
def _unresolved_lazy_dependency_descriptor(
    value: CallableModel,
    context_values: Dict[str, Any],
    missing_context: Tuple[str, ...],
    missing_transform_context: Tuple[Tuple[str, Tuple[str, ...]], ...] = (),
) -> _UnresolvedLazyDependencyIdentity:
    """Describe a lazy dependency whose runtime context cannot be resolved yet.

    The descriptor records the model's type identity, its structural dump, the
    ``FlowContext`` object path, the context values known so far, and which
    context / transform inputs are still missing.
    """

    type_identity = _model_type_identity(value)
    structural_dump = _unresolved_lazy_model_identity(value)
    flow_context_path = str(PyObjectPath.validate(FlowContext))
    return _UnresolvedLazyDependencyIdentity(
        kind="unresolved_lazy_dependency",
        model_type=type_identity,
        model=structural_dump,
        context_type=flow_context_path,
        context=context_values,
        missing_context=missing_context,
        missing_transform_context=missing_transform_context,
    )
def _model_type_identity(model: CallableModel) -> str:
    """Return a stable model-type identity for effective generated-model keys.

    Non-generated models use the ``PyObjectPath`` of their class. Generated
    models use their factory path when one is available. Otherwise they use
    ``"local:<identity>"``, where the identity is the class's stored
    ``__flow_model_identity__`` or, failing that, one computed from the config.
    """

    generated = _generated_model_instance(model)
    if generated is None:
        return str(PyObjectPath.validate(type(model)))

    generated_cls = type(generated)
    config = generated_cls.__flow_model_config__
    factory_path = _generated_model_factory_path_for_pickle(config, generated_cls)
    if factory_path is not None:
        return factory_path

    stored = getattr(generated_cls, "__flow_model_identity__", None)
    local_identity = _flow_model_config_identity(config) if stored is None else stored
    return f"local:{local_identity}"
def _validate_bound_param_value(
    config: _FlowModelConfig,
    param: _FlowModelParam,
    value: Any,
    source: str,
) -> Any:
    """Validate a construction-time bound parameter for a generated model.

    - Contextual parameters reject model dependencies and are coerced as
      contextual values.
    - Lazy parameters must end up as a model dependency, after an attempt to
      resolve a serialized dependency reference.
    - Other parameters accept a model dependency unchanged, or a value that
      coerces to the annotation. If coercion raises ``TypeError``, a
      serialized dependency reference is tried before re-raising.

    Raises:
        TypeError: when the value is not acceptable for the parameter kind.
    """

    if param.is_contextual:
        if not _is_model_dependency(value):
            return _coerce_contextual_value(config, param, value, source)
        raise TypeError(
            f"Parameter '{param.name}' is marked FromContext[...] and cannot be bound to a CallableModel. "
            "Bind a literal contextual default or supply it via compute()/with_context()."
        )

    if param.is_lazy:
        candidate = value if _is_model_dependency(value) else _resolve_serialized_dependency_ref(value, include_target_alias=True)
        if _is_model_dependency(candidate):
            return candidate
        raise TypeError(f"Parameter '{param.name}' is marked Lazy[...] and must be bound to a CallableModel dependency.")

    if _is_model_dependency(value):
        return value

    try:
        return _coerce_value(param.name, value, param.annotation, source)
    except TypeError:
        # Fall back to a serialized dependency reference; only accept it when
        # resolution actually produced a different object that is a dependency.
        resolved = _resolve_serialized_dependency_ref(value, include_target_alias=True)
        if resolved is value or not _is_model_dependency(resolved):
            raise
        return resolved
def _validate_bound_declared_context_defaults(model: "_GeneratedFlowModelBase", config: _FlowModelConfig) -> None:
    """Validate bound contextual defaults against the declared context type.

    Does nothing unless every contextual value can be resolved statically.
    When they can, the values are validated together and each contextual
    field that is actually bound on ``model`` is overwritten with its
    validated value. ``object.__setattr__`` is used for the write.
    """

    static_values = _resolved_static_contextual_values(model, config)
    if static_values is None:
        return

    validated = _validate_declared_context_values(config, static_values)
    for param in config.contextual_params:
        bound = getattr(model, param.name, _UNSET_FLOW_INPUT)
        if not _is_unset_flow_input(bound):
            object.__setattr__(model, param.name, validated[param.name])
def _statically_resolved_context_values(model: CallableModel, context_spec: _BoundContextSpec) -> Optional[Dict[str, Any]]:
    """Return static binding values when the whole spec resolves without runtime context.

    Operations are applied in order. As soon as any operation yields
    ``_UNSET`` (for example, a transform that needs contextual inputs), the
    result is ``None``.
    """

    resolved: Dict[str, Any] = {}

    for op in _effective_context_operations(context_spec):
        if isinstance(op, PatchContextOperation):
            patch = _evaluate_static_context_transform(op.binding)
            if patch is _UNSET:
                return None
            resolved.update(_validate_patch_result(model, patch))
        else:
            spec = op.spec
            if isinstance(spec, StaticValueSpec):
                candidate = spec.value
            else:
                candidate = _evaluate_static_context_transform(spec)
                if candidate is not _UNSET:
                    candidate = _coerce_model_context_value(model, op.name, candidate, "with_context()")
            if candidate is _UNSET:
                return None
            resolved[op.name] = candidate

    return resolved
def _merge_context_input_types(target: Dict[str, Any], updates: Dict[str, Any]) -> None:
    """Merge context input annotations without silently hiding conflicts.

    ``target`` is updated in place, one entry at a time in ``updates`` order.
    Entries merged before a conflict is found stay in ``target``.

    Raises:
        TypeError: if a name already in ``target`` maps to a different annotation.
    """

    for field_name in updates:
        incoming = updates[field_name]
        if field_name in target:
            existing = target[field_name]
            if existing != incoming:
                raise TypeError(f"Conflicting runtime context annotations for {field_name!r}: {existing!r} and {incoming!r}.")
        target[field_name] = incoming
def _validate_static_context_spec_declared_context(model: CallableModel, context_spec: _BoundContextSpec) -> _BoundContextSpec:
    """Validate a statically resolvable spec against the declared context type.

    ``context_spec`` is always returned unchanged. Validation runs only for
    generated models that declare a context type, and only when both the spec
    and the model's contextual values resolve without runtime context. Any
    error raised by that validation propagates to the caller.
    """

    generated = _generated_model_instance(model)
    if generated is not None:
        config = type(generated).__flow_model_config__
        if config.declared_context_type is not None:
            static_values = _statically_resolved_context_values(model, context_spec)
            resolved = None if static_values is None else _resolved_static_contextual_values(generated, config, static_values)
            if resolved is not None:
                _validate_declared_context_values(config, resolved)
    return context_spec
def _validate_context_transform_factory_kwargs(config: _FlowModelConfig, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Check and coerce the kwargs passed to a context-transform factory.

    Checks run in this order: unknown names, contextual names, then required
    regular parameters that were not supplied. Returns only the regular
    parameters present in ``kwargs``, coerced to their annotations.

    Raises:
        TypeError: on the first failed check.
    """

    known_names = {param.name for param in config.parameters}
    unknown = sorted(name for name in kwargs if name not in known_names)
    if unknown:
        raise TypeError(f"{_callable_name(config.func)}() got unexpected keyword argument(s): {', '.join(unknown)}")

    contextual = [name for name in kwargs if config.param(name).is_contextual]
    contextual.sort()
    if contextual:
        raise TypeError(
            f"{_callable_name(config.func)}() only binds regular parameters. Do not pass contextual parameter(s): {', '.join(contextual)}."
        )

    missing = []
    for param in config.regular_params:
        if param.name not in kwargs and not param.has_function_default:
            missing.append(param.name)
    if missing:
        raise TypeError(f"{_callable_name(config.func)}() is missing required regular parameter(s): {', '.join(missing)}")

    return {
        param.name: _coerce_value(param.name, kwargs[param.name], param.annotation, "Context transform argument")
        for param in config.regular_params
        if param.name in kwargs
    }
def _normalize_with_context(model: CallableModel, patches: Tuple[Any, ...], field_overrides: Dict[str, Any]) -> _BoundContextSpec:
    """Validate and normalize user-facing ``with_context(...)`` arguments.

    Positional arguments must be patch-shaped ``ContextTransform`` bindings.
    Keyword arguments are either field-shaped ``ContextTransform`` bindings or
    static values. A static value is coerced when the model's context contract
    knows the field's type. The resulting operations keep positional patches
    first, followed by keyword overrides in their given order.

    Raises:
        TypeError: for arguments of the wrong kind or shape.
    """

    operations: List[_ContextOperation] = []

    for patch in patches:
        if callable(patch):
            raise TypeError("Positional with_context() arguments must be bound @Flow.context_transform results that return a mapping.")
        if not isinstance(patch, ContextTransform):
            raise TypeError("Positional with_context() arguments must be @Flow.context_transform bindings that return a mapping.")
        if _binding_uses_patch_shape(patch):
            operations.append(PatchContextOperation(binding=patch))
            continue
        raise TypeError(
            "Field context transforms must be passed by keyword to with_context(...). Patch transforms belong in positional arguments."
        )

    _validate_with_context_field_names(model, list(field_overrides))
    contract = _model_context_contract(model)

    for name, value in field_overrides.items():
        if isinstance(value, ContextTransform):
            if _binding_uses_patch_shape(value):
                raise TypeError("Patch transforms must be passed positionally to with_context(...), not as keyword field overrides.")
            spec = value
        else:
            # Without a known field type there is nothing to validate a callable against.
            untyped = contract.input_types is None or name not in contract.input_types
            if untyped and callable(value):
                raise TypeError(
                    "Callable keyword values in with_context() must either be bound @Flow.context_transform results "
                    "or validate against a declared contextual field type."
                )
            static_value = value if untyped else _coerce_model_context_value(model, name, value, "with_context()")
            spec = StaticValueSpec(value=static_value)
        operations.append(FieldContextOperation(name=name, spec=spec))

    return _validate_static_context_spec_declared_context(model, _BoundContextSpec(operations=operations))
def _apply_context_spec(model: CallableModel, context_spec: _BoundContextSpec, context: ContextBase) -> ContextBase:
    """Apply bindings, project to the wrapped model, and build its runtime context.

    With no operations, the incoming context is reused when it already matches
    the model's context type. Otherwise a dependency context is derived, or,
    for a ``_BoundModelContext`` with a base context, the base is rebuilt from
    the projected values.

    With operations, the rewritten values are turned into a context the same
    way: a bound context's base context or a matching context is rebuilt with
    private state preserved, and anything else gets a fresh runtime context
    from the projected values.
    """

    is_bound_context = isinstance(context, _BoundModelContext)
    base_context = context._base_context if is_bound_context else None

    if context_spec.operations:
        rewritten = _apply_context_spec_values(model, context_spec, context)
        if is_bound_context:
            projected = _project_context_values_for_model(model, rewritten)
            if base_context is None:
                return _runtime_context_for_model(model, projected)
            return _context_from_values_preserving_private_state(base_context, projected)
        if _context_matches_type(context, model.context_type):
            return _context_from_values_preserving_private_state(context, rewritten)
        return _runtime_context_for_model(model, _project_context_values_for_model(model, rewritten))

    if is_bound_context:
        if base_context is None:
            return _dependency_context_for_model(model, context)
        projected = _project_context_values_for_model(model, _context_values(context))
        return _context_from_values_preserving_private_state(base_context, projected)
    if _context_matches_type(context, model.context_type):
        return context
    return _dependency_context_for_model(model, context)
+ + ``compute`` is intentionally not a second constructor. For generated models + it only supplies contextual inputs; regular parameters and model_base fields + must already be bound on the model instance. + """ + + if context is not _UNSET and kwargs: + raise TypeError("compute() accepts either one context object or contextual keyword inputs, but not both.") + + ctx_type = model.context_type + _ctx_is_optional = _is_optional_context_type(ctx_type) + + contract = _model_context_contract(model) + + if context is not _UNSET: + if context is None and _ctx_is_optional: + return None + if isinstance(context, FlowContext): + return context + if isinstance(context, ContextBase): + if _context_matches_type(context, model.context_type): + return context + return _runtime_context_for_model(model, _context_values(context)) + return contract.runtime_context_type.model_validate(context) + + if contract.generated_model is None: + default_context = _plain_model_default_context(model) + if default_context is not _UNSET: + default_values = _plain_model_default_context_values(model, contract.runtime_context_type) + return _plain_model_compute_context_from_default(model, default_context, default_values, kwargs, contract.runtime_context_type) + if not kwargs and _ctx_is_optional: + return None + return contract.runtime_context_type.model_validate(kwargs) + + generated = contract.generated_model + config = type(generated).__flow_model_config__ + regular_kwargs = sorted(name for name in config.regular_param_names if name in kwargs) + unresolved_regular = sorted(name for name in regular_kwargs if _is_unset_flow_input(getattr(generated, name, _UNSET_FLOW_INPUT))) + if unresolved_regular: + names = ", ".join(unresolved_regular) + raise TypeError( + f"compute() cannot satisfy unbound regular parameter(s): {names}. " + "Bind them at construction time; compute() only supplies runtime context." 
+ ) + + already_bound_regular = sorted(name for name in regular_kwargs if name not in unresolved_regular) + if already_bound_regular: + names = ", ".join(already_bound_regular) + raise TypeError( + f"compute() does not accept regular parameter override(s): {names}. " + "Those parameters are already bound on the model. Pass a context object if you need ambient fields with the same names." + ) + + base_kwargs = sorted(name for name in _model_base_field_names(generated) if name in kwargs) + if base_kwargs: + names = ", ".join(base_kwargs) + raise TypeError( + f"compute() does not accept model configuration override(s): {names}. Those fields are bound on the model at construction time." + ) + + ambient = dict(kwargs) + for param in config.contextual_params: + if param.name not in kwargs: + continue + ambient[param.name] = _coerce_contextual_value(config, param, kwargs[param.name], "compute() input") + return FlowContext(**ambient) + + +def _is_optional_context_type(context_type: Any) -> bool: + return get_origin(context_type) in _UNION_ORIGINS and type(None) in get_args(context_type) + + +def _context_matches_type(context: Any, context_type: Any) -> bool: + """Return whether an existing context object is accepted by a context annotation.""" + + if context is None: + return _is_optional_context_type(context_type) + if get_origin(context_type) in _UNION_ORIGINS: + return any(_context_matches_type(context, arg) for arg in get_args(context_type) if arg is not type(None)) + return isinstance(context_type, type) and isinstance(context, context_type) + + +def _bound_model_preserves_none_context(bound_model: "BoundModel") -> bool: + return not bound_model.context_spec.operations and _is_optional_context_type(bound_model.model.context_type) + + +def _build_bound_compute_context(bound_model: "BoundModel", context: Any, kwargs: Dict[str, Any]) -> Optional[ContextBase]: + """Construct the ambient context passed into a ``BoundModel`` by ``compute``.""" + + if context is not _UNSET 
and kwargs: + raise TypeError("compute() accepts either one context object or contextual keyword inputs, but not both.") + if context is not _UNSET: + return context + if not kwargs and _bound_model_preserves_none_context(bound_model): + return None + return _bound_model_ambient_context(bound_model, kwargs) + + +def _raw_input_values_for_debug(model: CallableModel, context: Any, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """Return caller-supplied context values without executing the model.""" + + if context is not _UNSET and kwargs: + raise TypeError("compute() accepts either one context object or contextual keyword inputs, but not both.") + if context is _UNSET: + return dict(kwargs) + if context is None: + return {} + if isinstance(context, ContextBase): + return _context_values(context) + if isinstance(context, Mapping): + return dict(context) + return _context_values(_model_context_contract(model).runtime_context_type.model_validate(context)) + + +def _partial_context_for_inspect(model: CallableModel, values: Dict[str, Any]) -> ContextBase: + contract = _model_context_contract(model) + if contract.input_types is None: + return FlowContext(**values) + return FlowContext(**{name: values[name] for name in contract.input_types if name in values}) + + +def _partial_dependency_context_for_inspect(model: CallableModel, context: ContextBase) -> ContextBase: + if isinstance(model, BoundModel): + return FlowContext(**_context_values(context)) + return _partial_context_for_inspect(model, _context_values(context)) + + +def _project_bound_dependency_context_for_inspect(model: "BoundModel", context: ContextBase) -> ContextBase: + values = _context_values(context) + projected = {name: values[name] for name in model.flow._runtime_inputs if name in values} + if isinstance(context, _BoundModelContext) and context._base_context is not None: + return _BoundModelContext.from_values(projected, base_context=context._base_context) + return FlowContext(**projected) + + +def 
def _generated_context_argument_specs(generated: "_GeneratedFlowModelBase", input_types: Optional[Dict[str, Any]]) -> Dict[str, InputSpec]:
    """Describe each contextual parameter of a generated model as an ``InputSpec``.

    A parameter is reported as ``"construction"`` when it was explicitly bound
    on the instance, ``"function_default"`` when the user function supplies a
    default, and as a required ``"runtime"`` input otherwise.
    """

    config = type(generated).__flow_model_config__
    explicit_fields = _bound_field_names(generated)
    specs: Dict[str, InputSpec] = {}
    for param in config.contextual_params:
        annotation = param.annotation if input_types is None else input_types[param.name]
        value = getattr(generated, param.name, _UNSET_FLOW_INPUT)
        if param.name in explicit_fields and not _is_unset_flow_input(value):
            specs[param.name] = InputSpec(param.name, annotation, False, _UNSET_FLOW_INPUT, value, "construction")
        elif param.has_function_default:
            specs[param.name] = InputSpec(param.name, annotation, False, param.function_default, param.function_default, "function_default")
        else:
            specs[param.name] = InputSpec(param.name, annotation, True, _UNSET_FLOW_INPUT, _UNSET_FLOW_INPUT, "runtime")
    return specs


def _plain_context_argument_specs(model: CallableModel, contract: _ModelContextContract) -> Dict[str, InputSpec]:
    """Describe the context inputs of a plain model as ``InputSpec`` entries.

    Fields covered by the model's default context are reported as
    ``"context_default"`` and are never required.
    """

    if contract.input_types is None:
        return {}
    default_values = _plain_model_default_context_values(model, contract.runtime_context_type)
    required_inputs = set(contract.required_names)
    if default_values is not None:
        required_inputs -= set(default_values)
    specs = {}
    for name, annotation in contract.input_types.items():
        if default_values is not None and name in default_values:
            specs[name] = InputSpec(name, annotation, False, default_values[name], default_values[name], "context_default")
        else:
            specs[name] = InputSpec(name, annotation, name in required_inputs, _UNSET_FLOW_INPUT, _UNSET_FLOW_INPUT, "runtime")
    return specs


def _direct_dependency_specs(
    model: CallableModel,
    context: Optional[ContextBase] = None,
    *,
    trim_context: bool = True,
) -> Tuple[DependencySpec, ...]:
    """Return one ``DependencySpec`` per model-valued regular parameter of ``model``.

    ``BoundModel`` wrappers are unwrapped after rewriting ``context``. Models
    that are not generated flow models have no inspect-visible dependencies.
    """

    if isinstance(model, BoundModel):
        rewritten_context = None if context is None else model._rewrite_context(context)
        return _direct_dependency_specs(model.model, rewritten_context, trim_context=trim_context)

    generated = _generated_model_instance(model)
    if generated is None:
        return ()

    specs = []
    config = type(generated).__flow_model_config__
    for param in config.regular_params:
        value = getattr(generated, param.name, _UNSET_FLOW_INPUT)
        if _is_unset_flow_input(value) or not _is_model_dependency(value):
            continue
        dependency_model = value
        dependency_context = None
        if context is not None:
            try:
                dependency_model, dependency_context = _resolved_dependency_invocation(value, context)
            except (TypeError, ValueError, ValidationError):
                # Resolution failed: fall back to a best-effort partial context.
                dependency_model = value
                dependency_context = _partial_dependency_context_for_inspect(value, context)
            else:
                # NOTE(review): nesting reconstructed from collapsed source; trimming is
                # assumed to apply only when resolution succeeded - confirm against upstream.
                contract = _model_context_contract(dependency_model)
                if trim_context and dependency_context is not None:
                    if isinstance(dependency_model, BoundModel):
                        dependency_context = _project_bound_dependency_context_for_inspect(dependency_model, dependency_context)
                    elif contract.input_types is not None:
                        values = _context_values(dependency_context)
                        dependency_context = _runtime_context_for_model(
                            dependency_model,
                            {name: values[name] for name in contract.input_types if name in values},
                        )
        specs.append(DependencySpec(param.name, dependency_model, dependency_context, lazy=param.is_lazy))
    return tuple(specs)


def _normalize_inspect_dependencies(dependencies: str) -> Literal["none", "direct", "recursive"]:
    """Validate the ``dependencies`` option of ``inspect`` and return it unchanged.

    Raises:
        ValueError: If ``dependencies`` is not one of the three accepted values.
    """

    if dependencies == "none":
        return "none"
    if dependencies == "direct":
        return "direct"
    if dependencies == "recursive":
        return "recursive"
    raise ValueError("dependencies must be one of: direct, none, recursive")


def _is_missing_contextual_input_error(error: TypeError) -> bool:
    """Return whether ``error``'s message starts with ``"Missing contextual input(s)"``."""

    return str(error).startswith("Missing contextual input(s)")


def _recursive_dependency_specs_for_flow(
    flow: "FlowAPI",
    context: Optional[ContextBase],
    *,
    prefix: str = "",
    lazy_parent: bool = False,
    active: Optional[Set[int]] = None,
) -> Tuple[DependencySpec, ...]:
    """Return inspect-visible dependency specs below ``flow`` with prefixed paths."""

    active = set() if active is None else active
    model_id = id(flow._compute_target)
    # A target already on the current traversal path is a cycle; stop descending.
    if model_id in active:
        return ()

    active.add(model_id)
    try:
        try:
            argument_context = flow._argument_context(context)
            direct_dependencies = flow._dependency_specs_for_inspect(
                context,
                argument_context,
                preserve_ambient_context=True,
            )
        except TypeError as exc:
            if not _is_missing_contextual_input_error(exc):
                raise
            return ()
        result = []
        for dependency in direct_dependencies:
            path = f"{prefix}.{dependency.path}" if prefix else dependency.path
            # Laziness is inherited: everything below a lazy edge is reported as lazy.
            lazy = lazy_parent or dependency.lazy
            prefixed = DependencySpec(path, dependency.model, dependency.context, lazy)
            result.append(prefixed)
            result.extend(
                _recursive_dependency_specs_for_flow(
                    dependency.model.flow,
                    dependency.context,
                    prefix=path,
                    lazy_parent=lazy,
                    active=active,
                )
            )
        return tuple(result)
    finally:
        active.remove(model_id)


# ---------------------------------------------------------------------------
# model.flow API and BoundModel wrapper
# ---------------------------------------------------------------------------


class FlowAPI:
    """API namespace exposed as ``model.flow``.

    ``FlowAPI`` works for both generated models and ordinary ``CallableModel``
    instances. Generated models get richer introspection because their function
    signature declares regular and contextual inputs. Plain models only expose
    what can be inferred from their ``context_type`` and pydantic fields.
    """

    _PUBLIC_HELP: ClassVar[Dict[str, str]] = {
        "compute": "Evaluate the model from a context object or runtime keyword context.",
        "with_context": "Return a new wrapper that binds or rewrites runtime context before evaluation; it does not mutate this model.",
        "inspect": "Return a readable debugging summary. Options: dependencies='direct|recursive|none'.",
    }
    _PUBLIC_NAMES: ClassVar[Tuple[str, ...]] = tuple(_PUBLIC_HELP)

    def __init__(self, model: CallableModel):
        self._model = model

    def __dir__(self) -> List[str]:
        """Return a focused list of public helpers for interactive autocomplete."""

        return sorted(self._PUBLIC_NAMES)

    def __repr__(self) -> str:
        helpers = ", ".join(self._PUBLIC_NAMES)
        return f"{type(self).__name__}(model={self._compute_target!r}, helpers=[{helpers}])"

    @property
    def _compute_target(self) -> CallableModel:
        """The model that ``compute`` and ``inspect`` act on."""

        return self._model

    def compute(self, context: Any = _UNSET, /, _options: Optional[FlowOptions] = None, **kwargs) -> Any:
        """Evaluate the model after building a runtime context from ``context`` or kwargs."""

        target = self._compute_target
        built_context = _build_compute_context(target, context, kwargs)
        return _maybe_auto_unwrap_external_result(target, target(built_context, _options=_options))

    @property
    def _context_inputs(self) -> Dict[str, Any]:
        """Declared contextual input names and expected types for this model."""

        contract = _model_context_contract(self._model)
        return dict(contract.input_types or {})

    @property
    def _runtime_inputs(self) -> Dict[str, Any]:
        """Direct runtime context fields this model may read from the caller."""

        return self._context_inputs

    @property
    def _required_inputs(self) -> Dict[str, Any]:
        """Required direct runtime context fields still needed from the caller."""

        contract = _model_context_contract(self._model)
        if contract.generated_model is None and _is_optional_context_type(self._model.context_type):
            return {}
        if contract.generated_model is None:
            if contract.input_types is None:
                return {}
            result = {name: contract.input_types[name] for name in contract.required_names}
            default_values = _plain_model_default_context_values(self._model, contract.runtime_context_type)
            if default_values is not None:
                for name in default_values:
                    result.pop(name, None)
            return result

        generated = contract.generated_model
        config = type(generated).__flow_model_config__
        result = {}
        for param in config.contextual_params:
            if not _is_unset_flow_input(getattr(generated, param.name, _UNSET_FLOW_INPUT)):
                continue
            if param.has_function_default:
                continue
            result[param.name] = param.annotation if contract.input_types is None else contract.input_types[param.name]
        return result

    @property
    def _context_argument_specs(self) -> Dict[str, InputSpec]:
        """Rich descriptions of declared direct contextual inputs."""

        contract = _model_context_contract(self._model)
        if contract.generated_model is not None:
            return _generated_context_argument_specs(contract.generated_model, contract.input_types)
        return _plain_context_argument_specs(self._model, contract)

    @property
    def _runtime_argument_specs(self) -> Dict[str, InputSpec]:
        """Rich descriptions of direct runtime inputs read from the caller."""

        specs = self._context_argument_specs
        runtime_inputs = self._runtime_inputs
        required_names = set(self._required_inputs)
        result = {}
        # Iterate the dict rather than a set of its keys so the result keeps declaration order.
        for name, annotation in runtime_inputs.items():
            spec = specs.get(name)
            result[name] = InputSpec(
                name,
                annotation,
                name in required_names,
                spec.default if spec is not None else _UNSET_FLOW_INPUT,
                spec.value if spec is not None else _UNSET_FLOW_INPUT,
                spec.source if spec is not None else "runtime",
            )
        return result

    @property
    def _argument_specs(self) -> Dict[str, InputSpec]:
        """Rich descriptions of direct construction and contextual inputs."""

        generated = _model_context_contract(self._model).generated_model
        if generated is None:
            return self._context_argument_specs

        config = type(generated).__flow_model_config__
        specs: Dict[str, InputSpec] = {}
        for param in config.regular_params:
            value = getattr(generated, param.name, _UNSET_FLOW_INPUT)
            if not _is_unset_flow_input(value):
                specs[param.name] = InputSpec(param.name, param.annotation, False, _UNSET_FLOW_INPUT, value, "construction")
            elif param.has_function_default:
                specs[param.name] = InputSpec(param.name, param.annotation, False, param.function_default, param.function_default, "function_default")
            else:
                specs[param.name] = InputSpec(param.name, param.annotation, True, _UNSET_FLOW_INPUT, _UNSET_FLOW_INPUT, "construction")
        specs.update(self._context_argument_specs)
        return specs

    def inspect(
        self,
        context: Any = _UNSET,
        /,
        *,
        dependencies: Literal["direct", "recursive", "none"] = "direct",
        **kwargs,
    ) -> FlowInspection:
        """Return a readable direct debugging summary for this model.

        Args:
            context: Optional runtime context object.
            dependencies: Dependency inspection depth. ``"direct"`` includes
                only immediate dependencies; ``"recursive"`` follows
                inspect-visible dependencies recursively; ``"none"`` skips
                dependency inspection.
            **kwargs: Runtime context field values.
        """

        dependency_depth = _normalize_inspect_dependencies(dependencies)
        dependency_context = None
        argument_context = None
        dependency_specs: Tuple[DependencySpec, ...]
        if context is not _UNSET or kwargs:
            raw_values = _raw_input_values_for_debug(self._compute_target, context, kwargs)
            try:
                dependency_context = _build_compute_context(self._compute_target, context, kwargs)
            except (TypeError, ValidationError):
                # Inspection stays best-effort: fall back to whatever inputs were supplied.
                dependency_context = _partial_context_for_inspect(self._compute_target, raw_values)
            try:
                argument_context = self._argument_context(dependency_context)
            except (TypeError, ValidationError) as exc:
                if isinstance(exc, TypeError) and not _is_missing_contextual_input_error(exc):
                    raise
                argument_context = None
            dependency_specs = self._dependency_specs_for_inspect(dependency_context, argument_context) if dependency_depth == "direct" else ()
            if dependency_depth == "recursive":
                dependency_specs = _recursive_dependency_specs_for_flow(self, dependency_context)
        else:
            try:
                argument_context = self._argument_context(None)
            except (TypeError, ValidationError) as exc:
                if isinstance(exc, TypeError) and not _is_missing_contextual_input_error(exc):
                    raise
                argument_context = None
            dependency_specs = self._dependency_specs_for_inspect(None, argument_context) if dependency_depth == "direct" else ()
            if dependency_depth == "recursive":
                dependency_specs = _recursive_dependency_specs_for_flow(self, None)
        return FlowInspection(
            model=self._compute_target,
            context_inputs=self._context_inputs,
            runtime_inputs=self._runtime_inputs,
            required_inputs=self._required_inputs,
            bound_inputs=self._bound_inputs,
            inputs=self._argument_specs_for_context(argument_context),
            dependencies=dependency_specs,
        )

    def _argument_context(self, context: Optional[ContextBase]) -> Optional[ContextBase]:
        """Return the context used to describe argument values; here the caller's context, unchanged."""

        return context

    def _dependency_specs_for_inspect(
        self,
        dependency_context: Optional[ContextBase],
        _argument_context: Optional[ContextBase],
        *,
        preserve_ambient_context: bool = False,
    ) -> Tuple[DependencySpec, ...]:
        """Return direct dependency specs for the compute target.

        ``_argument_context`` is not used by this implementation. Context
        trimming is disabled when ``preserve_ambient_context`` is true.
        """

        return _direct_dependency_specs(self._compute_target, dependency_context, trim_context=not preserve_ambient_context)

    def _context_argument_specs_for_context(self, context: Optional[ContextBase]) -> Dict[str, InputSpec]:
        """Return contextual input specs, marking fields present in ``context`` as runtime-supplied."""

        result = dict(self._context_argument_specs)
        if context is None:
            return result
        values = _context_values(context)
        for name, spec in list(result.items()):
            if name in values:
                result[name] = InputSpec(name, spec.type, False, spec.default, values[name], "runtime")
        return result

    def _runtime_argument_specs_for_context(self, context: Optional[ContextBase]) -> Dict[str, InputSpec]:
        """Return runtime input specs with values taken from ``context`` where present."""

        specs = self._context_argument_specs_for_context(context)
        runtime_inputs = self._runtime_inputs
        required_names = set(self._required_inputs)
        result = {}
        # Iterate the dict rather than a set of its keys so the result keeps declaration order.
        for name, annotation in runtime_inputs.items():
            spec = specs.get(name)
            result[name] = InputSpec(
                name,
                annotation,
                name in required_names,
                spec.default if spec is not None else _UNSET_FLOW_INPUT,
                spec.value if spec is not None else _UNSET_FLOW_INPUT,
                spec.source if spec is not None else "runtime",
            )
        return result

    def _argument_specs_for_context(self, context: Optional[ContextBase]) -> Dict[str, InputSpec]:
        """Return construction and contextual input specs, resolved against ``context``."""

        generated = _model_context_contract(self._model).generated_model
        if generated is None:
            return self._context_argument_specs_for_context(context)

        config = type(generated).__flow_model_config__
        specs: Dict[str, InputSpec] = {}
        for param in config.regular_params:
            value = getattr(generated, param.name, _UNSET_FLOW_INPUT)
            if not _is_unset_flow_input(value):
                specs[param.name] = InputSpec(param.name, param.annotation, False, _UNSET_FLOW_INPUT, value, "construction")
            elif param.has_function_default:
                specs[param.name] = InputSpec(param.name, param.annotation, False, param.function_default, param.function_default, "function_default")
            else:
                specs[param.name] = InputSpec(param.name, param.annotation, True, _UNSET_FLOW_INPUT, _UNSET_FLOW_INPUT, "construction")
        specs.update(self._context_argument_specs_for_context(context))
        return specs

    @property
    def _bound_inputs(self) -> Dict[str, Any]:
        """Inputs already fixed by construction-time values or static context bindings."""

        generated = _model_context_contract(self._model).generated_model
        if generated is not None:
            config = type(generated).__flow_model_config__
            result: Dict[str, Any] = {}
            explicit_fields = _bound_field_names(generated)
            for param in config.regular_params:
                value = getattr(generated, param.name, _UNSET_FLOW_INPUT)
                if _is_unset_flow_input(value):
                    continue
                result[param.name] = value
            for param in config.contextual_params:
                if param.name not in explicit_fields:
                    continue
                value = getattr(generated, param.name, _UNSET_FLOW_INPUT)
                if _is_unset_flow_input(value):
                    continue
                result[param.name] = value
            for name in _model_base_field_names(generated):
                if name in explicit_fields:
                    result[name] = getattr(generated, name)
            return result

        # Plain models: every pydantic field except ``meta`` counts as bound.
        result: Dict[str, Any] = {}
        model_fields = getattr(self._model.__class__, "model_fields", {})
        for name in model_fields:
            if name == "meta":
                continue
            result[name] = getattr(self._model, name)
        return result

    def with_context(self, *patches, **field_overrides) -> "BoundModel":
        """Return a wrapper that rewrites runtime context before evaluating this model."""

        context_spec = _normalize_with_context(self._model, patches, field_overrides)
        return BoundModel(model=self._model, context_spec=context_spec)


class BoundModel(WrapperModel):
    """A wrapper that rewrites context for exactly one wrapped model.

    ``BoundModel`` is deliberately a ``WrapperModel`` rather than mutating the
    wrapped model. This keeps context bindings scoped to the edge where they
    are used, lets dependency graphs show the wrapped model, and derives
    effective identity from the rewritten wrapped invocation.
    """

    context_spec: _BoundContextSpec = Field(default_factory=_BoundContextSpec, repr=False)

    def _rewrite_context(self, context: ContextBase) -> ContextBase:
        """Apply this wrapper's context bindings to an ambient runtime context."""

        return _apply_context_spec(self.model, self.context_spec, context)

    @property
    def context_type(self) -> Any:
        """The wrapped model's context type when ``None`` passes through, else ``_BoundModelContext``."""

        if _bound_model_preserves_none_context(self):
            return self.model.context_type
        return _BoundModelContext

    @Flow.call
    def __call__(self, context: ContextType) -> ResultBase:
        """Evaluate the wrapped model after rewriting context."""

        if context is None and _bound_model_preserves_none_context(self):
            return self.model(None)
        return self.model(self._rewrite_context(context))

    @Flow.deps
    def __deps__(self, context: ContextType) -> GraphDepList:
        """Expose the wrapped model as the single dependency of this binding wrapper."""

        if context is None and _bound_model_preserves_none_context(self):
            return self.model.__deps__(None)
        return [(self.model, [self._rewrite_context(context)])]

    def __repr__(self) -> str:
        args = []
        for operation in _effective_context_operations(self.context_spec):
            if isinstance(operation, PatchContextOperation):
                args.append(_context_transform_repr(operation.binding))
                continue
            value = operation.spec if isinstance(operation.spec, ContextTransform) else operation.spec.value
            args.append(f"{operation.name}={_context_transform_repr(value)}")
        return f"{self.model!r}.flow.with_context({', '.join(args)})"

    def _evaluation_identity_payload(
        self,
        context: ContextBase,
    ) -> Optional[Any]:
        """Describe this binding in terms of the rewritten wrapped call."""

        return {
            "kind": "bound_model_v1",
            "model": EvaluationDependency(self.model, self._rewrite_context(context)),
        }

    @property
    def flow(self) -> "FlowAPI":
        """Access bound flow helpers for execution, context transforms, and introspection."""

        return _BoundFlowAPI(self)
_BoundFlowAPI(FlowAPI): + """``model.flow`` implementation for ``BoundModel`` wrappers.""" + + def __init__(self, bound_model: BoundModel): + self._bound = bound_model + super().__init__(bound_model.model) + + @property + def _compute_target(self) -> CallableModel: + return self._bound + + def compute(self, context: Any = _UNSET, /, _options: Optional[FlowOptions] = None, **kwargs) -> Any: + """Evaluate the bound wrapper after building its ambient context.""" + + built_context = _build_bound_compute_context(self._bound, context, kwargs) + return _maybe_auto_unwrap_external_result(self._bound, self._bound(built_context, _options=_options)) + + def _argument_context(self, context: Optional[ContextBase]) -> Optional[ContextBase]: + if context is None: + if _bound_model_preserves_none_context(self._bound): + return None + _supplied_fields, required_dynamic_inputs = _dynamic_context_operation_effects(self._bound.context_spec, required_only=True) + if required_dynamic_inputs: + return None + return self._bound._rewrite_context(_bound_model_ambient_context(self._bound, {})) + return self._bound._rewrite_context(context) + + def _dependency_specs_for_inspect( + self, + _dependency_context: Optional[ContextBase], + argument_context: Optional[ContextBase], + *, + preserve_ambient_context: bool = False, + ) -> Tuple[DependencySpec, ...]: + return _direct_dependency_specs(self._bound.model, argument_context, trim_context=not preserve_ambient_context) + + @property + def _bound_inputs(self) -> Dict[str, Any]: + """Concrete values already fixed, including statically resolved context bindings.""" + + result = super()._bound_inputs + for name in self._context_inputs: + result.pop(name, None) + result.update(_statically_resolved_context_field_values(self._bound.model, self._bound.context_spec)) + return result + + @property + def _context_inputs(self) -> Dict[str, Any]: + """Declared contextual inputs of the wrapped model.""" + + return super()._context_inputs + + @property + def 
_context_argument_specs(self) -> Dict[str, InputSpec]: + """Rich descriptions of wrapped-model contextual inputs after bindings.""" + + result = dict(super()._context_argument_specs) + for name, value in _statically_resolved_context_field_values(self._bound.model, self._bound.context_spec).items(): + if name in result: + spec = result[name] + result[name] = InputSpec(name, spec.type, False, spec.default, value, "with_context") + + supplied_fields, _dynamic_inputs = _dynamic_context_operation_effects(self._bound.context_spec, required_only=False) + for name in supplied_fields: + if name in result: + spec = result[name] + result[name] = InputSpec(name, spec.type, False, spec.default, _UNSET_FLOW_INPUT, "context_transform") + return result + + def _context_argument_specs_for_context(self, context: Optional[ContextBase]) -> Dict[str, InputSpec]: + result = dict(self._context_argument_specs) + if context is None: + return result + values = _context_values(context) + for name, spec in list(result.items()): + if name not in values: + continue + source = spec.source if spec.source in {"context_transform", "with_context"} else "runtime" + result[name] = InputSpec(name, spec.type, False, spec.default, values[name], source) + return result + + @property + def _runtime_inputs(self) -> Dict[str, Any]: + """Direct runtime context inputs after applying this wrapper's bindings. + + Static context transforms may be evaluated to identify resolved fields. + Dynamic transforms contribute their own runtime context inputs. 
+ """ + + result = super()._context_inputs + for name in _statically_resolved_context_field_names(self._bound.model, self._bound.context_spec): + result.pop(name, None) + supplied_fields, dynamic_inputs = _dynamic_context_operation_effects(self._bound.context_spec, required_only=False) + for name in supplied_fields: + result.pop(name, None) + _merge_context_input_types(result, dynamic_inputs) + return result + + @property + def _runtime_argument_specs(self) -> Dict[str, InputSpec]: + """Rich descriptions of runtime inputs after applying this wrapper's bindings.""" + + result = {} + base_specs = self._context_argument_specs + for name, annotation in self._runtime_inputs.items(): + base = base_specs.get(name) + result[name] = InputSpec( + name, + annotation, + name in self._required_inputs, + base.default if base is not None else _UNSET_FLOW_INPUT, + base.value if base is not None else _UNSET_FLOW_INPUT, + base.source if base is not None else "runtime", + ) + _supplied_fields, dynamic_inputs = _dynamic_context_operation_effects(self._bound.context_spec, required_only=False) + required_dynamic = set(_dynamic_context_operation_effects(self._bound.context_spec, required_only=True)[1]) + for name, annotation in dynamic_inputs.items(): + result[name] = InputSpec(name, annotation, name in required_dynamic, _UNSET_FLOW_INPUT, _UNSET_FLOW_INPUT, "context_transform") + return result + + @property + def _required_inputs(self) -> Dict[str, Any]: + """Required direct runtime context inputs still missing after static bindings. + + Static context transforms may be evaluated to identify resolved fields. + Dynamic transforms contribute their own required runtime context inputs. 
+ """ + + result = super()._required_inputs + for name in _statically_resolved_context_field_names(self._bound.model, self._bound.context_spec): + result.pop(name, None) + supplied_fields, dynamic_inputs = _dynamic_context_operation_effects(self._bound.context_spec, required_only=True) + for name in supplied_fields: + result.pop(name, None) + _merge_context_input_types(result, dynamic_inputs) + contract = _model_context_contract(self._bound.model) + if contract.generated_model is None: + default_values = _plain_model_default_context_values(self._bound.model, contract.runtime_context_type) + if default_values is not None: + for name in default_values: + result.pop(name, None) + return result + + def with_context(self, *patches, **field_overrides) -> BoundModel: + context_spec = _normalize_with_context(self._bound.model, patches, field_overrides) + merged = _BoundContextSpec(operations=[*self._bound.context_spec.operations, *context_spec.operations]) + return BoundModel( + model=self._bound.model, + context_spec=_validate_static_context_spec_declared_context(self._bound.model, merged), + ) + + +class _GeneratedFlowModelBase(CallableModel): + """Base class for all classes created by ``@Flow.model``.""" + + __flow_model_config__: ClassVar[_FlowModelConfig] + __flow_model_identity__: ClassVar[str] + + def __reduce__(self): + """Prefer import-path restoration, falling back to serialized local factories.""" + + config = type(self).__flow_model_config__ + + state = self.__getstate__() + factory_path = _generated_model_factory_path_for_pickle(config, type(self)) + if factory_path is not None: + return (_new_generated_flow_model_for_pickle, (factory_path,), state) + + # Local generated classes are not normal importable class definitions: + # plain pickle cannot reconstruct them, and Ray workers should not + # re-run fragile annotation resolution. 
Carry the analyzed contract + # separately, but leave instance state in the outer pickle stream so + # normal pickle memo semantics remain intact. + payload = _LocalFlowModelPicklePayload( + serialized_config=_serialize_flow_model_config(config), + factory_kwargs=type(self).__flow_model_factory_kwargs__, + ) + return (_new_local_flow_model_for_pickle, (cloudpickle.dumps(payload, protocol=5),), state) + + @model_validator(mode="after") + def _validate_flow_model_fields(self): + """Validate all bound regular and contextual defaults after pydantic construction.""" + + config = self.__class__.__flow_model_config__ + + for param in config.parameters: + value = getattr(self, param.name, _UNSET_FLOW_INPUT) + if _is_unset_flow_input(value): + continue + if not param.is_contextual: + value = _resolve_bound_param_registry_ref(param, value) + object.__setattr__( + self, + param.name, + _validate_bound_param_value(config, param, value, "Contextual default" if param.is_contextual else "Field"), + ) + + _validate_bound_declared_context_defaults(self, config) + return self + + @property + def context_type(self) -> Type[ContextBase]: + return self.__class__.__flow_model_config__.context_type + + @property + def result_type(self) -> Type[ResultBase]: + return self.__class__.__flow_model_config__.result_type + + def _evaluation_identity_payload( + self, + context: ContextBase, + ) -> Optional[Any]: + return _generated_model_identity_payload(self, context) + + +# --------------------------------------------------------------------------- +# Generated model method builders and decorators +# --------------------------------------------------------------------------- + + +def _make_call_impl(config: _FlowModelConfig) -> _AnyCallable: + """Create the ``__call__`` implementation for one generated model class.""" + + def __call__(self, context): + """Resolve bound inputs, dependency inputs, and context inputs, then call the user function.""" + + missing_regular = 
_missing_regular_param_names(self, config) + if missing_regular: + missing = ", ".join(sorted(missing_regular)) + raise TypeError( + f"Missing regular parameter(s) for {_callable_name(config.func)}: {missing}. " + "Bind them at construction time; compute() only supplies contextual inputs." + ) + + fn_kwargs: Dict[str, Any] = {} + for param in config.regular_params: + value = _resolve_regular_param_value(self, param, context) + if param.is_lazy: + fn_kwargs[param.name] = _make_coercing_lazy_thunk(value, param.name, param.annotation) + else: + fn_kwargs[param.name] = _coerce_value(param.name, value, param.annotation, "Regular parameter") + + fn_kwargs.update(_resolved_contextual_inputs(self, config, context)) + + raw_result = config.func(**fn_kwargs) + if config.auto_wrap_result: + return GenericResult(value=raw_result) + return raw_result + + cast(Any, __call__).__signature__ = inspect.Signature( + parameters=[ + inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), + inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=config.context_type), + ], + return_annotation=config.result_type, + ) + return __call__ + + +def _make_deps_impl(config: _FlowModelConfig) -> _AnyCallable: + """Create the ``__deps__`` implementation for one generated model class.""" + + def __deps__(self, context): + """Declare non-lazy regular ``CallableModel`` inputs as graph dependencies.""" + + missing_regular = _missing_regular_param_names(self, config) + if missing_regular: + missing = ", ".join(sorted(missing_regular)) + raise TypeError(f"Missing regular parameter(s) for {_callable_name(config.func)}: {missing}. 
Bind them before dependency evaluation.") + + deps = [] + for param in config.regular_params: + if param.is_lazy: + continue + value = getattr(self, param.name, _UNSET_FLOW_INPUT) + if _is_model_dependency(value): + dependency_model, dependency_context = _resolved_dependency_invocation(value, context) + deps.append((dependency_model, [dependency_context])) + return deps + + cast(Any, __deps__).__signature__ = inspect.Signature( + parameters=[ + inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), + inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=config.context_type), + ], + return_annotation=GraphDepList, + ) + return __deps__ + + +def _factory_param_annotation(param: _FlowModelParam) -> Any: + """Return the public factory signature annotation for a user function parameter. + + The factory signature keeps the user's surface syntax visible: contextual + params show as FromContext[T], lazy params show as Lazy[T], and regular + params keep T. The generated Pydantic field annotation below has a different + job: describing the stored construction value after binding. + """ + + if param.is_contextual: + return FromContext[param.annotation] + if param.is_lazy: + return Lazy[param.annotation] + return param.annotation + + +def _pydantic_schema_safe_annotation(annotation: Any) -> Any: + # Only the generated Pydantic field declaration falls back to Any for known + # Pydantic schema-build failures. Runtime coercion still builds the real + # TypeAdapter and propagates unexpected errors. + try: + _type_adapter(annotation) + except (PydanticSchemaGenerationError, PydanticUndefinedAnnotation): + return Any + return annotation + + +def _generated_field_annotation(param: _FlowModelParam) -> Any: + """Return the generated model field annotation used for schema only. 
+ + This differs from _factory_param_annotation: the factory signature describes + user-facing inputs (FromContext[T], Lazy[T]), while this describes the value + stored on the generated Pydantic model after binding. SkipValidation keeps + this schema visible without letting Pydantic enforce it before ccflow can + distinguish literals, dependencies, lazy deps, and contextual defaults. + """ + + if param.is_contextual: + annotation = param.validation_annotation + elif param.is_lazy: + annotation = CallableModel + elif param.annotation is Any or param.annotation is inspect.Parameter.empty: + annotation = Any + else: + annotation = Union[param.annotation, CallableModel] + return SkipValidation[_pydantic_schema_safe_annotation(annotation)] + + +def _factory_signature(config: _FlowModelConfig, generated_cls: Type[BaseModel]) -> inspect.Signature: + """Return the public construction signature for a generated factory.""" + + parameters = [] + for param in config.parameters: + parameters.append( + inspect.Parameter( + param.name, + inspect.Parameter.KEYWORD_ONLY, + annotation=_factory_param_annotation(param), + default=param.function_default if param.has_function_default else _UNSET_FLOW_INPUT, + ) + ) + + param_names = {param.name for param in config.parameters} + for name, field_info in generated_cls.model_fields.items(): + if name == "meta" or name in param_names: + continue + default = _UNSET_FLOW_INPUT if field_info.is_required() or getattr(field_info, "default_factory", None) is not None else field_info.default + parameters.append( + inspect.Parameter( + name, + inspect.Parameter.KEYWORD_ONLY, + annotation=field_info.annotation, + default=default, + ) + ) + + return inspect.Signature(parameters=parameters, return_annotation=generated_cls) + + +def _context_transform_factory_signature(config: _FlowModelConfig) -> inspect.Signature: + """Return the public binding signature for a context transform factory.""" + + parameters = [] + for param in config.regular_params: + 
parameters.append( + inspect.Parameter( + param.name, + inspect.Parameter.KEYWORD_ONLY, + annotation=param.annotation, + default=param.function_default if param.has_function_default else inspect.Parameter.empty, + ) + ) + return inspect.Signature(parameters=parameters) + + +def _resolve_generated_model_bases(model_base: Type[CallableModel]) -> Tuple[type, ...]: + """Return the class bases for a generated model, preserving custom model bases.""" + + if not isinstance(model_base, type) or not issubclass(model_base, CallableModel): + raise TypeError(f"model_base must be a CallableModel subclass, got {model_base!r}") + + if issubclass(model_base, _GeneratedFlowModelBase): + return (model_base,) + if model_base is CallableModel: + return (_GeneratedFlowModelBase,) + return (_GeneratedFlowModelBase, model_base) + + +def _build_flow_model_factory_from_config(config: _FlowModelConfig, factory_kwargs: Dict[str, Any]) -> _AnyCallable: + """Build the generated model class/factory from an analyzed flow-model config.""" + + fn = config.func + factory_kwargs = dict(factory_kwargs) + model_base = factory_kwargs["model_base"] + # Preserve one generated-model identity across local pickle/cloudpickle + # restore. See ``_flow_model_config_identity`` for why this must be fixed at + # class-construction time instead of recalculated on every cache-key build. 
+ config_identity = factory_kwargs.setdefault("_flow_model_identity", _flow_model_config_identity(config)) + field_definitions: Dict[str, Any] = {} + + for param in config.parameters: + annotation = _generated_field_annotation(param) + if param.is_contextual: + default = Field(default_factory=_unset_flow_input_factory, exclude_if=_is_unset_flow_input) + elif param.has_function_default: + default = param.function_default + else: + default = Field(default_factory=_unset_flow_input_factory, exclude_if=_is_unset_flow_input) + field_definitions[param.name] = (annotation, default) + + GeneratedModel = cast( + type[_GeneratedFlowModelBase], + create_model( + f"_{_callable_name(fn)}_Model", + __base__=_resolve_generated_model_bases(model_base), + __module__=getattr(fn, "__module__", __name__), + __qualname__=f"_{_callable_name(fn)}_Model", + **field_definitions, + ), + ) + GeneratedModel.__call__ = Flow.call( + **{ + name: value + for name in ("cacheable", "volatile", "log_level", "validate_result", "verbose", "evaluator") + if (value := factory_kwargs.get(name, _UNSET)) is not _UNSET + } + )(_make_call_impl(config)) + GeneratedModel.__deps__ = Flow.deps(_make_deps_impl(config)) + update_abstractmethods(GeneratedModel) + GeneratedModel.__flow_model_config__ = config + GeneratedModel.__flow_model_factory_kwargs__ = factory_kwargs + GeneratedModel.__flow_model_identity__ = config_identity + GeneratedModel.__ccflow_tokenizer_cache__ = _generated_model_behavior_token(config_identity, model_base) + _register_generated_model_class(config, GeneratedModel) + GeneratedModel.model_rebuild() + + @wraps(fn) + def factory(**kwargs) -> _GeneratedFlowModelBase: + """Create a generated model instance with regular/contextual defaults bound.""" + + return GeneratedModel(**kwargs) + + cast(Any, factory)._generated_model = GeneratedModel + cast(Any, factory).__signature__ = _factory_signature(config, GeneratedModel) + factory.__doc__ = fn.__doc__ + return factory + + +def 
_flow_context_transform(func: Optional[_AnyCallable] = None) -> _AnyCallable: + """Decorator that turns a function into a serializable ``with_context`` transform factory. + + Regular parameters are bound when the transform factory is called. + ``FromContext`` parameters are read from the runtime context when the bound + model executes. Transform functions returning mappings are positional patch + transforms; transforms returning scalar values are field transforms. + """ + + def decorator(fn: _AnyCallable) -> _AnyCallable: + """Analyze one transform function and return its binding factory.""" + + _ensure_named_python_function(fn, decorator_name="@Flow.context_transform") + resolved_hints = get_type_hints(fn, include_extras=True) + sig = _resolved_flow_signature( + fn, + resolved_hints=resolved_hints, + require_return_annotation=True, + function_name=_callable_name(fn), + ) + config = _analyze_flow_context_transform(fn, sig, is_model_dependency=_is_model_dependency) + # Store the analyzed transform contract directly. Import-path detection + # during decoration is brittle because module globals usually still point + # at the undecorated function until the decorator returns. 
+ serialized_config = _serialize_context_transform_config(config) + + @wraps(fn) + def factory(**kwargs) -> ContextTransform: + """Bind regular transform arguments into a serializable spec.""" + + return ContextTransform( + serialized_config=serialized_config, + bound_args=_validate_context_transform_factory_kwargs(config, kwargs), + ) + + cast(Any, factory).__flow_context_transform_config__ = config + cast(Any, factory).__signature__ = _context_transform_factory_signature(config) + return factory + + if func is not None: + return decorator(func) + return decorator + + +def flow_model( + func: Optional[_AnyCallable] = None, + *, + context_type: Optional[Type[ContextBase]] = None, + auto_unwrap: bool = False, + model_base: Type[CallableModel] = CallableModel, + cacheable: Any = _UNSET, + volatile: Any = _UNSET, + log_level: Any = _UNSET, + validate_result: Any = _UNSET, + verbose: Any = _UNSET, + evaluator: Any = _UNSET, +) -> _AnyCallable: + """Decorator that generates a ``CallableModel`` class from a plain function. + + Unmarked parameters become construction-time model fields. Parameters + annotated as ``FromContext[T]`` are contextual inputs supplied by + ``FlowContext``, a declared ``context_type``, ``compute(...)`` kwargs, or + ``with_context(...)`` bindings. The returned object is a factory that + creates instances of the generated model class. + + Args: + func: The function being decorated. This is passed automatically in + bare decorator form, for example ``@Flow.model``. When using + options, for example ``@Flow.model(auto_unwrap=True)``, Python first + calls ``Flow.model(...)`` without a function and then applies the + returned decorator. + context_type: Optional ``ContextBase`` subclass used to validate all + contextual inputs together after individual ``FromContext[...]`` + fields are resolved. 
+ auto_unwrap: When ``True`` and ccflow auto-wraps a plain return + annotation in ``GenericResult[T]``, external + ``model.flow.compute(...)`` calls return the raw ``T`` value instead + of ``GenericResult[T]``. Dependency evaluation and direct model calls + keep the normal ccflow result contract. + model_base: Custom ``CallableModel`` subclass to use as the base class + for the generated model. + cacheable: Optional generated-model default for ``FlowOptions.cacheable``. + volatile: Optional generated-model default for ``FlowOptions.volatile``. + log_level: Optional generated-model default for ``FlowOptions.log_level``. + validate_result: Optional generated-model default for + ``FlowOptions.validate_result``. + verbose: Optional generated-model default for ``FlowOptions.verbose``. + evaluator: Optional generated-model default evaluator. + """ + + def decorator(fn: _AnyCallable) -> _AnyCallable: + """Analyze one user function and synthesize its generated model class.""" + + resolved_hints = get_type_hints(fn, include_extras=True) + sig = _resolved_flow_signature( + fn, + resolved_hints=resolved_hints, + require_return_annotation=True, + function_name=_callable_name(fn), + ) + config = _analyze_flow_model( + fn, + sig, + context_type=context_type, + auto_unwrap=auto_unwrap, + is_model_dependency=_is_model_dependency, + ) + factory_kwargs = { + "context_type": context_type, + "auto_unwrap": auto_unwrap, + "model_base": model_base, + "cacheable": cacheable, + "volatile": volatile, + "log_level": log_level, + "validate_result": validate_result, + "verbose": verbose, + "evaluator": evaluator, + } + factory_kwargs["_flow_model_identity"] = _flow_model_config_identity(config) + return _build_flow_model_factory_from_config(config, factory_kwargs) + + if func is not None: + return decorator(func) + return decorator diff --git a/ccflow/tests/config/conf_flow.yaml b/ccflow/tests/config/conf_flow.yaml new file mode 100644 index 0000000..7963d5f --- /dev/null +++ 
b/ccflow/tests/config/conf_flow.yaml @@ -0,0 +1,44 @@ +# Flow.model configurations for Hydra integration tests. + +flow_loader: + _target_: ccflow.tests.flow_model_hydra_fixtures.basic_loader + source: fixture_input + multiplier: 5 + +flow_source: + _target_: ccflow.tests.flow_model_hydra_fixtures.data_source + base_value: 100 + +flow_transformer: + _target_: ccflow.tests.flow_model_hydra_fixtures.data_transformer + source: flow_source + factor: 3 + +diamond_source: + _target_: ccflow.tests.flow_model_hydra_fixtures.data_source + base_value: 10 + +diamond_branch_a: + _target_: ccflow.tests.flow_model_hydra_fixtures.data_transformer + source: diamond_source + factor: 2 + +diamond_branch_b: + _target_: ccflow.tests.flow_model_hydra_fixtures.data_transformer + source: diamond_source + factor: 5 + +diamond_aggregator: + _target_: ccflow.tests.flow_model_hydra_fixtures.data_aggregator + input_a: diamond_branch_a + input_b: diamond_branch_b + operation: add + +contextual_loader_model: + _target_: ccflow.tests.flow_model_hydra_fixtures.contextual_loader + source: data_source + +contextual_processor_model: + _target_: ccflow.tests.flow_model_hydra_fixtures.contextual_processor + data: contextual_loader_model + prefix: output diff --git a/ccflow/tests/evaluators/test_common.py b/ccflow/tests/evaluators/test_common.py index 6db6d07..8577207 100644 --- a/ccflow/tests/evaluators/test_common.py +++ b/ccflow/tests/evaluators/test_common.py @@ -10,7 +10,10 @@ DateRangeContext, Evaluator, EvaluatorBase, + Flow, + FlowContext, FlowOptionsOverride, + FromContext, ModelEvaluationContext, NullContext, TransparentModelEvaluationContext, @@ -315,6 +318,30 @@ def __call__(self, context: ModelEvaluationContext): ) assert cache_key(opaque2) == cache_key(opaque2b) + def test_opaque_mec_order_preserved(self): + """Non-transparent evaluator wrapper order is identity-significant.""" + + class OpaqueEval(EvaluatorBase): + tag: str + + def __call__(self, context: ModelEvaluationContext): + 
return context() + + m = MyDateCallable(offset=1) + ctx = DateContext(date=date(2022, 1, 1)) + inner = ModelEvaluationContext(model=m, context=ctx) + + inner_then_outer = ModelEvaluationContext( + model=OpaqueEval(tag="outer"), + context=ModelEvaluationContext(model=OpaqueEval(tag="inner"), context=inner), + ) + outer_then_inner = ModelEvaluationContext( + model=OpaqueEval(tag="inner"), + context=ModelEvaluationContext(model=OpaqueEval(tag="outer"), context=inner), + ) + + assert cache_key(inner_then_outer) != cache_key(outer_then_inner) + def test_fn_deps_preserved_through_transparent(self): """fn='__deps__' is preserved when walking through transparent layers.""" m = MyDateCallable(offset=1) @@ -362,6 +389,51 @@ def test_basic(self): self.assertNotIn(key, evaluator.cache) self.assertNotIn(key, evaluator.ids) + def test_plain_callable_key_matches_public_cache_key(self): + """Existing CallableModels stay on the structural cache-key path.""" + m1 = MyDateCallable(offset=1) + evaluator = MemoryCacheEvaluator() + context = DateContext(date=date(2022, 1, 1)) + model_evaluation_context = ModelEvaluationContext(model=m1, context=context, options=dict(cacheable=True)) + + self.assertEqual(evaluator.key(model_evaluation_context), cache_key(model_evaluation_context)) + self.assertEqual(evaluator.key(model_evaluation_context), cache_key(model_evaluation_context, effective=True)) + + def test_plain_callable_key_fallback_does_not_log(self): + """Normal opt-out from effective identity should stay quiet.""" + m1 = MyDateCallable(offset=1) + evaluator = MemoryCacheEvaluator() + context = DateContext(date=date(2022, 1, 1)) + model_evaluation_context = ModelEvaluationContext(model=m1, context=context, options=dict(cacheable=True)) + + with self.assertNoLogs("ccflow.evaluators.common", level="DEBUG"): + self.assertEqual(evaluator.key(model_evaluation_context), cache_key(model_evaluation_context)) + + def test_effective_key_identity_errors_propagate(self): + """Unexpected 
effective-identity failures should not be hidden by structural fallback.""" + + class BadIdentityCallable(MyDateCallable): + def _evaluation_identity_payload(self, context): + raise ValueError("identity broke") + + m1 = BadIdentityCallable(offset=1) + evaluator = MemoryCacheEvaluator() + context = DateContext(date=date(2022, 1, 1)) + model_evaluation_context = ModelEvaluationContext(model=m1, context=context, options=dict(cacheable=True)) + + with self.assertRaisesRegex(ValueError, "identity broke"): + evaluator.key(model_evaluation_context) + + def test_plain_callable_deps_key_matches_public_cache_key(self): + """Non-__call__ evaluations stay structural.""" + m1 = MyDateCallable(offset=1) + evaluator = MemoryCacheEvaluator() + context = DateContext(date=date(2022, 1, 1)) + model_evaluation_context = ModelEvaluationContext(model=m1, context=context, fn="__deps__", options=dict(cacheable=True)) + + self.assertEqual(evaluator.key(model_evaluation_context), cache_key(model_evaluation_context)) + self.assertEqual(evaluator.key(model_evaluation_context), cache_key(model_evaluation_context, effective=True)) + def test_caching(self): # Create some hard-to hash structure with all kinds of custom types # We will put this on the callable to make sure caching still works @@ -511,6 +583,34 @@ def test_graph_deps_diamond(self): elif v.model.meta.name == "n0": self.assertEqual(set(graph.ids[dep_key].context.model.meta.name for dep_key in graph.graph[k]), set()) + def test_plain_callable_graph_keys_match_public_cache_key(self): + """Dependency graphs for ordinary CallableModels keep structural keys.""" + n0 = NodeModel(meta=dict(name="n0")) + n1 = NodeModel(meta=dict(name="n1"), deps_model=[n0]) + n2 = NodeModel(meta=dict(name="n2"), deps_model=[n0]) + root = NodeModel(meta=dict(name="n3"), deps_model=[n1, n2]) + context = DateContext(date=date(2022, 1, 1)) + + graph = get_dependency_graph(ModelEvaluationContext(model=root, context=context)) + + for key, evaluation_context in 
graph.ids.items(): + self.assertEqual(key, cache_key(evaluation_context)) + self.assertEqual(graph.root_id, cache_key(ModelEvaluationContext(model=root, context=context))) + + def test_plain_callable_graph_deduplicates_equal_models_by_key(self): + """Ordinary graph traversal deduplicates structurally equal nodes by key.""" + leaf1 = NodeModel(meta=dict(name="leaf")) + leaf2 = NodeModel(meta=dict(name="leaf")) + root = NodeModel(meta=dict(name="root"), deps_model=[leaf1, leaf2]) + context = DateContext(date=date(2022, 1, 1)) + + NodeModel._deps_calls = [] + graph = get_dependency_graph(ModelEvaluationContext(model=root, context=context)) + + self.assertEqual(NodeModel._deps_calls, [("root", date(2022, 1, 1)), ("leaf", date(2022, 1, 1))]) + self.assertEqual(len(graph.ids), 2) + self.assertEqual(graph.ids.keys(), graph.graph.keys()) + def test_graph_deps_circular(self): root = CircularModel() context = DateContext(date=date(2022, 1, 1)) @@ -548,6 +648,27 @@ def test_graph_evaluator_basic(self): self.assertIn(("n0", date(2022, 1, 1)), graph_calls[:2]) self.assertIn(("n1", date(2022, 1, 1)), graph_calls[:2]) + def test_graph_evaluator_root_id_uses_built_graph_key(self): + calls = 0 + + @Flow.context_transform + def bump(seed: FromContext[int]) -> int: + nonlocal calls + calls += 1 + return seed + calls + + @Flow.model + def add(a: FromContext[int]) -> int: + return a + + model = add().flow.with_context(a=bump()) + graph = get_dependency_graph(model.__call__.get_evaluation_context(model, FlowContext(seed=10))) + result = model.flow.compute(FlowContext(seed=10), _options={"evaluator": GraphEvaluator()}) + + self.assertIn(graph.root_id, graph.ids) + self.assertIsNotNone(result) + self.assertGreater(result.value, 10) + def test_graph_evaluator_diamond(self): n0 = NodeModel(meta=dict(name="n0")) n1 = NodeModel(meta=dict(name="n1"), deps_model=[n0]) diff --git a/ccflow/tests/flow_model_hydra_fixtures.py b/ccflow/tests/flow_model_hydra_fixtures.py new file mode 100644 index 
0000000..2ebb022 --- /dev/null +++ b/ccflow/tests/flow_model_hydra_fixtures.py @@ -0,0 +1,49 @@ +"""Flow.model fixtures used by Hydra integration tests.""" + +from datetime import date + +from ccflow import Flow, FromContext, GenericResult + + +@Flow.model +def basic_loader(source: str, multiplier: int, value: FromContext[int]) -> GenericResult[int]: + return GenericResult(value=value * multiplier) + + +@Flow.model +def data_source(base_value: int, value: FromContext[int]) -> GenericResult[int]: + return GenericResult(value=value + base_value) + + +@Flow.model +def data_transformer(source: int, factor: int) -> GenericResult[int]: + return GenericResult(value=source * factor) + + +@Flow.model +def data_aggregator(input_a: int, input_b: int, operation: str = "add") -> GenericResult[int]: + if operation == "add": + return GenericResult(value=input_a + input_b) + raise ValueError(f"unsupported operation: {operation}") + + +@Flow.model +def contextual_loader(source: str, start_date: FromContext[date], end_date: FromContext[date]) -> GenericResult[dict]: + return GenericResult( + value={ + "source": source, + "start_date": str(start_date), + "end_date": str(end_date), + } + ) + + +@Flow.model +def contextual_processor( + prefix: str, + data: dict, + start_date: FromContext[date], + end_date: FromContext[date], +) -> GenericResult[str]: + del start_date, end_date + return GenericResult(value=f"{prefix}:{data['source']}:{data['start_date']} to {data['end_date']}") diff --git a/ccflow/tests/test_callable.py b/ccflow/tests/test_callable.py index 43f86b5..d39f3e3 100644 --- a/ccflow/tests/test_callable.py +++ b/ccflow/tests/test_callable.py @@ -14,6 +14,7 @@ Flow, GenericResult, GraphDepList, + Lazy, MetaData, ModelRegistry, NullContext, @@ -783,3 +784,263 @@ class MyCallableParent_bad_decorator(MyCallableParent): @Flow.deps def foo(self, context): return [] + + +class TestAutoContext(TestCase): + """Tests for the opt-in @Flow.call(auto_context=...) 
path.""" + + def test_direct_function_call_form(self): + class AutoContextCallable(CallableModel): + def __call__(self, *, x: int) -> GenericResult: + return GenericResult(value=x) + + __call__ = Flow.call(__call__, auto_context=True) + + auto_ctx = AutoContextCallable.__call__.__wrapped__.__auto_context__ + + self.assertTrue(issubclass(auto_ctx, ContextBase)) + self.assertEqual(AutoContextCallable()(auto_ctx(x=3)).value, 3) + + def test_basic_usage_with_kwargs(self): + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, x: int, y: str = "default") -> GenericResult: + return GenericResult(value=f"{x}-{y}") + + model = AutoContextCallable() + + self.assertEqual(model(x=42, y="hello").value, "42-hello") + self.assertEqual(model(x=10).value, "10-default") + + def test_no_arg_call_uses_generated_context_defaults_only_for_auto_context(self): + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, x: int = 1, y: str = "a") -> GenericResult: + return GenericResult(value=f"{x}-{y}") + + class PlainCallable(CallableModel): + @Flow.call + def __call__(self, context: MyContext = MyContext(a="plain")) -> MyResult: + return MyResult(x=1, y=context.a) + + self.assertEqual(AutoContextCallable()().value, "1-a") + self.assertEqual(PlainCallable()().y, "plain") + + def test_no_arg_call_still_rejects_required_generated_context_fields(self): + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, x: int, y: str = "a") -> GenericResult: + return GenericResult(value=f"{x}-{y}") + + with self.assertRaisesRegex(TypeError, "missing 1 required positional argument: 'context'"): + AutoContextCallable()() + + def test_auto_context_attribute_and_registration(self): + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, value: int) -> GenericResult: + return GenericResult(value=value) + + inner = 
AutoContextCallable.__call__.__wrapped__ + self.assertTrue(hasattr(inner, "__auto_context__")) + + auto_ctx = inner.__auto_context__ + self.assertTrue(issubclass(auto_ctx, ContextBase)) + self.assertIn("value", auto_ctx.model_fields) + self.assertTrue(hasattr(auto_ctx, "__ccflow_import_path__")) + self.assertTrue(auto_ctx.__ccflow_import_path__.startswith(LOCAL_ARTIFACTS_MODULE_NAME)) + + def test_call_with_context_object(self): + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, x: int, y: str = "default") -> GenericResult: + return GenericResult(value=f"{x}-{y}") + + auto_ctx = AutoContextCallable.__call__.__wrapped__.__auto_context__ + ctx = auto_ctx(x=99, y="context") + + self.assertEqual(AutoContextCallable()(ctx).value, "99-context") + + def test_with_parent_context(self): + class ParentContext(ContextBase): + base_value: str = "base" + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=ParentContext) + def __call__(self, *, x: int, base_value: str) -> GenericResult: + return GenericResult(value=f"{x}-{base_value}") + + auto_ctx = AutoContextCallable.__call__.__wrapped__.__auto_context__ + + self.assertTrue(issubclass(auto_ctx, ParentContext)) + self.assertIn("base_value", auto_ctx.model_fields) + self.assertIn("x", auto_ctx.model_fields) + self.assertEqual(AutoContextCallable()(x=42, base_value="custom").value, "42-custom") + + def test_parent_fields_must_be_in_signature(self): + class ParentContext(ContextBase): + required_field: str + + with self.assertRaisesRegex(TypeError, "must be included in function signature"): + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=ParentContext) + def __call__(self, *, x: int) -> GenericResult: + return GenericResult(value=x) + + def test_parent_field_type_incompatibility_rejected(self): + class ParentContext(ContextBase): + base: int + + with self.assertRaisesRegex(TypeError, "incompatible"): + + class 
AutoContextCallable(CallableModel): + @Flow.call(auto_context=ParentContext) + def __call__(self, *, base: str) -> GenericResult: + return GenericResult(value=base) + + def test_parent_field_defaults_remain_authoritative_for_auto_context(self): + class ParentContext(ContextBase): + base: str = "parent" + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=ParentContext) + def __call__(self, *, base: str = "function") -> GenericResult: + return GenericResult(value=base) + + self.assertEqual(AutoContextCallable()().value, "parent") + + def test_cloudpickle_roundtrip(self): + class AutoContextCallable(CallableModel): + multiplier: int = 2 + + @Flow.call(auto_context=True) + def __call__(self, *, x: int) -> GenericResult: + return GenericResult(value=x * self.multiplier) + + restored = rcploads(rcpdumps(AutoContextCallable(multiplier=3))) + + self.assertEqual(restored(x=10).value, 30) + + def test_ray_task_execution(self): + class AutoContextCallable(CallableModel): + factor: int = 2 + + @Flow.call(auto_context=True) + def __call__(self, *, x: int, y: int = 1) -> GenericResult: + return GenericResult(value=(x + y) * self.factor) + + @ray.remote + def run_callable(model, **kwargs): + return model(**kwargs).value + + with ray.init(num_cpus=1): + result = ray.get(run_callable.remote(AutoContextCallable(factor=5), x=10, y=2)) + + self.assertEqual(result, 60) + + def test_postponed_annotations_are_resolved(self): + namespace = {} + exec( + """ +from __future__ import annotations + +from ccflow import CallableModel, Flow, GenericResult + + +class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, x: int) -> GenericResult[int]: + return GenericResult(value=x) + + +result = AutoContextCallable().flow.compute(x=1) +""", + namespace, + namespace, + ) + + self.assertEqual(namespace["result"].value, 1) + + def test_postponed_annotations_unresolved_names_stay_loud(self): + namespace = {} + with self.assertRaises(NameError): 
+ exec( + """ +from __future__ import annotations + +from ccflow import CallableModel, Flow, GenericResult + + +class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, x: MissingType) -> GenericResult[int]: + return GenericResult(value=x) +""", + namespace, + namespace, + ) + + def test_normal_keyword_only_flow_call_without_auto_context_still_fails(self): + class BadCallable(CallableModel): + @Flow.call + def __call__(self, *, x: int, y: str = "default") -> GenericResult: + return GenericResult(value=f"{x}-{y}") + + with self.assertRaisesRegex(ValueError, "__call__ method must take a single argument, named 'context'"): + BadCallable() + + def test_invalid_auto_context_value(self): + with self.assertRaisesRegex(TypeError, "auto_context must be False, True, or a ContextBase subclass"): + + @Flow.call(auto_context="invalid") + def bad_func(self, *, x: int) -> GenericResult: + return GenericResult(value=x) + + def test_auto_context_rejects_var_args(self): + with self.assertRaisesRegex(TypeError, "variadic positional"): + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *args: int) -> GenericResult: + return GenericResult(value=len(args)) + + def test_auto_context_rejects_var_kwargs(self): + with self.assertRaisesRegex(TypeError, "variadic keyword"): + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, **kwargs: int) -> GenericResult: + return GenericResult(value=len(kwargs)) + + def test_auto_context_requires_return_annotation(self): + with self.assertRaisesRegex(TypeError, "must have a return type annotation"): + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, value: int): + return GenericResult(value=value) + + def test_auto_context_rejects_missing_annotation(self): + with self.assertRaisesRegex(TypeError, "must have a type annotation"): + + class AutoContextCallable(CallableModel): 
+ @Flow.call(auto_context=True) + def __call__(self, *, value) -> GenericResult: + return GenericResult(value=value) + + def test_auto_context_rejects_lazy_annotation(self): + with self.assertRaisesRegex(TypeError, "Lazy"): + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, value: Lazy[int]) -> GenericResult: + return GenericResult(value=value) + + def test_auto_context_rejects_callable_model_default(self): + with self.assertRaisesRegex(TypeError, "CallableModel"): + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, value: int = MyCallable(i=1)) -> GenericResult: + return GenericResult(value=value) diff --git a/ccflow/tests/test_context.py b/ccflow/tests/test_context.py index ad98bd9..611e447 100644 --- a/ccflow/tests/test_context.py +++ b/ccflow/tests/test_context.py @@ -13,6 +13,7 @@ DateContext, DateRangeContext, DatetimeContext, + FlowContext, FreqContext, FreqDateContext, FreqDateRangeContext, @@ -53,6 +54,13 @@ def test_null_context_validation(self): self.assertIsInstance(NullContext.model_validate(DateContext(date="0d")), NullContext) self.assertRaises(ValueError, NullContext.model_validate, [True]) + def test_flow_context_hash_freezes_nested_pydantic_values(self): + c1 = FlowContext(payload=DateContext(date=date(2024, 1, 1))) + c2 = FlowContext(payload=DateContext(date="2024-01-01")) + + self.assertEqual(c1, c2) + self.assertEqual(hash(c1), hash(c2)) + def test_context_with_defaults(self): # Contexts may define default values. 
Extending the assumptions above: # Any context inherits the behavior from NullContext, and can be @@ -275,8 +283,13 @@ def split_camel(name: str): def test_inheritance(self): """Test that if a context has a superset of fields of another context, it is a subclass of that context.""" - for parent_name, parent_class in self.classes.items(): - for child_name, child_class in self.classes.items(): + # Exclude FlowContext from this test - it's a special universal carrier with no + # declared fields (uses extra="allow"), so the "superset implies subclass" logic + # doesn't apply to it. + classes_to_check = {name: cls for name, cls in self.classes.items() if name != "FlowContext"} + + for parent_name, parent_class in classes_to_check.items(): + for child_name, child_class in classes_to_check.items(): if parent_class is child_class: continue diff --git a/ccflow/tests/test_evaluator.py b/ccflow/tests/test_evaluator.py index cc34155..dabf815 100644 --- a/ccflow/tests/test_evaluator.py +++ b/ccflow/tests/test_evaluator.py @@ -1,9 +1,21 @@ from datetime import date from unittest import TestCase -from ccflow import DateContext, Evaluator, ModelEvaluationContext +import pytest -from .evaluators.util import MyDateCallable +from ccflow import CallableModel, DateContext, Evaluator, Flow, ModelEvaluationContext + +from .evaluators.util import MyDateCallable, MyResult + + +class MyAutoContextDateCallable(CallableModel): + """Auto context version of MyDateCallable for testing evaluators.""" + + offset: int + + @Flow.call(auto_context=DateContext) + def __call__(self, *, date: date) -> MyResult: + return MyResult(x=date.day + self.offset) class TestEvaluator(TestCase): @@ -32,3 +44,57 @@ def test_evaluator_deps(self): evaluator = Evaluator() out2 = evaluator.__deps__(model_evaluation_context) self.assertEqual(out2, out) + + +@pytest.mark.parametrize( + "callable_class", + [MyDateCallable, MyAutoContextDateCallable], + ids=["standard", "auto_context"], +) +class TestEvaluatorParametrized: 
+ """Test evaluators work with both standard and auto_context callables.""" + + def test_evaluator_with_context_object(self, callable_class): + """Test evaluator with a context object.""" + m1 = callable_class(offset=1) + context = DateContext(date=date(2022, 1, 1)) + model_evaluation_context = ModelEvaluationContext(model=m1, context=context) + + out = model_evaluation_context() + assert out == MyResult(x=2) # day 1 + offset 1 + + evaluator = Evaluator() + out2 = evaluator(model_evaluation_context) + assert out2 == out + + def test_evaluator_with_fn_specified(self, callable_class): + """Test evaluator with fn='__call__' explicitly specified.""" + m1 = callable_class(offset=1) + context = DateContext(date=date(2022, 1, 1)) + model_evaluation_context = ModelEvaluationContext(model=m1, context=context, fn="__call__") + + out = model_evaluation_context() + assert out == MyResult(x=2) + + def test_evaluator_direct_call_matches(self, callable_class): + """Test that evaluator result matches direct call.""" + m1 = callable_class(offset=5) + context = DateContext(date=date(2022, 1, 15)) + + # Direct call + direct_result = m1(context) + + # Via evaluator + model_evaluation_context = ModelEvaluationContext(model=m1, context=context) + evaluator_result = model_evaluation_context() + + assert direct_result == evaluator_result + assert direct_result == MyResult(x=20) # day 15 + offset 5 + + def test_evaluator_with_kwargs(self, callable_class): + """Test that evaluator works when callable is called with kwargs.""" + m1 = callable_class(offset=1) + + # Call with kwargs + result = m1(date=date(2022, 1, 10)) + assert result == MyResult(x=11) # day 10 + offset 1 diff --git a/ccflow/tests/test_flow_context.py b/ccflow/tests/test_flow_context.py new file mode 100644 index 0000000..506e8f6 --- /dev/null +++ b/ccflow/tests/test_flow_context.py @@ -0,0 +1,452 @@ +"""Tests for FlowContext, FlowAPI, and BoundModel under the FromContext design.""" + +import pickle +from concurrent.futures 
import ThreadPoolExecutor +from datetime import date, timedelta + +import cloudpickle +import pytest + +from ccflow import BoundModel, CallableModel, ContextBase, Flow, FlowContext, FromContext, GenericResult + + +class NumberContext(ContextBase): + x: int + + +class OffsetModel(CallableModel): + offset: int + + @Flow.call + def __call__(self, context: NumberContext) -> GenericResult[int]: + return GenericResult(value=context.x + self.offset) + + +@Flow.context_transform +def shift_start_date(start_date: FromContext[date], days: int) -> date: + return start_date - timedelta(days=days) + + +@Flow.context_transform +def shift_window(start_date: FromContext[date], end_date: FromContext[date], days: int) -> dict[str, object]: + return { + "start_date": start_date - timedelta(days=days), + "end_date": end_date - timedelta(days=days), + } + + +@Flow.context_transform +def offset_value(value: FromContext[int], amount: int) -> int: + return value + amount + + +@Flow.context_transform +def offset_b(b: FromContext[int], amount: int) -> int: + return b + amount + + +@Flow.context_transform +def double_x(x: FromContext[int]) -> int: + return x * 2 + + +def test_flow_context_basic_properties(): + ctx = FlowContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31), label="x") + assert ctx.start_date == date(2024, 1, 1) + assert ctx.end_date == date(2024, 1, 31) + assert ctx.label == "x" + assert dict(ctx) == {"start_date": date(2024, 1, 1), "end_date": date(2024, 1, 31), "label": "x"} + + +def test_flow_context_value_semantics_and_hash(): + first = FlowContext(x=1, values=[1, 2]) + second = FlowContext(x=1, values=[1, 2]) + third = FlowContext(x=2, values=[1, 2]) + + assert first == second + assert first != third + assert len({first, second, third}) == 2 + + +def test_flow_context_hash_reflects_nested_mutable_changes(): + ctx = FlowContext(values=[1]) + assert hash(ctx) == hash(FlowContext(values=[1])) + + ctx.values.append(2) + + assert ctx == FlowContext(values=[1, 2]) + 
assert ctx != FlowContext(values=[1]) + assert hash(ctx) == hash(FlowContext(values=[1, 2])) + + +def test_flow_context_hash_handles_nested_models_and_rejects_opaque_unhashable_values(): + class WithDict: + __hash__ = None + + def __init__(self): + self.values = {"a": [1, 2]} + + class UnhashableNoState: + __slots__ = () + __hash__ = None + + nested = FlowContext(model=NumberContext(x=1), values={"items": [{2, 1}]}) + same = FlowContext(model=NumberContext(x=1), values={"items": [{1, 2}]}) + + assert nested == nested + assert nested != {"model": NumberContext(x=1)} + assert hash(nested) == hash(same) + + with pytest.raises(TypeError, match="unhashable value"): + hash(FlowContext(value=WithDict())) + + with pytest.raises(TypeError, match="unhashable value"): + hash(FlowContext(value=UnhashableNoState())) + + +def test_flow_context_pickle_and_cloudpickle_roundtrip(): + ctx = FlowContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31), tags=frozenset({"a", "b"})) + assert pickle.loads(pickle.dumps(ctx)) == ctx + assert cloudpickle.loads(cloudpickle.dumps(ctx)) == ctx + + +def test_flow_api_introspection_for_from_context_model(): + @Flow.model + def add(a: int, b: FromContext[int], c: FromContext[int] = 5) -> int: + return a + b + c + + model = add(a=10) + assert model.flow.inspect().context_inputs == {"b": int, "c": int} + assert model.flow.inspect().runtime_inputs == {"b": int, "c": int} + assert model.flow.inspect().required_inputs == {"b": int} + assert model.flow.inspect().bound_inputs == {"a": 10} + assert model.flow.compute(b=2).value == 17 + + +def test_flow_api_compute_accepts_single_context_or_kwargs_but_not_both(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + model = add(a=10) + assert model.flow.compute(b=5).value == 15 + assert model.flow.compute(FlowContext(b=6)).value == 16 + + with pytest.raises(TypeError, match="either one context object or contextual keyword inputs"): + model.flow.compute(FlowContext(b=5), 
b=6) + + +def test_bound_model_with_context_static_and_transform(): + @Flow.model + def load_window(start_date: FromContext[date], end_date: FromContext[date]) -> GenericResult[dict]: + return GenericResult(value={"start": start_date, "end": end_date}) + + model = load_window() + shifted = model.flow.with_context( + shift_window(days=7), + end_date=date(2024, 1, 31), + ) + + result = shifted(FlowContext(start_date=date(2024, 1, 8), end_date=date(2024, 2, 7))) + assert result.value == {"start": date(2024, 1, 1), "end": date(2024, 1, 31)} + + +def test_bound_model_with_context_is_branch_local_and_chained(): + @Flow.model + def source(value: FromContext[int]) -> int: + return value + + @Flow.model + def combine(left: int, right: int, value: FromContext[int]) -> int: + return left + right + value + + base = source() + left = base.flow.with_context(value=offset_value(amount=1)) + right = base.flow.with_context(value=offset_value(amount=2)).flow.with_context(value=offset_value(amount=10)) + model = combine(left=left, right=right) + + assert model.flow.compute(value=5).value == (6 + 15 + 5) + + +def test_chained_with_context_merges_patch_transforms(): + @Flow.model + def load(start_date: FromContext[date], end_date: FromContext[date]) -> dict: + return {"start": start_date, "end": end_date} + + base = load() + # First with_context applies a patch, second chains another patch on top + chained = base.flow.with_context(shift_window(days=7)).flow.with_context(shift_window(days=3)) + + # Both patches should be present in the ordered context spec. + assert [operation.kind for operation in chained.context_spec.operations] == ["patch", "patch"] + + # Patches read the original context, then apply writes left-to-right. 
+ # patch1: start - 7, end - 7 => Jan 1, Jan 24 + # patch2: start - 3, end - 3 => Jan 5, Jan 28 (overwrites patch1 keys) + result = chained(FlowContext(start_date=date(2024, 1, 8), end_date=date(2024, 1, 31))) + assert result.value == {"start": date(2024, 1, 5), "end": date(2024, 1, 28)} + + +def test_compute_kwargs_can_supply_ambient_context_for_upstream_transforms(): + @Flow.model + def source(value: FromContext[int]) -> int: + return value + + @Flow.model + def combine(left: int, right: int, bonus: FromContext[int]) -> int: + return left + right + bonus + + base = source() + model = combine( + left=base.flow.with_context(value=offset_value(amount=1)), + right=base.flow.with_context(value=offset_value(amount=10)), + ) + + assert model.flow.inspect().context_inputs == {"bonus": int} + assert model.flow.inspect().runtime_inputs == {"bonus": int} + assert model.flow.compute(value=5, bonus=100).value == (6 + 15 + 100) + + +def test_bound_flow_api_separates_declared_and_runtime_context_inputs(): + @Flow.model + def add(a: int, b: FromContext[int], c: FromContext[int] = 5) -> int: + return a + b + c + + @Flow.context_transform + def from_seed(seed: FromContext[int]) -> int: + return seed + 1 + + bound = add(a=10).flow.with_context(b=from_seed()) + + assert bound.flow.inspect().context_inputs == {"b": int, "c": int} + assert bound.flow.inspect().runtime_inputs == {"c": int, "seed": int} + assert bound.flow.inspect().required_inputs == {"seed": int} + for name, annotation in bound.flow.inspect().required_inputs.items(): + assert bound.flow.inspect().runtime_inputs[name] == annotation + assert bound.flow.inspect().bound_inputs == {"a": 10} + assert bound.flow.compute(seed=1).value == 17 + + +def test_bound_flow_api_rejects_conflicting_runtime_input_annotations(): + @Flow.model + def add(b: FromContext[int], c: FromContext[int]) -> int: + return b + c + + @Flow.context_transform + def from_int(seed: FromContext[int]) -> int: + return seed + 1 + + @Flow.context_transform + 
def from_str(seed: FromContext[str]) -> int: + return int(seed) + 1 + + bound = add().flow.with_context(b=from_int(), c=from_str()) + + with pytest.raises(TypeError, match="Conflicting runtime context annotations for 'seed'"): + bound.flow.inspect().runtime_inputs + + +def test_bound_flow_api_keeps_dynamic_transform_source_inputs_after_later_field_bindings(): + @Flow.model + def add(x: FromContext[int], y: FromContext[int]) -> int: + return x + y + + @Flow.context_transform + def from_y(y: FromContext[int]) -> int: + return y + 1 + + @Flow.context_transform + def from_x(x: FromContext[int]) -> int: + return x + 10 + + bound = add().flow.with_context(x=from_y()).flow.with_context(y=from_x()) + + assert bound.flow.inspect().context_inputs == {"x": int, "y": int} + assert bound.flow.inspect().runtime_inputs == {"x": int, "y": int} + assert bound.flow.inspect().required_inputs == {"x": int, "y": int} + assert bound.flow.compute(x=1, y=2).value == 14 + + +def test_bound_flow_api_reports_optional_transform_context_inputs_as_runtime_only(): + @Flow.model + def add(a: FromContext[int], b: FromContext[int]) -> int: + return a + b + + @Flow.context_transform + def seed_plus_one(seed: FromContext[int] = 0) -> int: + return seed + 1 + + bound = add().flow.with_context(b=seed_plus_one()) + + assert bound.flow.inspect().runtime_inputs == {"a": int, "seed": int} + assert bound.flow.inspect().required_inputs == {"a": int} + assert bound.flow.inspect().bound_inputs == {} + assert bound.flow.compute(a=10).value == 11 + assert bound.flow.compute(a=10, seed=5).value == 16 + + +def test_bound_model_rejects_regular_field_context_overrides(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + with pytest.raises(TypeError, match="only accepts contextual fields"): + add(a=1).flow.with_context(a=3) + + +def test_bound_model_repr_matches_user_facing_api(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + model = add(a=1) + bound = 
model.flow.with_context(b=offset_b(amount=1)) + assert repr(bound) == f"{model!r}.flow.with_context(b=offset_b(amount=1))" + + +def test_bound_model_serialization_roundtrip_preserves_static_transforms(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + bound = add(a=10).flow.with_context(b=5) + + dumped = bound.model_dump(mode="python") + assert dumped["context_spec"] == { + "operations": [{"kind": "field", "name": "b", "spec": {"kind": "static_value", "value": 5}}], + } + + restored = type(bound).model_validate(dumped) + assert restored.flow.compute().value == 15 + assert restored.model.flow.inspect().bound_inputs == {"a": 10} + + +def test_bound_model_json_roundtrip_preserves_context_transforms(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + bound = add(a=10).flow.with_context(b=offset_b(amount=1)) + dumped = bound.model_dump(mode="json") + binding = dumped["context_spec"]["operations"][0]["spec"] + assert binding["serialized_config"] is not None + + restored = type(bound).model_validate(dumped) + assert restored.flow.compute(b=4).value == 15 + + +def test_bound_model_cloudpickle_roundtrip_preserves_context_transforms(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + bound = add(a=10).flow.with_context(b=offset_b(amount=1)) + restored = cloudpickle.loads(cloudpickle.dumps(bound)) + assert restored.flow.compute(b=4).value == 15 + + +def test_bound_model_plain_pickle_roundtrip_preserves_context_transforms(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + bound = add(a=10).flow.with_context(b=offset_b(amount=1)) + restored = pickle.loads(pickle.dumps(bound, protocol=5)) + assert restored.flow.compute(b=4).value == 15 + + +def test_transformed_dag_cloudpickle_roundtrip_preserves_context_transforms(): + @Flow.model + def source(value: FromContext[int]) -> int: + return value + + @Flow.model + def combine(left: int, right: int, value: 
FromContext[int]) -> int: + return left + right + value + + base = source() + model = combine( + left=base.flow.with_context(value=offset_value(amount=1)), + right=base.flow.with_context(value=offset_value(amount=10)), + ) + restored = cloudpickle.loads(cloudpickle.dumps(model)) + + assert restored.flow.compute(value=5).value == (6 + 15 + 5) + + +def test_bound_model_pydantic_roundtrip_preserves_context_transforms(): + """model_dump(mode='python') + model_validate must preserve serialized context bindings.""" + + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + bound = add(a=10).flow.with_context(b=offset_b(amount=1)) + assert bound.flow.compute(b=4).value == 15 + + dumped = bound.model_dump(mode="python") + assert dumped["context_spec"]["operations"][0]["spec"]["kind"] == "context_transform" + + restored = type(bound).model_validate(dumped) + assert restored.flow.compute(b=4).value == 15 + + +def test_bound_model_context_spec_dump_contains_patch_and_field_specs(): + """model_dump(mode='json') should emit explicit tagged context-spec objects.""" + + @Flow.model + def load_window(start_date: FromContext[date], end_date: FromContext[date]) -> GenericResult[dict]: + return GenericResult(value={"start": start_date, "end": end_date}) + + dumped = load_window().flow.with_context(shift_window(days=7), start_date=shift_start_date(days=1)).model_dump(mode="json") + assert dumped["context_spec"]["operations"][0]["binding"]["kind"] == "context_transform" + assert dumped["context_spec"]["operations"][1]["name"] == "start_date" + assert dumped["context_spec"]["operations"][1]["spec"]["kind"] == "context_transform" + + +def test_regular_callable_models_still_support_with_context(): + model = OffsetModel(offset=10) + shifted = model.flow.with_context(x=double_x()) + assert shifted(NumberContext(x=5)).value == 20 + + +def test_flow_api_for_regular_callable_model(): + model = OffsetModel(offset=10) + assert model.flow.compute(x=5).value == 15 + assert 
model.flow.inspect().context_inputs == {"x": int} + assert model.flow.inspect().runtime_inputs == {"x": int} + assert model.flow.inspect().required_inputs == {"x": int} + assert model.flow.inspect().bound_inputs == {"offset": 10} + + +def test_generated_flow_model_compute_is_thread_safe(): + @Flow.model + def add(a: int, b: FromContext[int], c: FromContext[int]) -> int: + return a + b + c + + model = add(a=10) + + def worker(n: int) -> int: + return model.flow.compute(b=n, c=n + 1).value + + with ThreadPoolExecutor(max_workers=8) as executor: + results = list(executor.map(worker, range(20))) + + assert results == [10 + n + n + 1 for n in range(20)] + + +def test_bound_model_restore_is_thread_safe(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + dumped = add(a=10).flow.with_context(b=5).model_dump(mode="python") + + def worker(_: int) -> int: + restored = BoundModel.model_validate(dumped) + return restored.flow.compute().value + + with ThreadPoolExecutor(max_workers=8) as executor: + results = list(executor.map(worker, range(20))) + + assert results == [15] * 20 diff --git a/ccflow/tests/test_flow_model.py b/ccflow/tests/test_flow_model.py new file mode 100644 index 0000000..a6810d7 --- /dev/null +++ b/ccflow/tests/test_flow_model.py @@ -0,0 +1,3997 @@ +"""Focused tests for the FromContext-based Flow.model API.""" + +import base64 +import graphlib +import importlib +import inspect +import pickle +import subprocess +import sys +from datetime import date, timedelta +from pathlib import Path +from types import ModuleType +from typing import Annotated, Any, Callable, Literal, Optional, get_args + +import pytest +import ray +from pydantic import BaseModel as PydanticBaseModel, Field, PrivateAttr, ValidationError, model_validator +from ray.cloudpickle import dumps as rcpdumps, loads as rcploads + +import ccflow +import ccflow._flow_model_binding as binding_module +import ccflow.flow_model as flow_model_module +from ccflow import ( + 
BaseModel, + CallableModel, + ContextBase, + DateRangeContext, + EvaluatorBase, + Flow, + FlowContext, + FlowOptionsOverride, + FromContext, + GenericResult, + Lazy, + ModelEvaluationContext, + ModelRegistry, +) +from ccflow.callable import FlowOptions +from ccflow.evaluators import GraphEvaluator, LoggingEvaluator, MemoryCacheEvaluator, cache_key, combine_evaluators, get_dependency_graph +from ccflow.exttypes import PyObjectPath + + +class SimpleContext(ContextBase): + value: int + + +class ExternalPydanticPayload(PydanticBaseModel): + x: int + _bonus: int = PrivateAttr(default=1) + + +class ExternalCcflowPayload(BaseModel): + x: int + _bonus: int = PrivateAttr(default=1) + + +class ParentRangeContext(ContextBase): + start_date: date + end_date: date + + +class RichRangeContext(ParentRangeContext): + label: str = "child" + + +class OrderedContext(ContextBase): + a: int + b: int + + @model_validator(mode="after") + def _validate_order(self): + if self.a > self.b: + raise ValueError("a must be <= b") + return self + + +@pytest.fixture +def local_ray_runtime(): + with ray.init(num_cpus=1): + yield + + +@Flow.model +def basic_loader(source: str, multiplier: int, value: FromContext[int]) -> GenericResult[int]: + return GenericResult(value=value * multiplier) + + +@Flow.model +def string_processor(value: FromContext[int], prefix: str = "value=", suffix: str = "!") -> GenericResult[str]: + return GenericResult(value=f"{prefix}{value}{suffix}") + + +@Flow.model +def data_source(base_value: int, value: FromContext[int]) -> GenericResult[int]: + return GenericResult(value=value + base_value) + + +@Flow.model +def data_transformer(source: int, factor: int) -> GenericResult[int]: + return GenericResult(value=source * factor) + + +@Flow.model +def data_aggregator(input_a: int, input_b: int, operation: str = "add") -> GenericResult[int]: + if operation == "add": + return GenericResult(value=input_a + input_b) + raise ValueError(f"unsupported operation: {operation}") + + 
+@Flow.model +def pipeline_stage1(initial: int, value: FromContext[int]) -> GenericResult[int]: + return GenericResult(value=value + initial) + + +@Flow.model +def pipeline_stage2(stage1_output: int, multiplier: int) -> GenericResult[int]: + return GenericResult(value=stage1_output * multiplier) + + +@Flow.model +def pipeline_stage3(stage2_output: int, offset: int) -> GenericResult[int]: + return GenericResult(value=stage2_output + offset) + + +@Flow.model +def date_range_loader_previous_day( + source: str, + start_date: FromContext[date], + end_date: FromContext[date], + include_weekends: bool = False, +) -> GenericResult[dict]: + del include_weekends + return GenericResult( + value={ + "source": source, + "start_date": str(start_date - timedelta(days=1)), + "end_date": str(end_date), + } + ) + + +@Flow.model +def date_range_processor(raw_data: dict, normalize: bool = False) -> GenericResult[str]: + prefix = "normalized:" if normalize else "raw:" + return GenericResult(value=f"{prefix}{raw_data['source']}:{raw_data['start_date']} to {raw_data['end_date']}") + + +@Flow.model +def contextual_loader(source: str, start_date: FromContext[date], end_date: FromContext[date]) -> GenericResult[dict]: + return GenericResult( + value={ + "source": source, + "start_date": str(start_date), + "end_date": str(end_date), + } + ) + + +@Flow.model +def contextual_processor( + prefix: str, + data: dict, + start_date: FromContext[date], + end_date: FromContext[date], +) -> GenericResult[str]: + del start_date, end_date + return GenericResult(value=f"{prefix}:{data['source']}:{data['start_date']} to {data['end_date']}") + + +@Flow.context_transform +def increment_b(b: FromContext[int], amount: int) -> int: + return b + amount + + +@Flow.context_transform +def shift_integer_window(start_date: FromContext[int], end_date: FromContext[int], amount: int) -> dict[str, object]: + return { + "start_date": start_date + amount, + "end_date": end_date + amount, + } + + +@Flow.context_transform 
+def bump_start_date(start_date: FromContext[int], amount: int) -> int: + return start_date + amount + + +@Flow.context_transform +def annotated_start_patch(start_date: FromContext[int]) -> Annotated[dict[str, object], "meta"]: + return {"start_date": start_date + 1} + + +@Flow.context_transform +def optional_start_patch(start_date: FromContext[int]) -> dict[str, object] | None: + return {"start_date": start_date + 2} + + +@Flow.context_transform +def parity_bucket(raw: FromContext[int]) -> int: + return raw % 2 + + +@Flow.context_transform +def seed_plus_one(seed: FromContext[int]) -> int: + return seed + 1 + + +@Flow.context_transform +def non_idempotent_a_step(a: FromContext[int]) -> int: + return 2 if a == 1 else 3 + + +@Flow.context_transform +def static_bad() -> int: + return 2 + + +def lazy_context_transform_for_rejection(value: Lazy[int]) -> int: + return value() + + +@Flow.context_transform +def static_patch() -> dict[str, object]: + return {"a": 2} + + +def test_module_level_flow_model_examples_and_transforms_execute(): + assert data_aggregator(input_a=1, input_b=2, operation="add").flow.compute().value == 3 + with pytest.raises(ValueError, match="unsupported operation"): + data_aggregator(input_a=1, input_b=2, operation="multiply").flow.compute() + + raw = date_range_loader_previous_day(source="library").flow.compute(start_date=date(2024, 1, 2), end_date=date(2024, 1, 3)).value + assert raw == {"source": "library", "start_date": "2024-01-01", "end_date": "2024-01-03"} + assert date_range_processor(raw_data=raw).flow.compute().value.startswith("raw:library") + assert date_range_processor(raw_data=raw, normalize=True).flow.compute().value.startswith("normalized:library") + + contextual = contextual_loader(source="library").flow.compute(start_date=date(2024, 1, 1), end_date=date(2024, 1, 2)).value + assert contextual_processor(prefix="p", data=contextual).flow.compute(start_date=date(2024, 1, 1), end_date=date(2024, 1, 2)).value == ( + "p:library:2024-01-01 
to 2024-01-02" + ) + + assert ( + pipeline_stage3(stage2_output=pipeline_stage2(stage1_output=pipeline_stage1(initial=2), multiplier=3), offset=4).flow.compute(value=5).value + == 25 + ) + + @Flow.model + def load(start_date: FromContext[int], end_date: FromContext[int], bucket: FromContext[int]) -> int: + return start_date * 100 + end_date * 10 + bucket + + shifted = load().flow.with_context( + shift_integer_window(amount=2), + start_date=bump_start_date(amount=10), + bucket=parity_bucket(), + ) + assert shifted.flow.compute(start_date=1, end_date=5, raw=7).value == 1171 + + @Flow.model + def add(a: FromContext[int]) -> int: + return a + + assert add().flow.with_context(a=non_idempotent_a_step()).flow.compute(a=1).value == 2 + assert add().flow.with_context(a=non_idempotent_a_step()).flow.compute(a=10).value == 3 + assert add().flow.with_context(static_patch()).flow.compute(a=1).value == 2 + + +def test_flow_model_rejects_invalid_decorator_targets(): + with pytest.raises(TypeError): + Flow.model(123) + with pytest.raises(TypeError): + Flow.model(lambda: None) + + +def test_context_transform_defaults_and_public_validation(): + @Flow.context_transform + def default_amount(amount: int = 5) -> int: + return amount + + @Flow.context_transform + def default_seed(seed: FromContext[int] = 9) -> int: + return seed + 1 + + @Flow.context_transform + def dynamic_patch(seed: FromContext[int]) -> dict[str, object]: + return {"a": seed} + + @Flow.model + def add(a: FromContext[int], b: FromContext[int]) -> int: + return a + b + + assert add().flow.with_context(a=default_amount(), b=default_seed()).flow.compute().value == 15 + assert add().flow.with_context(dynamic_patch()).flow.compute(seed=1, b=2).value == 3 + + with pytest.raises(TypeError, match="unexpected keyword"): + increment_b(amount=1, extra=2) + with pytest.raises(TypeError, match="Do not pass contextual"): + increment_b(b=1, amount=2) + with pytest.raises(TypeError, match="missing required regular"): + increment_b() + 
+ with pytest.raises(TypeError, match="Positional with_context"): + add().flow.with_context(lambda: {"a": 1}) + with pytest.raises(TypeError, match="Positional with_context"): + add().flow.with_context(123) + + +def test_plain_and_bound_optional_compute_paths(): + class OptionalContextModel(CallableModel): + @Flow.call + def __call__(self, context: Optional[SimpleContext] = None) -> GenericResult[int]: + return GenericResult(value=0 if context is None else context.value) + + assert OptionalContextModel().flow.compute(None).value == 0 + assert OptionalContextModel().flow.compute().value == 0 + assert OptionalContextModel().flow.inspect().required_inputs == {} + + bound = OptionalContextModel().flow.with_context() + assert bound.flow.compute(FlowContext(value=3)).value == 3 + with pytest.raises(TypeError, match="either one context object"): + bound.flow.compute(FlowContext(value=3), value=4) + + +def test_bound_optional_none_context_preserves_wrapped_dependencies(): + class Dep(CallableModel): + @Flow.call + def __call__(self, context: FlowContext) -> GenericResult[int]: + return GenericResult(value=1) + + class Root(CallableModel): + dep: Dep + + @Flow.call + def __call__(self, context: Optional[FlowContext] = None) -> GenericResult[int]: + return GenericResult(value=self.dep(FlowContext()).value + (0 if context is None else context.bonus)) + + @Flow.deps + def __deps__(self, context: Optional[FlowContext]) -> list[tuple[CallableModel, list[ContextBase]]]: + return [(self.dep, [FlowContext()])] + + root = Root(dep=Dep()) + bound = root.flow.with_context() + graph = get_dependency_graph(bound.__call__.get_evaluation_context(bound, None)) + + assert len(graph.ids) == 2 + assert bound.flow.compute().value == 1 + + +def test_from_context_anchor_behavior(): + @Flow.model + def foo(a: int, b: FromContext[int]) -> int: + return a + b + + assert foo(a=11).flow.compute(b=12).value == 23 + assert foo(a=11, b=12).flow.compute().value == 23 + + with pytest.raises(TypeError, 
match="compute\\(\\) cannot satisfy unbound regular parameter\\(s\\): a"): + foo().flow.compute(a=11, b=12) + + +def test_regular_param_accepts_upstream_model(): + @Flow.model + def source(value: FromContext[int], offset: int) -> int: + return value + offset + + @Flow.model + def foo(a: int, b: FromContext[int]) -> int: + return a + b + + model = foo(a=source(offset=5)) + assert model.flow.compute(value=7, b=12).value == 24 + assert model.flow.compute(FlowContext(value=7, b=12)).value == 24 + + +def test_regular_param_containers_are_literals(): + @Flow.model + def source(value: FromContext[int]) -> int: + return value + + @Flow.model + def inspect_values(values: list[Any]) -> int: + return sum(isinstance(value, CallableModel) for value in values) + + model = inspect_values(values=[source()]) + + assert model.__deps__(FlowContext(value=10)) == [] + assert model.flow.compute(value=10).value == 1 + + @Flow.model + def total(values: list[int]) -> int: + return sum(values) + + with pytest.raises(TypeError, match="Field 'values'"): + total(values=[source()]) + + +def test_regular_param_upstream_dependency_coerced(): + """Upstream model returning str should be coerced to downstream int annotation.""" + + @Flow.model + def str_source(tag: FromContext[str]) -> str: + return tag + + @Flow.model + def consumer(x: int, bonus: FromContext[int]) -> int: + return x + bonus + + model = consumer(x=str_source()) + # str_source returns "42" (a str); consumer expects x: int; should be coerced + result = model.flow.compute(tag="42", bonus=10) + assert result.value == 52 + assert isinstance(result.value, int) + + # Also test that invalid coercion raises + with pytest.raises(TypeError, match="Regular parameter"): + model.flow.compute(tag="not_a_number", bonus=10) + + +def test_regular_param_lazy_upstream_dependency_coerced(): + """Lazy upstream model output should be coerced on first call.""" + + @Flow.model + def lazy_source(v: FromContext[int]) -> str: + return str(v) + + @Flow.model + 
def consumer(x: Lazy[int], bonus: FromContext[int]) -> int: + return x() + bonus + + model = consumer(x=lazy_source()) + result = model.flow.compute(v=7, bonus=3) + assert result.value == 10 + assert isinstance(result.value, int) + + +def test_regular_param_plain_callable_model_projects_dependency_context(): + class ValueContext(ContextBase): + value: int + + class PlainSource(CallableModel): + @property + def context_type(self): + return ValueContext + + @property + def result_type(self): + return GenericResult[int] + + @Flow.call + def __call__(self, context: ValueContext) -> GenericResult[int]: + return GenericResult(value=context.value * 10) + + @Flow.model + def root(x: int, bonus: FromContext[int]) -> int: + return x + bonus + + model = root(x=PlainSource()) + deps = model.__deps__(FlowContext(value=3, bonus=7)) + + assert len(deps) == 1 + dep_model, dep_contexts = deps[0] + assert isinstance(dep_model, PlainSource) + assert dep_contexts == [ValueContext(value=3)] + assert model.flow.compute(FlowContext(value=3, bonus=7)).value == 37 + + +def test_bound_regular_param_name_can_collide_with_ambient_context(): + @Flow.model + def source(a: FromContext[int]) -> int: + return a + + @Flow.model + def combine(a: int, left: int, bonus: FromContext[int]) -> int: + return a + left + bonus + + model = combine(a=100, left=source()) + assert model.flow.compute(FlowContext(a=7, bonus=5)).value == 112 + + with pytest.raises(TypeError, match="does not accept regular parameter override\\(s\\): a"): + model.flow.compute(a=7, bonus=5) + + +def test_contextual_param_rejects_callable_model(): + @Flow.model + def source(offset: int, value: FromContext[int]) -> GenericResult[int]: + return GenericResult(value=value + offset) + + @Flow.model + def foo(a: int, b: FromContext[int]) -> int: + return a + b + + with pytest.raises(TypeError, match="cannot be bound to a CallableModel"): + foo(a=1, b=source(offset=2)) + + +def test_contextual_construction_defaults_and_bound_inputs(): + 
@Flow.model + def foo(a: int, b: FromContext[int]) -> int: + return a + b + + model = foo(a=11, b=12) + assert model.flow.inspect().bound_inputs == {"a": 11, "b": 12} + assert model.flow.inspect().context_inputs == {"b": int} + assert model.flow.inspect().required_inputs == {} + assert model.flow.compute().value == 23 + + +def test_contextual_function_defaults_remain_contextual(): + @Flow.model + def foo(a: int, b: FromContext[int] = 5) -> int: + return a + b + + model = foo(a=2) + assert model.flow.inspect().bound_inputs == {"a": 2} + assert model.flow.inspect().context_inputs == {"b": int} + assert model.flow.inspect().required_inputs == {} + assert model.flow.compute().value == 7 + assert model.flow.compute(b=10).value == 12 + + +def test_context_type_accepts_richer_subclass_for_from_context(): + @Flow.model(context_type=ParentRangeContext) + def span_days(multiplier: int, start_date: FromContext[date], end_date: FromContext[date]) -> int: + return multiplier * ((end_date - start_date).days + 1) + + model = span_days(multiplier=2) + assert model.flow.compute(start_date="2024-01-01", end_date="2024-01-03").value == 6 + assert model(RichRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 4), label="x")).value == 8 + + +def test_declared_context_type_introspection_reports_effective_field_type(): + class ModeContext(ContextBase): + mode: Literal["a"] + + @Flow.model(context_type=ModeContext) + def choose(mode: FromContext[str]) -> str: + return mode + + model = choose() + + assert model.flow.inspect().context_inputs == {"mode": Literal["a"]} + assert model.flow.inspect().required_inputs == {"mode": Literal["a"]} + assert model.flow.compute(mode="a").value == "a" + with pytest.raises(ValidationError): + model.flow.compute(mode="b") + + +def test_context_type_validation_applies_to_resolved_contextual_values(): + @Flow.model(context_type=OrderedContext) + def add(a: FromContext[int], b: FromContext[int]) -> int: + return a + b + + with 
pytest.raises(ValueError, match="a must be <= b"): + add().flow.compute(a=2, b=1) + + with pytest.raises(ValueError, match="a must be <= b"): + add(a=2, b=1) + + +def test_context_type_validates_construction_time_contextual_defaults_early(): + class PositiveContext(ContextBase): + x: int = Field(gt=0) + + @Flow.model(context_type=PositiveContext) + def identity(x: FromContext[int]) -> int: + return x + + with pytest.raises(ValueError, match="greater than 0"): + identity(x=-1) + + +def test_context_type_validates_partial_contextual_defaults_early(): + class PositivePairContext(ContextBase): + x: int = Field(gt=0) + y: int + + @Flow.model(context_type=PositivePairContext) + def add(x: FromContext[int], y: FromContext[int]) -> int: + return x + y + + with pytest.raises(ValueError, match="greater than 0"): + add(x=-1) + + +def test_context_type_validates_static_with_context_overrides_early(): + @Flow.model(context_type=OrderedContext) + def add(a: FromContext[int], b: FromContext[int]) -> int: + return a + b + + with pytest.raises(ValueError, match="a must be <= b"): + add().flow.with_context(a=2, b=1) + + +def test_context_type_validates_chained_static_with_context_overrides_early(): + @Flow.model(context_type=OrderedContext) + def add(a: FromContext[int], b: FromContext[int]) -> int: + return a + b + + with pytest.raises(ValueError, match="a must be <= b"): + add().flow.with_context(a=2).flow.with_context(b=1) + + +def test_context_type_validates_partial_static_with_context_overrides_early(): + class PositivePairContext(ContextBase): + x: int = Field(gt=0) + y: int + + @Flow.model(context_type=PositivePairContext) + def add(x: FromContext[int], y: FromContext[int]) -> int: + return x + y + + with pytest.raises(ValueError, match="greater than 0"): + add().flow.with_context(x=-1) + + +def test_context_type_validates_static_field_transform_overrides_early(): + @Flow.model(context_type=OrderedContext) + def add(a: FromContext[int], b: FromContext[int] = 1) -> int: + 
return a + b + + with pytest.raises(ValueError, match="a must be <= b"): + add().flow.with_context(a=static_bad()) + + +def test_context_type_validates_static_patch_transform_overrides_early(): + @Flow.model(context_type=OrderedContext) + def add(a: FromContext[int], b: FromContext[int] = 1) -> int: + return a + b + + with pytest.raises(ValueError, match="a must be <= b"): + add().flow.with_context(static_patch()) + + +def test_context_named_parameters_are_just_regular_parameters(): + @Flow.model + def loader(context: DateRangeContext, source: str = "db") -> GenericResult[str]: + return GenericResult(value=f"{source}:{context.start_date}:{context.end_date}") + + model = loader(context=DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 2)), source="api") + assert model.flow.inspect().bound_inputs["context"] == DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 2)) + assert model.flow.inspect().context_inputs == {} + assert model.flow.compute().value == "api:2024-01-01:2024-01-02" + + with pytest.raises(TypeError, match="Missing regular parameter\\(s\\) for loader: context"): + loader(source="api").flow.compute(start_date="2024-01-01", end_date="2024-01-02") + + +def test_auto_unwrap_defaults_to_false_for_auto_wrapped_results(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + model = add(a=10) + result = model.flow.compute(b=5) + assert model.result_type == GenericResult[int] + assert inspect.signature(type(model).__call__).return_annotation == GenericResult[int] + assert type(result) is GenericResult[int] + assert repr(result) == "GenericResult[int](value=15)" + assert result.value == 15 + + +def test_compute_does_not_unwrap_explicit_generic_result_returns(): + @Flow.model + def load(value: FromContext[int]) -> GenericResult[int]: + return GenericResult(value=value * 2) + + model = load() + result = model.flow.compute(value=3) + assert model.result_type == GenericResult[int] + assert type(result) is 
GenericResult[int] + assert repr(result) == "GenericResult[int](value=6)" + assert result.value == 6 + + +def test_flow_model_rejects_union_resultbase_return_annotations(): + with pytest.raises(TypeError, match="does not support Union or Optional ResultBase"): + + @Flow.model + def optional_result(value: FromContext[int]) -> Optional[GenericResult[int]]: + return GenericResult(value=value) + + with pytest.raises(TypeError, match="does not support Union or Optional ResultBase"): + + @Flow.model + def pep604_optional_result(value: FromContext[int]) -> GenericResult[int] | None: + return GenericResult(value=value) + + with pytest.raises(TypeError, match="does not support Union or Optional ResultBase"): + + @Flow.model + def annotated_optional_result(value: FromContext[int]) -> Annotated[GenericResult[int] | None, "meta"]: + return GenericResult(value=value) + + +def test_flow_model_allows_plain_union_return_annotations(): + @Flow.model + def choose(flag: FromContext[bool]) -> int | str: + return 1 if flag else "one" + + assert choose().flow.compute(flag=True).value == 1 + assert choose().flow.compute(flag=False).value == "one" + + +def test_flow_model_handles_annotated_result_annotations(): + @Flow.model + def explicit_result() -> Annotated[GenericResult[int], "meta"]: + return GenericResult(value=1) + + @Flow.model + def plain_union(flag: FromContext[bool]) -> Annotated[int | str, "meta"]: + return 1 if flag else "one" + + result = explicit_result().flow.compute() + assert type(result) is GenericResult[int] + assert result.value == 1 + assert plain_union().flow.compute(flag=True).value == 1 + assert plain_union().flow.compute(flag=False).value == "one" + + +def test_auto_unwrap_can_be_enabled_for_auto_wrapped_results(): + @Flow.model(auto_unwrap=True) + def add(a: int, b: FromContext[int]) -> int: + return a + b + + assert add(a=10).flow.compute(b=5) == 15 + + +def test_auto_unwrap_only_affects_external_compute_results(): + @Flow.model + def source(value: 
FromContext[int]) -> int: + return value * 2 + + @Flow.model(auto_unwrap=True) + def add(left: int, bonus: FromContext[int]) -> int: + return left + bonus + + model = add(left=source()) + assert model.flow.compute(FlowContext(value=4, bonus=3)) == 11 + + +def test_auto_wrap_validates_return_type(): + @Flow.model + def bad(x: FromContext[int]) -> int: + return "oops" + + with pytest.raises(ValidationError, match=r"GenericResult\[int\]"): + bad().flow.compute(x=1) + + +def test_auto_wrap_respects_validate_result_false(): + @Flow.model(validate_result=False) + def bad_decorator() -> int: + return "oops" + + @Flow.model + def bad_runtime() -> int: + return "oops" + + decorator_result = bad_decorator().flow.compute() + runtime_result = bad_runtime().flow.compute(_options=FlowOptions(validate_result=False)) + + assert isinstance(decorator_result, GenericResult) + assert decorator_result.value == "oops" + assert isinstance(runtime_result, GenericResult) + assert runtime_result.value == "oops" + + +def test_auto_wrap_coerces_compatible_return(): + @Flow.model + def coerce(x: FromContext[int]) -> float: + return 3 + + result = coerce().flow.compute(x=1) + assert result.value == 3.0 + assert isinstance(result.value, float) + + +def test_model_base_allows_custom_callable_model_subclass(): + class CustomFlowBase(CallableModel): + multiplier: int = 1 + + @model_validator(mode="after") + def _validate_multiplier(self): + if self.multiplier <= 0: + raise ValueError("multiplier must be positive") + return self + + def scaled(self, value: int) -> int: + return value * self.multiplier + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.model(model_base=CustomFlowBase) + def add(a: int, b: FromContext[int]) -> int: + return a + b + + model = add(a=10, multiplier=3) + assert isinstance(model, CustomFlowBase) + assert model.multiplier == 3 + assert model.scaled(4) == 12 + assert 
model.flow.compute(b=5).value == 15 + + with pytest.raises(ValueError, match="multiplier must be positive"): + add(a=10, multiplier=0) + + +def test_model_base_must_be_callable_model_subclass(): + with pytest.raises(TypeError, match="model_base must be a CallableModel subclass"): + + @Flow.model(model_base=int) + def add(a: int, b: FromContext[int]) -> int: + return a + b + + +def test_context_named_regular_parameter_can_coexist_with_from_context(): + @Flow.model + def mixed(context: SimpleContext, y: FromContext[int]) -> int: + return context.value + y + + model = mixed(context=SimpleContext(value=10)) + assert model.flow.inspect().bound_inputs == {"context": SimpleContext(value=10)} + assert model.flow.inspect().context_inputs == {"y": int} + assert model.flow.compute(y=5).value == 15 + + +@pytest.mark.parametrize("reserved_name", ["flow", "meta", "context_type", "result_type", "type_"]) +def test_flow_model_rejects_reserved_parameter_names(reserved_name): + namespace = {"Flow": Flow, "FromContext": FromContext} + exec( + f"def bad({reserved_name}: str, value: FromContext[int]) -> str:\n return str(value)\n", + namespace, + ) + + with pytest.raises(TypeError, match=f"Parameter name\\(s\\) '{reserved_name}' are reserved"): + Flow.model(namespace["bad"]) + + +def test_context_type_requires_from_context(): + with pytest.raises(TypeError, match="context_type=... 
requires FromContext"): + + @Flow.model(context_type=DateRangeContext) + def bad(x: int) -> int: + return x + + +def test_lazy_dependency_remains_lazy(): + calls = {"source": 0} + + @Flow.model + def source(value: FromContext[int]) -> int: + calls["source"] += 1 + return value * 10 + + @Flow.model + def choose(value: int, lazy_value: Lazy[int], threshold: FromContext[int]) -> int: + if value > threshold: + return value + return lazy_value() + + eager = choose(value=50, lazy_value=source()) + assert eager.flow.compute(FlowContext(value=3, threshold=10)).value == 50 + assert calls["source"] == 0 + + deferred = choose(value=5, lazy_value=source()) + assert deferred.flow.compute(FlowContext(value=3, threshold=10)).value == 30 + assert calls["source"] == 1 + + +def test_lazy_parameter_requires_dependency_binding(): + @Flow.model + def source(value: FromContext[int]) -> int: + return value + + @Flow.model + def choose(lazy_value: Lazy[int]) -> int: + return lazy_value() + + with pytest.raises(TypeError, match="Lazy"): + choose(lazy_value=1) + + assert choose(lazy_value=source()).flow.compute(value=3).value == 3 + + +def test_lazy_parameter_rejects_literal_function_default(): + with pytest.raises(TypeError, match="Lazy"): + + @Flow.model + def choose(lazy_value: Lazy[int] = 1) -> int: + return lazy_value() + + +def test_lazy_and_from_context_combination_is_rejected(): + with pytest.raises(TypeError, match="cannot combine Lazy"): + + @Flow.model + def bad(x: Lazy[FromContext[int]]) -> int: + return x() + + +def test_bare_flow_marker_annotations_are_rejected(): + with pytest.raises(TypeError, match=r"FromContext\[T\]"): + + @Flow.model + def bad_context(x: FromContext) -> int: + return 1 + + with pytest.raises(TypeError, match=r"Lazy\[T\]"): + + @Flow.model + def bad_lazy(x: Lazy) -> int: + return x() + + with pytest.raises(TypeError, match=r"FromContext\[T\]"): + + @Flow.context_transform + def bad_transform_context(x: FromContext) -> int: + return 1 + + with 
pytest.raises(TypeError, match=r"Lazy\[T\]"): + + @Flow.context_transform + def bad_transform_lazy(x: Lazy) -> int: + return 1 + + +def test_auto_wrap_and_serialization_roundtrip(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + model = add(a=10) + dumped = model.model_dump(mode="python") + restored = type(model).model_validate(dumped) + + assert restored.flow.inspect().bound_inputs == {"a": 10} + assert restored.flow.inspect().required_inputs == {"b": int} + assert restored.flow.compute(b=5).value == 15 + + +def test_generated_models_cloudpickle_roundtrip(): + @Flow.model + def multiply(a: int, b: FromContext[int]) -> int: + return a * b + + model = multiply(a=6) + restored = rcploads(rcpdumps(model, protocol=5)) + assert restored.flow.compute(b=7).value == 42 + + +def test_generated_models_plain_pickle_roundtrip(): + @Flow.model + def multiply(a: int, b: FromContext[int]) -> int: + return a * b + + model = multiply(a=6) + restored = pickle.loads(pickle.dumps(model, protocol=5)) + assert restored.flow.compute(b=7).value == 42 + + +def test_generated_model_direct_call_plain_pickle_uses_serialized_factory(monkeypatch): + module = ModuleType("ccflow_test_direct_model") + + def multiply(a: int, b: FromContext[int]) -> int: + return a * b + + multiply.__module__ = module.__name__ + multiply.__qualname__ = multiply.__name__ + module.multiply = multiply + monkeypatch.setitem(sys.modules, module.__name__, module) + + factory = Flow.model(multiply) + assert not hasattr(module.multiply, "_generated_model") + + model = factory(a=6) + restored = pickle.loads(pickle.dumps(model, protocol=5)) + assert restored.flow.compute(b=7).value == 42 + + +def test_local_generated_model_effective_cache_key_survives_pickle_roundtrip(): + def make_model(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + return add(a=1) + + model = make_model() + context = FlowContext(b=2) + before = 
cache_key(model.__call__.get_evaluation_context(model, context), effective=True) + + for dumps, loads in ((pickle.dumps, pickle.loads), (rcpdumps, rcploads)): + restored = loads(dumps(model, protocol=5)) + after = cache_key(restored.__call__.get_evaluation_context(restored, context), effective=True) + assert after == before + + +def test_generated_model_plain_pickle_preserves_external_pydantic_private_state(): + @Flow.model + def read(payload: object) -> int: + return payload.x + payload._bonus + + payload = ExternalPydanticPayload(x=2) + payload._bonus = 40 + + restored = pickle.loads(pickle.dumps(read(payload=payload), protocol=5)) + + assert restored.payload._bonus == 40 + assert restored.flow.compute().value == 42 + + +def test_generated_model_pickle_preserves_external_ccflow_private_state(): + @Flow.model + def read(payload: object) -> int: + return payload.x + payload._bonus + + payload = ExternalCcflowPayload(x=2) + payload._bonus = 40 + + for dumps, loads in ((pickle.dumps, pickle.loads), (rcpdumps, rcploads)): + restored = loads(dumps(read(payload=payload), protocol=5)) + + assert isinstance(restored.payload, ExternalCcflowPayload) + assert restored.payload._bonus == 40 + assert restored.flow.compute().value == 42 + + +def test_generated_model_pickle_preserves_outer_graph_identity_and_cycles(): + @Flow.model + def read(payload: object) -> int: + return 0 + + for dumps, loads in ((pickle.dumps, pickle.loads), (rcpdumps, rcploads)): + shared = [] + restored_model, restored_shared = loads(dumps((read(payload=shared), shared), protocol=5)) + assert restored_model.payload is restored_shared + + cycle = read(payload=None) + cycle.payload = cycle + restored_cycle = loads(dumps(cycle, protocol=5)) + assert restored_cycle.payload is restored_cycle + + +def test_local_generated_model_cloudpickle_preserves_local_pydantic_literal_state(): + def make_model(): + class LocalPayload(PydanticBaseModel): + x: int + _bonus: int = PrivateAttr(default=1) + + @Flow.model + def 
read(payload: object) -> int: + return payload.x + payload._bonus + + payload = LocalPayload(x=2) + payload._bonus = 40 + return read(payload=payload) + + restored = rcploads(rcpdumps(make_model(), protocol=5)) + + assert restored.payload._bonus == 40 + assert restored.flow.compute().value == 42 + + +def test_bound_model_pickle_preserves_external_pydantic_static_context_value(): + @Flow.model + def read(payload: FromContext[object]) -> int: + return payload.x + payload._bonus + + payload = ExternalPydanticPayload(x=2) + payload._bonus = 40 + bound = read().flow.with_context(payload=payload) + assert bound.flow.compute().value == 42 + + for dumps, loads in ((pickle.dumps, pickle.loads), (rcpdumps, rcploads)): + restored = loads(dumps(bound, protocol=5)) + restored_payload = restored.context_spec.operations[0].spec.value + + assert isinstance(restored_payload, ExternalPydanticPayload) + assert restored_payload._bonus == 40 + assert restored.flow.compute().value == 42 + + +def test_bound_model_pickle_preserves_external_pydantic_context_transform_bound_arg(): + @Flow.model + def read(value: FromContext[int]) -> int: + return value + + @Flow.context_transform + def derive(payload: object) -> int: + return payload.x + payload._bonus + + payload = ExternalPydanticPayload(x=2) + payload._bonus = 40 + bound = read().flow.with_context(value=derive(payload=payload)) + + for dumps, loads in ((pickle.dumps, pickle.loads), (rcpdumps, rcploads)): + restored = loads(dumps(bound, protocol=5)) + restored_payload = restored.context_spec.operations[0].spec.bound_args["payload"] + + assert isinstance(restored_payload, ExternalPydanticPayload) + assert restored_payload._bonus == 40 + assert restored.flow.compute().value == 42 + + +def test_bound_model_pickle_preserves_external_ccflow_static_context_value(): + @Flow.model + def read(payload: FromContext[object]) -> int: + return payload.x + payload._bonus + + payload = ExternalCcflowPayload(x=2) + payload._bonus = 40 + bound = 
read().flow.with_context(payload=payload) + assert bound.flow.compute().value == 42 + + for dumps, loads in ((pickle.dumps, pickle.loads), (rcpdumps, rcploads)): + restored = loads(dumps(bound, protocol=5)) + restored_payload = restored.context_spec.operations[0].spec.value + + assert isinstance(restored_payload, ExternalCcflowPayload) + assert restored_payload._bonus == 40 + assert restored.flow.compute().value == 42 + + +def test_bound_model_pickle_preserves_external_ccflow_context_transform_bound_arg(): + @Flow.model + def read(value: FromContext[int]) -> int: + return value + + @Flow.context_transform + def derive(payload: object) -> int: + return payload.x + payload._bonus + + payload = ExternalCcflowPayload(x=2) + payload._bonus = 40 + bound = read().flow.with_context(value=derive(payload=payload)) + + for dumps, loads in ((pickle.dumps, pickle.loads), (rcpdumps, rcploads)): + restored = loads(dumps(bound, protocol=5)) + restored_payload = restored.context_spec.operations[0].spec.bound_args["payload"] + + assert isinstance(restored_payload, ExternalCcflowPayload) + assert restored_payload._bonus == 40 + assert restored.flow.compute().value == 42 + + +def test_local_generated_model_plain_pickle_handles_function_default_state(): + def make_model(): + @Flow.model + def first(xs: list[int] = [1], b: FromContext[int] = 2) -> int: + return xs[0] + b + + return first() + + restored = pickle.loads(pickle.dumps(make_model(), protocol=5)) + + assert restored.flow.compute().value == 3 + + +def test_unresolved_lazy_nested_local_generated_dependency_identity_survives_pickle_roundtrip(): + def make_model(): + @Flow.model + def dependency(d: FromContext[int]) -> int: + return d + + @Flow.model + def source(dep_value: int, a: FromContext[int]) -> int: + return dep_value + a + + @Flow.model + def choose(x: int, lazy_value: Lazy[int], use_lazy: FromContext[bool]) -> int: + return lazy_value() if use_lazy else x + + return choose(x=7, lazy_value=source(dep_value=dependency())) 
+ + model = make_model() + context = FlowContext(use_lazy=False) + before_eval = model.__call__.get_evaluation_context(model, context) + before_key = cache_key(before_eval, effective=True) + before_root = get_dependency_graph(before_eval).root_id + + for dumps, loads in ((pickle.dumps, pickle.loads), (rcpdumps, rcploads)): + restored = loads(dumps(model, protocol=5)) + after_eval = restored.__call__.get_evaluation_context(restored, context) + + assert cache_key(after_eval, effective=True) == before_key + assert get_dependency_graph(after_eval).root_id == before_root + + +def test_generated_model_pydantic_roundtrip_via_base_model(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + from ccflow import BaseModel + + model = add(a=10) + dumped = model.model_dump(mode="python") + + # BaseModel.model_validate uses the _target_ field (PyObjectPath) to + # reconstruct the correct generated class. + restored = BaseModel.model_validate(dumped) + assert type(restored) is type(model) + assert restored.flow.compute(b=5).value == 15 + + +def test_generated_model_json_roundtrip(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + from ccflow import BaseModel + + model = add(a=10, b=3) + json_str = model.model_dump_json() + restored = BaseModel.model_validate_json(json_str) + assert type(restored) is type(model) + assert restored.flow.compute().value == 13 + assert restored.flow.compute(b=7).value == 17 + + +def test_generated_model_dependency_input_json_roundtrip(): + from ccflow import BaseModel + + model = data_transformer(source=data_source(base_value=1), factor=2) + restored = BaseModel.model_validate_json(model.model_dump_json()) + + assert restored.flow.compute(value=10).value == 22 + assert isinstance(restored.source, CallableModel) + + +def test_generated_model_lazy_dependency_input_json_roundtrip(): + @Flow.model + def choose(source: Lazy[int], use_value: FromContext[bool]) -> int: + return source() if use_value 
else 0 + + from ccflow import BaseModel + + model = choose(source=data_source(base_value=1)) + restored = BaseModel.model_validate_json(model.model_dump_json()) + + assert restored.flow.compute(value=10, use_value=True).value == 11 + assert restored.flow.compute(value=10, use_value=False).value == 0 + assert isinstance(restored.source, CallableModel) + + +def test_generated_model_type_dict_regular_input_stays_literal_when_annotation_accepts_dict(): + @Flow.model + def payload_type(payload: dict) -> str: + return type(payload).__name__ + + type_payload = data_source(base_value=1).model_dump(mode="python") + model = payload_type(payload=type_payload) + + assert model.flow.compute(value=10).value == "dict" + assert model.payload == type_payload + + +def test_generated_model_target_alias_dict_regular_input_stays_literal(): + @Flow.model + def payload_type(payload: Any) -> str: + return type(payload).__name__ + + alias_payload = data_source(base_value=1).model_dump(mode="python", by_alias=True) + model = payload_type(payload=alias_payload) + + assert model.flow.compute(value=10).value == "dict" + assert model.payload == alias_payload + + +def test_generated_model_target_alias_restores_dependency_after_literal_validation_fails(): + @Flow.model + def total(value: int) -> int: + return value + + alias_payload = data_source(base_value=1).model_dump(mode="python", by_alias=True) + model = total(value=alias_payload) + + assert model.flow.compute(FlowContext(value=10)).value == 11 + assert isinstance(model.value, CallableModel) + + +def test_generated_model_type_marker_restores_dependency_after_literal_validation_fails(): + @Flow.model + def total(value: int) -> int: + return value + + type_payload = data_source(base_value=1).model_dump(mode="python") + model = total(value=type_payload) + + assert model.flow.compute(FlowContext(value=10)).value == 11 + assert isinstance(model.value, CallableModel) + + +@pytest.mark.parametrize( + "payload", + [ + {"_target_": 
"does.not.exist.Class", "foo": 1}, + {"_target_": "builtins.dict", "foo": 1}, + {"type_": "does.not.exist.Class", "foo": 1}, + {"type_": "builtins.dict", "foo": 1}, + ], +) +def test_serialized_dependency_fallback_preserves_literal_validation_error(payload): + @Flow.model + def total(value: int) -> int: + return value + + with pytest.raises(TypeError, match="Field 'value': expected int, got dict"): + total(value=payload) + + +def test_registry_lookup_preserves_literal_error_when_candidate_type_is_incompatible(): + registry = ModelRegistry.root().clear() + registry.add("context", SimpleContext(value=1)) + + @Flow.model + def total(values: list[int]) -> int: + return sum(values) + + try: + with pytest.raises(TypeError, match="Field 'values': expected list, got str"): + total(values="context") + finally: + registry.clear() + + +def test_importable_generated_model_uses_stable_module_path_for_type_serialization(): + model = basic_loader(source="library", multiplier=3) + stable_path = f"{__name__}._basic_loader_Model" + + assert getattr(sys.modules[__name__], "_basic_loader_Model") is type(model) + assert "__ccflow_import_path__" not in type(model).__dict__ + assert str(PyObjectPath.validate(type(model))) == stable_path + assert str(model.model_dump(mode="python")["type_"]) == stable_path + + +def test_importable_generated_model_duplicate_names_raise_conflict(monkeypatch): + module = ModuleType("ccflow_test_duplicate_generated_models") + module.Flow = Flow + module.FromContext = FromContext + monkeypatch.setitem(sys.modules, module.__name__, module) + + exec( + """ +def stage(value: FromContext[int]) -> int: + return value + 1 +""", + module.__dict__, + ) + first_factory = Flow.model(module.stage) + module.first = first_factory + first = first_factory() + + exec( + """ +def stage(value: FromContext[int]) -> int: + return value + 2 +""", + module.__dict__, + ) + with pytest.raises(ValueError, match="already occupied"): + Flow.model(module.stage) + + assert getattr(module, 
"_stage_Model") is type(first) + assert not hasattr(module, "_stage_Model_2") + assert str(PyObjectPath.validate(type(first))) == f"{module.__name__}._stage_Model" + assert first.flow.compute(value=10).value == 11 + + +def test_generated_model_pickle_path_ignores_stale_pyobjectpath_cache(monkeypatch): + module = ModuleType("ccflow_test_stale_generated_pickle_path") + module.Flow = Flow + module.FromContext = FromContext + monkeypatch.setitem(sys.modules, module.__name__, module) + + exec( + """ +def foo(a: int, x: FromContext[int]) -> int: + return a + x + 1 +""", + module.__dict__, + ) + first_factory = Flow.model(module.foo) + module.foo = first_factory + first = first_factory(a=10) + path = f"{module.__name__}.foo" + + # Prime PyObjectPath/import_string's cache with the first factory. Pickle + # path selection must still inspect the live module attribute after a reload + # or replacement, otherwise old models can serialize by a path that a clean + # process resolves to different behavior. 
+ assert PyObjectPath(path).object is first_factory + + exec( + """ +def foo(a: int, x: FromContext[int]) -> int: + return a + x + 100 +""", + module.__dict__, + ) + second_factory = Flow.model(module.foo) + module.foo = second_factory + + config = type(first).__flow_model_config__ + assert flow_model_module._generated_model_factory_path_for_pickle(config, type(first)) is None + assert first.__reduce__()[0] is flow_model_module._new_local_flow_model_for_pickle + assert pickle.loads(pickle.dumps(first, protocol=5)).flow.compute(x=1).value == 12 + assert second_factory(a=10).flow.compute(x=1).value == 111 + + +def test_reloaded_importable_generated_model_keeps_clean_process_path(tmp_path, monkeypatch): + module_dir = tmp_path / "reload_case" + module_dir.mkdir() + module_path = module_dir / "repro_mod.py" + module_path.write_text( + "\n".join( + [ + "from ccflow import Flow, FromContext", + "", + "@Flow.model", + "def foo(x: FromContext[int]) -> int:", + " return x + 1", + "", + ] + ) + ) + monkeypatch.syspath_prepend(str(module_dir)) + + import repro_mod + + assert str(repro_mod.foo().type_) == "repro_mod._foo_Model" + reloaded = importlib.reload(repro_mod) + model = reloaded.foo() + payload = model.model_dump_json() + script = ( + "import sys\n" + f"sys.path.insert(0, {str(module_dir)!r})\n" + "from ccflow import BaseModel\n" + f"model = BaseModel.model_validate_json({payload!r})\n" + "assert model.flow.compute(x=2).value == 3\n" + ) + result = subprocess.run([sys.executable, "-c", script], capture_output=True, text=True, timeout=30) + + assert str(model.type_) == "repro_mod._foo_Model" + assert result.returncode == 0, f"Clean-process reload JSON roundtrip failed:\n{result.stderr}" + + +def test_reloaded_importable_generated_model_allows_stale_factory_aliases(tmp_path, monkeypatch): + module_dir = tmp_path / "reload_alias_case" + module_dir.mkdir() + module_path = module_dir / "alias_mod.py" + module_path.write_text( + "\n".join( + [ + "from ccflow import Flow, 
FromContext", + "", + "def foo(x: FromContext[int]) -> int:", + " return x + 1", + "", + "bar = Flow.model(foo)", + "", + ] + ) + ) + monkeypatch.syspath_prepend(str(module_dir)) + + import alias_mod + + assert str(alias_mod.bar().type_) == "alias_mod._foo_Model" + reloaded = importlib.reload(alias_mod) + + assert str(reloaded.bar().type_) == "alias_mod._foo_Model" + assert reloaded.bar().flow.compute(x=2).value == 3 + + +def test_reloaded_importable_generated_model_allows_stale_decorator_aliases(tmp_path, monkeypatch): + module_dir = tmp_path / "reload_decorator_alias_case" + module_dir.mkdir() + module_path = module_dir / "decor_alias_mod.py" + module_path.write_text( + "\n".join( + [ + "from ccflow import Flow, FromContext", + "", + "@Flow.model", + "def foo(x: FromContext[int]) -> int:", + " return x + 1", + "", + "foo_alias = foo", + "", + ] + ) + ) + monkeypatch.syspath_prepend(str(module_dir)) + + import decor_alias_mod + + assert str(decor_alias_mod.foo().type_) == "decor_alias_mod._foo_Model" + reloaded = importlib.reload(decor_alias_mod) + + assert str(reloaded.foo().type_) == "decor_alias_mod._foo_Model" + assert str(reloaded.foo_alias().type_) == "decor_alias_mod._foo_Model" + assert reloaded.foo().flow.compute(x=2).value == 3 + + +def test_importable_bound_model_context_transform_json_roundtrip_cross_process(): + model = basic_loader(source="library", multiplier=3).flow.with_context(value=increment_b(amount=3)) + payload = model.model_dump_json() + script = ( + "from ccflow import BaseModel\n" + f"data = {payload!r}\n" + "model = BaseModel.model_validate_json(data)\n" + "result = model.flow.compute(b=1)\n" + "assert result.value == 12, f'Expected 12, got {result.value}'\n" + ) + result = subprocess.run([sys.executable, "-c", script], capture_output=True, text=True, timeout=30) + assert result.returncode == 0, f"Cross-process bound-model JSON roundtrip failed:\n{result.stderr}" + + +def test_dependency_graph_cloudpickle_roundtrip(): + from 
ccflow.evaluators import get_dependency_graph + + @Flow.model + def source(value: FromContext[int]) -> int: + return value * 10 + + @Flow.model + def root(x: int, penalty: FromContext[int]) -> int: + return x + penalty + + model = root(x=source()) + ctx = FlowContext(value=10, penalty=1) + graph = get_dependency_graph(model.__call__.get_evaluation_context(model, ctx)) + + restored = rcploads(rcpdumps(graph)) + assert restored.root_id == graph.root_id + assert set(restored.graph.keys()) == set(graph.graph.keys()) + assert set(restored.ids.keys()) == set(graph.ids.keys()) + + # The restored graph's evaluation contexts should still be functional + for key in graph.ids: + original_ec = graph.ids[key] + restored_ec = restored.ids[key] + assert type(restored_ec.model).__name__ == type(original_ec.model).__name__ + assert restored_ec.fn == original_ec.fn + + +def test_with_context_validates_static_override_types(): + """Static value type mismatch should be caught when context specs are normalized.""" + + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + with pytest.raises(TypeError, match="with_context\\(\\)"): + add(a=1).flow.with_context(b="not_an_int") + + +def test_context_transform_serializes_embedded_config_and_bound_args(): + transform_factory = Flow.context_transform(increment_b.__wrapped__) + binding = transform_factory(amount=3) + assert isinstance(binding, flow_model_module.ContextTransform) + assert binding.kind == "context_transform" + assert binding.serialized_config is not None + assert binding.bound_args == {"amount": 3} + + +def test_context_transform_rejects_none_for_required_param(): + with pytest.raises(TypeError, match="Context transform argument"): + increment_b(amount=None) + + +def test_context_transform_rejects_lazy_params(): + with pytest.raises(TypeError, match="does not support Lazy"): + Flow.context_transform(lazy_context_transform_for_rejection) + + +def 
test_context_transform_direct_call_uses_serialized_payload_when_original_binding_is_plain(monkeypatch): + module = ModuleType("ccflow_test_direct_context_transform") + + def increment(value: FromContext[int]) -> int: + return value + 1 + + increment.__module__ = module.__name__ + increment.__qualname__ = increment.__name__ + module.increment = increment + monkeypatch.setitem(sys.modules, module.__name__, module) + + transform_factory = Flow.context_transform(increment) + binding = transform_factory() + + assert binding.serialized_config is not None + + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + bound = add(a=10).flow.with_context(b=binding) + restored = pickle.loads(pickle.dumps(bound, protocol=5)) + assert restored.flow.compute(value=4).value == 15 + + +def test_context_transform_json_roundtrip_recoerces_regular_bound_args(): + from ccflow import BaseModel + + @Flow.model + def load(day: FromContext[date]) -> str: + return day.isoformat() + + @Flow.context_transform + def shift_from_anchor(anchor: date, days: FromContext[int]) -> date: + return anchor + timedelta(days=days) + + bound = load().flow.with_context(day=shift_from_anchor(anchor=date(2024, 1, 1))) + restored = BaseModel.model_validate_json(bound.model_dump_json()) + + assert bound.flow.compute(days=2).value == "2024-01-03" + assert restored.flow.compute(days=2).value == "2024-01-03" + + +def test_context_transform_json_roundtrip_reports_malformed_payload_cleanly(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + bound = add(a=10).flow.with_context(b=increment_b(amount=1)) + dumped = bound.model_dump(mode="python") + dumped["context_spec"]["operations"][0]["spec"]["serialized_config"] = "not-base64" + + restored = type(bound).model_validate(dumped) + with pytest.raises(TypeError, match="payload does not contain"): + restored.flow.compute(b=4) + + +def test_context_transform_supports_nested_functions_with_serialized_payload(): + @Flow.model 
+ def add(a: int, b: FromContext[int]) -> int: + return a + b + + @Flow.context_transform + def nested_transform(b: FromContext[int], amount: int) -> int: + return b + amount + + binding = nested_transform(amount=3) + assert binding.serialized_config is not None + + bound = add(a=1).flow.with_context(b=binding) + restored = rcploads(rcpdumps(bound)) + assert restored.flow.compute(b=4).value == 8 + + +def test_context_transform_supports_non_importable_main_functions_with_serialized_payload(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + def main_transform(value: FromContext[int]) -> int: + return value + 1 + + main_transform.__module__ = "__main__" + main_transform.__qualname__ = main_transform.__name__ + + transformed = Flow.context_transform(main_transform) + binding = transformed() + assert binding.serialized_config is not None + + bound = add(a=1).flow.with_context(b=binding) + restored = rcploads(rcpdumps(bound)) + assert restored.flow.compute(value=4).value == 6 + + +def test_context_transform_nested_function_survives_ray_task(local_ray_runtime): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + @Flow.context_transform + def nested_transform(b: FromContext[int], amount: int) -> int: + return b + amount + + bound = add(a=1).flow.with_context(b=nested_transform(amount=3)) + + @ray.remote + def run_model(model): + return model.flow.compute(b=4).value + + assert ray.get(run_model.remote(bound)) == 8 + + +def test_with_context_rejects_raw_callables(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + with pytest.raises(TypeError, match="with_context\\(\\) 'b': expected int"): + add(a=1).flow.with_context(b=lambda ctx: ctx.b + 1) + + +def test_with_context_accepts_callable_literals_for_callable_context_fields(): + def increment(value: int) -> int: + return value + 1 + + @Flow.model + def apply(fn: FromContext[Callable[[int], int]], value: FromContext[int]) -> int: + return 
fn(value) + + assert apply(fn=increment).flow.compute(value=2).value == 3 + assert apply().flow.compute(fn=increment, value=2).value == 3 + assert apply().flow.with_context(fn=increment).flow.compute(value=2).value == 3 + + +def test_with_context_rejects_wrong_transform_position(): + @Flow.model + def load(start_date: FromContext[int], end_date: FromContext[int]) -> int: + return start_date + end_date + + with pytest.raises(TypeError, match="Field context transforms must be passed by keyword"): + load().flow.with_context(increment_b(amount=1)) + + with pytest.raises(TypeError, match="Patch transforms must be passed positionally"): + load().flow.with_context(start_date=shift_integer_window(amount=10)) + + +def test_chained_with_context_preserves_patch_and_field_order(): + @Flow.context_transform + def patch_a() -> dict[str, int]: + return {"a": 2} + + @Flow.model + def source(a: FromContext[int]) -> int: + return a + + field_then_patch = source().flow.with_context(a=1).flow.with_context(patch_a()) + patch_then_field = source().flow.with_context(patch_a()).flow.with_context(a=1) + + assert ( + field_then_patch.model_dump(mode="json")["context_spec"]["operations"] + != patch_then_field.model_dump(mode="json")["context_spec"]["operations"] + ) + assert field_then_patch.flow.compute().value == 2 + assert patch_then_field.flow.compute().value == 1 + + +def test_chained_with_context_transform_reads_original_context(): + @Flow.context_transform + def patch_a() -> dict[str, int]: + return {"a": 2} + + @Flow.context_transform + def b_from_a(a: FromContext[int]) -> int: + return a + 3 + + @Flow.model + def source(a: FromContext[int], b: FromContext[int]) -> int: + return a + b + + bound = source().flow.with_context(patch_a()).flow.with_context(b=b_from_a()) + + assert bound.flow.compute(a=10).value == 15 + assert bound.flow.inspect().required_inputs == {"a": int} + + +def test_chained_with_context_later_field_override_skips_dead_field_transform(): + @Flow.model + def 
source(a: FromContext[int]) -> int: + return a + + bound = source().flow.with_context(a=seed_plus_one()).flow.with_context(a=1) + + assert bound.flow.inspect().bound_inputs == {"a": 1} + assert bound.flow.inspect().runtime_inputs == {} + assert bound.flow.inspect().required_inputs == {} + assert bound.flow.compute().value == 1 + + for dumps, loads in ((pickle.dumps, pickle.loads), (rcpdumps, rcploads)): + restored = loads(dumps(bound, protocol=5)) + assert restored.flow.compute().value == 1 + + +def test_with_context_accepts_wrapped_mapping_patch_annotations(): + @Flow.model + def load(start_date: FromContext[int], end_date: FromContext[int]) -> int: + return start_date * 1000 + end_date + + annotated = load().flow.with_context(annotated_start_patch()) + optional = load().flow.with_context(optional_start_patch()) + + assert annotated.flow.compute(start_date=1, end_date=5).value == 2005 + assert optional.flow.compute(start_date=1, end_date=5).value == 3005 + + +def test_patch_then_keyword_override_precedence(): + @Flow.model + def load(start_date: FromContext[int], end_date: FromContext[int]) -> int: + return start_date * 1000 + end_date + + bound = load().flow.with_context(shift_integer_window(amount=10), start_date=100) + result = bound(FlowContext(start_date=1, end_date=2)) + assert result.value == 100_012 + + +def test_chained_transforms_read_original_runtime_context(): + @Flow.model + def load(start_date: FromContext[int], end_date: FromContext[int]) -> int: + return start_date * 1000 + end_date + + bound = load().flow.with_context( + shift_integer_window(amount=10), + start_date=bump_start_date(amount=100), + ) + + result = bound(FlowContext(start_date=1, end_date=2)) + assert result.value == 101_012 + + +def test_registry_integration_for_generated_models(): + registry = ModelRegistry.root().clear() + model = basic_loader(source="library", multiplier=3) + registry.add("loader", model) + + retrieved = registry["loader"] + assert isinstance(retrieved, 
CallableModel) + assert retrieved(SimpleContext(value=4)).value == 12 + + +def test_any_annotation_preserves_literal_strings(): + """A parameter typed Any should keep literal strings; registry should not steal them.""" + registry = ModelRegistry.root().clear() + dep_model = basic_loader(source="library", multiplier=1) + registry.add("my_key", dep_model) + + @Flow.model + def uses_any(x: Any, y: FromContext[int]) -> int: + return y if isinstance(x, str) else 999 + + model = uses_any(x="my_key") + result = model.flow.compute(y=3) + assert result.value == 3, "literal string should not be replaced by registry entry for Any-typed param" + + +def test_registry_lookup_does_not_steal_coercible_string_literals(): + registry = ModelRegistry.root().clear() + registry.add("3", data_source(base_value=10)) + registry.add("2024-01-01", data_source(base_value=20)) + registry.add("data.csv", data_source(base_value=30)) + registry.add("loader", data_source(base_value=40)) + + @Flow.model + def uses_int(x: int) -> str: + return f"{type(x).__name__}:{x}" + + @Flow.model + def uses_date(day: date) -> str: + return f"{type(day).__name__}:{day}" + + @Flow.model + def uses_path(path: Path) -> str: + return f"{type(path).__name__}:{path}" + + try: + assert uses_int(x="3").flow.compute(value=1).value == "int:3" + assert uses_date(day="2024-01-01").flow.compute(value=1).value == "date:2024-01-01" + path_model = uses_path(path="data.csv") + assert path_model.path == Path("data.csv") + assert path_model.flow.compute(value=1).value.endswith(":data.csv") + + dependency_model = uses_int(x="loader") + assert isinstance(dependency_model.x, CallableModel) + assert dependency_model.flow.compute(value=1).value == "int:41" + finally: + registry.clear() + + +def test_unexpected_type_adapter_errors_are_not_silently_swallowed(): + class BrokenSchema: + @classmethod + def __get_pydantic_core_schema__(cls, source, handler): + raise RuntimeError("boom") + + def bad(x: BrokenSchema, y: FromContext[int]) -> 
int: + del x, y + return 0 + + with pytest.raises(RuntimeError, match="boom"): + Flow.model(bad) + + +@pytest.mark.parametrize("error", [RuntimeError("boom"), TypeError("boom")]) +def test_unexpected_type_validation_errors_are_not_rewritten(error): + from pydantic_core import core_schema + + class BrokenValidation: + @classmethod + def __get_pydantic_core_schema__(cls, source, handler): + del source, handler + + def validate(value): + del value + raise error + + return core_schema.no_info_plain_validator_function(validate) + + @Flow.model + def bad(x: BrokenValidation, y: FromContext[int]) -> int: + del x, y + return 0 + + with pytest.raises(type(error), match="boom"): + bad(x=object()) + + +@pytest.mark.parametrize("error", [RuntimeError("boom"), AttributeError("boom")]) +def test_unexpected_type_hint_resolution_errors_propagate(monkeypatch, error): + def broken_get_type_hints(*args, **kwargs): + raise error + + monkeypatch.setattr(flow_model_module, "get_type_hints", broken_get_type_hints) + + def add(x: int) -> int: + return x + + with pytest.raises(type(error), match="boom"): + Flow.model(add) + + def transform(x: FromContext[int]) -> int: + return x + + with pytest.raises(type(error), match="boom"): + Flow.context_transform(transform) + + +def test_generated_model_flow_api_introspection_and_execution(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + model = add(a=10) + assert model.flow.inspect().context_inputs == {"b": int} + assert model.flow.inspect().bound_inputs == {"a": 10} + assert model.flow.inspect().required_inputs == {"b": int} + assert model.flow.compute(b=5).value == 15 + + +def test_flow_api_is_self_describing_for_interactive_sessions(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + flow = add(a=1).flow + + assert dir(flow) == ["compute", "inspect", "with_context"] + assert "inspect" in dir(flow) + assert "input_specs" not in dir(flow) + assert "argument_specs" not in dir(flow) + 
assert "inspect" in repr(flow) + + inspect_signature = inspect.signature(flow.inspect) + assert "inputs" not in inspect_signature.parameters + assert inspect_signature.parameters["dependencies"].annotation == Literal["direct", "recursive", "none"] + assert inspect_signature.return_annotation is flow_model_module.FlowInspection + + +def test_flow_api_completes_after_binding_property_value_in_ipython(): + pytest.importorskip("IPython") + from IPython.core.interactiveshell import InteractiveShell + + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + shell = InteractiveShell.instance() + shell.user_ns["flow_model_completion_target"] = add(a=1) + shell.user_ns["flow_model_completion_flow"] = shell.user_ns["flow_model_completion_target"].flow + + matches = shell.Completer.attr_matches("flow_model_completion_flow.") + assert "flow_model_completion_flow.compute" in matches + assert "flow_model_completion_flow.inspect" in matches + assert "flow_model_completion_flow.input_specs" not in matches + assert "flow_model_completion_flow.argument_specs" not in matches + assert "flow_model_completion_flow.help" not in matches + assert "flow_model_completion_flow.validate_inputs" not in matches + + +def test_flow_inspect_inputs_include_defaults_and_sources(): + @Flow.model + def add(a: int, b: FromContext[int], c: FromContext[int] = 5) -> int: + return a + b + c + + model = add(a=10) + inspection = model.flow.inspect() + + assert inspection.inputs["a"] == flow_model_module.InputSpec("a", int, False, flow_model_module._UNSET_FLOW_INPUT, 10, "construction") + assert inspection.inputs["b"] == flow_model_module.InputSpec( + "b", int, True, flow_model_module._UNSET_FLOW_INPUT, flow_model_module._UNSET_FLOW_INPUT, "runtime" + ) + assert inspection.inputs["c"] == flow_model_module.InputSpec("c", int, False, 5, 5, "function_default") + assert inspection.inputs["b"].required + assert not inspection.inputs["c"].required + assert model.flow.inspect(b=2).inputs["b"] 
== flow_model_module.InputSpec("b", int, False, flow_model_module._UNSET_FLOW_INPUT, 2, "runtime") + + +def test_flow_inspect_inputs_keep_dependency_value_but_use_compact_repr(): + @Flow.model + def add(a: int, b: int) -> int: + return a + b + + child = add(a=1, b=1) + model = add(a=1, b=child) + + spec = model.flow.inspect().inputs["b"] + assert spec.value is child + assert spec.value_repr == "" + assert "meta=" not in repr(spec) + assert "value=" in repr(spec) + + +def test_flow_inspect_with_runtime_values_is_structural_not_a_validator(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + model = add(a=10) + + inspection = model.flow.inspect(b=2, unused=1) + assert inspection.inputs["b"].value == 2 + assert inspection.inputs["b"].source == "runtime" + assert "unused" not in inspection.inputs + + regular_inspection = model.flow.inspect(a=1, b=2) + assert regular_inspection.inputs["b"].value == 2 + assert regular_inspection.inputs["b"].source == "runtime" + assert regular_inspection.inputs["a"].value == 10 + assert "input check" not in repr(regular_inspection) + + missing_inspection = model.flow.inspect(FlowContext()) + assert missing_inspection.inputs["b"].required + assert flow_model_module._is_unset_flow_input(missing_inspection.inputs["b"].value) + + +def test_flow_inspect_reports_direct_dependencies_and_unused_context(): + @Flow.model + def source(value: FromContext[int]) -> int: + return value * 2 + + @Flow.model + def root(x: int, bonus: FromContext[int]) -> int: + return x + bonus + + model = root(x=source()) + explanation = model.flow.inspect(value=3, bonus=4, unused=5) + + assert explanation.inputs["x"].value is model.flow.inspect().inputs["x"].value + assert explanation.required_inputs == {"bonus": int} + assert len(explanation.dependencies) == 1 + assert explanation.dependencies[0].path == "x" + assert explanation.dependencies[0].context == FlowContext(value=3) + child_inspection = 
explanation.dependencies[0].model.flow.inspect(explanation.dependencies[0].context) + assert child_inspection.runtime_inputs == {"value": int} + assert child_inspection.inputs["value"].value == 3 + assert "dependencies" in str(explanation) + assert repr(explanation) == str(explanation) + assert "FlowInspection(model=_root_Model)" in repr(explanation) + assert "inputs:" in repr(explanation) + assert "x -> _source_Model context=FlowContext(value=3)" in repr(explanation) + + assert model.flow.inspect(dependencies="none").dependencies == () + assert model.flow.inspect().runtime_inputs == {"bonus": int} + assert set(model.flow.inspect().bound_inputs) == {"x"} + with pytest.raises(ValueError, match="dependencies must be one of"): + model.flow.inspect(dependencies="full") + + +def test_flow_with_context_returns_new_bound_model_without_mutating_source(): + @Flow.model(auto_unwrap=False) + def add(x: int, y: int, z: FromContext[int] = 2) -> int: + return x + y + z + + @Flow.context_transform() + def shift_1(z: FromContext[int]) -> int: + return z + 1 + + model = add(x=1, y=add(x=1, y=1)) + bound = model.flow.with_context(z=shift_1()) + + assert model.flow.compute(z=2).value == 7 + assert bound.flow.compute(z=2).value == 9 + assert model.flow.inspect().bound_inputs.keys() == {"x", "y"} + assert bound.flow.inspect().runtime_inputs == {"z": int} + + +def test_bound_model_inspect_reports_wrapped_argument_dependencies(): + @Flow.model(auto_unwrap=False) + def add(x: int, y: int, z: FromContext[int] = 2) -> int: + return x + y + z + + @Flow.context_transform() + def shift_1(z: FromContext[int]) -> int: + return z + 1 + + bound = add(x=1, y=add(x=1, y=1)).flow.with_context(z=shift_1()) + + inspection = bound.flow.inspect(z=2) + + assert "FlowInspection(model=_add_Model.flow.with_context(...))" in repr(inspection) + assert inspection.inputs["z"] == flow_model_module.InputSpec("z", int, False, 2, 3, "context_transform") + assert len(inspection.dependencies) == 1 + assert 
inspection.dependencies[0].path == "y" + assert inspection.dependencies[0].context == FlowContext(z=3) + assert "" not in repr(inspection) + + root = add(x=1, y=bound) + root_inspection = root.flow.inspect() + assert root_inspection.required_inputs == {} + assert root_inspection.dependencies[0].model.flow.inspect().required_inputs == {"z": int} + + root_with_context = root.flow.inspect(z=2) + child_with_context = root_with_context.dependencies[0].model.flow.inspect(root_with_context.dependencies[0].context) + assert child_with_context.runtime_inputs == {"z": int} + assert child_with_context.inputs["z"].value == 3 + assert "context=FlowContext(z=2)" in repr(root_with_context) + + +def test_flow_inspect_direct_dependencies_do_not_traverse_grandchildren(): + @Flow.model + def leaf(v: FromContext[int]) -> int: + return v + + @Flow.model + def middle(x: int) -> int: + return x + + @Flow.model + def root(x: int) -> int: + return x + + model = root(x=middle(x=leaf())) + + root_inspection = model.flow.inspect() + child_inspection = root_inspection.dependencies[0].model.flow.inspect() + + assert root_inspection.dependencies[0].path == "x" + assert child_inspection.dependencies[0].model.flow.inspect().required_inputs == {"v": int} + + +def test_flow_inspect_recursive_dependencies_report_grandchild_requirements(): + @Flow.model + def leaf(v: FromContext[int]) -> int: + return v + + @Flow.model + def middle(x: int) -> int: + return x + + @Flow.model + def root(x: int) -> int: + return x + + model = root(x=middle(x=leaf())) + + missing = model.flow.inspect(dependencies="recursive") + + assert tuple(dependency.path for dependency in missing.dependencies) == ("x", "x.x") + assert missing.dependencies[1].model.flow.inspect().required_inputs == {"v": int} + + supplied = model.flow.inspect(v=2, dependencies="recursive") + + assert supplied.inputs == model.flow.inspect(v=2, dependencies="none").inputs + assert tuple(dependency.path for dependency in model.flow.inspect(v=2, 
dependencies="direct").dependencies) == ("x",) + assert tuple(dependency.context for dependency in supplied.dependencies) == (FlowContext(v=2), FlowContext(v=2)) + assert "v" not in supplied.inputs + leaf_with_context = supplied.dependencies[1].model.flow.inspect(supplied.dependencies[1].context) + assert leaf_with_context.inputs["v"].value == 2 + + +def test_flow_inspect_recursive_dependencies_treat_bound_transform_inputs_as_used(): + @Flow.model + def leaf(v: FromContext[int]) -> int: + return v + + @Flow.model + def middle(x: int) -> int: + return x + + @Flow.context_transform() + def shift(seed: FromContext[int]) -> int: + return seed + 1 + + model = middle(x=leaf().flow.with_context(v=shift())) + + missing = model.flow.inspect(dependencies="recursive") + assert missing.dependencies[0].model.flow.inspect().required_inputs == {"seed": int} + + supplied = model.flow.inspect(seed=1, dependencies="recursive") + assert tuple(dependency.path for dependency in supplied.dependencies) == ("x",) + assert flow_model_module._context_values(supplied.dependencies[0].context) == {"seed": 1} + child = supplied.dependencies[0].model.flow.inspect(supplied.dependencies[0].context) + assert child.inputs["v"].value == 2 + + +def test_flow_inspect_projects_bound_dependency_context_to_runtime_inputs(): + @Flow.model + def leaf(v: FromContext[int]) -> int: + return v + + @Flow.model + def root(x: int, bonus: FromContext[int]) -> int: + return x + bonus + + @Flow.context_transform() + def shift(seed: FromContext[int]) -> int: + return seed + 1 + + model = root(x=leaf().flow.with_context(v=shift())) + inspection = model.flow.inspect(seed=1, bonus=10, unused=99) + + assert flow_model_module._context_values(inspection.dependencies[0].context) == {"seed": 1} + child = inspection.dependencies[0].model.flow.inspect(inspection.dependencies[0].context) + assert child.inputs["v"].value == 2 + + +def test_flow_inspect_dependency_requirements_skip_lazy_edges(): + @Flow.model + def leaf(v: 
FromContext[int]) -> int: + return v + + @Flow.model + def root(x: Lazy[int]) -> int: + return 0 + + inspection = root(x=leaf()).flow.inspect(dependencies="recursive") + + assert inspection.dependencies[0].lazy + assert inspection.dependencies[0].model.flow.inspect().required_inputs == {"v": int} + assert "x -> _leaf_Model lazy" in repr(inspection) + + +def test_flow_inspect_dependency_requirements_skip_lazy_descendants(): + @Flow.model + def leaf(v: FromContext[int]) -> int: + return v + + @Flow.model + def middle(x: int) -> int: + return x + + @Flow.model + def root(x: Lazy[int]) -> int: + return 0 + + model = root(x=middle(x=leaf())) + inspection = model.flow.inspect(dependencies="recursive") + + assert model.flow.compute().value == 0 + assert tuple(dependency.path for dependency in inspection.dependencies) == ("x", "x.x") + assert all(dependency.lazy for dependency in inspection.dependencies) + assert inspection.dependencies[1].model.flow.inspect().required_inputs == {"v": int} + + +def test_flow_inspect_recursive_dependencies_tolerates_partial_transform_context(): + @Flow.model + def leaf(v: FromContext[int]) -> int: + return v + + @Flow.model + def middle(x: int) -> int: + return x + + @Flow.context_transform() + def shift(seed: FromContext[int], other: FromContext[int]) -> int: + return seed + other + + model = middle(x=leaf().flow.with_context(v=shift())) + + inspection = model.flow.inspect(seed=1, dependencies="recursive") + + assert tuple(dependency.path for dependency in inspection.dependencies) == ("x",) + child_inspection = inspection.dependencies[0].model.flow.inspect(inspection.dependencies[0].context) + assert child_inspection.runtime_inputs == {"seed": int, "other": int} + assert child_inspection.inputs["v"].source == "context_transform" + assert flow_model_module._is_unset_flow_input(child_inspection.inputs["v"].value) + + +def test_bound_flow_inspect_tolerates_partial_transform_context_without_dependencies(): + @Flow.model + def leaf(v: 
FromContext[int]) -> int: + return v + + @Flow.context_transform() + def shift(seed: FromContext[int], other: FromContext[int]) -> int: + return seed + other + + bound = leaf().flow.with_context(v=shift()) + + inspection = bound.flow.inspect(seed=1, dependencies="none") + + assert inspection.runtime_inputs == {"seed": int, "other": int} + assert inspection.required_inputs == {"seed": int, "other": int} + assert inspection.inputs["v"].source == "context_transform" + assert flow_model_module._is_unset_flow_input(inspection.inputs["v"].value) + assert inspection.dependencies == () + + +def test_flow_inspect_tolerates_partial_plain_callable_dependency_context(): + class PlainSource(CallableModel): + @property + def context_type(self): + return SimpleContext + + @property + def result_type(self): + return GenericResult[int] + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.model + def root(x: int, bonus: FromContext[int]) -> int: + return x + bonus + + inspection = root(x=PlainSource()).flow.inspect(bonus=1) + + assert len(inspection.dependencies) == 1 + assert inspection.dependencies[0].context == FlowContext() + assert inspection.dependencies[0].model.flow.inspect().required_inputs == {"value": int} + + +def test_plain_callable_flow_inspect_reports_partial_runtime_context(): + class PairContext(ContextBase): + a: int + b: int + + class PlainAdder(CallableModel): + @property + def context_type(self): + return PairContext + + @property + def result_type(self): + return GenericResult[int] + + @Flow.call + def __call__(self, context: PairContext) -> GenericResult[int]: + return GenericResult(value=context.a + context.b) + + inspection = PlainAdder().flow.inspect(a=1) + + assert inspection.inputs["a"].value == 1 + assert flow_model_module._is_unset_flow_input(inspection.inputs["b"].value) + + +def test_bound_flow_inspect_uses_static_context_for_dependency_requirements(): + 
@Flow.model(auto_unwrap=False) + def leaf(v: FromContext[int]) -> int: + return v + + @Flow.model(auto_unwrap=False) + def root(x: int, v: FromContext[int] = 0) -> int: + return x + 1 + + bound = root(x=leaf()).flow.with_context(v=1) + inspection = bound.flow.inspect() + + assert bound.flow.compute().value == 2 + assert inspection.dependencies[0].context == FlowContext(v=1) + assert inspection.dependencies[0].model.flow.inspect(inspection.dependencies[0].context).inputs["v"].value == 1 + + +def test_flow_inspect_recursive_dependencies_do_not_swallow_transform_type_errors(): + @Flow.model + def leaf(v: FromContext[int]) -> int: + return v + + @Flow.model + def root(x: int) -> int: + return x + + @Flow.context_transform() + def broken(seed: FromContext[int]) -> int: + raise TypeError("transform bug") + + model = root(x=leaf().flow.with_context(v=broken())) + + with pytest.raises(TypeError, match="transform bug"): + model.flow.inspect(seed=1, dependencies="recursive") + + +def test_dependency_evaluation_preserves_original_exception_type_with_context_note(): + class CustomDependencyError(RuntimeError): + pass + + @Flow.model + def child() -> int: + raise CustomDependencyError("boom") + + @Flow.model + def root(x: int) -> int: + return x + + with pytest.raises(CustomDependencyError) as exc_info: + root(x=child()).flow.compute() + + if hasattr(exc_info.value, "__notes__"): + assert exc_info.value.__notes__ == ["Error while evaluating dependency root.x -> _child_Model."] + + +def test_generated_factory_signature_is_keyword_only_and_includes_model_base_fields(): + sig = inspect.signature(basic_loader) + + assert all(param.kind is inspect.Parameter.KEYWORD_ONLY for param in sig.parameters.values()) + assert list(sig.parameters) == ["source", "multiplier", "value"] + assert sig.parameters["source"].default is flow_model_module._UNSET_FLOW_INPUT + assert sig.parameters["multiplier"].default is flow_model_module._UNSET_FLOW_INPUT + assert sig.parameters["value"].default is 
flow_model_module._UNSET_FLOW_INPUT + + with pytest.raises(TypeError, match="positional"): + basic_loader("library", 3) + + class CustomFlowBase(CallableModel): + multiplier: int = 1 + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.model(model_base=CustomFlowBase) + def add(a: int, b: FromContext[int]) -> int: + return a + b + + custom_sig = inspect.signature(add) + assert custom_sig.parameters["a"].kind is inspect.Parameter.KEYWORD_ONLY + assert custom_sig.parameters["b"].kind is inspect.Parameter.KEYWORD_ONLY + assert custom_sig.parameters["multiplier"].kind is inspect.Parameter.KEYWORD_ONLY + assert custom_sig.parameters["multiplier"].default == 1 + + +def test_context_transform_factory_signature_only_exposes_regular_bindings(): + sig = inspect.signature(increment_b) + + assert list(sig.parameters) == ["amount"] + assert sig.parameters["amount"].kind is inspect.Parameter.KEYWORD_ONLY + assert sig.parameters["amount"].annotation is int + assert sig.parameters["amount"].default is inspect.Parameter.empty + assert sig.return_annotation is inspect.Signature.empty + + with pytest.raises(TypeError, match="positional"): + increment_b(1) + + +def test_plain_callable_flow_api_paths(): + class PlainModel(CallableModel): + @property + def context_type(self): + return SimpleContext + + @property + def result_type(self): + return GenericResult[int] + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.deps + def __deps__(self, context: SimpleContext): + del context + return [] + + model = PlainModel() + + assert dir(model.flow) == ["compute", "inspect", "with_context"] + assert not hasattr(model.flow, "context_inputs") + assert not hasattr(model.flow, "runtime_inputs") + assert not hasattr(model.flow, "required_inputs") + assert not hasattr(model.flow, "bound_inputs") + assert 
model.flow.inspect().context_inputs == {"value": int} + assert model.flow.inspect().required_inputs == {"value": int} + assert model.flow.inspect().bound_inputs == {} + assert model.flow.compute({"value": 3}).value == 3 + + with pytest.raises(TypeError, match="either one context object or contextual keyword inputs"): + model.flow.compute(SimpleContext(value=1), value=2) + + +def test_plain_callable_flow_api_is_base_property(): + class PlainModel(CallableModel): + offset: int = 7 + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=self.offset + context.value) + + model = PlainModel() + + assert dir(model.flow) == ["compute", "inspect", "with_context"] + assert model(SimpleContext(value=3)).value == 10 + + +def test_plain_callable_flow_compute_preserves_matching_context_subclass(): + class RequestContext(SimpleContext): + request_id: str + + class PlainModel(CallableModel): + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[tuple[int, str]]: + return GenericResult(value=(context.value, getattr(context, "request_id", "missing"))) + + context = RequestContext(value=3, request_id="abc") + + model = PlainModel() + + assert model.flow.compute(context).value == (3, "abc") + assert model.flow.with_context().flow.compute(context).value == (3, "abc") + + bound = model.flow.with_context(value=4) + other_context = RequestContext(value=3, request_id="def") + + assert bound.flow.compute(context).value == (4, "abc") + assert bound.flow.compute(other_context).value == (4, "def") + first_key = cache_key(bound.__call__.get_evaluation_context(bound, context), effective=True) + second_key = cache_key(bound.__call__.get_evaluation_context(bound, other_context), effective=True) + assert first_key != second_key + + +def test_plain_callable_flow_compute_uses_default_context_when_available(): + class DefaultContext(ContextBase): + value: int + tag: str = "default" + + class PlainModel(CallableModel): + 
@Flow.call + def __call__(self, context: DefaultContext = DefaultContext(value=7)) -> GenericResult[tuple[int, str]]: + return GenericResult(value=(context.value, context.tag)) + + model = PlainModel() + + assert model.flow.inspect().required_inputs == {} + assert model.flow.compute().value == (7, "default") + assert model.flow.compute(value=3).value == (3, "default") + assert model.flow.compute(tag="runtime").value == (7, "runtime") + + empty_bound = model.flow.with_context() + assert empty_bound.flow.inspect().required_inputs == {} + assert empty_bound.flow.compute().value == (7, "default") + + bound = model.flow.with_context(tag="bound") + assert bound.flow.inspect().required_inputs == {} + assert bound.flow.compute().value == (7, "bound") + assert bound.flow.compute(value=3).value == (3, "bound") + + +def test_plain_callable_default_context_private_state_is_preserved(): + class DefaultContext(ContextBase): + value: int + _bonus: int = PrivateAttr(default=1) + + default_context = DefaultContext(value=7) + default_context._bonus = 40 + + class PlainModel(CallableModel): + @Flow.call + def __call__(self, context: DefaultContext = default_context) -> GenericResult[tuple[int, int, bool]]: + return GenericResult(value=(context.value, context._bonus, context is default_context)) + + model = PlainModel() + + assert model.flow.compute().value == (7, 40, True) + assert model.flow.compute(value=8).value == (8, 40, False) + assert model.flow.with_context().flow.compute().value == (7, 40, True) + assert model.flow.with_context(value=8).flow.compute().value == (8, 40, False) + + +def test_bound_plain_callable_flow_compute_uses_default_context_for_dynamic_transforms(): + class DefaultContext(ContextBase): + value: int + seed: int + + class PlainModel(CallableModel): + @Flow.call + def __call__(self, context: DefaultContext = DefaultContext(value=7, seed=8)) -> GenericResult[tuple[int, int]]: + return GenericResult(value=(context.value, context.seed)) + + 
@Flow.context_transform + def from_seed(seed: FromContext[int]) -> int: + return seed + 1 + + bound = PlainModel().flow.with_context(value=from_seed()) + + assert bound.flow.inspect().required_inputs == {} + assert bound.flow.inspect().runtime_inputs == {"seed": int} + assert bound.flow.compute().value == (9, 8) + assert bound.flow.compute(seed=10).value == (11, 10) + + +def test_bound_plain_callable_dependency_uses_default_context_baseline(): + class DefaultContext(ContextBase): + value: int + seed: int + _bonus: int = PrivateAttr(default=1) + + default_context = DefaultContext(value=7, seed=8) + default_context._bonus = 40 + + class PlainModel(CallableModel): + @Flow.call + def __call__(self, context: DefaultContext = default_context) -> GenericResult[tuple[int, int, int]]: + return GenericResult(value=(context.value, context.seed, context._bonus)) + + @Flow.context_transform + def from_seed(seed: FromContext[int]) -> int: + return seed + 1 + + @Flow.model + def consume(x: tuple[int, int, int]) -> tuple[int, int, int]: + return x + + static_bound = PlainModel().flow.with_context(value=3) + assert static_bound.flow.compute().value == (3, 8, 40) + assert consume(x=static_bound).flow.compute().value == (3, 8, 40) + + dynamic_bound = PlainModel().flow.with_context(value=from_seed()) + assert dynamic_bound.flow.compute().value == (9, 8, 40) + assert consume(x=dynamic_bound).flow.compute().value == (9, 8, 40) + assert consume(x=dynamic_bound).flow.compute(seed=10).value == (11, 10, 40) + + model = consume(x=dynamic_bound) + graph = get_dependency_graph(model.__call__.get_evaluation_context(model, model.context_type())) + plain_contexts = [] + for evaluation_context in graph.ids.values(): + while isinstance(evaluation_context.context, ModelEvaluationContext): + evaluation_context = evaluation_context.context + if isinstance(evaluation_context.model, PlainModel): + plain_contexts.append(evaluation_context.context) + + assert plain_contexts == [DefaultContext(value=9, 
seed=8)] + assert plain_contexts[0]._bonus == 40 + + +def test_unhashable_annotations_still_validate(): + annotation = Annotated[int, []] + + @Flow.model + def add(x: annotation, y: FromContext[annotation]) -> int: + return x + y + + assert add(x="2").flow.compute(y="3").value == 5 + + +def test_flow_model_internal_contract_helpers_cover_portable_edge_shapes(): + annotation = Annotated[dict[str, list[GenericResult[int]]], "contract"] + restored = binding_module._restore_annotation(binding_module._serialize_annotation(annotation)) + + assert get_args(restored)[1:] == ("contract",) + assert binding_module._restore_annotation(binding_module._serialize_annotation(Literal["a", "b"])) == Literal["a", "b"] + assert binding_module._restore_annotation(binding_module._serialize_annotation(int | None)) == Optional[int] + assert binding_module._clone_function_without_annotations(5) == 5 + assert repr(get_args(Lazy[int])[1]) == "Lazy" + assert repr(get_args(FromContext[int])[1]) == "FromContext" + with pytest.raises(TypeError, match="Lazy is an annotation marker"): + Lazy() + with pytest.raises(TypeError, match="Unknown serialized annotation payload"): + binding_module._restore_annotation(object()) + with pytest.raises(TypeError, match="Unknown serialized annotation payload kind"): + binding_module._restore_annotation(binding_module._SerializedAnnotation(kind="bad", value=None)) + with pytest.raises(TypeError, match="Unknown Flow.model parameter payload"): + binding_module._restore_flow_model_param(object()) + with pytest.raises(TypeError, match="Unknown Flow.model config payload"): + binding_module._restore_flow_model_config(object()) + + compatible = binding_module._context_type_annotations_compatible + assert compatible(Any, int) + assert compatible(int, Any) + assert compatible(int | str, int) + assert compatible(int | None, type(None)) + assert compatible(Literal["a", "b"], Literal["a"]) + assert compatible(str, Literal["a"]) + assert compatible(list[int], list[int]) + 
assert not compatible(int, int | None) + assert not compatible(Literal["a"], Literal["b"]) + assert not compatible(Literal["a"], str) + assert not compatible(list[int], list[str]) + + flow_model_module.clear_flow_model_caches() + unhashable_annotation = Annotated[int, []] + assert flow_model_module._type_adapter(unhashable_annotation).validate_python("1") == 1 + assert flow_model_module._type_adapter(unhashable_annotation).validate_python("2") == 2 + assert flow_model_module._concrete_context_type(Optional[SimpleContext]) is SimpleContext + assert flow_model_module._concrete_context_type(int) is None + assert flow_model_module._bound_field_names(object()) == set() + assert flow_model_module._expected_type_repr(int | str) == "int | str" + assert flow_model_module._coerce_value("payload", object(), object(), "Regular parameter").__class__ is object + assert flow_model_module._unwrap_model_result(3) == 3 + assert flow_model_module._resolve_registry_candidate("__missing_registry_entry__") is None + assert flow_model_module._registry_candidate_allowed(object(), "literal") + assert not flow_model_module._registry_candidate_allowed(int, "not-an-int") + assert flow_model_module._registry_candidate_allowed(int, "3") + with pytest.raises(ImportError, match="does not have a _generated_model"): + flow_model_module._new_generated_flow_model_for_pickle("ccflow.GenericResult") + with pytest.raises(TypeError, match="model_base must be a CallableModel subclass"): + flow_model_module._resolve_generated_model_bases(int) + assert flow_model_module._context_transform_identifier(static_patch()) == "static_patch" + assert flow_model_module._context_transform_repr(static_patch()) == "static_patch()" + assert flow_model_module._context_transform_repr(increment_b(amount=2)) == "increment_b(amount=2)" + assert flow_model_module._context_transform_repr(5) == "5" + + @Flow.model + def dep(value: FromContext[int]) -> int: + return value + + with pytest.raises(TypeError, match="must return a 
mapping"): + flow_model_module._validate_patch_result(dep(), 1) + with pytest.raises(TypeError, match="string field names"): + flow_model_module._validate_patch_result(dep(), {1: 2}) + lazy_thunk = flow_model_module._make_lazy_thunk(dep(), FlowContext(value=4)) + assert lazy_thunk() == 4 + assert lazy_thunk() == 4 + coercing_thunk = flow_model_module._make_coercing_lazy_thunk(lambda: "5", "value", int) + assert coercing_thunk() == 5 + assert coercing_thunk() == 5 + + +def test_compute_accepts_context_object_for_from_context_models(): + model = basic_loader(source="library", multiplier=3) + + assert model.flow.inspect().context_inputs == {"value": int} + assert model.flow.inspect().required_inputs == {"value": int} + assert model.flow.compute({"value": 4}).value == 12 + assert model.flow.compute(SimpleContext(value=5)).value == 15 + + with pytest.raises(TypeError, match="either one context object or contextual keyword inputs"): + model.flow.compute(SimpleContext(value=1), value=2) + + +def test_additional_validation_and_hint_fallback_paths(monkeypatch): + class MissingFieldContext(ContextBase): + start_date: date + + with pytest.raises(TypeError, match="must define fields for all FromContext parameters"): + + @Flow.model(context_type=MissingFieldContext) + def bad_missing(start_date: FromContext[date], end_date: FromContext[date]) -> int: + return 0 + + class ExtraRequiredContext(ContextBase): + start_date: date + end_date: date + label: str + + with pytest.raises(TypeError, match="has required fields that are not declared as FromContext parameters"): + + @Flow.model(context_type=ExtraRequiredContext) + def bad_extra(start_date: FromContext[date], end_date: FromContext[date]) -> int: + return 0 + + class BadAnnotationContext(ContextBase): + value: str + + with pytest.raises(TypeError, match="annotates"): + + @Flow.model(context_type=BadAnnotationContext) + def bad_annotation(value: FromContext[int]) -> int: + return value + + with pytest.raises(TypeError, 
match="context_type must be a ContextBase subclass"): + + @Flow.model(context_type=int) + def bad_context_type(value: FromContext[int]) -> int: + return value + + @Flow.model + def source(value: FromContext[int]) -> GenericResult[int]: + return GenericResult(value=value) + + with pytest.raises(TypeError, match="cannot default to a CallableModel"): + + @Flow.model + def bad_default(value: FromContext[int] = source()) -> int: + return value + + with pytest.raises(TypeError, match="return type annotation"): + + @Flow.model + def missing_return(value: int): + return value + + with pytest.raises(TypeError, match="does not support positional-only parameter 'value'"): + + @Flow.model + def bad_positional_only(value: int, /, bonus: FromContext[int]) -> int: + return value + bonus + + with pytest.raises(TypeError, match="does not support variadic positional parameter 'values'"): + + @Flow.model + def bad_varargs(*values: int) -> int: + return sum(values) + + with pytest.raises(TypeError, match="does not support variadic keyword parameter 'values'"): + + @Flow.model + def bad_varkw(**values: int) -> int: + return sum(values.values()) + + @Flow.model + def keyword_only(value: int, *, bonus: FromContext[int]) -> int: + return value + bonus + + assert keyword_only(value=2).flow.compute(bonus=3).value == 5 + + @Flow.model + def keyword_only_context(*, context: SimpleContext, offset: int) -> int: + return context.value + offset + + assert keyword_only_context(context=SimpleContext(value=3), offset=4).flow.compute().value == 7 + + def missing_hints(*args, **kwargs): + raise AttributeError("missing hints") + + monkeypatch.setattr(flow_model_module, "get_type_hints", missing_hints) + + def add(x: int, y: FromContext[int]) -> int: + return x + y + + with pytest.raises(AttributeError, match="missing hints"): + Flow.model(add) + + +def test_unresolved_forward_refs_do_not_silently_strip_from_context(): + namespace: dict[str, object] = {} + + with pytest.raises(NameError, 
match="MissingType"): + exec( + """ +from __future__ import annotations +from ccflow import Flow, FromContext + +@Flow.context_transform +def transform(a: MissingType, b: FromContext[int]) -> int: + return b +""", + namespace, + ) + + with pytest.raises(NameError, match="MissingType"): + exec( + """ +from __future__ import annotations +from ccflow import Flow, FromContext + +@Flow.model +def model(a: MissingType, b: FromContext[int]) -> int: + return b +""", + namespace, + ) + + +def test_context_type_validates_parameterized_annotations(): + class IntListContext(ContextBase): + vals: list[int] + + @Flow.model(context_type=IntListContext) + def total(vals: FromContext[list[int]]) -> int: + return sum(vals) + + assert total().flow.compute(vals=["1", "2"]).value == 3 + + class IntDictContext(ContextBase): + vals: dict[str, int] + + @Flow.model(context_type=IntDictContext) + def total_dict(vals: FromContext[dict[str, int]]) -> int: + return sum(vals.values()) + + assert total_dict().flow.compute(vals={"a": "1", "b": 2}).value == 3 + + class StrListContext(ContextBase): + vals: list[str] + + with pytest.raises(TypeError, match="annotates"): + + @Flow.model(context_type=StrListContext) + def bad(vals: FromContext[list[int]]) -> int: + return sum(vals) + + +def test_context_type_allows_compatible_union_literal_and_generic_fields(): + class RichContext(ContextBase): + flag: Literal["a"] + maybe_value: int | None + values: list[int] + + @Flow.model(context_type=RichContext) + def summarize( + flag: FromContext[Literal["a", "b"]], + maybe_value: FromContext[int | None], + values: FromContext[list[int]], + ) -> int: + return len(flag) + (maybe_value or 0) + sum(values) + + assert summarize().flow.compute(flag="a", maybe_value=None, values=[1, 2]).value == 4 + + +def test_context_type_rejects_nullable_field_for_non_nullable_from_context(): + class OptionalValueContext(ContextBase): + value: int | None + + with pytest.raises(TypeError, match="annotates"): + + 
@Flow.model(context_type=OptionalValueContext) + def add_one(value: FromContext[int]) -> int: + return value + 1 + + +def test_compute_forwards_options_with_custom_evaluator(): + calls = {"count": 0} + + @Flow.model + def counter(value: FromContext[int]) -> int: + calls["count"] += 1 + return value + + cache = MemoryCacheEvaluator() + model = counter() + + result1 = model.flow.compute(value=10, _options=FlowOptions(evaluator=cache, cacheable=True)) + result2 = model.flow.compute(value=10, _options=FlowOptions(evaluator=cache, cacheable=True)) + + assert result1.value == 10 + assert result2.value == 10 + assert calls["count"] == 1 + + +def test_compute_forwards_options_with_graph_evaluator(): + @Flow.model + def source(value: FromContext[int]) -> int: + return value * 10 + + @Flow.model + def root(x: int, bonus: FromContext[int]) -> int: + return x + bonus + + model = root(x=source()) + + # GraphEvaluator evaluates in topo order; verify _options flows through + # and the graph evaluator is actually used (doesn't raise CycleError, computes correctly) + result = model.flow.compute( + FlowContext(value=3, bonus=7), + _options=FlowOptions(evaluator=GraphEvaluator()), + ) + + assert result.value == 37 + + +def test_compute_forwards_options_through_bound_model(): + calls = {"count": 0} + + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + calls["count"] += 1 + return a + b + + cache = MemoryCacheEvaluator() + bound = add(a=10).flow.with_context(b=5) + + result1 = bound.flow.compute(_options=FlowOptions(evaluator=cache, cacheable=True)) + result2 = bound.flow.compute(_options=FlowOptions(evaluator=cache, cacheable=True)) + + assert result1.value == 15 + assert result2.value == 15 + assert calls["count"] == 1 + + +def test_compute_forwards_options_for_plain_callable_model(): + calls = {"count": 0} + + class Counter(CallableModel): + offset: int + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + calls["count"] += 1 + return 
GenericResult(value=context.value + self.offset) + + cache = MemoryCacheEvaluator() + model = Counter(offset=5) + + result1 = model.flow.compute(value=10, _options=FlowOptions(evaluator=cache, cacheable=True)) + result2 = model.flow.compute(value=10, _options=FlowOptions(evaluator=cache, cacheable=True)) + + assert result1.value == 15 + assert result2.value == 15 + assert calls["count"] == 1 + + +def test_bound_plain_callable_compute_applies_context_before_validation(): + calls = {"count": 0} + + class RequiredContext(ContextBase): + a: int + b: int + + class PlainSource(CallableModel): + @Flow.call + def __call__(self, context: RequiredContext) -> GenericResult[int]: + calls["count"] += 1 + return GenericResult(value=context.a + context.b) + + assert PlainSource().flow.with_context(a=1).flow.compute(b=2).value == 3 + assert calls["count"] == 1 + + +def test_bound_plain_callable_empty_with_context_preserves_optional_none_context(): + calls = {"count": 0} + + class OptionalContext(ContextBase): + value: int = 1 + + class PlainSource(CallableModel): + @Flow.call + def __call__(self, context: Optional[OptionalContext] = None) -> GenericResult[int]: + calls["count"] += 1 + return GenericResult(value=0 if context is None else context.value) + + bound = PlainSource().flow.with_context() + + assert bound.flow.compute().value == 0 + assert bound(None).value == 0 + assert bound({"value": 5}).value == 5 + assert bound.flow.compute(value=5).value == 5 + assert bound.flow.compute(_options=FlowOptions(evaluator=GraphEvaluator())).value == 0 + assert calls["count"] == 5 + + +def test_bound_plain_callable_dynamic_context_transform_runs_before_validation(): + calls = {"count": 0} + + class RequiredContext(ContextBase): + a: int + b: int + + class PlainSource(CallableModel): + @Flow.call + def __call__(self, context: RequiredContext) -> GenericResult[int]: + calls["count"] += 1 + return GenericResult(value=context.a + context.b) + + assert 
PlainSource().flow.with_context(a=seed_plus_one()).flow.compute(seed=1, b=10).value == 12 + assert calls["count"] == 1 + + +def test_bound_plain_callable_direct_call_applies_context_before_validation(): + calls = {"count": 0} + + class RequiredContext(ContextBase): + a: int + b: int + + class PlainSource(CallableModel): + @Flow.call + def __call__(self, context: RequiredContext) -> GenericResult[int]: + calls["count"] += 1 + return GenericResult(value=context.a + context.b) + + bound = PlainSource().flow.with_context(a=1) + + assert bound(FlowContext(b=2)).value == 3 + assert bound.__deps__(FlowContext(b=2)) == [(bound.model, [RequiredContext(a=1, b=2)])] + assert calls["count"] == 1 + + +def test_bound_plain_callable_direct_kwargs_use_flow_context(): + class PlainSource(CallableModel): + @Flow.call + def __call__(self, context: FlowContext) -> GenericResult[str]: + return GenericResult(value=type(context).__name__) + + bound = PlainSource().flow.with_context(a=1) + + assert bound(b=2).value == "FlowContext" + assert bound.flow.compute(b=2).value == "FlowContext" + + +def test_bound_plain_callable_compute_preserves_bound_scoped_options(): + calls = {"source": 0, "evaluator": 0} + + class RequiredContext(ContextBase): + a: int + b: int + + class PlainSource(CallableModel): + @Flow.call + def __call__(self, context: RequiredContext) -> GenericResult[int]: + calls["source"] += 1 + return GenericResult(value=context.a + context.b) + + class OffsetEvaluator(EvaluatorBase): + def __call__(self, context: ModelEvaluationContext): + calls["evaluator"] += 1 + result = context() + return result.model_copy(update={"value": result.value + 100}) + + bound = PlainSource().flow.with_context(a=1) + + with FlowOptionsOverride(options={"evaluator": OffsetEvaluator()}, models=(bound,)): + assert bound.flow.compute(b=2).value == 103 + + assert calls == {"source": 1, "evaluator": 1} + + +def test_bound_plain_callable_dependency_preserves_bound_scoped_options(): + calls = {"source": 0, 
"consumer": 0, "evaluator": 0} + + class RequiredContext(ContextBase): + a: int + b: int + + class PlainSource(CallableModel): + @Flow.call + def __call__(self, context: RequiredContext) -> GenericResult[int]: + calls["source"] += 1 + return GenericResult(value=context.a + context.b) + + class OffsetEvaluator(EvaluatorBase): + def __call__(self, context: ModelEvaluationContext): + calls["evaluator"] += 1 + result = context() + return result.model_copy(update={"value": result.value + 100}) + + @Flow.model + def consumer(x: int) -> int: + calls["consumer"] += 1 + return x + + bound = PlainSource().flow.with_context(a=1) + model = consumer(x=bound) + + with FlowOptionsOverride(options={"evaluator": OffsetEvaluator()}, models=(bound,)): + assert model.flow.compute(b=2).value == 103 + + assert calls == {"source": 1, "consumer": 1, "evaluator": 1} + + +def test_bound_plain_callable_dependency_identity_ignores_unused_ambient_context(): + calls = {"source": 0, "consumer": 0} + + class RequiredContext(ContextBase): + a: int + b: int + + class PlainSource(CallableModel): + @Flow.call + def __call__(self, context: RequiredContext) -> GenericResult[int]: + calls["source"] += 1 + return GenericResult(value=context.a + context.b) + + @Flow.model + def consumer(x: int) -> int: + calls["consumer"] += 1 + return x + + cache = MemoryCacheEvaluator() + model = consumer(x=PlainSource().flow.with_context(a=1, b=2)) + + with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): + assert model.flow.compute(unused="one").value == 3 + assert model.flow.compute(unused="two").value == 3 + + assert calls == {"source": 1, "consumer": 1} + + +def test_bound_dependency_identity_rewrites_dynamic_context_once(): + calls = {"source": 0, "consumer": 0} + + @Flow.model + def source(a: FromContext[int]) -> int: + calls["source"] += 1 + return a + + @Flow.model + def consumer(x: int) -> int: + calls["consumer"] += 1 + return x + + cache = MemoryCacheEvaluator() + model = 
consumer(x=source().flow.with_context(a=non_idempotent_a_step())) + + with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): + assert model.flow.compute(a=1).value == 2 + assert model.flow.compute(a=2).value == 3 + + assert calls == {"source": 2, "consumer": 2} + + +def test_lazy_bound_plain_callable_dependency_preserves_bound_scoped_options(): + calls = {"source": 0, "consumer": 0, "evaluator": 0} + + class RequiredContext(ContextBase): + a: int + b: int + + class PlainSource(CallableModel): + @Flow.call + def __call__(self, context: RequiredContext) -> GenericResult[int]: + calls["source"] += 1 + return GenericResult(value=context.a + context.b) + + class OffsetEvaluator(EvaluatorBase): + def __call__(self, context: ModelEvaluationContext): + calls["evaluator"] += 1 + result = context() + return result.model_copy(update={"value": result.value + 100}) + + @Flow.model + def consumer(lazy_value: Lazy[int], use_lazy: FromContext[bool]) -> int: + calls["consumer"] += 1 + return lazy_value() if use_lazy else 0 + + bound = PlainSource().flow.with_context(a=1) + model = consumer(lazy_value=bound) + + with FlowOptionsOverride(options={"evaluator": OffsetEvaluator()}, models=(bound,)): + assert model.flow.compute(b=2, use_lazy=True).value == 103 + + assert calls == {"source": 1, "consumer": 1, "evaluator": 1} + + +def test_bound_flow_required_inputs_subtracts_static_context(): + class RequiredContext(ContextBase): + a: int + b: int + + class PlainSource(CallableModel): + @Flow.call + def __call__(self, context: RequiredContext) -> GenericResult[int]: + return GenericResult(value=context.a + context.b) + + @Flow.model + def add(a: FromContext[int], b: FromContext[int]) -> int: + return a + b + + assert PlainSource().flow.with_context(a=1).flow.inspect().required_inputs == {"b": int} + assert add().flow.with_context(a=1).flow.inspect().required_inputs == {"b": int} + assert add().flow.with_context(a=static_bad()).flow.inspect().required_inputs == {"b": 
int} + assert add().flow.with_context(static_patch()).flow.inspect().required_inputs == {"b": int} + assert add().flow.with_context(a=1, b=2).flow.inspect().required_inputs == {} + + +def test_bound_flow_required_inputs_reflects_dynamic_field_transform_inputs(): + @Flow.model + def add(a: FromContext[int], b: FromContext[int]) -> int: + return a + b + + bound = add().flow.with_context(a=seed_plus_one()) + + assert bound.flow.compute(seed=1, b=10).value == 12 + assert bound.flow.inspect().context_inputs == {"a": int, "b": int} + assert bound.flow.inspect().runtime_inputs == {"b": int, "seed": int} + assert bound.flow.inspect().required_inputs == {"b": int, "seed": int} + + +def test_bound_flow_bound_inputs_include_static_context_bindings(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + bound = add(a=1).flow.with_context(b=2) + + assert bound.flow.inspect().context_inputs == {"b": int} + assert bound.flow.inspect().runtime_inputs == {} + assert bound.flow.inspect().required_inputs == {} + assert bound.flow.inspect().bound_inputs == {"a": 1, "b": 2} + + +def test_bound_flow_bound_inputs_drops_static_patch_after_dynamic_override(): + @Flow.model + def add(a: FromContext[int], b: FromContext[int]) -> int: + return a + b + + bound = add().flow.with_context(static_patch()).flow.with_context(a=seed_plus_one()) + + assert bound.flow.compute(seed=3, b=10).value == 14 + assert bound.flow.inspect().bound_inputs == {} + assert bound.flow.inspect().context_inputs == {"a": int, "b": int} + assert bound.flow.inspect().runtime_inputs == {"b": int, "seed": int} + assert bound.flow.inspect().required_inputs == {"b": int, "seed": int} + + +def test_generated_model_cache_ignores_unused_flow_context_fields(): + calls = {"source": 0, "root": 0} + + @Flow.model + def source(value: FromContext[int]) -> int: + calls["source"] += 1 + return value * 10 + + @Flow.model + def root(x: int, bonus: FromContext[int]) -> int: + calls["root"] += 1 + return x + bonus 
+ + cache = MemoryCacheEvaluator() + model = root(x=source()) + + with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): + assert model(FlowContext(value=3, bonus=7, unused="one")).value == 37 + assert model(FlowContext(value=3, bonus=7, unused="two")).value == 37 + + assert calls == {"source": 1, "root": 1} + assert len(cache.cache) == 2 + + +def test_generated_model_cache_uses_effective_key_through_transparent_evaluator(): + calls = {"count": 0} + + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + calls["count"] += 1 + return a + b + + cache = MemoryCacheEvaluator() + evaluator = combine_evaluators(LoggingEvaluator(), cache) + model = add(a=10) + + with FlowOptionsOverride(options={"evaluator": evaluator, "cacheable": True}): + assert model.flow.compute(b=1, unused="one").value == 11 + assert model.flow.compute(b=1, unused="two").value == 11 + + assert calls["count"] == 1 + assert len(cache.cache) == 1 + + +def test_effective_cache_key_ignores_untokenizable_unused_ambient_context(): + class BadToken: + def __deepcopy__(self, memo): + return self + + def __getstate__(self): + raise RuntimeError("unused field should not be tokenized") + + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + model = add(a=1) + clean_context = FlowContext(b=2) + noisy_context = FlowContext(b=2, unused=BadToken()) + clean_eval = model.__call__.get_evaluation_context(model, clean_context) + noisy_eval = model.__call__.get_evaluation_context(model, noisy_context) + + assert cache_key(noisy_eval, effective=True) == cache_key(clean_eval, effective=True) + + +def test_recursive_effective_cache_key_ignores_untokenizable_unused_ambient_context(): + class BadToken: + def __deepcopy__(self, memo): + return self + + def __getstate__(self): + raise RuntimeError("unused field should not be tokenized") + + @Flow.model + def cycle(value: int, a: FromContext[int]) -> int: + return a + + model = cycle(value=cycle(value=1)) + model.value = model 
+ context = FlowContext(a=1, unused=BadToken()) + evaluation = model.__call__.get_evaluation_context(model, context) + + cache_key(evaluation, effective=True) + graph = get_dependency_graph(evaluation) + assert graph.root_id in graph.ids + with pytest.raises(graphlib.CycleError): + tuple(graphlib.TopologicalSorter(graph.graph).static_order()) + + +def test_cache_key_effective_option_preserves_plain_callable_structural_identity(): + calls = {"count": 0} + + class Counter(CallableModel): + @Flow.call + def __call__(self, context: FlowContext) -> GenericResult[int]: + calls["count"] += 1 + return GenericResult(value=context.value) + + model = Counter() + eval1 = model.__call__.get_evaluation_context(model, FlowContext(value=10, unused="one")) + eval2 = model.__call__.get_evaluation_context(model, FlowContext(value=10, unused="two")) + + assert cache_key(eval1) != cache_key(eval2) + assert cache_key(eval1, effective=True) == cache_key(eval1) + assert cache_key(eval2, effective=True) == cache_key(eval2) + + +def test_generated_model_effective_cache_key_includes_behavior_token(monkeypatch): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + def helper_v1(): + return 1 + + def helper_v2(): + return 2 + + model = add(a=10) + model_type = type(model) + context = model.__call__.get_evaluation_context(model, FlowContext(b=1, unused="same"), _options={"cacheable": True}) + cache = MemoryCacheEvaluator() + + monkeypatch.setattr(model_type, "__ccflow_tokenizer_deps__", [helper_v1], raising=False) + key1 = cache.key(context) + + monkeypatch.setattr(model_type, "__ccflow_tokenizer_deps__", [helper_v2], raising=False) + monkeypatch.delattr(model_type, "__ccflow_tokenizer_cache__", raising=False) + key2 = cache.key(context) + + assert key1 != key2 + + +def test_generated_model_cache_changes_when_consumed_context_field_changes(): + calls = {"count": 0} + + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + calls["count"] += 1 + return a + b + + 
cache = MemoryCacheEvaluator() + model = add(a=10) + + with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): + assert model.flow.compute(b=1, unused="same").value == 11 + assert model.flow.compute(b=2, unused="same").value == 12 + + assert calls["count"] == 2 + assert len(cache.cache) == 2 + + +def test_generated_model_cache_changes_when_regular_literal_input_changes(): + calls = {"count": 0} + + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + calls["count"] += 1 + return a + b + + cache = MemoryCacheEvaluator() + + with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): + assert add(a=10).flow.compute(b=1).value == 11 + assert add(a=20).flow.compute(b=1).value == 21 + + assert calls["count"] == 2 + assert len(cache.cache) == 2 + + +def test_generated_model_cache_does_not_ignore_context_read_by_nontransparent_evaluator(): + class AddAmbient(EvaluatorBase): + def __call__(self, context: ModelEvaluationContext): + result = context() + return result.model_copy(update={"value": result.value + context.context.unused}) + + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + cache = MemoryCacheEvaluator() + evaluator = combine_evaluators(AddAmbient(), cache) + model = add(a=10) + + with FlowOptionsOverride(options={"evaluator": evaluator, "cacheable": True}): + assert model.flow.compute(b=1, unused=100).value == 111 + assert model.flow.compute(b=1, unused=200).value == 211 + + assert len(cache.cache) == 2 + + +def test_generated_model_cache_key_preserves_result_validation_option(): + calls = {"count": 0} + + @Flow.model + def raw_result(value: FromContext[int]) -> GenericResult[int]: + calls["count"] += 1 + return {"value": str(value)} + + cache = MemoryCacheEvaluator() + model = raw_result() + + with FlowOptionsOverride(options=FlowOptions(evaluator=cache, cacheable=True, validate_result=False)): + first = model.flow.compute(value=3, unused="one") + + with 
FlowOptionsOverride(options=FlowOptions(evaluator=cache, cacheable=True, validate_result=True)): + second = model.flow.compute(value=3, unused="two") + + assert first == {"value": "3"} + assert second == GenericResult[int](value=3) + assert calls["count"] == 2 + assert len(cache.cache) == 2 + + +def test_generated_model_cache_ignores_unresolved_unused_lazy_dependency_context(): + calls = {"choose": 0} + + @Flow.model + def source(missing: FromContext[int]) -> int: + return missing * 10 + + @Flow.model + def choose(x: int, lazy_value: Lazy[int], use_lazy: FromContext[bool]) -> int: + calls["choose"] += 1 + return lazy_value() if use_lazy else x + + cache = MemoryCacheEvaluator() + model = choose(x=7, lazy_value=source()) + + with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): + assert model.flow.compute(use_lazy=False, unused="one").value == 7 + assert model.flow.compute(use_lazy=False, unused="two").value == 7 + + assert calls["choose"] == 1 + assert len(cache.cache) == 1 + + +def test_unused_lazy_plain_dependency_defers_missing_context_validation(): + calls = {"source": 0, "choose": 0} + + class RequiredContext(ContextBase): + missing: int + + class PlainSource(CallableModel): + @Flow.call + def __call__(self, context: RequiredContext) -> GenericResult[int]: + calls["source"] += 1 + return GenericResult(value=context.missing * 10) + + @Flow.model + def choose(x: int, lazy_value: Lazy[int], use_lazy: FromContext[bool]) -> int: + calls["choose"] += 1 + return lazy_value() if use_lazy else x + + cache = MemoryCacheEvaluator() + model = choose(x=7, lazy_value=PlainSource()) + + with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): + assert model.flow.compute(use_lazy=False, unused="one").value == 7 + assert model.flow.compute(use_lazy=False, unused="two").value == 7 + with pytest.raises(ValidationError): + model.flow.compute(use_lazy=True) + + assert calls == {"source": 0, "choose": 2} + assert len(cache.cache) == 1 + + +def 
test_unused_lazy_bound_dependency_uses_partial_context_identity(): + calls = {"source": 0, "choose": 0} + + @Flow.model + def source(a: FromContext[int], b: FromContext[int]) -> int: + calls["source"] += 1 + return a + b + + @Flow.model + def choose(x: int, lazy_value: Lazy[int], use_lazy: FromContext[bool]) -> int: + calls["choose"] += 1 + return lazy_value() if use_lazy else x + + cache = MemoryCacheEvaluator() + model = choose(x=7, lazy_value=source().flow.with_context(a=1)) + + with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): + assert model.flow.compute(use_lazy=False, a=100).value == 7 + assert model.flow.compute(use_lazy=False, a=200).value == 7 + with pytest.raises(TypeError, match="Missing contextual input"): + model.flow.compute(use_lazy=True, a=300) + + assert calls == {"source": 0, "choose": 2} + assert len(cache.cache) == 1 + + +def test_unused_lazy_bound_dependency_with_unresolved_transform_has_stable_identity(): + calls = {"source": 0, "choose": 0} + + @Flow.model + def source(a: FromContext[int], b: FromContext[int]) -> int: + calls["source"] += 1 + return a + b + + @Flow.context_transform + def a_from_seed(seed: FromContext[int]) -> int: + return seed + 1 + + @Flow.model + def choose(x: int, lazy_value: Lazy[int], use_lazy: FromContext[bool]) -> int: + calls["choose"] += 1 + return lazy_value() if use_lazy else x + + cache = MemoryCacheEvaluator() + model = choose(x=7, lazy_value=source().flow.with_context(a=a_from_seed())) + + with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): + assert model.flow.compute(use_lazy=False, b=1, unused="one").value == 7 + assert model.flow.compute(use_lazy=False, b=2, unused="two").value == 7 + with pytest.raises(TypeError, match="Missing contextual input"): + model.flow.compute(use_lazy=True, b=3) + + assert calls == {"source": 0, "choose": 2} + assert len(cache.cache) == 1 + + +def test_unused_lazy_resolved_dependency_identity_is_conservative(): + calls = {"source": 
0, "choose": 0} + + @Flow.model + def source(a: FromContext[int]) -> int: + calls["source"] += 1 + return a + + @Flow.model + def choose(x: int, lazy_value: Lazy[int], use_lazy: FromContext[bool]) -> int: + calls["choose"] += 1 + return lazy_value() if use_lazy else x + + cache = MemoryCacheEvaluator() + + with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): + assert choose(x=7, lazy_value=source().flow.with_context(a=1)).flow.compute(use_lazy=False).value == 7 + assert choose(x=7, lazy_value=source().flow.with_context(a=2)).flow.compute(use_lazy=False).value == 7 + + assert calls == {"source": 0, "choose": 2} + assert len(cache.cache) == 2 + + +def test_used_lazy_bound_dependency_identity_applies_dynamic_context_transform(): + calls = {"source": 0, "choose": 0} + + @Flow.model + def source(a: FromContext[int], b: FromContext[int]) -> int: + calls["source"] += 1 + return a + b + + @Flow.model + def choose(lazy_value: Lazy[int], use_lazy: FromContext[bool]) -> int: + calls["choose"] += 1 + return lazy_value() if use_lazy else 0 + + cache = MemoryCacheEvaluator() + model = choose(lazy_value=source().flow.with_context(a=seed_plus_one())) + + with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): + assert model.flow.compute(use_lazy=True, seed=1, b=10).value == 12 + assert model.flow.compute(use_lazy=True, seed=2, b=10).value == 13 + + assert calls == {"source": 2, "choose": 2} + + +def test_used_lazy_generated_dependency_identity_respects_contextual_defaults(): + calls = {"dep": 0, "source": 0, "choose": 0} + + @Flow.model + def dep(v: FromContext[int]) -> int: + calls["dep"] += 1 + return v + + @Flow.model + def source(d: int, a: FromContext[int], b: FromContext[int]) -> int: + calls["source"] += 1 + return a + b + d + + @Flow.model + def choose(lazy_value: Lazy[int], use_lazy: FromContext[bool]) -> int: + calls["choose"] += 1 + return lazy_value() if use_lazy else 0 + + cache = MemoryCacheEvaluator() + model = 
choose(lazy_value=source(d=dep(), a=1)) + + with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): + assert model.flow.compute(use_lazy=True, b=10, v=3).value == 14 + assert model.flow.compute(use_lazy=True, b=10, v=999).value == 1010 + + assert calls == {"dep": 2, "source": 2, "choose": 2} + + +def test_generated_model_cache_distinguishes_unresolved_lazy_dependency_models(): + calls = {"choose": 0} + + @Flow.model + def source_one(missing: FromContext[int]) -> int: + return missing * 10 + + @Flow.model + def source_two(missing: FromContext[int]) -> int: + return missing * 100 + + @Flow.model + def choose(x: int, lazy_value: Lazy[int], use_lazy: FromContext[bool]) -> int: + calls["choose"] += 1 + return lazy_value() if use_lazy else x + + cache = MemoryCacheEvaluator() + + with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): + assert choose(x=7, lazy_value=source_one()).flow.compute(use_lazy=False).value == 7 + assert choose(x=7, lazy_value=source_two()).flow.compute(use_lazy=False).value == 7 + + assert calls["choose"] == 2 + assert len(cache.cache) == 2 + + +def test_bound_generated_sibling_dependencies_keep_distinct_rewritten_contexts_with_graph_cache(): + calls = {"source": 0, "root": 0} + + @Flow.model + def source(value: FromContext[int]) -> int: + calls["source"] += 1 + return value + + @Flow.model + def root(left: int, right: int) -> int: + calls["root"] += 1 + return left + right + + cache = MemoryCacheEvaluator() + evaluator = combine_evaluators(cache, GraphEvaluator()) + shared = source() + model = root(left=shared.flow.with_context(value=1), right=shared.flow.with_context(value=2)) + + with FlowOptionsOverride(options={"evaluator": evaluator, "cacheable": True}): + assert model.flow.compute(value=99, unused="ambient").value == 3 + + assert calls == {"source": 2, "root": 1} + assert len(cache.cache) >= 3 + + +def test_bound_generated_model_dependency_graph_traverses_collapsed_child_deps(): + @Flow.model + def 
source(value: FromContext[int]) -> int: + return value * 10 + + @Flow.model + def root(x: int, bonus: FromContext[int]) -> int: + return x + bonus + + bound = root(x=source()).flow.with_context(bonus=1) + graph = get_dependency_graph(bound.__call__.get_evaluation_context(bound, FlowContext(value=2, bonus=99))) + + assert graph.root_id not in graph.graph[graph.root_id] + assert len(graph.ids) == 3 + assert len(graph.graph[graph.root_id]) == 1 + + +def test_bound_model_cache_follows_rewritten_context_not_ambient_source_fields(): + calls = {"count": 0} + + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + calls["count"] += 1 + return a + b + + cache = MemoryCacheEvaluator() + bound = add(a=10).flow.with_context(b=parity_bucket()) + + with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): + assert bound(FlowContext(raw=1)).value == 11 + assert bound(FlowContext(raw=3)).value == 11 + + assert calls["count"] == 1 + assert len(cache.cache) == 2 + + +def test_bound_model_cache_respects_wrapped_model_scoped_evaluator(): + calls = {"add": 0, "evaluator": 0} + + class OffsetEvaluator(EvaluatorBase): + def __call__(self, context: ModelEvaluationContext): + calls["evaluator"] += 1 + result = context() + return result.model_copy(update={"value": result.value + 100}) + + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + calls["add"] += 1 + return a + b + + cache = MemoryCacheEvaluator() + base = add(a=1) + bound = base.flow.with_context(b=2) + + with FlowOptionsOverride(options={"evaluator": cache, "cacheable": True}): + with FlowOptionsOverride(options={"evaluator": OffsetEvaluator()}, models=(base,)): + assert bound.flow.compute().value == 103 + assert bound.flow.compute().value == 3 + + assert calls == {"add": 2, "evaluator": 1} + + +def test_generated_models_cross_process_pickle(): + """Module-level @Flow.model instances are deserializable in a separate process.""" + model = basic_loader(source="library", multiplier=3) + data = 
pickle.dumps(model, protocol=5) + encoded = base64.b64encode(data).decode() + script = ( + "import pickle, base64\n" + f"data = base64.b64decode('{encoded}')\n" + "model = pickle.loads(data)\n" + "from ccflow import FlowContext\n" + "result = model.flow.compute(value=4)\n" + "assert result.value == 12, f'Expected 12, got {result.value}'\n" + ) + result = subprocess.run([sys.executable, "-c", script], capture_output=True, text=True, timeout=30) + assert result.returncode == 0, f"Cross-process unpickle failed:\n{result.stderr}" + + +def test_local_generated_models_cross_process_cloudpickle(): + """Local @Flow.model instances carry their generated class across processes.""" + from ray.cloudpickle import dumps as rcpdumps + + def make_model(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + return add(a=1) + + encoded = base64.b64encode(rcpdumps(make_model(), protocol=5)).decode() + script = ( + "import base64\n" + "from ray.cloudpickle import loads as rcploads\n" + f"data = base64.b64decode('{encoded}')\n" + "model = rcploads(data)\n" + "result = model.flow.compute(b=2)\n" + "assert result.value == 3, f'Expected 3, got {result.value}'\n" + ) + result = subprocess.run([sys.executable, "-c", script], capture_output=True, text=True, timeout=30) + assert result.returncode == 0, f"Cross-process local cloudpickle failed:\n{result.stderr}" + + +def test_local_generated_model_postponed_annotations_cross_process_cloudpickle(): + """Local generated models should restore from analyzed config, not worker-side type-hint resolution.""" + from ray.cloudpickle import dumps as rcpdumps + + namespace: dict[str, Any] = {} + exec( + """ +from __future__ import annotations +from ccflow import Flow, FromContext + +def make_model(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + return add(a=1) +""", + namespace, + ) + + encoded = base64.b64encode(rcpdumps(namespace["make_model"](), protocol=5)).decode() + script = ( + "import 
base64\n" + "from ray.cloudpickle import loads as rcploads\n" + f"data = base64.b64decode('{encoded}')\n" + "model = rcploads(data)\n" + "result = model.flow.compute(b=2)\n" + "assert result.value == 3, f'Expected 3, got {result.value}'\n" + ) + result = subprocess.run([sys.executable, "-c", script], capture_output=True, text=True, timeout=30) + assert result.returncode == 0, f"Cross-process postponed-annotation cloudpickle failed:\n{result.stderr}" + + +def test_local_generated_model_complex_annotations_same_process_cloudpickle(): + """Local restore should reuse the analyzed contract for nested typing shapes.""" + + def make_model(): + @Flow.model + def summarize( + values: Annotated[list[int], Field(min_length=1)], + mode: Literal["sum", "first"], + offsets: tuple[int, ...], + maybe_offset: int | None, + scale: FromContext[int], + ) -> GenericResult[int]: + total = values[0] if mode == "first" else sum(values) + return GenericResult(value=(total + sum(offsets) + (maybe_offset or 0)) * scale) + + return summarize(values=[1, 2], mode="sum", offsets=(3,), maybe_offset=None) + + restored = rcploads(rcpdumps(make_model(), protocol=5)) + + assert restored.flow.compute(scale=2) == GenericResult[int](value=12) + + +def test_model_base_fields_visible_in_bound_inputs(): + """model_base fields that are explicitly set should appear in bound_inputs.""" + + class CustomFlowBase(CallableModel): + multiplier: int = 1 + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.model(model_base=CustomFlowBase) + def add(a: int, b: FromContext[int]) -> int: + return a + b + + model = add(a=10, multiplier=3) + assert model.flow.inspect().bound_inputs == {"a": 10, "multiplier": 3} + + # Default-only model_base field is NOT in bound_inputs + model_default = add(a=10) + assert model_default.flow.inspect().bound_inputs == {"a": 10} + + +def _annotation_contains(annotation: object, expected: object) -> bool: + 
if annotation is expected: + return True + return any(_annotation_contains(arg, expected) for arg in get_args(annotation)) + + +def _schema_contains(schema: object, predicate) -> bool: + if isinstance(schema, dict): + if predicate(schema): + return True + return any(_schema_contains(value, predicate) for value in schema.values()) + if isinstance(schema, list): + return any(_schema_contains(value, predicate) for value in schema) + return False + + +def test_generated_model_fields_preserve_construction_schema(): + @Flow.model + def str_source(tag: FromContext[str]) -> str: + return tag + + @Flow.model + def consumer(x: int, lazy_value: Lazy[int], y: FromContext[int]) -> int: + return x + lazy_value() + y + + generated_cls = getattr(consumer, "_generated_model") + fields = generated_cls.model_fields + properties = generated_cls.model_json_schema()["properties"] + + assert _annotation_contains(fields["x"].annotation, int) + assert _annotation_contains(fields["x"].annotation, CallableModel) + assert _schema_contains(properties["x"], lambda node: node.get("type") == "integer") + assert _schema_contains(properties["x"], lambda node: node.get("$ref", "").endswith("/CallableModel")) + + assert _annotation_contains(fields["lazy_value"].annotation, CallableModel) + assert _schema_contains(properties["lazy_value"], lambda node: node.get("$ref", "").endswith("/CallableModel")) + + assert _annotation_contains(fields["y"].annotation, int) + assert _schema_contains(properties["y"], lambda node: node.get("type") == "integer") + + model = consumer(x=str_source(), lazy_value=str_source()) + assert model.flow.compute(tag="3", y="4").value == 10 + + with pytest.raises(TypeError, match="Lazy"): + consumer(x=1, lazy_value=1) + + with pytest.raises(TypeError, match="Regular parameter"): + model.flow.compute(tag="not_a_number", y=1) + + +def test_model_base_fields_rejected_by_compute(): + """compute() should reject kwargs matching model_base field names.""" + + class 
CustomFlowBase(CallableModel): + multiplier: int = 1 + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.model(model_base=CustomFlowBase) + def add(a: int, b: FromContext[int]) -> int: + return a + b + + model = add(a=10, multiplier=3) + with pytest.raises(TypeError, match="does not accept model configuration override\\(s\\): multiplier"): + model.flow.compute(b=5, multiplier=99) + + +def test_flow_model_public_exports_exclude_context_spec_models(): + assert "StaticValueSpec" not in flow_model_module.__all__ + assert "ContextTransform" not in flow_model_module.__all__ + assert "flow_context_transform" not in flow_model_module.__all__ + assert "DependencySpec" not in flow_model_module.__all__ + assert "InputCheck" not in flow_model_module.__all__ + assert not hasattr(ccflow, "StaticValueSpec") + assert not hasattr(ccflow, "ContextTransform") + assert not hasattr(ccflow, "flow_context_transform") + assert not hasattr(ccflow, "DependencySpec") + assert not hasattr(ccflow, "InputCheck") + assert not hasattr(flow_model_module, "flow_context_transform") + assert not hasattr(flow_model_module, "InputCheck") diff --git a/ccflow/tests/test_flow_model_hydra.py b/ccflow/tests/test_flow_model_hydra.py new file mode 100644 index 0000000..e93bac8 --- /dev/null +++ b/ccflow/tests/test_flow_model_hydra.py @@ -0,0 +1,56 @@ +"""Hydra integration tests for the FromContext-based Flow.model API.""" + +from datetime import date +from pathlib import Path + +from ccflow import CallableModel, FlowContext, ModelRegistry + +CONFIG_PATH = str(Path(__file__).parent / "config" / "conf_flow.yaml") + + +def setup_function(): + ModelRegistry.root().clear() + + +def teardown_function(): + ModelRegistry.root().clear() + + +def test_basic_loader_from_yaml(): + registry = ModelRegistry.root() + registry.load_config_from_path(CONFIG_PATH) + + loader = registry["flow_loader"] + assert isinstance(loader, CallableModel) + 
assert loader.flow.compute(value=10).value == 50 + + +def test_registry_dependency_from_yaml(): + registry = ModelRegistry.root() + registry.load_config_from_path(CONFIG_PATH) + + transformer = registry["flow_transformer"] + assert transformer.source is registry["flow_source"] + assert transformer.flow.compute(value=5).value == 315 + + +def test_diamond_dependency_from_yaml_reuses_shared_source(): + registry = ModelRegistry.root() + registry.load_config_from_path(CONFIG_PATH) + + aggregator = registry["diamond_aggregator"] + assert aggregator.input_a.source is registry["diamond_source"] + assert aggregator.input_b.source is registry["diamond_source"] + assert aggregator.flow.compute(value=10).value == 140 + + +def test_from_context_pipeline_from_yaml(): + registry = ModelRegistry.root() + registry.load_config_from_path(CONFIG_PATH) + + loader = registry["contextual_loader_model"] + processor = registry["contextual_processor_model"] + + assert processor.data is loader + result = processor.flow.compute(FlowContext(start_date=date(2024, 3, 1), end_date=date(2024, 3, 31))) + assert result.value == "output:data_source:2024-03-01 to 2024-03-31" diff --git a/docs/wiki/Flow-Model.md b/docs/wiki/Flow-Model.md new file mode 100644 index 0000000..4bed310 --- /dev/null +++ b/docs/wiki/Flow-Model.md @@ -0,0 +1,633 @@ +# Flow Model + +`@Flow.model` turns a plain Python function into a real `CallableModel`. + +The design is intentionally narrow: + +- ordinary unmarked parameters are regular bound inputs, +- `FromContext[T]` marks the only runtime/contextual inputs, +- `@Flow.context_transform` defines reusable contextual rewrites, +- `.flow.compute(...)` is the execution entry point for the full DAG, +- `.flow.with_context(*patches, **field_overrides)` rewires contextual inputs on one dependency edge, +- upstream `CallableModel`s can still be passed as ordinary inputs. + +The goal is that a reader can look at one function signature and immediately +see: + +1. 
which values come from runtime context, +1. which values must be bound as regular configuration or dependencies, +1. how to rewrite contextual inputs for one branch of the graph. + +Generated models still plug into the existing evaluator, registry, cache, +Hydra, and serialization machinery — `@Flow.model` does not create a new +execution engine. + +## Core Example + +```python +from ccflow import Flow, FromContext + + +@Flow.model +def foo(a: int, b: FromContext[int]) -> int: + return a + b + + +# Build an instance with a=11 bound, then supply b=12 at runtime: +configured = foo(a=11) +result = configured.flow.compute(b=12) +assert result.value == 23 # .value unwraps the GenericResult wrapper + +# Or create a different instance that stores b=12 as its contextual default: +prefilled = foo(a=11, b=12) +result = prefilled.flow.compute() +assert result.value == 23 +``` + +> **Note:** When the function returns a plain value (like `int` above) instead +> of a `ResultBase` subclass, `@Flow.model` automatically wraps it in +> `GenericResult`. Access the inner value with `.value`. + +This is the core contract: + +- `a` is a regular parameter — it must be bound at construction time, +- `b` is contextual because it is marked with `FromContext[int]` — it can come + from runtime context, a contextual default stored on the model instance, or a + function default, +- `.flow.compute(...)` may carry extra ambient context for upstream graph + branches, but it never binds regular parameters. + +Nothing is being mutated at execution time in the second example. +`prefilled = foo(a=11, b=12)` constructs a different model instance whose +contextual default for `b` is already `12`. Because `b` is still contextual, +incoming runtime context can still override that default. + +This means the following is **invalid**: + +```python +foo().flow.compute(a=11, b=12) +# TypeError: compute() cannot satisfy unbound regular parameter(s): a. 
+# Bind them at construction time; compute() only supplies runtime context. +``` + +`a` is not contextual, so it must be bound at construction time (`foo(a=11)`). +By contrast, extra ambient fields that are only needed by upstream +`with_context(...)` rewrites are allowed in `compute(**kwargs)` because +`@Flow.model` generated models use `FlowContext` (an open bag) as their +runtime context type. + +## Regular Parameters vs Contextual Parameters + +### Regular Parameters + +Regular parameters are the unmarked ones. + +They can be satisfied by: + +- a literal value, +- a default value from the function signature, +- an upstream `CallableModel`. + +When an upstream model is supplied, `@Flow.model` evaluates it with the current +context and passes the resolved value into the function. This is how you wire +stages together — just pass one model as an argument to another: + +```python +from ccflow import Flow, FlowContext, FromContext + + +@Flow.model +def load_value(value: FromContext[int], offset: int) -> int: + return value + offset + + +@Flow.model +def add(a: int, b: FromContext[int]) -> int: + return a + b + + +# Wire load_value into add's 'a' parameter: +model = add(a=load_value(offset=5)) + +# At runtime, load_value runs first (value=7 + offset=5 = 12), +# then add runs (a=12 + b=12 = 24): +assert model.flow.compute(value=7, b=12).value == 24 +``` + +Only direct regular-parameter values are treated as upstream dependencies in +this first version. Containers such as `list`, `tuple`, and `dict` are ordinary +literal values; `@Flow.model` does not scan them for nested models. + +### Contextual Parameters + +Contextual parameters are the ones marked with `FromContext[...]`. + +They can be satisfied by: + +- runtime context, +- construction-time keyword inputs, stored as contextual defaults on the model instance, +- keyword callable literals for `FromContext[Callable[..., T]]` fields, +- function defaults. 
+ +Construction-time contextual defaults cannot be `CallableModel` values, because +that would be ambiguous with dependency binding. If the contextual field itself +is typed to accept a `CallableModel`, pass the model as runtime context or via +`.flow.with_context(...)` as ordinary data. A raw callable passed positionally is +always treated as a regular argument candidate, not as a contextual default. In +other words, `FromContext[Callable[..., T]]` allows a callable literal only when +it is provided by keyword for that contextual field. + +A construction-time value for a contextual parameter is still a default, not a +conversion into a regular bound parameter. + +Generated models reserve a few framework attribute names for the model API: +`flow`, `meta`, `context_type`, `result_type`, and `type_`. Do not use these as +`@Flow.model` function parameter names. + +```python +from ccflow import Flow, FromContext + + +@Flow.model +def add(a: int, b: FromContext[int]) -> int: + return a + b + + +model = add(a=10, b=5) +assert model.flow.compute().value == 15 +assert model.flow.compute(b=7).value == 17 +``` + +Contextual precedence is: + +1. branch-local `.flow.with_context(...)` rewrites, +1. incoming runtime context, +1. contextual defaults stored on the model instance, +1. function defaults. + +## `.flow.compute(...)` + +`.flow.compute(...)` is the ergonomic execution entry point for contextual +execution of the whole DAG. + +For generated `@Flow.model` stages it accepts either: + +- keyword inputs that become the ambient runtime context bag, or +- one context object. + +It does not accept both at the same time. + +Plain handwritten `CallableModel` instances also expose `.flow.compute(...)`. +For those models, keyword inputs build or update the runtime context. If the +decorated `@Flow.call` method declares a default context object, no-argument +`.flow.compute()` uses that default, and keyword inputs override fields from +that default for the `.flow.compute(...)` call. 
+ +```python +from ccflow import Flow, FlowContext, FromContext + + +@Flow.model +def add(a: int, b: FromContext[int]) -> int: + return a + b + + +model = add(a=10) +assert model.flow.compute(b=5).value == 15 +assert model.flow.compute(FlowContext(b=6)).value == 16 +``` + +For `@Flow.model` generated models, the kwargs form is intentionally a DAG +entry point: it can include extra fields needed only by upstream transformed +dependencies. Regular parameters are still never read from runtime context. +`compute()` enforces two guardrails on keyword inputs: + +- If a key matches an **unbound** regular parameter, it raises early instead of + silently treating that value as configuration. +- If a key matches an **already-bound** regular parameter, it raises to prevent + accidental rebinding. Use a context object (`FlowContext`) when you need + ambient fields whose names collide with bound regular parameters. + +```python +from ccflow import Flow, FromContext + + +@Flow.model +def source(value: FromContext[int]) -> int: + return value + + +@Flow.model +def add(left: int, right: int, bonus: FromContext[int]) -> int: + return left + right + bonus + + +@Flow.context_transform +def add_offset(value: FromContext[int], amount: int) -> int: + return value + amount + + +base = source() +model = add( + left=base.flow.with_context(value=add_offset(amount=1)), + right=base.flow.with_context(value=add_offset(amount=10)), +) + +assert model.flow.inspect().context_inputs == {"bonus": int} +assert model.flow.compute(value=5, bonus=100).value == 121 +``` + +If a regular parameter is already bound on the root model and you need to pass +an ambient context field with the same name for upstream graph nodes, use a +context object instead of keyword inputs. 
The kwargs form rejects keys that +match already-bound regular parameters to prevent accidental rebinding: + +```python +from ccflow import Flow, FlowContext, FromContext + + +@Flow.model +def source(a: FromContext[int]) -> int: + return a + + +@Flow.model +def combine(a: int, left: int, bonus: FromContext[int]) -> int: + return a + left + bonus + + +model = combine(a=100, left=source()) + +# The context object form lets ambient 'a=7' flow to upstream nodes +# while root 'a' stays bound to 100: +assert model.flow.compute(FlowContext(a=7, bonus=5)).value == 112 + +# The kwargs form rejects 'a' because it is a bound regular parameter: +# model.flow.compute(a=7, bonus=5) +# → TypeError: compute() does not accept regular parameter override(s): a. +``` + +`compute()` returns the same result object you would get from `model(context)`, +unless `auto_unwrap=True` is enabled for an auto-wrapped plain return type: + +```python +from ccflow import Flow, FromContext + + +@Flow.model(auto_unwrap=True) +def add(a: int, b: FromContext[int]) -> int: + return a + b + + +result = add(a=10).flow.compute(b=5) +assert result == 15 +``` + +## `@Flow.context_transform` and `.flow.with_context(...)` + +`@Flow.context_transform` defines reusable, serializable contextual rewrites using the +same `FromContext[...]` language as `@Flow.model`. A transform's return type +determines how it can be used in `with_context()`: + +- **Patch transforms** return a `Mapping` (e.g. `dict[str, object]`) of + contextual field names to replacement values. They are passed as **positional + inputs** to `with_context()`. +- **Field transforms** return a single scalar value. They are passed as + **keyword inputs** to `with_context()`, keyed by the contextual field they + replace. 
+ +```python +from datetime import date, timedelta + +from ccflow import DateRangeContext, Flow, FromContext + + +@Flow.model +def count_visitors(location: str, start_date: FromContext[date], end_date: FromContext[date]) -> int: + days = (end_date - start_date).days + 1 + return days * 12 + len(location) + + +@Flow.model +def visitor_delta(current: int, previous: int, start_date: FromContext[date], end_date: FromContext[date]) -> dict: + return {"window": f"{start_date} -> {end_date}", "change": current - previous} + + +# Patch transform — shifts multiple fields together, passed positionally: +@Flow.context_transform +def previous_window(start_date: FromContext[date], end_date: FromContext[date], days: int) -> dict[str, object]: + return { + "start_date": start_date - timedelta(days=days), + "end_date": end_date - timedelta(days=days), + } + + +# Field transform — shifts one field, passed as a keyword: +@Flow.context_transform +def shift_end(end_date: FromContext[date], days: int) -> date: + return end_date - timedelta(days=days) + + +current = count_visitors(location="library") + +# Patch transform: shift both dates back 30 days +previous = current.flow.with_context(previous_window(days=30)) + +# Mix both forms: the patch shifts both dates, then the keyword override replaces end_date with its own shift +custom = current.flow.with_context( + previous_window(days=30), + end_date=shift_end(days=7), # keyword override replaces end_date from the patch +) + +delta = visitor_delta(current=current, previous=previous) +result = delta.flow.compute(DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31))) +``` + +In this example, `current` and `previous` share the same underlying +`count_visitors` configuration but see different date windows at runtime. +`previous` uses a patch transform to shift both dates back 30 days. +`custom` demonstrates combining both forms: the patch sets both dates, then the +keyword override replaces `end_date` with a different shift. 
Keyword overrides +always apply last. + +Key rules: + +- `with_context()` only targets contextual fields, +- positional inputs must be patch transforms, +- keyword overrides may be literals or field transforms, +- raw positional callables are rejected; use named `@Flow.context_transform` + helpers for positional patch transforms, +- keyword callable literals are allowed only when the target field is typed as + `FromContext[Callable[..., T]]`; other keyword callables must be field + transforms, +- transforms are branch-local — they only affect the wrapped dependency, not + the entire pipeline, +- patch results merge left-to-right, then keyword overrides apply last, +- every transform evaluates against the original incoming runtime context; if + multiple fields must move together, put that logic inside one patch + transform. + +Context transforms serialize the analyzed transform contract directly in +`serialized_config` so bound models can move through pickle and Ray workers +without re-resolving annotations from the defining module. This applies to +importable module functions, local functions, nested functions, `__main__`, and +notebook-defined transforms. Treat that `serialized_config` as an opaque +generated artifact owned by ccflow, not as a stable hand-written YAML/JSON +configuration format. + +## `context_type=...` + +When you want the `FromContext[...]` fields to match an existing nominal +context shape, use `context_type=...`: + +```python +from datetime import date + +from ccflow import DateRangeContext, Flow, FromContext + + +@Flow.model(context_type=DateRangeContext) +def count_visitors(location: str, start_date: FromContext[date], end_date: FromContext[date]) -> int: + days = (end_date - start_date).days + 1 + return days * 12 + len(location) +``` + +That preserves the primary `FromContext[...]` authoring model while letting +callers pass richer context objects whose relevant fields satisfy the declared +`context_type`. 
+ +`context_type=...` is a validation/coercion contract for the named +`FromContext[...]` fields. Generated `@Flow.model` instances still expose +`FlowContext` as their runtime `context_type`. + +If the function genuinely needs the runtime context object itself inside the +function body on each call, use a normal `CallableModel` subclass instead of +`@Flow.model`. + +For class-based `CallableModel` methods that want to declare context fields as +keyword-only parameters, see `Flow.call(auto_context=...)` in +[Workflows](Workflows#flow-decorator). + +## Introspection + +The public `.flow` surface is intentionally small: + +- `model.flow.compute(...)`: evaluate the model, +- `model.flow.with_context(...)`: return a branch-local contextual wrapper, +- `model.flow.inspect(...)`: return a structured debugging summary. + +Use `inspect(...)` when you want to understand what is bound, what is still +contextual, and which direct dependencies are attached: + +```python +inspection = model.flow.inspect() + +inspection.inputs # direct function inputs and their sources +inspection.context_inputs # declared contextual contract +inspection.runtime_inputs # direct runtime inputs after wrapper bindings +inspection.required_inputs # required runtime inputs still needed +inspection.bound_inputs # construction/static values already fixed +inspection.dependencies # dependency edges, controlled by dependencies=... +``` + +The top-level input fields are intentionally current-level only. They describe +the model or wrapper you inspected, not a flattened view of the whole dependency +graph. `inspection.inputs` is a dict from that model's function input name to an +`InputSpec`. Each `InputSpec` reports the expected type, whether the input is +required, the declared default, the effective direct value if known, and whether +that value came from construction, a function default, runtime context, or +`with_context(...)`. + +Dependency information lives under `inspection.dependencies`. 
Each dependency +edge reports the parameter path, target model, projected context values when +known, and whether the edge is lazy. To inspect a child, inspect that child +model directly: + +```python +inspection = model.flow.inspect(value=3) +dependency = inspection.dependencies[0] + +if dependency.context is None: + child = dependency.model.flow.inspect() +else: + child = dependency.model.flow.inspect(dependency.context) + +child.inputs +child.runtime_inputs +child.required_inputs +``` + +This explicit nesting avoids merging unrelated models into one ambiguous input +namespace. It also avoids name collisions when multiple dependencies use the +same context field name for different branches. + +`context_inputs` intentionally stays faithful to the model's declared contract. +For bound models, `with_context(...)` bindings are reflected in +`runtime_inputs`, `required_inputs`, `bound_inputs`, and `inputs`. Literal +bindings satisfy their target fields. Transform bindings with runtime inputs add +those source context inputs to the effective runtime view. Static transforms, +meaning transforms with no contextual parameters, may be evaluated during +introspection so their output fields can be reported precisely. A transform +parameter like `seed: FromContext[int] = 0` is still a runtime input; its default +only means the caller is not required to provide it. + +`inspect(...)` reports the direct API for the current model or wrapper. By +default, it also inspects immediate dependencies: + +```python +model.flow.inspect(dependencies="direct") # default +model.flow.inspect(dependencies="recursive") # follow inspect-visible dependencies +model.flow.inspect(dependencies="none") # skip dependency inspection +``` + +Recursive inspection follows dependencies visible from constructed +`@Flow.model` inputs and `with_context(...)` wrappers. It is intended for +debugging generated-model trees, not as a complete evaluator graph. 
A +hand-written `CallableModel` can still appear as a dependency target when it is +bound to an `@Flow.model` regular input, but `inspect(...)` does not expand that +hand-written model's custom `CallableModel.__deps__` implementation. Recursive +mode changes which dependency edges are listed; it does not change the meaning +of the top-level input fields. Lazy dependency edges are listed as `lazy`; +inspect the lazy target directly if you want to see what it could require when +called. + +Because introspection may evaluate static `@Flow.context_transform` functions, +context transforms should be deterministic, side-effect-free, and cheap. This is +the same practical contract expected by cache identity and dependency analysis. + +Example: + +```python +from ccflow import Flow, FromContext + + +@Flow.model +def add(a: int, b: FromContext[int], c: FromContext[int] = 5) -> int: + return a + b + c + + +model = add(a=10) +inspection = model.flow.inspect() +assert inspection.context_inputs == {"b": int, "c": int} +assert inspection.runtime_inputs == {"b": int, "c": int} +assert inspection.required_inputs == {"b": int} +assert inspection.bound_inputs == {"a": 10} +assert set(inspection.inputs) == {"a", "b", "c"} + + +@Flow.context_transform +def from_seed(seed: FromContext[int]) -> int: + return seed + 1 + + +bound = add(a=10).flow.with_context(b=from_seed()) +bound_inspection = bound.flow.inspect() +assert bound_inspection.context_inputs == {"b": int, "c": int} +assert bound_inspection.runtime_inputs == {"c": int, "seed": int} +assert bound_inspection.required_inputs == {"seed": int} +assert bound_inspection.bound_inputs == {"a": 10} +``` + +In the bound example, `b` remains in `context_inputs` because `add` still +declares `b` as part of its contextual contract. It is absent from +`runtime_inputs` because this wrapper supplies `b` from `from_seed()`. `seed` +appears in `runtime_inputs` because the transform reads it from the caller's +runtime context.
+ +When you pass a proposed context object or runtime kwargs to `inspect(...)`, +inspection uses those values to fill known direct `inputs` and project context +onto dependency edges. It does not validate the proposed context or report +unused fields: + +```python +inspection = add(a=10).flow.inspect(b=2, unused=1) +assert inspection.inputs["b"].value == 2 +assert "unused" not in inspection.inputs +``` + +This keeps `inspect(...)` structural. A stricter debug-time input checker can be +added later with explicit current-model versus graph-wide semantics. + +## Lazy Dependencies + +`Lazy[T]` defers evaluation of an upstream dependency until the function body +explicitly calls it. This is useful when a dependency is expensive and only +needed conditionally: + +```python +from ccflow import Flow, FlowContext, FromContext, Lazy + + +@Flow.model +def load_value(value: FromContext[int]) -> int: + return value * 10 + + +@Flow.model +def maybe_use(current: int, fallback: Lazy[int], threshold: FromContext[int]) -> int: + if current > threshold: + return current # fallback is never evaluated + return fallback() # evaluate only when needed + + +model = maybe_use(current=50, fallback=load_value()) + +# current (50) > threshold (10), so load_value never runs: +assert model.flow.compute(value=3, threshold=10).value == 50 + +# current (5) <= threshold (10), so load_value runs (3 * 10 = 30): +model2 = maybe_use(current=5, fallback=load_value()) +assert model2.flow.compute(value=3, threshold=10).value == 30 +``` + +Without `Lazy[T]`, the upstream model would always run. With it, the function +controls exactly when (and whether) the dependency executes. + +## When To Use `@Flow.model` + +Use `@Flow.model` when: + +- the stage logic is naturally a plain function, +- you want ordinary inputs to look like ordinary Python function parameters, +- the contextual contract is small and explicit, +- the main goal is easy graph authoring on top of existing ccflow machinery. 
+ +Use a hand-written class-based `CallableModel` when: + +- the model needs custom methods or substantial internal state, +- the full context object is the natural primary interface, +- the stage is no longer best expressed as one function and a small amount of + wiring. + +## Troubleshooting + +**`compute()` says a field is not contextual** + +That field is a regular parameter. Bind it at construction time. Only +`FromContext[...]` fields belong in `compute()`. + +**`with_context()` rejects a field** + +`with_context()` only rewrites contextual inputs. If you are trying to attach one +stage to another, pass the upstream model as a regular argument at construction +time. + +**A contextual parameter still shows up in `context_inputs` after I bound it** + +That is expected. `context_inputs` reports the declared contextual contract of +the model or wrapped model. It does not mean the current wrapper still requires +the caller to provide that field. + +Use `model.flow.inspect().runtime_inputs` to see the effective direct runtime +context inputs after `with_context(...)` bindings. Use +`model.flow.inspect().required_inputs` to see what still must be provided by the +caller. Static transforms with no contextual parameters may be evaluated during +introspection, so their output fields can be removed from `runtime_inputs` and +`required_inputs` or added to `bound_inputs`. + +**A shared dependency runs more than once** + +`@Flow.model` authors the graph cleanly, but execution still follows the normal +ccflow evaluator path. If you need deduplication or graph scheduling, use the +appropriate evaluators and cache settings just as you would for class-based +`CallableModel`s. 
diff --git a/docs/wiki/Key-Features.md b/docs/wiki/Key-Features.md index 616e3d8..724c853 100644 --- a/docs/wiki/Key-Features.md +++ b/docs/wiki/Key-Features.md @@ -22,6 +22,10 @@ The naming was inspired by the open source library [Pydantic](https://docs.pydan `CallableModel`'s are called with a context (something that derives from `ContextBase`) and returns a result (something that derives from `ResultBase`). As an example, you may have a `SQLReader` callable model that when called with a `DateRangeContext` returns a `ArrowResult` (wrapper around a Arrow table) with data in the date range defined by the context by querying some SQL database. +`@Flow.model` is a plain-function API for defining `CallableModel`s with normal Python function signatures. +It keeps the same evaluator, registry, cache, and result machinery while making contextual execution more ergonomic via helpers like `.flow.compute(...)` and `.flow.with_context(...)`. +See [Flow Model](Flow-Model) for the full guide. + ## Model Registry A `ModelRegistry` is a named collection of models. diff --git a/docs/wiki/Workflows.md b/docs/wiki/Workflows.md index 6789d79..6ef92b1 100644 --- a/docs/wiki/Workflows.md +++ b/docs/wiki/Workflows.md @@ -355,6 +355,12 @@ The behavior of the `@Flow.call` decorator can be controlled in several ways: - by setting `options` in the `meta` attribute of the CallableModel - by passing `_options` directly to the `__call__` method +For hand-written `CallableModel` classes, `@Flow.call(auto_context=True)` is +also available when the `__call__` method should declare context fields as +keyword-only parameters instead of accepting one explicit context object. This +is an opt-in `Flow.call` feature; it does not add the dependency wiring or +`FromContext[...]` semantics provided by `@Flow.model`. 
+ An example of the first one (model-specific options) is to disable validation of the result type on a particular model ```python @@ -542,12 +548,17 @@ For a direct `CallableModel` call, the cache key depends on: - `model.model_dump(mode="python")` - `compute_behavior_token(type(model))` -For a `ModelEvaluationContext`, the cache key depends on: +For a `ModelEvaluationContext`, the default structural cache key depends on: - the underlying context payload plus the function name being evaluated (`__call__` vs `__deps__`) - the behavior token of the underlying model class - the data and behavior tokens of any **non-transparent** evaluators in the chain +`MemoryCacheEvaluator` and graph construction request an effective key for +generated `@Flow.model` nodes. That effective key can project a runtime +`FlowContext` down to the contextual fields a generated node actually reads, so +unused ambient fields do not necessarily split memory-cache or graph identity. + Transparent evaluators are skipped, so wrapping a model with logging or other pass-through evaluators does not change its cache identity. `compute_behavior_token()` hashes the class's Python-defined methods and also consults `__ccflow_tokenizer_deps__` for behavior that lives outside the class body. This is useful when the callable depends on module-level helpers or shared helper classes that should also invalidate the cache key when they change. diff --git a/docs/wiki/_Sidebar.md b/docs/wiki/_Sidebar.md index 6b20b2f..75c31c0 100644 --- a/docs/wiki/_Sidebar.md +++ b/docs/wiki/_Sidebar.md @@ -13,6 +13,7 @@ Notes for editors: - [Installation](Installation) - [Design Goals](Design-Goals) - [Key Features](Key-Features) +- [Flow Model](Flow-Model) - [First Steps](First-Steps) **Tutorials**