From 6977e5b6a7e0cf05139e342f203edcc9d809e4bc Mon Sep 17 00:00:00 2001 From: Pascal Tomecek Date: Mon, 11 May 2026 21:15:43 -0400 Subject: [PATCH] feat: native tokenization engine, drop dask runtime dep Replaces the dask-based normalize_token / tokenize re-exports with a native singledispatch engine. Adds structural handlers for common types (stdlib, datetime, Decimal, UUID, pathlib, Enum, partial, MappingProxy, methods, OrderedDict, code objects, functions, types, pydantic BaseModel) and handlers for numpy and pandas (Timestamp only) that are registered at import time when those packages are installed. Unknown types fall back to a cloudpickle digest, raising TypeError when pickling fails. The public tokenize() retains its variadic (*args, **kwargs) signature for drop-in compatibility. compute_data_token, compute_cache_token, and compute_behavior_token from PR #196 keep their signatures; compute_data_token now hashes through the native engine. Cycle detection via a module-level ContextVar prevents RecursionError on self-referential structures and emits a stable __cycle__ marker. Drops dask from runtime dependencies; adds cloudpickle. 
Signed-off-by: Pascal Tomecek --- ccflow/tests/utils/test_tokenize.py | 691 ++++++++++++++++++++++++++-- ccflow/utils/tokenize.py | 367 ++++++++++++++- pyproject.toml | 2 +- 3 files changed, 999 insertions(+), 61 deletions(-) diff --git a/ccflow/tests/utils/test_tokenize.py b/ccflow/tests/utils/test_tokenize.py index 0166894..35945f3 100644 --- a/ccflow/tests/utils/test_tokenize.py +++ b/ccflow/tests/utils/test_tokenize.py @@ -1,16 +1,25 @@ """Tests for tokenize helpers used by cache keys.""" +import enum as _enum +import re +from collections import OrderedDict +from datetime import date, datetime, time, timedelta +from decimal import Decimal +from functools import partial +from pathlib import Path, PurePosixPath +from types import MappingProxyType +from uuid import UUID + +import numpy as np +import pandas as pd import pytest +from pydantic import BaseModel as _PlainPydantic from ccflow.callable import CallableModel, ContextBase, EvaluatorBase, ModelEvaluationContext from ccflow.context import NullContext from ccflow.evaluators.common import cache_key from ccflow.result import GenericResult -from ccflow.utils.tokenize import compute_behavior_token, compute_cache_token, compute_data_token - -# --------------------------------------------------------------------------- -# Data token -# --------------------------------------------------------------------------- +from ccflow.utils.tokenize import compute_behavior_token, compute_cache_token, compute_data_token, normalize_token, tokenize class TestComputeDataToken: @@ -56,11 +65,6 @@ def f(self): assert token1 != token2 -# --------------------------------------------------------------------------- -# Basic behavior -# --------------------------------------------------------------------------- - - class TestComputeBehaviorToken: def test_returns_sha256_hex(self): class M: @@ -188,11 +192,6 @@ def f(self): assert compute_behavior_token(make_model(1)) != compute_behavior_token(make_model(2)) -# 
--------------------------------------------------------------------------- -# Method collection -# --------------------------------------------------------------------------- - - class TestMethodCollection: def test_includes_regular_methods(self): class M: @@ -266,11 +265,6 @@ def f(self): assert compute_behavior_token(A) == compute_behavior_token(B) -# --------------------------------------------------------------------------- -# Dependencies (__ccflow_tokenizer_deps__) -# --------------------------------------------------------------------------- - - def _helper_add(x): return x + 1 @@ -398,11 +392,6 @@ def g(self): compute_behavior_token(A) -# --------------------------------------------------------------------------- -# Integration with cache_key() -# --------------------------------------------------------------------------- - - class TestCacheKeyIntegration: def test_callable_model_includes_behavior(self): """cache_key for a CallableModel includes the behavior hash.""" @@ -539,11 +528,6 @@ def __call__(self, context: ModelEvaluationContext): assert key1 != key2 -# --------------------------------------------------------------------------- -# Decorator unwrapping (Flow.call, etc.) 
-# --------------------------------------------------------------------------- - - class TestDecoratorUnwrapping: def test_flow_call_different_impls_differ(self): """@Flow.call wrappers are unwrapped — different implementations hash differently.""" @@ -578,11 +562,6 @@ def __call__(self, context: NullContext) -> GenericResult: assert compute_behavior_token(A) == compute_behavior_token(B) -# --------------------------------------------------------------------------- -# MRO / inherited methods -# --------------------------------------------------------------------------- - - class TestInheritedMethods: def test_inherited_call_included(self): """Subclass that inherits __call__ from parent picks up parent's method.""" @@ -628,3 +607,645 @@ def g(self): assert t_base != t_sub # But base's cached token is unaffected assert compute_behavior_token(Base) == t_base + + +class _Color(_enum.Enum): + RED = 1 + GREEN = 2 + BLUE = 3 + + +class TestNormalizeTokenPrimitives: + @pytest.mark.parametrize( + "value,expected", + [ + (None, None), + (True, True), + (False, False), + (42, 42), + (3.14, 3.14), + ("hello", "hello"), + (b"data", b"data"), + ], + ) + def test_primitives(self, value, expected): + assert normalize_token(value) == expected + + @pytest.mark.parametrize( + "value,expected", + [ + (date(2024, 1, 15), ("date", "2024-01-15")), + (datetime(2024, 1, 15, 10, 30, 0), ("datetime", "2024-01-15T10:30:00")), + (time(10, 30, 0), ("time", "10:30:00")), + (timedelta(hours=1, minutes=30), ("timedelta", 0, 5400, 0)), + ], + ) + def test_datetime_types(self, value, expected): + assert normalize_token(value) == expected + + @pytest.mark.parametrize( + "value,expected", + [ + ((1, "a", True), ("tuple", (1, "a", True))), + ([1, 2, 3], ("list", (1, 2, 3))), + ({3, 1, 2}, ("set", (1, 2, 3))), + (frozenset({3, 1, 2}), ("frozenset", (1, 2, 3))), + ({"b": 2, "a": 1}, ("dict", (("a", 1), ("b", 2)))), + ], + ) + def test_collections(self, value, expected): + assert normalize_token(value) == 
expected + + def test_uuid(self): + u = UUID("12345678-1234-5678-1234-567812345678") + assert normalize_token(u) == ("uuid", "12345678-1234-5678-1234-567812345678") + + def test_path(self): + assert normalize_token(Path("/tmp/test.txt")) == ("path", "/tmp/test.txt") + + def test_pure_path(self): + assert normalize_token(PurePosixPath("/tmp/test.txt")) == ("path", "/tmp/test.txt") + + def test_enum(self): + assert normalize_token(_Color.RED) == ("enum", _Color.__module__, "_Color", "RED") + + def test_nested_collections(self): + data = {"key": [1, (2, 3)]} + assert normalize_token(data) == ("dict", (("key", ("list", (1, ("tuple", (2, 3))))),)) + + +class TestNormalizeTokenAdditionalBuiltins: + def test_complex(self): + assert normalize_token(complex(1, 2)) == ("complex", 1.0, 2.0) + assert normalize_token(complex(1, 2)) != normalize_token(complex(2, 1)) + + def test_ellipsis(self): + assert normalize_token(...) == ("ellipsis",) + + def test_slice(self): + assert normalize_token(slice(1, 10, 2)) == ("slice", 1, 10, 2) + assert normalize_token(slice(1, 10)) == ("slice", 1, 10, None) + assert normalize_token(slice(1, 10)) != normalize_token(slice(1, 11)) + + def test_builtin_function(self): + # len is a builtin with __self__ = builtins module + assert normalize_token(len) == ("builtin", "builtins", "len", ("module", "builtins")) + assert normalize_token(len) != normalize_token(print) + + def test_decimal(self): + assert normalize_token(Decimal("3.14")) == ("decimal", "3.14") + assert normalize_token(Decimal("3.14")) != normalize_token(Decimal("3.15")) + + def test_partial(self): + p1 = partial(int, base=16) + p2 = partial(int, base=10) + assert normalize_token(p1) != normalize_token(p2) + assert normalize_token(p1) == normalize_token(partial(int, base=16)) + + def test_mappingproxy(self): + mp = MappingProxyType({"a": 1, "b": 2}) + # MappingProxy is tagged distinctly from dict so a proxy is never confused with the same-keyed dict + assert normalize_token(mp) == 
("mappingproxy", (("a", 1), ("b", 2))) + assert normalize_token(mp) != normalize_token({"a": 1, "b": 2}) + + +class TestNormalizeTokenFunctionsAndTypes: + def test_function(self): + def my_func(x): + return x + 1 + + result = normalize_token(my_func) + assert result[0] == "__function__" + assert isinstance(result[1], str) + assert len(result[1]) == 64 + + def test_function_deterministic(self): + def my_func(x): + return x + 1 + + assert normalize_token(my_func) == normalize_token(my_func) + + def test_type(self): + assert normalize_token(int) == ("type", "builtins.int") + + def test_pydantic_basemodel(self): + class PlainPydantic(_PlainPydantic): + x: int = 1 + + obj = PlainPydantic(x=5) + result = normalize_token(obj) + assert result[0] == "pydantic" + assert "PlainPydantic" in result[1] + # Same data → same canonical form + assert normalize_token(PlainPydantic(x=5)) == normalize_token(PlainPydantic(x=5)) + assert normalize_token(PlainPydantic(x=5)) != normalize_token(PlainPydantic(x=6)) + + +class TestNormalizeTokenNumpyPandas: + def test_numpy_ndarray(self): + arr = np.array([1, 2, 3], dtype=np.int64) + result = normalize_token(arr) + assert result[0] == "ndarray" + assert result[1] == "int64" + assert result[2] == (3,) + assert normalize_token(arr) == normalize_token(np.array([1, 2, 3], dtype=np.int64)) + + def test_numpy_different_data(self): + assert normalize_token(np.array([1, 2, 3])) != normalize_token(np.array([1, 2, 4])) + + def test_numpy_different_dtype(self): + a = np.array([1, 2, 3], dtype=np.float32) + b = np.array([1, 2, 3], dtype=np.float64) + assert normalize_token(a) != normalize_token(b) + + def test_numpy_empty_same_dtype(self): + a = np.array([], dtype=np.float64) + b = np.array([], dtype=np.float64) + assert normalize_token(a) == normalize_token(b) + + def test_numpy_empty_diff_dtype(self): + a = np.array([], dtype=np.float32) + b = np.array([], dtype=np.float64) + assert normalize_token(a) != normalize_token(b) + + def 
test_numpy_discontiguous_array(self): + arr = np.arange(10) + assert normalize_token(arr[::2]) != normalize_token(arr[::3]) + + def test_numpy_structured_array(self): + dt = np.dtype([("x", np.int32), ("y", np.float64)]) + arr = np.array([(1, 2.0), (3, 4.0)], dtype=dt) + assert normalize_token(arr)[0] == "ndarray" + + def test_numpy_scalar(self): + s = np.int64(42) + assert normalize_token(s) == ("np_scalar", "int64", 42) + + def test_numpy_datetime64(self): + token = normalize_token(np.datetime64("2024-01-01")) + assert token == ("np_scalar", "datetime64", date(2024, 1, 1)) + + def test_pandas_timestamp(self): + ts = pd.Timestamp("2024-01-15") + assert normalize_token(ts) == ("pd_timestamp", ts.isoformat()) + + def test_pandas_dataframe_via_cloudpickle(self): + df1 = pd.DataFrame({"a": [1, 2, 3]}) + df2 = pd.DataFrame({"a": [1, 2, 3]}) + df3 = pd.DataFrame({"a": [1, 2, 4]}) + assert normalize_token(df1) == normalize_token(df2) + assert normalize_token(df1) != normalize_token(df3) + + def test_pandas_series_via_cloudpickle(self): + s1 = pd.Series([1, 2, 3], name="x") + s2 = pd.Series([1, 2, 3], name="x") + s3 = pd.Series([1, 2, 4], name="x") + assert normalize_token(s1) == normalize_token(s2) + assert normalize_token(s1) != normalize_token(s3) + + +class TestNormalizeTokenCloudpickleFallback: + def test_compiled_regex_same_pattern_same_token(self): + r1 = re.compile("abc", re.IGNORECASE) + r2 = re.compile("abc", re.IGNORECASE) + assert normalize_token(r1) == normalize_token(r2) + + def test_unpicklable_raises_typeerror(self): + class Unpicklable: + def __reduce__(self): + raise RuntimeError("cannot pickle") + + with pytest.raises(TypeError, match="Cannot tokenize"): + normalize_token(Unpicklable()) + + def test_arbitrary_object_falls_back(self): + class Foo: + def __init__(self, x): + self.x = x + + result = normalize_token(Foo(1)) + assert result[0] == "__cloudpickle__" + assert isinstance(result[1], str) + + +class TestNormalizeTokenContainerEdgeCases: + def 
test_dict_with_mixed_key_types(self): + d = {1: "a", "1": "b"} + result = normalize_token(d) + assert result[0] == "dict" + assert len(result[1]) == 2 + + def test_dict_order_independence(self): + assert normalize_token({"b": 2, "a": 1}) == normalize_token({"a": 1, "b": 2}) + + def test_inf_and_negative_inf_distinct(self): + assert normalize_token(float("inf")) != normalize_token(float("-inf")) + + +class TestTokenizePublicAPI: + def test_tokenize_returns_hex_string(self): + token = tokenize({"a": [1, 2, 3]}) + assert isinstance(token, str) + assert len(token) == 64 # sha256 hex + + def test_tokenize_deterministic(self): + assert tokenize({"a": 1, "b": 2}) == tokenize({"b": 2, "a": 1}) + + def test_tokenize_distinguishes(self): + assert tokenize({"a": 1}) != tokenize({"a": 2}) + + +class TestNormalizeTokenDeterminismRegressions: + _ADDR_RE = re.compile(r"0x[0-9a-fA-F]{4,}") + + def _assert_portable(self, canonical): + r = repr(canonical) + assert "", "exec") + c2 = compile("x = 'bar'", "", "exec") + assert tokenize(c1) != tokenize(c2) + + def test_function_docstring_still_stripped(self): + # Conversely, _hash_function_bytecode must still ignore docstrings on function bodies. + def with_doc(): + """A docstring.""" + return 1 + + def without_doc(): + return 1 + + # Wrapping in classes lets compute_behavior_token surface the difference (or lack thereof). + WithDoc = type("X", (object,), {"f": with_doc}) + WithoutDoc = type("X", (object,), {"f": without_doc}) + assert compute_behavior_token(WithDoc) == compute_behavior_token(WithoutDoc) + + def test_slice_recurses_on_bounds(self): + class Marker: + pass + + m = Marker() + canon = normalize_token(slice(m, None, None)) + # The Marker should have been normalized via the type/cloudpickle dispatch — no raw repr leaks + # like "<...Marker object at 0x...>" should appear. 
+ assert "object at 0x" not in repr(canon) + + def test_slice_primitive_bounds_stable(self): + assert tokenize(slice(1, 2, 3)) == tokenize(slice(1, 2, 3)) + assert tokenize(slice(1, 2, 3)) != tokenize(slice(1, 2, 4)) + + def test_timedelta_microsecond_precision(self): + td1 = timedelta(days=10**8, microseconds=1) + td2 = timedelta(days=10**8, microseconds=2) + # total_seconds() loses microsecond precision at large day counts due to float rounding; + # the (days, seconds, microseconds) decomposition keeps them distinct. + assert tokenize(td1) != tokenize(td2) + + def test_enum_includes_module(self): + E1 = _enum.Enum("Color", {"RED": 1}, module="pkg.one") + E2 = _enum.Enum("Color", {"RED": 1}, module="pkg.two") + assert tokenize(E1.RED) != tokenize(E2.RED) diff --git a/ccflow/utils/tokenize.py b/ccflow/utils/tokenize.py index 528c5ed..836ec59 100644 --- a/ccflow/utils/tokenize.py +++ b/ccflow/utils/tokenize.py @@ -1,17 +1,33 @@ -# ruff: noqa: F401 """Tokenization utilities for ccflow models. -Re-exports ``normalize_token`` and ``tokenize`` from dask for data hashing. -Adds helpers for dask-based data hashing, ccflow-specific behavior hashing, -and combined cache-token hashing, useful for cache-key invalidation when -callable logic changes. +Provides a native ``normalize_token`` ``singledispatch`` registry that +canonicalizes Python objects to deterministically hashable forms, plus +helpers for ccflow-specific behavior hashing and combined cache-token +hashing used by ``cache_key()`` for cache-key invalidation when callable +logic changes. + +The native engine replaces the previous ``dask.base.tokenize`` dependency. +Type handlers cover stdlib primitives, datetime, Decimal, UUID, pathlib, +``functools.partial``, ``MappingProxyType``, numpy, pandas, and plain +pydantic ``BaseModel``. Unknown objects fall back to a cloudpickle-based +digest; objects whose pickling raises an exception produce a clear +``TypeError``. 
""" +import contextvars +import enum import hashlib import inspect +from collections import OrderedDict +from datetime import date, datetime, time, timedelta +from decimal import Decimal +from functools import partial, singledispatch +from pathlib import PurePath +from types import CodeType, MappingProxyType, MethodType, MethodWrapperType, ModuleType from typing import Any, Callable, Iterable, List, Optional, Tuple +from uuid import UUID -from dask.base import normalize_token, tokenize +from pydantic import BaseModel as _PydanticBaseModel __all__ = ( "compute_behavior_token", @@ -24,7 +40,6 @@ def _sha256_hexdigest(*parts: bytes | str) -> str: """Return a SHA-256 hex digest for one or more byte/string parts.""" - hasher = hashlib.sha256() for part in parts: if isinstance(part, str): @@ -33,10 +48,306 @@ def _sha256_hexdigest(*parts: bytes | str) -> str: return hasher.hexdigest() -def compute_data_token(value: Any) -> str: - """Compute a deterministic data token using dask's tokenization.""" +# ContextVar (rather than threading.local) so the visited set auto-isolates across asyncio tasks. +_visited: contextvars.ContextVar[Optional[set]] = contextvars.ContextVar("_ccflow_normalize_visited", default=None) + + +def _with_cycle_check(obj: Any, build: Callable[[], Any]) -> Any: + """Invoke ``build`` with object-identity cycle detection, returning ``("__cycle__", type_name)`` on re-entry.""" + visited = _visited.get() + created = visited is None + if created: + visited = set() + token = _visited.set(visited) + elif id(obj) in visited: + return ("__cycle__", type(obj).__name__) + visited.add(id(obj)) + try: + return build() + finally: + visited.discard(id(obj)) + if created: + _visited.reset(token) + + +@singledispatch +def normalize_token(obj: Any) -> Any: + """Produce a canonical, deterministically hashable representation of ``obj``. 
+ + This is a ``singledispatch`` function — register handlers for new types via:: + + @normalize_token.register(MyType) + def _(obj): + return ("mytype", ...) + + Unknown types fall back to a ``cloudpickle``-based digest, raising ``TypeError`` on pickling failure. + """ + try: + import cloudpickle + except ImportError as exc: # pragma: no cover - defensive + raise TypeError(f"Cannot tokenize object of type {type(obj).__qualname__}: cloudpickle is not installed.") from exc + + try: + pickled = cloudpickle.dumps(obj) + except Exception as exc: + raise TypeError(f"Cannot tokenize object of type {type(obj).__qualname__}. Register a normalize_token handler for this type.") from exc + return ("__cloudpickle__", hashlib.sha256(pickled).hexdigest()) + + +@normalize_token.register(type(None)) +def _normalize_none(obj): + return None + + +@normalize_token.register(bool) +@normalize_token.register(int) +@normalize_token.register(float) +@normalize_token.register(str) +@normalize_token.register(bytes) +def _normalize_primitive(obj): + return obj + + +@normalize_token.register(date) +def _normalize_date(obj): + return ("date", obj.isoformat()) + + +@normalize_token.register(datetime) +def _normalize_datetime(obj): + return ("datetime", obj.isoformat()) + + +@normalize_token.register(time) +def _normalize_time(obj): + return ("time", obj.isoformat()) + + +@normalize_token.register(timedelta) +def _normalize_timedelta(obj): + return ("timedelta", obj.days, obj.seconds, obj.microseconds) + + +@normalize_token.register(UUID) +def _normalize_uuid(obj): + return ("uuid", str(obj)) + + +@normalize_token.register(PurePath) +def _normalize_path(obj): + return ("path", str(obj)) + + +@normalize_token.register(enum.Enum) +def _normalize_enum(obj): + return ("enum", type(obj).__module__, type(obj).__qualname__, obj.name) + + +@normalize_token.register(tuple) +def _normalize_tuple(obj): + return ("tuple", tuple(normalize_token(item) for item in obj)) + + +@normalize_token.register(list) 
+def _normalize_list(obj): + return _with_cycle_check(obj, lambda: ("list", tuple(normalize_token(item) for item in obj))) + + +@normalize_token.register(set) +def _normalize_set(obj): + return ("set", tuple(sorted((normalize_token(item) for item in obj), key=repr))) + + +@normalize_token.register(frozenset) +def _normalize_frozenset(obj): + return ("frozenset", tuple(sorted((normalize_token(item) for item in obj), key=repr))) + + +@normalize_token.register(dict) +def _normalize_dict(obj): + return _with_cycle_check( + obj, + lambda: ("dict", tuple(sorted(((normalize_token(k), normalize_token(v)) for k, v in obj.items()), key=repr))), + ) + - return tokenize(value) +@normalize_token.register(OrderedDict) +def _normalize_ordered_dict(obj): + return _with_cycle_check(obj, lambda: ("__ordereddict__", tuple((normalize_token(k), normalize_token(v)) for k, v in obj.items()))) + + +@normalize_token.register(complex) +def _normalize_complex(obj): + return ("complex", obj.real, obj.imag) + + +@normalize_token.register(type(Ellipsis)) +def _normalize_ellipsis(obj): + return ("ellipsis",) + + +@normalize_token.register(slice) +def _normalize_slice(obj): + return ("slice", normalize_token(obj.start), normalize_token(obj.stop), normalize_token(obj.step)) + + +@normalize_token.register(type(len)) # builtin_function_or_method +def _normalize_builtin(obj): + # __self__ is the bound instance for methods like [].append, the module for unbound builtins like + # math.sin (where __self__ is the math module), or absent for some C-level callables. Distinguish + # the module case (cheap by name) from the bound-instance case (recurse) so math.sin vs cmath.sin + # don't collide and `[1,2].append` vs `[3,4].append` differ by their bound list contents. 
+ self_obj = getattr(obj, "__self__", None) + module = getattr(obj, "__module__", None) or "" + if isinstance(self_obj, ModuleType): + return ("builtin", module, obj.__qualname__, ("module", self_obj.__name__)) + return ("builtin", module, obj.__qualname__, normalize_token(self_obj)) + + +@normalize_token.register(Decimal) +def _normalize_decimal(obj): + return ("decimal", str(obj)) + + +@normalize_token.register(partial) +def _normalize_partial(obj): + return ("partial", normalize_token(obj.func), normalize_token(obj.args), normalize_token(sorted(obj.keywords.items()))) + + +@normalize_token.register(MappingProxyType) +def _normalize_mappingproxy(obj): + # Preserve the proxy's iteration order rather than sorting (the underlying mapping might be an + # OrderedDict where order is semantically meaningful). Tagged separately so a proxy is never + # mistaken for an equivalent plain dict. + return _with_cycle_check(obj, lambda: ("mappingproxy", tuple((normalize_token(k), normalize_token(v)) for k, v in obj.items()))) + + +@normalize_token.register(CodeType) +def _normalize_code(obj): + # Faithfully include identifier/signature fields so two functions that differ only in attribute + # access, local variable names, or signature don't collide. Docstring stripping is left to + # _hash_function_bytecode, which knows it's looking at function code (not a `compile(..., "exec")` + # block whose first const may be real program data). 
+ return ( + "code", + obj.co_code, + tuple(normalize_token(c) for c in obj.co_consts), + obj.co_names, + obj.co_varnames, + obj.co_freevars, + obj.co_cellvars, + obj.co_argcount, + obj.co_kwonlyargcount, + obj.co_posonlyargcount, + obj.co_flags, + ) + + +@normalize_token.register(type(lambda: None)) # FunctionType +def _normalize_function(obj): + return ("__function__", _hash_function_bytecode(obj)) + + +@normalize_token.register(MethodType) +def _normalize_method(obj): + return ("__method__", normalize_token(obj.__func__), normalize_token(obj.__self__)) + + +@normalize_token.register(MethodWrapperType) +def _normalize_method_wrapper(obj): + # method-wrapper objects (e.g. `{}.__init__`) have no __func__; key by name and bound instance. + return ("__method_wrapper__", obj.__name__, normalize_token(obj.__self__)) + + +@normalize_token.register(type) +def _normalize_type(obj): + return ("type", f"{obj.__module__}.{obj.__qualname__}") + + +@normalize_token.register(_PydanticBaseModel) +def _normalize_pydantic_basemodel(obj): + """Hash a pydantic model by its non-excluded fields (and any ``extra='allow'`` extras).""" + + def build(): + type_path = f"{type(obj).__module__}.{type(obj).__qualname__}" + model_fields = type(obj).model_fields + # Iterate model_fields directly rather than `for k, v in obj` so models that override __iter__ + # (or otherwise hide fields) still tokenize structurally. 
+ fields = tuple((name, normalize_token(getattr(obj, name))) for name, info in model_fields.items() if not info.exclude) + extras = getattr(obj, "__pydantic_extra__", None) or {} + extras_canonical = tuple(sorted(((k, normalize_token(v)) for k, v in extras.items()), key=repr)) + return ("pydantic", type_path, fields, extras_canonical) + + return _with_cycle_check(obj, build) + + +def _register_numpy() -> None: + try: + import numpy as np + except ImportError: # pragma: no cover + return + + # Perf note: dask streams a contiguous view directly into the hasher; we materialize via tobytes() and + # have a separate fast-path for object-dtype arrays. ~10x slower than dask on large numeric arrays and + # ~18x slower on object-dtype arrays — fine for typical config-style cache keys, room for improvement + # if multi-MB arrays become routine cache inputs. + @normalize_token.register(np.ndarray) + def _normalize_ndarray(obj): + # Object-dtype arrays store PyObject* pointers; tobytes() embeds process-local addresses, so recurse + # element-wise instead. Cycle detection is keyed on the array because tolist() returns a fresh list. + if obj.dtype.hasobject: + return _with_cycle_check(obj, lambda: ("ndarray", str(obj.dtype), obj.shape, normalize_token(obj.tolist()))) + return ("ndarray", str(obj.dtype), obj.shape, hashlib.sha256(np.ascontiguousarray(obj).tobytes()).hexdigest()) + + @normalize_token.register(np.ma.MaskedArray) + def _normalize_masked_array(obj): + # MaskedArray is a subclass of ndarray, so without an explicit handler the mask would be silently + # dropped. Normalize the underlying data as a plain ndarray and include the mask + fill_value. 
+ data = normalize_token(np.asarray(obj.data)) + mask = obj.mask + mask_canonical = bool(mask) if mask is np.ma.nomask else normalize_token(np.asarray(mask)) + return ("masked_array", data, mask_canonical, normalize_token(obj.fill_value)) + + @normalize_token.register(np.generic) + def _normalize_np_scalar(obj): + return ("np_scalar", str(type(obj).__name__), obj.item()) + + +def _register_pandas() -> None: + try: + import pandas as pd + except ImportError: # pragma: no cover + return + + # Only Timestamp has a structural handler; DataFrame/Series/Index fall through to the cloudpickle + # fallback. That works today but is fragile across pandas version upgrades — a follow-up could add + # structural handlers (matching dask) for stability and a perf win on large frames. + @normalize_token.register(pd.Timestamp) + def _normalize_pd_timestamp(obj): + return ("pd_timestamp", obj.isoformat()) + + +_register_numpy() +_register_pandas() + + +def tokenize(*args: Any, **kwargs: Any) -> str: + """Return a deterministic SHA-256 hex digest of the given args/kwargs (variadic to match ``dask.base.tokenize``).""" + payload = (args, kwargs) if kwargs else args + visited_token = _visited.set(set()) + try: + return _sha256_hexdigest(repr(normalize_token(payload))) + finally: + _visited.reset(visited_token) + + +def compute_data_token(value: Any) -> str: + """Compute a deterministic data token for a single value.""" + visited_token = _visited.set(set()) + try: + return _sha256_hexdigest(repr(normalize_token(value))) + finally: + _visited.reset(visited_token) def compute_cache_token(*, data_values: Iterable[Any] = (), behavior_classes: Iterable[type] = ()) -> str: @@ -128,31 +439,37 @@ def _function_state(func: Callable) -> Tuple[Any, Any, Tuple[Tuple[str, bool, An def _hash_function_bytecode(func: Callable) -> Optional[str]: """Return a SHA-256 hex digest of a function's behavior-relevant state. - The function is first unwrapped through any decorator chains, so that - e.g. 
``@Flow.call`` wrappers do not mask the real implementation. - - In addition to ``co_code`` and ``co_consts``, this includes: - - positional defaults (``__defaults__``) - - keyword-only defaults (``__kwdefaults__``) - - closure cell contents - so that behavior changes that do not affect bytecode alone still change - the token. + Unwraps decorator chains (``inspect.unwrap``) so that wrappers like + ``@Flow.call`` do not mask the implementation. Includes the recursively + normalized code object (with the leading docstring const stripped here, + where we know we have a function body), positional and keyword-only + defaults, and closure cell contents. - Returns ``None`` for objects without ``__code__`` (C builtins, etc.). + Returns ``None`` for objects without ``__code__``. """ unwrapped = _unwrap_function(func) if unwrapped is None: return None code = unwrapped.__code__ - # Include constants (skip first if it's the docstring) consts = code.co_consts - if consts and isinstance(consts[0], str): + # Function code starts with the docstring slot (a str when present, None when absent). Strip it + # so adding/removing a docstring doesn't change the behavior token. 
+ if consts and isinstance(consts[0], (str, type(None))): consts = consts[1:] - return _sha256_hexdigest( + code_canonical = ( + "code", code.co_code, - repr(consts), - compute_data_token(_function_state(unwrapped)), + tuple(normalize_token(c) for c in consts), + code.co_names, + code.co_varnames, + code.co_freevars, + code.co_cellvars, + code.co_argcount, + code.co_kwonlyargcount, + code.co_posonlyargcount, + code.co_flags, ) + return _sha256_hexdigest(repr(code_canonical), compute_data_token(_function_state(unwrapped))) def _dependency_info(dep: object, *, _visited: Tuple[type, ...]) -> Optional[Tuple[Tuple[str, str, str, str], str, str]]: diff --git a/pyproject.toml b/pyproject.toml index 48e4fa0..580ed5d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ classifiers = [ dependencies = [ "cloudpathlib", - "dask", + "cloudpickle", "deprecated", "hydra-core", "jinja2",