Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 6 additions & 67 deletions ccflow/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,15 @@

import collections.abc
import copy
import inspect
import logging
import pathlib
import platform
import sys
import warnings
from types import GenericAlias, MappingProxyType
from typing import Any, Callable, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin
from types import MappingProxyType
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar

import omegaconf
from omegaconf import DictConfig, OmegaConf
from packaging import version
from pydantic import (
BaseModel as PydanticBaseModel,
ConfigDict,
Expand Down Expand Up @@ -90,66 +87,7 @@ def get_registry_dependencies(self, types: Optional[Tuple["ModelType"]] = None)
return deps


# Pydantic 2 has different handling of serialization.
# This requires some workarounds at the moment until the feature is added to easily get a mode that
# is compatible with Pydantic 1
# This is done by adjusting annotations via a MetaClass for any annotation that includes a BaseModel,
# such that the new annotation contains SerializeAsAny
# https://docs.pydantic.dev/latest/concepts/serialization/#serializing-with-duck-typing
# https://github.com/pydantic/pydantic/issues/6423
# https://github.com/pydantic/pydantic-core/pull/740
# See https://github.com/pydantic/pydantic/issues/6381 for inspiration on implementation
# NOTE: For this logic to be removed, require https://github.com/pydantic/pydantic-core/pull/1478
from pydantic._internal._model_construction import ModelMetaclass # noqa: E402

_IS_PY39 = version.parse(platform.python_version()) < version.parse("3.10")


def _adjust_annotations(annotation):
origin = get_origin(annotation)
args = get_args(annotation)
if not _IS_PY39:
from types import UnionType

if origin is UnionType:
origin = Union

if isinstance(annotation, GenericAlias) or (inspect.isclass(annotation) and issubclass(annotation, PydanticBaseModel)):
return SerializeAsAny[annotation]
elif origin and args:
# Filter out typing.Type and generic types
if origin is type or (inspect.isclass(origin) and issubclass(origin, Generic)):
return annotation
elif origin is ClassVar: # ClassVar doesn't accept a tuple of length 1 in py39
return ClassVar[_adjust_annotations(args[0])]
else:
try:
return origin[tuple(_adjust_annotations(arg) for arg in args)]
except TypeError:
raise TypeError(f"Could not adjust annotations for {origin}")
else:
return annotation


class _SerializeAsAnyMeta(ModelMetaclass):
def __new__(self, name: str, bases: Tuple[type], namespaces: Dict[str, Any], **kwargs):
annotations: dict = namespaces.get("__annotations__", {})

for base in bases:
for base_ in base.__mro__:
if base_ is PydanticBaseModel:
annotations.update(base_.__annotations__)

for field, annotation in annotations.items():
if not field.startswith("__"):
annotations[field] = _adjust_annotations(annotation)

namespaces["__annotations__"] = annotations

return super().__new__(self, name, bases, namespaces, **kwargs)


class BaseModel(PydanticBaseModel, _RegistryMixin, metaclass=_SerializeAsAnyMeta):
class BaseModel(PydanticBaseModel, _RegistryMixin):
"""BaseModel is a base class for all pydantic models within the ccflow framework.

This gives us a way to add functionality to the framework, including
Expand Down Expand Up @@ -204,6 +142,8 @@ def type_(self) -> PyObjectPath:
# where the default behavior is just to drop the mis-named value. This prevents that
extra="forbid",
ser_json_timedelta="float",
# Polymorphic serialization is the behavior of allowing a subclass of a model (or Pydantic dataclass) to override serialization so that the subclass' serialization is used, rather than the original model types's serialization. This will expose all the data defined on the subclass in the serialized payload.
polymorphic_serialization=True,
)

def __str__(self):
Expand Down Expand Up @@ -238,8 +178,7 @@ def get_widget(

kwargs = {"fallback": str, "mode": "json"}
kwargs.update(json_kwargs or {})
# Can't use self.model_dump_json or self.model_dump because they don't expose the fallback argument
return JSON(self.__pydantic_serializer__.to_python(self, **kwargs), **(widget_kwargs or {}))
return JSON(self.model_dump(**kwargs), **(widget_kwargs or {}))

def __panel__(self):
"""Return a Panel viewable for this model.
Expand Down
55 changes: 13 additions & 42 deletions ccflow/tests/test_base_serialize.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import pickle
import platform
import unittest
from typing import Annotated, ClassVar, Dict, List, Optional, Type, Union
from typing import Annotated, Optional

import numpy as np
from packaging import version
from pydantic import BaseModel as PydanticBaseModel, ConfigDict, Field, ValidationError

from ccflow import BaseModel, NDArray
Expand Down Expand Up @@ -162,6 +160,18 @@ def test_serialization_enum(self):
def test_serialization_nested_subclass(self):
self._check_serialization(NestedModel(a=ChildModel(field1=0, field2=10)))

def test_serialization_nested_subclass_uses_duck_typing(self):
model = NestedModel(a=ChildModel(field1=0, field2=10))

assert model.model_dump(mode="python") == {
"a": {
"field1": 0,
"field2": 10,
"type_": "ccflow.tests.test_base_serialize.ChildModel",
},
"type_": "ccflow.tests.test_base_serialize.NestedModel",
}

def test_from_str_serialization(self):
serialized = '{"_target_": "ccflow.tests.test_base_serialize.ChildModel", "field1": 9, "field2": 4}'
deserialized = BaseModel.model_validate_json(serialized)
Expand Down Expand Up @@ -213,45 +223,6 @@ class C(PydanticBaseModel):
# C implements the normal pydantic BaseModel whichhould allow extra fields.
_ = C(extra_field1=1)

def test_serialize_as_any(self):
# https://docs.pydantic.dev/latest/concepts/serialization/#serializing-with-duck-typing
# https://github.com/pydantic/pydantic/issues/6423
# This test could be removed once there is a different solution to the issue above
from pydantic import SerializeAsAny
from pydantic.types import constr

if version.parse(platform.python_version()) >= version.parse("3.10"):
pipe_union = A | int
else:
pipe_union = Union[A, int]

class MyNestedModel(BaseModel):
a1: A
a2: Optional[Union[A, int]]
a3: Dict[str, Optional[List[A]]]
a4: ClassVar[A]
a5: Type[A]
a6: constr(min_length=1)
a7: pipe_union

target = {
"a1": SerializeAsAny[A],
"a2": Optional[Union[SerializeAsAny[A], int]],
"a4": ClassVar[SerializeAsAny[A]],
"a5": Type[A],
"a6": constr(min_length=1), # Uses Annotation
"a7": Union[SerializeAsAny[A], int],
}
target["a3"] = dict[str, Optional[list[SerializeAsAny[A]]]]
annotations = MyNestedModel.__annotations__
self.assertEqual(str(annotations["a1"]), str(target["a1"]))
self.assertEqual(str(annotations["a2"]), str(target["a2"]))
self.assertEqual(str(annotations["a3"]), str(target["a3"]))
self.assertEqual(str(annotations["a4"]), str(target["a4"]))
self.assertEqual(str(annotations["a5"]), str(target["a5"]))
self.assertEqual(str(annotations["a6"]), str(target["a6"]))
self.assertEqual(str(annotations["a7"]), str(target["a7"]))

def test_pickle_consistency(self):
model = MultiAttributeModel(z=1, y="test", x=3.14, w=True)
serialized = pickle.dumps(model)
Expand Down
79 changes: 79 additions & 0 deletions ccflow/tests/test_evaluation_context_serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import json
from datetime import date

from ccflow import DateContext
from ccflow.callable import ModelEvaluationContext
from ccflow.evaluators import GraphEvaluator, LoggingEvaluator, MultiEvaluator
from ccflow.tests.evaluators.util import NodeModel


def _make_nested_mec(model):
ctx = DateContext(date=date(2022, 1, 1))
mec = model.__call__.get_evaluation_context(model, ctx)
assert isinstance(mec, ModelEvaluationContext)
# ensure nested: outer model is an evaluator, inner is a ModelEvaluationContext
assert isinstance(mec.context, ModelEvaluationContext)
return mec


def test_mec_model_dump_basic():
m = NodeModel()
mec = _make_nested_mec(m)

d = mec.model_dump()
assert isinstance(d, dict)
assert "fn" in d and "model" in d and "context" in d and "options" in d

s = mec.model_dump_json()
parsed = json.loads(s)
assert parsed["fn"] == d["fn"]
# Also verify mode-specific dumps
d_py = mec.model_dump(mode="python")
assert isinstance(d_py, dict)
d_json = mec.model_dump(mode="json")
assert isinstance(d_json, dict)
json.dumps(d_json)


def test_mec_model_dump_diamond_graph():
n0 = NodeModel()
n1 = NodeModel(deps_model=[n0])
n2 = NodeModel(deps_model=[n0])
root = NodeModel(deps_model=[n1, n2])

mec = _make_nested_mec(root)

d = mec.model_dump()
assert isinstance(d, dict)
assert set(["fn", "model", "context", "options"]).issubset(d.keys())

s = mec.model_dump_json()
json.loads(s)
# verify mode dumps
d_py = mec.model_dump(mode="python")
assert isinstance(d_py, dict)
d_json = mec.model_dump(mode="json")
assert isinstance(d_json, dict)
json.dumps(d_json)


def test_mec_model_dump_with_multi_evaluator():
m = NodeModel()
_ = LoggingEvaluator() # ensure import/validation
evaluator = MultiEvaluator(evaluators=[LoggingEvaluator(), GraphEvaluator()])

# Simulate how Flow builds evaluation context with a custom evaluator
ctx = DateContext(date=date(2022, 1, 1))
mec = ModelEvaluationContext(model=evaluator, context=m.__call__.get_evaluation_context(m, ctx))

d = mec.model_dump()
assert isinstance(d, dict)
assert "fn" in d and "model" in d and "context" in d
s = mec.model_dump_json()
json.loads(s)
# verify mode dumps
d_py = mec.model_dump(mode="python")
assert isinstance(d_py, dict)
d_json = mec.model_dump(mode="json")
assert isinstance(d_json, dict)
json.dumps(d_json)
2 changes: 1 addition & 1 deletion ccflow/ui/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,5 +254,5 @@ def _on_model_change(self, event):
self._tabs.active = 0

# Update & show JSONEditor
self._json_editor.value = model.__pydantic_serializer__.to_python(model, fallback=str, mode="json")
self._json_editor.value = model.model_dump(mode="json", fallback=str)
self._json_container.visible = True
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ dependencies = [
"orjson",
"pandas",
"pyarrow",
"pydantic>=2.6,<3",
"pydantic>=2.13,<3",
"smart_open",
"tenacity",
]
Expand Down
Loading