diff --git a/modern_di/providers/context_provider.py b/modern_di/providers/context_provider.py index e6579ab..2b5e41c 100644 --- a/modern_di/providers/context_provider.py +++ b/modern_di/providers/context_provider.py @@ -10,7 +10,7 @@ class ContextProvider(AbstractProvider[types.T_co]): - __slots__ = [*AbstractProvider.BASE_SLOTS] + __slots__ = [*AbstractProvider.BASE_SLOTS, "_context_type"] def __init__( self, @@ -20,10 +20,11 @@ def __init__( bound_type: type | None = types.UNSET, # type: ignore[assignment] ) -> None: super().__init__(scope=scope, bound_type=bound_type if bound_type != types.UNSET else context_type) + self._context_type = context_type def validate(self, container: "Container") -> dict[str, typing.Any]: # noqa: ARG002 return {"bound_type": self.bound_type, "self": self} def resolve(self, container: "Container") -> types.T_co | None: container = container.find_container(self.scope) - return container.context_registry.find_context(typing.cast(type[types.T_co], self.bound_type)) + return container.context_registry.find_context(typing.cast(type[types.T_co], self._context_type))