From 68ee37914f0a26861c36641bcfcb587e61153dea Mon Sep 17 00:00:00 2001 From: Artur Shiriev Date: Mon, 16 Mar 2026 13:26:51 +0300 Subject: [PATCH] raise error for required but missing context --- modern_di/providers/context_provider.py | 14 +++++++--- modern_di/providers/factory.py | 10 ++++++- tests/providers/test_context_provider.py | 34 ++++++++++++++++++------ 3 files changed, 45 insertions(+), 13 deletions(-) diff --git a/modern_di/providers/context_provider.py b/modern_di/providers/context_provider.py index 71a7af4..b54e1fd 100644 --- a/modern_di/providers/context_provider.py +++ b/modern_di/providers/context_provider.py @@ -10,10 +10,16 @@ class ContextProvider(AbstractProvider[types.T_co]): - __slots__ = AbstractProvider.BASE_SLOTS - - def __init__(self, *, scope: Scope = Scope.APP, context_type: type[types.T_co]) -> None: - super().__init__(scope=scope, bound_type=context_type) + __slots__ = [*AbstractProvider.BASE_SLOTS, "_context_type"] + + def __init__( + self, + *, + scope: Scope = Scope.APP, + context_type: type[types.T_co], + bound_type: type | None = types.UNSET, # type: ignore[assignment] + ) -> None: + super().__init__(scope=scope, bound_type=bound_type if bound_type != types.UNSET else context_type) def validate(self, container: "Container") -> dict[str, typing.Any]: # noqa: ARG002 return {"bound_type": self.bound_type, "self": self} diff --git a/modern_di/providers/factory.py b/modern_di/providers/factory.py index aede648..45bce30 100644 --- a/modern_di/providers/factory.py +++ b/modern_di/providers/factory.py @@ -3,6 +3,7 @@ import typing from modern_di import errors, types +from modern_di.providers import ContextProvider from modern_di.providers.abstract import AbstractProvider from modern_di.scope import Scope from modern_di.types_parser import SignatureItem, parse_creator @@ -63,11 +64,18 @@ def _compile_kwargs(self, container: "Container") -> dict[str, typing.Any]: if provider: break + is_kwarg_not_found = not self._kwargs or k not in self._kwargs if provider: result[k] = provider + if is_kwarg_not_found and isinstance(provider, ContextProvider) and provider.resolve(container) is None: + raise RuntimeError( + errors.FACTORY_ARGUMENT_RESOLUTION_ERROR.format( + arg_name=k, arg_type=v.arg_type, bound_type=self.bound_type or self._creator + ) + ) continue - if (not self._kwargs or k not in self._kwargs) and v.default == types.UNSET: + if v.default == types.UNSET and is_kwarg_not_found: raise RuntimeError( errors.FACTORY_ARGUMENT_RESOLUTION_ERROR.format( arg_name=k, arg_type=v.arg_type, bound_type=self.bound_type or self._creator diff --git a/tests/providers/test_context_provider.py b/tests/providers/test_context_provider.py index 77be649..7d093ef 100644 --- a/tests/providers/test_context_provider.py +++ b/tests/providers/test_context_provider.py @@ -1,18 +1,30 @@ +import dataclasses import datetime -from modern_di import Container, Scope, providers +import pytest + +from modern_di import Container, Group, Scope, providers -context_provider = providers.ContextProvider(scope=Scope.APP, context_type=datetime.datetime) request_context_provider = providers.ContextProvider(scope=Scope.REQUEST, context_type=datetime.datetime) +@dataclasses.dataclass(kw_only=True, slots=True, frozen=True) +class SomeFactory: + arg1: datetime.datetime + + +class MyGroup(Group): + context_provider = providers.ContextProvider(scope=Scope.APP, context_type=datetime.datetime) + some_factory = providers.Factory(creator=SomeFactory) + + def test_context_provider() -> None: now = datetime.datetime.now(tz=datetime.timezone.utc) app_container = Container(context={datetime.datetime: now}) - app_container.validate_provider(context_provider) - instance1 = app_container.resolve_provider(context_provider) - instance2 = app_container.resolve_provider(context_provider) + app_container.validate_provider(MyGroup.context_provider) + instance1 = app_container.resolve_provider(MyGroup.context_provider) + instance2 = app_container.resolve_provider(MyGroup.context_provider) assert instance1 is instance2 is now @@ -20,14 +32,20 @@ def test_context_provider_set_context_after_creation() -> None: now = datetime.datetime.now(tz=datetime.timezone.utc) app_container = Container() app_container.set_context(datetime.datetime, now) - instance1 = app_container.resolve_provider(context_provider) - instance2 = app_container.resolve_provider(context_provider) + instance1 = app_container.resolve_provider(MyGroup.context_provider) + instance2 = app_container.resolve_provider(MyGroup.context_provider) assert instance1 is instance2 is now def test_context_provider_not_found() -> None: app_container = Container() - assert app_container.resolve_provider(context_provider) is None + assert app_container.resolve_provider(MyGroup.context_provider) is None + + +def test_context_provider_not_found_but_required() -> None: + app_container = Container(groups=[MyGroup]) + with pytest.raises(RuntimeError, match=r"Argument arg1 of type cannot be resolved"): + app_container.resolve(SomeFactory) def test_context_provider_in_request_scope() -> None: