Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions modern_di/providers/context_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,16 @@


class ContextProvider(AbstractProvider[types.T_co]):
__slots__ = AbstractProvider.BASE_SLOTS

def __init__(self, *, scope: Scope = Scope.APP, context_type: type[types.T_co]) -> None:
super().__init__(scope=scope, bound_type=context_type)
__slots__ = [*AbstractProvider.BASE_SLOTS, "_context_type"]

def __init__(
self,
*,
scope: Scope = Scope.APP,
context_type: type[types.T_co],
bound_type: type | None = types.UNSET, # type: ignore[assignment]
) -> None:
super().__init__(scope=scope, bound_type=bound_type if bound_type != types.UNSET else context_type)

def validate(self, container: "Container") -> dict[str, typing.Any]: # noqa: ARG002
return {"bound_type": self.bound_type, "self": self}
Expand Down
10 changes: 9 additions & 1 deletion modern_di/providers/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import typing

from modern_di import errors, types
from modern_di.providers import ContextProvider
from modern_di.providers.abstract import AbstractProvider
from modern_di.scope import Scope
from modern_di.types_parser import SignatureItem, parse_creator
Expand Down Expand Up @@ -63,11 +64,18 @@ def _compile_kwargs(self, container: "Container") -> dict[str, typing.Any]:
if provider:
break

is_kwarg_not_found = not self._kwargs or k not in self._kwargs
if provider:
result[k] = provider
if is_kwarg_not_found and isinstance(provider, ContextProvider) and provider.resolve(container) is None:
raise RuntimeError(
errors.FACTORY_ARGUMENT_RESOLUTION_ERROR.format(
arg_name=k, arg_type=v.arg_type, bound_type=self.bound_type or self._creator
)
)
continue

if (not self._kwargs or k not in self._kwargs) and v.default == types.UNSET:
if v.default == types.UNSET and is_kwarg_not_found:
raise RuntimeError(
errors.FACTORY_ARGUMENT_RESOLUTION_ERROR.format(
arg_name=k, arg_type=v.arg_type, bound_type=self.bound_type or self._creator
Expand Down
34 changes: 26 additions & 8 deletions tests/providers/test_context_provider.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,51 @@
import dataclasses
import datetime

from modern_di import Container, Scope, providers
import pytest

from modern_di import Container, Group, Scope, providers


context_provider = providers.ContextProvider(scope=Scope.APP, context_type=datetime.datetime)
request_context_provider = providers.ContextProvider(scope=Scope.REQUEST, context_type=datetime.datetime)


@dataclasses.dataclass(kw_only=True, slots=True, frozen=True)
class SomeFactory:
arg1: datetime.datetime


class MyGroup(Group):
context_provider = providers.ContextProvider(scope=Scope.APP, context_type=datetime.datetime)
some_factory = providers.Factory(creator=SomeFactory)


def test_context_provider() -> None:
now = datetime.datetime.now(tz=datetime.timezone.utc)
app_container = Container(context={datetime.datetime: now})
app_container.validate_provider(context_provider)
instance1 = app_container.resolve_provider(context_provider)
instance2 = app_container.resolve_provider(context_provider)
app_container.validate_provider(MyGroup.context_provider)
instance1 = app_container.resolve_provider(MyGroup.context_provider)
instance2 = app_container.resolve_provider(MyGroup.context_provider)
assert instance1 is instance2 is now


def test_context_provider_set_context_after_creation() -> None:
now = datetime.datetime.now(tz=datetime.timezone.utc)
app_container = Container()
app_container.set_context(datetime.datetime, now)
instance1 = app_container.resolve_provider(context_provider)
instance2 = app_container.resolve_provider(context_provider)
instance1 = app_container.resolve_provider(MyGroup.context_provider)
instance2 = app_container.resolve_provider(MyGroup.context_provider)
assert instance1 is instance2 is now


def test_context_provider_not_found() -> None:
app_container = Container()
assert app_container.resolve_provider(context_provider) is None
assert app_container.resolve_provider(MyGroup.context_provider) is None


def test_context_provider_not_found_but_required() -> None:
app_container = Container(groups=[MyGroup])
with pytest.raises(RuntimeError, match=r"Argument arg1 of type <class 'datetime.datetime'> cannot be resolved"):
app_container.resolve(SomeFactory)


def test_context_provider_in_request_scope() -> None:
Expand Down
Loading