diff --git a/imednet/core/endpoint/mixins.py b/imednet/core/endpoint/mixins.py index 3f755e58..5262cc6c 100644 --- a/imednet/core/endpoint/mixins.py +++ b/imednet/core/endpoint/mixins.py @@ -11,7 +11,6 @@ Iterable, List, Optional, - Protocol, Type, TypeVar, cast, @@ -25,21 +24,12 @@ from imednet.utils.filters import build_filter_string from .base import BaseEndpoint +from .protocols import EndpointProtocol -if TYPE_CHECKING: # pragma: no cover - imported for type hints only - - class EndpointProtocol(Protocol): - PATH: str - MODEL: Type[JsonModel] - _id_param: str - _cache_name: Optional[str] - requires_study_key: bool - PAGE_SIZE: int - _pop_study_filter: bool - _missing_study_exception: type[Exception] - - def _auto_filter(self, filters: Dict[str, Any]) -> Dict[str, Any]: ... - def _build_path(self, *segments: Any) -> str: ... +if TYPE_CHECKING: # pragma: no cover + # EndpointProtocol is imported from .protocols, but we keep this import for + # backward compatibility if needed + pass T = TypeVar("T", bound=JsonModel) @@ -134,11 +124,12 @@ def _resolve_params( filters: Dict[str, Any], ) -> tuple[Optional[str], Dict[str, Any], Dict[str, Any]]: # This method handles filter normalization and cache retrieval preparation - # Assuming _auto_filter is available via self (BaseEndpoint) - filters = self._auto_filter(filters) # type: ignore[attr-defined] + # Assuming _auto_filter is available via self (EndpointProtocol) + filters = cast(EndpointProtocol, self)._auto_filter(filters) # Extract special parameters using the hook special_params = self._extract_special_params(filters) + if special_params: if extra_params is None: extra_params = {} @@ -181,12 +172,14 @@ class ListEndpointMixin(ParamMixin, CacheMixin, ParsingMixin[T]): PAGE_SIZE: int = DEFAULT_PAGE_SIZE def _get_path(self, study: Optional[str]) -> str: + # Cast to EndpointProtocol to access PATH and _build_path + protocol_self = cast(EndpointProtocol, self) segments: Iterable[Any] - if self.requires_study_key: - segments = (study, self.PATH) + if protocol_self.requires_study_key: + segments = (study, protocol_self.PATH) else: - segments = (self.PATH,) if self.PATH else () - return self._build_path(*segments) # type: ignore[attr-defined] + segments = (protocol_self.PATH,) if protocol_self.PATH else () + return protocol_self._build_path(*segments) def _resolve_parse_func(self) -> Callable[[Any], T]: """ @@ -229,6 +222,31 @@ def _execute_sync_list( self._update_local_cache(result, study, has_filters, cache) return result + def _prepare_list_request( + self, + study_key: Optional[str], + extra_params: Optional[Dict[str, Any]], + filters: Dict[str, Any], + refresh: bool, + ) -> tuple[Optional[List[T]], str, Dict[str, Any], Optional[str], bool, Any]: + """ + Prepare parameters, cache, and path for list request. + + Returns: + Tuple of (cached_result, path, params, study, has_other_filters, cache_obj) + """ + # self is ListEndpointMixin, which inherits ParamMixin and CacheMixin + study, params, other_filters = self._resolve_params(study_key, extra_params, filters) + + cache = self._get_local_cache() + cached_result = self._check_cache_hit(study, refresh, other_filters, cache) + + if cached_result is not None: + return cast(List[T], cached_result), "", {}, study, False, None + + path = self._get_path(study) + return None, path, params, study, bool(other_filters), cache + def _list_impl( self, client: RequestorProtocol | AsyncRequestorProtocol, @@ -240,24 +258,23 @@ def _list_impl( **filters: Any, ) -> List[T] | Awaitable[List[T]]: - study, params, other_filters = self._resolve_params(study_key, extra_params, filters) + cached_result, path, params, study, has_filters, cache = self._prepare_list_request( + study_key, extra_params, filters, refresh + ) - cache = self._get_local_cache() - cached_result = self._check_cache_hit(study, refresh, other_filters, cache) if cached_result is not None: - return cast(List[T], cached_result) + return cached_result - path = self._get_path(study) paginator = paginator_cls(client, path, params=params, page_size=self.PAGE_SIZE) parse_func = self._resolve_parse_func() if hasattr(paginator, "__aiter__"): return self._execute_async_list( - cast(AsyncPaginator, paginator), parse_func, study, bool(other_filters), cache + cast(AsyncPaginator, paginator), parse_func, study, has_filters, cache ) return self._execute_sync_list( - cast(Paginator, paginator), parse_func, study, bool(other_filters), cache + cast(Paginator, paginator), parse_func, study, has_filters, cache ) @@ -339,7 +356,9 @@ def _get_path_for_id(self, study_key: Optional[str], item_id: Any) -> str: segments = (study_key, self.PATH, item_id) else: segments = (self.PATH, item_id) if self.PATH else (item_id,) - return self._build_path(*segments) + # self is mixed into BaseEndpoint which implements _build_path + # Use cast for type checking compliance without importing BaseEndpoint + return cast(BaseEndpoint, self)._build_path(*segments) def _raise_not_found(self, study_key: Optional[str], item_id: Any) -> None: raise ValueError(f"{self.MODEL.__name__} not found") @@ -350,6 +369,7 @@ def _get_impl_path( *, study_key: Optional[str], item_id: Any, + is_async: bool = False, ) -> T | Awaitable[T]: path = self._get_path_for_id(study_key, item_id) @@ -361,15 +381,20 @@ def process_response(response: Any) -> T: self._raise_not_found(study_key, item_id) return self._parse_item(data) - if inspect.iscoroutinefunction(client.get): + if is_async: async def _await() -> T: - response = await client.get(path) # type: ignore + # We assume client is AsyncRequestorProtocol because is_async=True + # But we can't be sure type-wise unless we narrow it down. + # In practice, caller ensures this. + aclient = cast(AsyncRequestorProtocol, client) + response = await aclient.get(path) return process_response(response) return _await() - response = client.get(path) # type: ignore + sclient = cast(RequestorProtocol, client) + response = sclient.get(path) return process_response(response) diff --git a/imednet/core/endpoint/protocols.py b/imednet/core/endpoint/protocols.py new file mode 100644 index 00000000..889a3b82 --- /dev/null +++ b/imednet/core/endpoint/protocols.py @@ -0,0 +1,25 @@ +from typing import Any, Dict, Optional, Protocol, Type, runtime_checkable + +from imednet.models.json_base import JsonModel + + +@runtime_checkable +class EndpointProtocol(Protocol): + """Protocol defining the interface for endpoint classes.""" + + PATH: str + MODEL: Type[JsonModel] + _id_param: str + _cache_name: Optional[str] + requires_study_key: bool + PAGE_SIZE: int + _pop_study_filter: bool + _missing_study_exception: type[Exception] + + def _auto_filter(self, filters: Dict[str, Any]) -> Dict[str, Any]: + """Apply automatic filters (e.g., default study key).""" + ... + + def _build_path(self, *segments: Any) -> str: + """Build the API path.""" + ... diff --git a/imednet/endpoints/jobs.py b/imednet/endpoints/jobs.py index 9a905753..f9d0c815 100644 --- a/imednet/endpoints/jobs.py +++ b/imednet/endpoints/jobs.py @@ -66,7 +66,7 @@ async def async_get(self, study_key: str, batch_id: str) -> JobStatus: client = self._require_async_client() return await cast( Awaitable[JobStatus], - self._get_impl_path(client, study_key=study_key, item_id=batch_id), + self._get_impl_path(client, study_key=study_key, item_id=batch_id, is_async=True), ) def _execute_list_request( diff --git a/tests/unit/endpoints/test_jobs_async.py b/tests/unit/endpoints/test_jobs_async.py new file mode 100644 index 00000000..cef26d84 --- /dev/null +++ b/tests/unit/endpoints/test_jobs_async.py @@ -0,0 +1,31 @@ +from unittest.mock import AsyncMock + +import pytest + +import imednet.endpoints.jobs as jobs +from imednet.models.jobs import JobStatus + + +@pytest.mark.asyncio +async def test_async_get_success(dummy_client, context, response_factory): + # Setup async mock + async_client = AsyncMock() + async_client.get.return_value = response_factory({"jobId": "1"}) + + ep = jobs.JobsEndpoint(dummy_client, context, async_client=async_client) + + result = await ep.async_get("S1", "B1") + + async_client.get.assert_called_once_with("/api/v1/edc/studies/S1/jobs/B1") + assert isinstance(result, JobStatus) + + +@pytest.mark.asyncio +async def test_async_get_not_found(dummy_client, context, response_factory): + async_client = AsyncMock() + async_client.get.return_value = response_factory({}) + + ep = jobs.JobsEndpoint(dummy_client, context, async_client=async_client) + + with pytest.raises(ValueError): + await ep.async_get("S1", "B1")