89 changes: 57 additions & 32 deletions imednet/core/endpoint/mixins.py
@@ -11,7 +11,6 @@
Iterable,
List,
Optional,
Protocol,
Type,
TypeVar,
cast,
@@ -25,21 +24,12 @@
from imednet.utils.filters import build_filter_string

from .base import BaseEndpoint
from .protocols import EndpointProtocol

if TYPE_CHECKING: # pragma: no cover - imported for type hints only

class EndpointProtocol(Protocol):
PATH: str
MODEL: Type[JsonModel]
_id_param: str
_cache_name: Optional[str]
requires_study_key: bool
PAGE_SIZE: int
_pop_study_filter: bool
_missing_study_exception: type[Exception]

def _auto_filter(self, filters: Dict[str, Any]) -> Dict[str, Any]: ...
def _build_path(self, *segments: Any) -> str: ...
if TYPE_CHECKING: # pragma: no cover
# EndpointProtocol is imported from .protocols, but we keep this import for
# backward compatibility if needed
pass
Comment on lines +29 to +32

Copilot AI Feb 10, 2026

The if TYPE_CHECKING block is now a no-op (pass) and the accompanying comment about “backward compatibility” is misleading since nothing is conditionally imported/defined anymore. Consider removing this block entirely, or restoring a concrete type-only import/alias if something actually needs to be guarded for type-checking.


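As a sketch of the two alternatives the comment describes — assuming nothing in mixins.py still needs a type-only name, since the module now imports EndpointProtocol from .protocols at runtime — either of the following would be clearer than the empty guard (the module name in option 2 is made up, for illustration only):

```python
# Option 1 (sketch): drop the empty guard; the runtime import earlier in this
# diff already makes EndpointProtocol available to both mypy and the interpreter.
from .protocols import EndpointProtocol

# Option 2 (sketch): keep a guard only when something is genuinely imported for
# annotations; "_typing_helpers" below is a hypothetical module name.
if TYPE_CHECKING:  # pragma: no cover
    from ._typing_helpers import AnnotationOnlyType
```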

T = TypeVar("T", bound=JsonModel)
@@ -134,11 +124,12 @@ def _resolve_params(
filters: Dict[str, Any],
) -> tuple[Optional[str], Dict[str, Any], Dict[str, Any]]:
# This method handles filter normalization and cache retrieval preparation
# Assuming _auto_filter is available via self (BaseEndpoint)
filters = self._auto_filter(filters) # type: ignore[attr-defined]
# Assuming _auto_filter is available via self (EndpointProtocol)
filters = cast(EndpointProtocol, self)._auto_filter(filters)

# Extract special parameters using the hook
special_params = self._extract_special_params(filters)

if special_params:
if extra_params is None:
extra_params = {}
@@ -181,12 +172,14 @@ class ListEndpointMixin(ParamMixin, CacheMixin, ParsingMixin[T]):
PAGE_SIZE: int = DEFAULT_PAGE_SIZE

def _get_path(self, study: Optional[str]) -> str:
# Cast to EndpointProtocol to access PATH and _build_path
protocol_self = cast(EndpointProtocol, self)
segments: Iterable[Any]
if self.requires_study_key:
segments = (study, self.PATH)
if protocol_self.requires_study_key:
segments = (study, protocol_self.PATH)
else:
segments = (self.PATH,) if self.PATH else ()
return self._build_path(*segments) # type: ignore[attr-defined]
segments = (protocol_self.PATH,) if protocol_self.PATH else ()
return protocol_self._build_path(*segments)

def _resolve_parse_func(self) -> Callable[[Any], T]:
"""
Expand Down Expand Up @@ -229,6 +222,31 @@ def _execute_sync_list(
self._update_local_cache(result, study, has_filters, cache)
return result

def _prepare_list_request(
self,
study_key: Optional[str],
extra_params: Optional[Dict[str, Any]],
filters: Dict[str, Any],
refresh: bool,
) -> tuple[Optional[List[T]], str, Dict[str, Any], Optional[str], bool, Any]:
"""
Prepare parameters, cache, and path for list request.

Returns:
Tuple of (cached_result, path, params, study, has_other_filters, cache_obj)
"""
# self is ListEndpointMixin, which inherits ParamMixin and CacheMixin
study, params, other_filters = self._resolve_params(study_key, extra_params, filters)

cache = self._get_local_cache()
cached_result = self._check_cache_hit(study, refresh, other_filters, cache)

if cached_result is not None:
return cast(List[T], cached_result), "", {}, study, False, None

path = self._get_path(study)
return None, path, params, study, bool(other_filters), cache

def _list_impl(
self,
client: RequestorProtocol | AsyncRequestorProtocol,
@@ -240,24 +258,23 @@ def _list_impl(
**filters: Any,
) -> List[T] | Awaitable[List[T]]:

study, params, other_filters = self._resolve_params(study_key, extra_params, filters)
cached_result, path, params, study, has_filters, cache = self._prepare_list_request(
study_key, extra_params, filters, refresh
)

cache = self._get_local_cache()
cached_result = self._check_cache_hit(study, refresh, other_filters, cache)
if cached_result is not None:
return cast(List[T], cached_result)
return cached_result

path = self._get_path(study)
paginator = paginator_cls(client, path, params=params, page_size=self.PAGE_SIZE)
parse_func = self._resolve_parse_func()

if hasattr(paginator, "__aiter__"):
return self._execute_async_list(
cast(AsyncPaginator, paginator), parse_func, study, bool(other_filters), cache
cast(AsyncPaginator, paginator), parse_func, study, has_filters, cache
)

return self._execute_sync_list(
cast(Paginator, paginator), parse_func, study, bool(other_filters), cache
cast(Paginator, paginator), parse_func, study, has_filters, cache
)


@@ -339,7 +356,9 @@ def _get_path_for_id(self, study_key: Optional[str], item_id: Any) -> str:
segments = (study_key, self.PATH, item_id)
else:
segments = (self.PATH, item_id) if self.PATH else (item_id,)
return self._build_path(*segments)
# self is mixed into BaseEndpoint which implements _build_path
# Use cast for type checking compliance without importing BaseEndpoint
Comment on lines +359 to +360

Copilot AI Feb 10, 2026

The comment says this cast is used “without importing BaseEndpoint”, but BaseEndpoint is already imported at the top of this module. Either update the comment to reflect the real intent (type narrowing) or remove it to avoid confusion.

Suggested change
# self is mixed into BaseEndpoint which implements _build_path
# Use cast for type checking compliance without importing BaseEndpoint
# self is expected to be mixed into BaseEndpoint, which implements _build_path;
# cast narrows self to BaseEndpoint so type checkers see _build_path as available

return cast(BaseEndpoint, self)._build_path(*segments)

def _raise_not_found(self, study_key: Optional[str], item_id: Any) -> None:
raise ValueError(f"{self.MODEL.__name__} not found")
@@ -350,6 +369,7 @@ def _get_impl_path(
*,
study_key: Optional[str],
item_id: Any,
is_async: bool = False,
) -> T | Awaitable[T]:
path = self._get_path_for_id(study_key, item_id)

@@ -361,15 +381,20 @@ def process_response(response: Any) -> T:
self._raise_not_found(study_key, item_id)
return self._parse_item(data)

if inspect.iscoroutinefunction(client.get):
if is_async:

async def _await() -> T:
response = await client.get(path) # type: ignore
# We assume client is AsyncRequestorProtocol because is_async=True
# But we can't be sure type-wise unless we narrow it down.
# In practice, caller ensures this.
aclient = cast(AsyncRequestorProtocol, client)
response = await aclient.get(path)
return process_response(response)

return _await()

response = client.get(path) # type: ignore
sclient = cast(RequestorProtocol, client)
response = sclient.get(path)
return process_response(response)


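The recurring change in mixins.py is replacing `# type: ignore[attr-defined]` with `cast(EndpointProtocol, self)`, so the type checker knows which attributes the mixin expects its host class to provide. A minimal, self-contained sketch of that pattern — all names below are invented for illustration, not taken from the library:

```python
from typing import Protocol, cast


class HasPath(Protocol):
    PATH: str

    def _build_path(self, *segments: str) -> str: ...


class PathMixin:
    """Mixin whose methods rely on attributes supplied by the concrete class."""

    def describe(self) -> str:
        # cast() only informs the type checker; it performs no runtime check or
        # conversion, so this is the typed equivalent of `# type: ignore[attr-defined]`.
        typed_self = cast(HasPath, self)
        return typed_self._build_path(typed_self.PATH)


class Endpoint(PathMixin):
    PATH = "items"

    def _build_path(self, *segments: str) -> str:
        return "/" + "/".join(segments)


print(Endpoint().describe())  # -> /items
```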
25 changes: 25 additions & 0 deletions imednet/core/endpoint/protocols.py
@@ -0,0 +1,25 @@
from typing import Any, Dict, Optional, Protocol, Type, runtime_checkable

from imednet.models.json_base import JsonModel


@runtime_checkable
class EndpointProtocol(Protocol):
"""Protocol defining the interface for endpoint classes."""

PATH: str
MODEL: Type[JsonModel]
_id_param: str
_cache_name: Optional[str]
requires_study_key: bool
PAGE_SIZE: int
_pop_study_filter: bool
_missing_study_exception: type[Exception]

def _auto_filter(self, filters: Dict[str, Any]) -> Dict[str, Any]:
"""Apply automatic filters (e.g., default study key)."""
...

def _build_path(self, *segments: Any) -> str:
"""Build the API path."""
...
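Because the protocol is decorated with @runtime_checkable, isinstance() can confirm that an object exposes all of these members (presence only — signatures and types are not verified). A toy sketch of a class that structurally satisfies it; every attribute value here is invented for illustration:

```python
from typing import Any, Dict, Optional, Type

from imednet.core.endpoint.protocols import EndpointProtocol
from imednet.models.json_base import JsonModel


class ToyEndpoint:
    """Illustrative only: not a real endpoint from the library."""

    PATH: str = "toys"
    MODEL: Type[JsonModel] = JsonModel
    _id_param: str = "toyId"
    _cache_name: Optional[str] = None
    requires_study_key: bool = True
    PAGE_SIZE: int = 100
    _pop_study_filter: bool = False
    _missing_study_exception: type[Exception] = ValueError

    def _auto_filter(self, filters: Dict[str, Any]) -> Dict[str, Any]:
        return dict(filters)

    def _build_path(self, *segments: Any) -> str:
        return "/" + "/".join(str(s) for s in segments)


# Structural check: passes because all protocol members are present.
assert isinstance(ToyEndpoint(), EndpointProtocol)
```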
2 changes: 1 addition & 1 deletion imednet/endpoints/jobs.py
@@ -66,7 +66,7 @@ async def async_get(self, study_key: str, batch_id: str) -> JobStatus:
client = self._require_async_client()
return await cast(
Awaitable[JobStatus],
self._get_impl_path(client, study_key=study_key, item_id=batch_id),
self._get_impl_path(client, study_key=study_key, item_id=batch_id, is_async=True),
)

def _execute_list_request(
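The change above threads is_async=True through from jobs.py, so the caller — rather than the old inspect.iscoroutinefunction(client.get) check — decides which branch of _get_impl_path runs. A self-contained sketch of that dispatch shape, using invented client classes rather than the library's requestor types:

```python
import asyncio
from typing import Any, Awaitable, Union


class SyncClient:
    def get(self, path: str) -> dict:
        return {"path": path}


class AsyncClient:
    async def get(self, path: str) -> dict:
        return {"path": path}


def get_impl(client: Any, path: str, *, is_async: bool = False) -> Union[dict, Awaitable[dict]]:
    # Mirrors the dispatch in _get_impl_path: the flag, not introspection of
    # client.get, selects the synchronous or awaitable code path.
    if is_async:

        async def _await() -> dict:
            return await client.get(path)

        return _await()
    return client.get(path)


print(get_impl(SyncClient(), "/jobs/B1"))
print(asyncio.run(get_impl(AsyncClient(), "/jobs/B1", is_async=True)))
```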
31 changes: 31 additions & 0 deletions tests/unit/endpoints/test_jobs_async.py
@@ -0,0 +1,31 @@
from unittest.mock import AsyncMock

import pytest

import imednet.endpoints.jobs as jobs
from imednet.models.jobs import JobStatus


@pytest.mark.asyncio
async def test_async_get_success(dummy_client, context, response_factory):
# Setup async mock
async_client = AsyncMock()
async_client.get.return_value = response_factory({"jobId": "1"})

ep = jobs.JobsEndpoint(dummy_client, context, async_client=async_client)

result = await ep.async_get("S1", "B1")

async_client.get.assert_called_once_with("/api/v1/edc/studies/S1/jobs/B1")
assert isinstance(result, JobStatus)


@pytest.mark.asyncio
async def test_async_get_not_found(dummy_client, context, response_factory):
async_client = AsyncMock()
async_client.get.return_value = response_factory({})

ep = jobs.JobsEndpoint(dummy_client, context, async_client=async_client)

with pytest.raises(ValueError):
await ep.async_get("S1", "B1")