Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 184 additions & 1 deletion ccflow/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,28 @@
import pathlib
import platform
import sys
import types
import typing
import warnings
from functools import singledispatch
from types import GenericAlias, MappingProxyType
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Dict,
Generic,
List,
NamedTuple,
Optional,
Tuple,
Type,
TypeVar,
Union,
get_args,
get_origin,
)

from packaging import version

Expand Down Expand Up @@ -165,6 +184,155 @@ def __new__(self, name: str, bases: Tuple[type], namespaces: Dict[str, Any], **k
return super().__new__(self, name, bases, namespaces, **kwargs)


class _PydanticGenericTypeSpec(NamedTuple):
# Pickle must not rely on process-local names like ``GenericResult[int]``.
# Store generic args as data so the receiver can rebuild ``origin[args]``.
origin: Type[PydanticBaseModel]
args: Tuple[Any, ...]


class _GenericAliasTypeSpec(NamedTuple):
# Builtin/typing aliases such as ``list[GenericResult[int]]`` are normally
# pickleable, but their nested Pydantic generic args may not be. Only wrap
# aliases when at least one nested arg had to become portable data.
origin: Any
args: Tuple[Any, ...]
typing_name: Optional[str]


def _is_pydantic_generic_specialization(value: Any) -> bool:
if not isinstance(value, type):
return False
metadata = getattr(value, "__pydantic_generic_metadata__", None)
return bool(metadata and metadata.get("origin") is not None)


@singledispatch
def _portable_generic_type_arg(value: Any) -> Any:
"""Convert fragile generic type arguments into pickleable rebuild specs.

The top-level reducer stores ``origin`` and ``args`` for a model instance,
for example ``GenericResult`` and ``(int,)`` for ``GenericResult[int]``.
Some args are themselves runtime Pydantic generic classes, such as the
``ListResult[int]`` in ``GenericResult[ListResult[int]]``. Those nested
classes have the same cross-process problem as the top-level class: pickle
may look for a process-local global named ``ListResult[int]``.

This helper walks only type arguments, not model field values. Plain
importable args like ``int`` are left alone. Pydantic generic
specializations become explicit ``(origin, args)`` specs. Generic aliases
like ``list[ListResult[int]]`` are wrapped only if one of their nested args
needed that conversion. Some type args contain list/tuple containers, such
as the argument list in ``typing.Callable[[ListResult[int]], int]``; those
containers are also walked because they are part of the type expression.
"""
if _is_pydantic_generic_specialization(value):
# Example: ``ListResult[int]`` becomes
# ``_PydanticGenericTypeSpec(ListResult, (int,))``. This avoids storing
# the generated class object that may not be globally registered in the
# receiving process.
metadata = value.__pydantic_generic_metadata__
return _PydanticGenericTypeSpec(
metadata["origin"],
tuple(_portable_generic_type_arg(arg) for arg in metadata.get("args", ())),
)

origin = get_origin(value)
args = get_args(value)
if origin is not None and args:
# Example: ``list[ListResult[int]]`` is a normal generic alias, but it
# contains a fragile Pydantic specialization. Rebuild the alias from a
# portable version of its args only when recursion changed something.
portable_args = tuple(_portable_generic_type_arg(arg) for arg in args)
if portable_args != args:
typing_name = getattr(value, "_name", None) if getattr(value, "__module__", None) == "typing" else None
return _GenericAliasTypeSpec(origin, portable_args, typing_name)
return value


@_portable_generic_type_arg.register(list)
def _(value: list) -> Any:
portable_items = [_portable_generic_type_arg(item) for item in value]
return portable_items if portable_items != value else value


@_portable_generic_type_arg.register(tuple)
def _(value: tuple) -> Any:
portable_items = tuple(_portable_generic_type_arg(item) for item in value)
return portable_items if portable_items != value else value


@singledispatch
def _restore_generic_type_arg(value: Any) -> Any:
return value


@_restore_generic_type_arg.register(_PydanticGenericTypeSpec)
def _(value: _PydanticGenericTypeSpec) -> Any:
origin, args = value
return origin[tuple(_restore_generic_type_arg(arg) for arg in args)]


@_restore_generic_type_arg.register(_GenericAliasTypeSpec)
def _(value: _GenericAliasTypeSpec) -> Any:
origin, args, typing_name = value
restored_args = tuple(_restore_generic_type_arg(arg) for arg in args)
if typing_name is not None:
alias = getattr(typing, typing_name)
if typing_name == "Optional":
non_none_args = tuple(arg for arg in restored_args if arg is not type(None))
if len(non_none_args) == 1:
return alias[non_none_args[0]]
return alias[restored_args]
try:
return origin[restored_args]
except TypeError as exc:
if len(restored_args) == 1:
try:
return origin[restored_args[0]]
except TypeError:
pass
# ``types.UnionType`` is not itself subscriptable; rebuild PEP 604
# unions from their members if one of those members needed
# portable restoration.
union_type = getattr(types, "UnionType", None)
if union_type is not None and origin is union_type:
result = restored_args[0]
for arg in restored_args[1:]:
result = result | arg
return result
raise exc


@_restore_generic_type_arg.register(list)
def _(value: list) -> list:
return [_restore_generic_type_arg(item) for item in value]


@_restore_generic_type_arg.register(tuple)
def _(value: tuple) -> tuple:
return tuple(_restore_generic_type_arg(item) for item in value)


def _new_ccflow_generic_model(origin: Type[PydanticBaseModel], args: Tuple[Any, ...]) -> PydanticBaseModel:
"""Restore a Pydantic generic specialization without a process-local global.

Pydantic registers runtime generic specializations like ``GenericResult[int]``
in the origin module only in the process that first materializes them. Pickle
can then serialize instances by that global name, but fresh workers may not
have created the same specialization yet. Restore from the stable origin
class plus generic args instead of relying on that process-local module
mutation.
"""

# Materialize the specialized class in this process. Pickle will restore the
# raw Pydantic state afterwards through BaseModel.__setstate__; keeping that
# state in the outer pickle stream preserves memo semantics for shared
# references, cycles, and protocol-5 buffers.
cls = origin[tuple(_restore_generic_type_arg(arg) for arg in args)]
return cls.__new__(cls)


class BaseModel(PydanticBaseModel, _RegistryMixin, metaclass=_SerializeAsAnyMeta):
"""BaseModel is a base class for all pydantic models within the ccflow framework.

Expand Down Expand Up @@ -334,6 +502,21 @@ def __setstate__(self, state):
state["__pydantic_fields_set__"] = set(state["__pydantic_fields_set__"])
super().__setstate__(state)

def __reduce_ex__(self, protocol):
if _is_pydantic_generic_specialization(type(self)):
# Pydantic's default reducer may serialize runtime generic classes
# by names like ``GenericResult[int]``. Those names exist only
# after a process has materialized that exact specialization, so
# Ray/fresh workers can fail to import them while unpickling.
# Serialize a stable constructor instead: generic origin plus
# portable generic args. Leave Pydantic instance state to the outer
# pickle stream so shared references, cycles, and protocol-5 buffers
# keep normal pickle semantics.
metadata = type(self).__pydantic_generic_metadata__
args = tuple(_portable_generic_type_arg(arg) for arg in metadata.get("args", ()))
return (_new_ccflow_generic_model, (metadata["origin"], args), self.__getstate__())
return super().__reduce_ex__(protocol)


class _ModelRegistryData(PydanticBaseModel):
"""A data structure representation of the model registry, without the associated functionality"""
Expand Down
Loading
Loading