From db3fb18e3ca2a80deadfe85779445f70de8e9f1c Mon Sep 17 00:00:00 2001 From: Adam Dangoor Date: Tue, 24 Feb 2026 23:24:13 +0000 Subject: [PATCH] Add asyncio support with async client classes Implement async versions of all three client classes (AsyncVWS, AsyncCloudRecoService, AsyncVuMarkService) alongside transport abstraction. Adds AsyncTransport protocol and AsyncHTTPXTransport using httpx.AsyncClient. Includes 99 new async integration tests with complete exception coverage. All 287 tests pass with strict mypy and ruff validation. Co-Authored-By: Claude Haiku 4.5 --- docs/source/api-reference.rst | 12 + pyproject.toml | 2 + spelling_private_dict.txt | 2 + src/vws/__init__.py | 6 + src/vws/_async_vws_request.py | 77 +++ src/vws/async_query.py | 226 +++++++ src/vws/async_vumark_service.py | 170 ++++++ src/vws/async_vws.py | 692 ++++++++++++++++++++++ src/vws/transports.py | 122 +++- tests/conftest.py | 55 +- tests/test_async_cloud_reco_exceptions.py | 120 ++++ tests/test_async_query.py | 225 +++++++ tests/test_async_vws.py | 500 ++++++++++++++++ tests/test_async_vws_exceptions.py | 304 ++++++++++ tests/test_transports.py | 147 ++++- 15 files changed, 2656 insertions(+), 4 deletions(-) create mode 100644 src/vws/_async_vws_request.py create mode 100644 src/vws/async_query.py create mode 100644 src/vws/async_vumark_service.py create mode 100644 src/vws/async_vws.py create mode 100644 tests/test_async_cloud_reco_exceptions.py create mode 100644 tests/test_async_query.py create mode 100644 tests/test_async_vws.py create mode 100644 tests/test_async_vws_exceptions.py diff --git a/docs/source/api-reference.rst b/docs/source/api-reference.rst index 3fc7441c..c2765245 100644 --- a/docs/source/api-reference.rst +++ b/docs/source/api-reference.rst @@ -5,6 +5,18 @@ API Reference :undoc-members: :members: +.. automodule:: vws.async_vws + :undoc-members: + :members: + +.. automodule:: vws.async_query + :undoc-members: + :members: + +.. automodule:: vws.async_vumark_service + :undoc-members: + :members: + .. automodule:: vws.reports :undoc-members: :members: diff --git a/pyproject.toml b/pyproject.toml index 88d229a1..b4fbfeea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,7 @@ optional-dependencies.dev = [ "pyright==1.1.408", "pyroma==5.0.1", "pytest==9.0.2", + "pytest-asyncio==1.3.0", "pytest-cov==7.0.0", "pyyaml==6.0.3", "ruff==0.15.2", @@ -359,6 +360,7 @@ ignore_path = [ # but Vulture does not enable this. ignore_names = [ # Public API classes imported by users from vws.transports + "AsyncHTTPXTransport", "HTTPXTransport", # pytest configuration "pytest_collect_file", diff --git a/spelling_private_dict.txt b/spelling_private_dict.txt index 8831d3d3..73d358a0 100644 --- a/spelling_private_dict.txt +++ b/spelling_private_dict.txt @@ -27,6 +27,8 @@ admin api args ascii +async +asyncio beartype bool boolean diff --git a/src/vws/__init__.py b/src/vws/__init__.py index a091641f..a181cb81 100644 --- a/src/vws/__init__.py +++ b/src/vws/__init__.py @@ -1,11 +1,17 @@ """A library for Vuforia Web Services.""" +from .async_query import AsyncCloudRecoService +from .async_vumark_service import AsyncVuMarkService +from .async_vws import AsyncVWS from .query import CloudRecoService from .vumark_service import VuMarkService from .vws import VWS __all__ = [ "VWS", + "AsyncCloudRecoService", + "AsyncVWS", + "AsyncVuMarkService", "CloudRecoService", "VuMarkService", ] diff --git a/src/vws/_async_vws_request.py b/src/vws/_async_vws_request.py new file mode 100644 index 00000000..4ebd50c8 --- /dev/null +++ b/src/vws/_async_vws_request.py @@ -0,0 +1,77 @@ +"""Internal helper for making authenticated async requests to the +Vuforia Target API. +""" + +from beartype import BeartypeConf, beartype +from vws_auth_tools import authorization_header, rfc_1123_date + +from vws.response import Response +from vws.transports import AsyncTransport + + +@beartype(conf=BeartypeConf(is_pep484_tower=True)) +async def async_target_api_request( + *, + content_type: str, + server_access_key: str, + server_secret_key: str, + method: str, + data: bytes, + request_path: str, + base_vws_url: str, + request_timeout_seconds: float | tuple[float, float], + extra_headers: dict[str, str], + transport: AsyncTransport, +) -> Response: + """Make an async request to the Vuforia Target API. + + Args: + content_type: The content type of the request. + server_access_key: A VWS server access key. + server_secret_key: A VWS server secret key. + method: The HTTP method which will be used in the + request. + data: The request body which will be used in the + request. + request_path: The path to the endpoint which will be + used in the request. + base_vws_url: The base URL for the VWS API. + request_timeout_seconds: The timeout for the request. + This can be a float to set both the connect and + read timeouts, or a (connect, read) tuple. + extra_headers: Additional headers to include in the + request. + transport: The async HTTP transport to use for the + request. + + Returns: + The response to the request. + """ + date_string = rfc_1123_date() + + signature_string = authorization_header( + access_key=server_access_key, + secret_key=server_secret_key, + method=method, + content=data, + content_type=content_type, + date=date_string, + request_path=request_path, + ) + + headers = { + "Authorization": signature_string, + "Date": date_string, + "Content-Type": content_type, + **extra_headers, + } + + url = base_vws_url.rstrip("/") + request_path + + return await transport( + method=method, + url=url, + headers=headers, + data=data, + request_timeout=request_timeout_seconds, + ) diff --git a/src/vws/async_query.py b/src/vws/async_query.py new file mode 100644 index 00000000..56981d4e --- /dev/null +++ b/src/vws/async_query.py @@ -0,0 +1,226 @@ +"""Async tools for interacting with the Vuforia Cloud Recognition +Web APIs. +""" + +import datetime +import json +from http import HTTPMethod, HTTPStatus +from typing import Any, Self + +from beartype import BeartypeConf, beartype +from urllib3.filepost import encode_multipart_formdata +from vws_auth_tools import authorization_header, rfc_1123_date + +from vws._image_utils import ImageType as _ImageType +from vws._image_utils import get_image_data as _get_image_data +from vws.exceptions.cloud_reco_exceptions import ( + AuthenticationFailureError, + BadImageError, + InactiveProjectError, + MaxNumResultsOutOfRangeError, + RequestTimeTooSkewedError, +) +from vws.exceptions.custom_exceptions import ( + RequestEntityTooLargeError, + ServerError, +) +from vws.include_target_data import CloudRecoIncludeTargetData +from vws.reports import QueryResult, TargetData +from vws.transports import AsyncHTTPXTransport, AsyncTransport + + +@beartype(conf=BeartypeConf(is_pep484_tower=True)) +class AsyncCloudRecoService: + """An async interface to the Vuforia Cloud Recognition Web + APIs. + """ + + def __init__( + self, + *, + client_access_key: str, + client_secret_key: str, + base_vwq_url: str = "https://cloudreco.vuforia.com", + request_timeout_seconds: float | tuple[float, float] = 30.0, + transport: AsyncTransport | None = None, + ) -> None: + """ + Args: + client_access_key: A VWS client access key. + client_secret_key: A VWS client secret key. + base_vwq_url: The base URL for the VWQ API. + request_timeout_seconds: The timeout for each + HTTP request. This can be a float to set both + the connect and read timeouts, or a + (connect, read) tuple. + transport: The async HTTP transport to use for + requests. Defaults to + ``AsyncHTTPXTransport()``. + """ + self._client_access_key = client_access_key + self._client_secret_key = client_secret_key + self._base_vwq_url = base_vwq_url + self._request_timeout_seconds = request_timeout_seconds + self._transport = transport or AsyncHTTPXTransport() + + async def aclose(self) -> None: + """Close the underlying transport if it supports closing.""" + close = getattr(self._transport, "aclose", None) + if close is not None: + await close() + + async def __aenter__(self) -> Self: + """Enter the async context manager.""" + return self + + async def __aexit__(self, *_args: object) -> None: + """Exit the async context manager and close the transport.""" + await self.aclose() + + async def query( + self, + *, + image: _ImageType, + max_num_results: int = 1, + include_target_data: CloudRecoIncludeTargetData = ( + CloudRecoIncludeTargetData.TOP + ), + ) -> list[QueryResult]: + """Use the Vuforia Web Query API to make an Image + Recognition Query. + + See + https://developer.vuforia.com/library/web-api/vuforia-query-web-api + for parameter details. + + Args: + image: The image to make a query against. + max_num_results: The maximum number of matching + targets to be returned. + include_target_data: Indicates if target_data + records shall be returned for the matched + targets. Accepted values are top (default + value, only return target_data for top ranked + match), none (return no target_data), all + (for all matched targets). + + Raises: + ~vws.exceptions.cloud_reco_exceptions.AuthenticationFailureError: + The client access key pair is not correct. + ~vws.exceptions.cloud_reco_exceptions.MaxNumResultsOutOfRangeError: + ``max_num_results`` is not within the range (1, 50). + ~vws.exceptions.cloud_reco_exceptions.InactiveProjectError: The + project is inactive. + ~vws.exceptions.cloud_reco_exceptions.RequestTimeTooSkewedError: + There is an error with the time sent to Vuforia. + ~vws.exceptions.cloud_reco_exceptions.BadImageError: There is a + problem with the given image. For example, it must be a JPEG or + PNG file in the grayscale or RGB color space. + ~vws.exceptions.custom_exceptions.RequestEntityTooLargeError: The + given image is too large. + ~vws.exceptions.custom_exceptions.ServerError: There is an + error with Vuforia's servers. + + Returns: + An ordered list of target details of matching + targets. + """ + image_content = _get_image_data(image=image) + body: dict[str, Any] = { + "image": ( + "image.jpeg", + image_content, + "image/jpeg", + ), + "max_num_results": ( + None, + int(max_num_results), + "text/plain", + ), + "include_target_data": ( + None, + include_target_data.value, + "text/plain", + ), + } + date = rfc_1123_date() + request_path = "/v1/query" + content, content_type_header = encode_multipart_formdata(fields=body) + method = HTTPMethod.POST + + authorization_string = authorization_header( + access_key=self._client_access_key, + secret_key=self._client_secret_key, + method=method, + content=content, + # Note that this is not the actual Content-Type + # header value sent. + content_type="multipart/form-data", + date=date, + request_path=request_path, + ) + + headers = { + "Authorization": authorization_string, + "Date": date, + "Content-Type": content_type_header, + } + + response = await self._transport( + method=method, + url=self._base_vwq_url.rstrip("/") + request_path, + headers=headers, + data=content, + request_timeout=self._request_timeout_seconds, + ) + + if response.status_code == HTTPStatus.REQUEST_ENTITY_TOO_LARGE: + raise RequestEntityTooLargeError(response=response) + + if "Integer out of range" in response.text: + raise MaxNumResultsOutOfRangeError( + response=response, + ) + + if ( + response.status_code >= HTTPStatus.INTERNAL_SERVER_ERROR + ): # pragma: no cover + raise ServerError(response=response) + + result_code = json.loads(s=response.text)["result_code"] + if result_code != "Success": + exception = { + "AuthenticationFailure": (AuthenticationFailureError), + "BadImage": BadImageError, + "InactiveProject": InactiveProjectError, + "RequestTimeTooSkewed": (RequestTimeTooSkewedError), + }[result_code] + raise exception(response=response) + + result: list[QueryResult] = [] + result_list = list( + json.loads(s=response.text)["results"], + ) + for item in result_list: + target_data: TargetData | None = None + if "target_data" in item: + target_data_dict = item["target_data"] + metadata = target_data_dict["application_metadata"] + timestamp_string = target_data_dict["target_timestamp"] + target_timestamp = datetime.datetime.fromtimestamp( + timestamp=timestamp_string, + tz=datetime.UTC, + ) + target_data = TargetData( + name=target_data_dict["name"], + application_metadata=metadata, + target_timestamp=target_timestamp, + ) + + query_result = QueryResult( + target_id=item["target_id"], + target_data=target_data, + ) + + result.append(query_result) + return result diff --git a/src/vws/async_vumark_service.py b/src/vws/async_vumark_service.py new file mode 100644 index 00000000..c1e189ee --- /dev/null +++ b/src/vws/async_vumark_service.py @@ -0,0 +1,170 @@ +"""Async interface to the Vuforia VuMark Generation Web API.""" + +import json +from http import HTTPMethod, HTTPStatus +from typing import Self + +from beartype import BeartypeConf, beartype + +from vws._async_vws_request import async_target_api_request +from vws.exceptions.custom_exceptions import ServerError +from vws.exceptions.vws_exceptions import ( + AuthenticationFailureError, + BadRequestError, + DateRangeError, + FailError, + InvalidAcceptHeaderError, + InvalidInstanceIdError, + InvalidTargetTypeError, + RequestTimeTooSkewedError, + TargetStatusNotSuccessError, + TooManyRequestsError, + UnknownTargetError, +) +from vws.transports import AsyncHTTPXTransport, AsyncTransport +from vws.vumark_accept import VuMarkAccept + + +@beartype(conf=BeartypeConf(is_pep484_tower=True)) +class AsyncVuMarkService: + """An async interface to the Vuforia VuMark Generation Web + API. + """ + + def __init__( + self, + *, + server_access_key: str, + server_secret_key: str, + base_vws_url: str = "https://vws.vuforia.com", + request_timeout_seconds: float | tuple[float, float] = 30.0, + transport: AsyncTransport | None = None, + ) -> None: + """ + Args: + server_access_key: A VWS server access key. + server_secret_key: A VWS server secret key. + base_vws_url: The base URL for the VWS API. + request_timeout_seconds: The timeout for each + HTTP request. This can be a float to set both + the connect and read timeouts, or a + (connect, read) tuple. + transport: The async HTTP transport to use for + requests. Defaults to + ``AsyncHTTPXTransport()``. + """ + self._server_access_key = server_access_key + self._server_secret_key = server_secret_key + self._base_vws_url = base_vws_url + self._request_timeout_seconds = request_timeout_seconds + self._transport = transport or AsyncHTTPXTransport() + + async def aclose(self) -> None: + """Close the underlying transport if it supports closing.""" + close = getattr(self._transport, "aclose", None) + if close is not None: + await close() + + async def __aenter__(self) -> Self: + """Enter the async context manager.""" + return self + + async def __aexit__(self, *_args: object) -> None: + """Exit the async context manager and close the transport.""" + await self.aclose() + + async def generate_vumark_instance( + self, + *, + target_id: str, + instance_id: str, + accept: VuMarkAccept, + ) -> bytes: + """Generate a VuMark instance image. + + See + https://developer.vuforia.com/library/vuforia-engine/web-api/vumark-generation-web-api/ + for parameter details. + + Args: + target_id: The ID of the VuMark target. + instance_id: The instance ID to encode in the + VuMark. + accept: The image format to return. + + Returns: + The VuMark instance image bytes. + + Raises: + ~vws.exceptions.vws_exceptions.AuthenticationFailureError: The + secret key is not correct. + ~vws.exceptions.vws_exceptions.FailError: There was an error with + the request. For example, the given access key does not match a + known database. + ~vws.exceptions.vws_exceptions.InvalidAcceptHeaderError: The + Accept header value is not supported. + ~vws.exceptions.vws_exceptions.InvalidInstanceIdError: The + instance ID is invalid. For example, it may be empty. + ~vws.exceptions.vws_exceptions.InvalidTargetTypeError: The target + is not a VuMark template target. + ~vws.exceptions.vws_exceptions.RequestTimeTooSkewedError: There is + an error with the time sent to Vuforia. + ~vws.exceptions.vws_exceptions.TargetStatusNotSuccessError: The + target is not in the success state. + ~vws.exceptions.vws_exceptions.UnknownTargetError: The given target + ID does not match a target in the database. + ~vws.exceptions.custom_exceptions.ServerError: There is an error + with Vuforia's servers. + ~vws.exceptions.vws_exceptions.TooManyRequestsError: Vuforia is + rate limiting access. + """ + request_path = f"/targets/{target_id}/instances" + content_type = "application/json" + request_data = json.dumps( + obj={"instance_id": instance_id}, + ).encode(encoding="utf-8") + + response = await async_target_api_request( + content_type=content_type, + server_access_key=self._server_access_key, + server_secret_key=self._server_secret_key, + method=HTTPMethod.POST, + data=request_data, + request_path=request_path, + base_vws_url=self._base_vws_url, + request_timeout_seconds=(self._request_timeout_seconds), + extra_headers={"Accept": accept}, + transport=self._transport, + ) + + if ( + response.status_code == HTTPStatus.TOO_MANY_REQUESTS + ): # pragma: no cover + # The Vuforia API returns a 429 response with no + # JSON body. + raise TooManyRequestsError(response=response) + + if ( + response.status_code >= HTTPStatus.INTERNAL_SERVER_ERROR + ): # pragma: no cover + raise ServerError(response=response) + + if response.status_code == HTTPStatus.OK: + return response.content + + result_code = json.loads(s=response.text)["result_code"] + + exception = { + "AuthenticationFailure": (AuthenticationFailureError), + "BadRequest": BadRequestError, + "DateRangeError": DateRangeError, + "Fail": FailError, + "InvalidAcceptHeader": InvalidAcceptHeaderError, + "InvalidInstanceId": InvalidInstanceIdError, + "InvalidTargetType": InvalidTargetTypeError, + "RequestTimeTooSkewed": RequestTimeTooSkewedError, + "TargetStatusNotSuccess": (TargetStatusNotSuccessError), + "UnknownTarget": UnknownTargetError, + }[result_code] + + raise exception(response=response) diff --git a/src/vws/async_vws.py b/src/vws/async_vws.py new file mode 100644 index 00000000..f0f9005a --- /dev/null +++ b/src/vws/async_vws.py @@ -0,0 +1,692 @@ +"""Async tools for interacting with Vuforia APIs.""" + +import asyncio +import base64 +import json +from datetime import date +from http import HTTPMethod, HTTPStatus +from typing import Self + +from beartype import BeartypeConf, beartype + +from vws._async_vws_request import async_target_api_request +from vws._image_utils import ImageType as _ImageType +from vws._image_utils import get_image_data as _get_image_data +from vws.exceptions.custom_exceptions import ( + ServerError, + TargetProcessingTimeoutError, +) +from vws.exceptions.vws_exceptions import ( + AuthenticationFailureError, + BadImageError, + BadRequestError, + DateRangeError, + FailError, + ImageTooLargeError, + MetadataTooLargeError, + ProjectHasNoAPIAccessError, + ProjectInactiveError, + ProjectSuspendedError, + RequestQuotaReachedError, + RequestTimeTooSkewedError, + TargetNameExistError, + TargetQuotaReachedError, + TargetStatusNotSuccessError, + TargetStatusProcessingError, + TooManyRequestsError, + UnknownTargetError, +) +from vws.reports import ( + DatabaseSummaryReport, + TargetRecord, + TargetStatusAndRecord, + TargetStatuses, + TargetSummaryReport, +) +from vws.response import Response +from vws.transports import AsyncHTTPXTransport, AsyncTransport + + +@beartype(conf=BeartypeConf(is_pep484_tower=True)) +class AsyncVWS: + """An async interface to Vuforia Web Services APIs.""" + + def __init__( + self, + *, + server_access_key: str, + server_secret_key: str, + base_vws_url: str = "https://vws.vuforia.com", + request_timeout_seconds: float | tuple[float, float] = 30.0, + transport: AsyncTransport | None = None, + ) -> None: + """ + Args: + server_access_key: A VWS server access key. + server_secret_key: A VWS server secret key. + base_vws_url: The base URL for the VWS API. + request_timeout_seconds: The timeout for each + HTTP request. This can be a float to set both + the connect and read timeouts, or a + (connect, read) tuple. + transport: The async HTTP transport to use for + requests. Defaults to + ``AsyncHTTPXTransport()``. + """ + self._server_access_key = server_access_key + self._server_secret_key = server_secret_key + self._base_vws_url = base_vws_url + self._request_timeout_seconds = request_timeout_seconds + self._transport = transport or AsyncHTTPXTransport() + + async def aclose(self) -> None: + """Close the underlying transport if it supports closing.""" + close = getattr(self._transport, "aclose", None) + if close is not None: + await close() + + async def __aenter__(self) -> Self: + """Enter the async context manager.""" + return self + + async def __aexit__(self, *_args: object) -> None: + """Exit the async context manager and close the transport.""" + await self.aclose() + + async def make_request( + self, + *, + method: str, + data: bytes, + request_path: str, + expected_result_code: str, + content_type: str, + extra_headers: dict[str, str] | None = None, + ) -> Response: + """Make an async request to the Vuforia Target API. + + Args: + method: The HTTP method which will be used in + the request. + data: The request body which will be used in the + request. + request_path: The path to the endpoint which + will be used in the request. + expected_result_code: See + "VWS API Result Codes" on + https://developer.vuforia.com/library/web-api/cloud-targets-web-services-api. + content_type: The content type of the request. + extra_headers: Additional headers to include in + the request. + + Returns: + The response to the request. + + Raises: + ~vws.exceptions.custom_exceptions.ServerError: + There is an error with Vuforia's servers. + ~vws.exceptions.vws_exceptions.TooManyRequestsError: + Vuforia is rate limiting access. + json.JSONDecodeError: The server did not respond + with valid JSON. This may happen if the + server address is not a valid Vuforia server. + """ + response = await async_target_api_request( + content_type=content_type, + server_access_key=self._server_access_key, + server_secret_key=self._server_secret_key, + method=method, + data=data, + request_path=request_path, + base_vws_url=self._base_vws_url, + request_timeout_seconds=self._request_timeout_seconds, + extra_headers=extra_headers or {}, + transport=self._transport, + ) + + if ( + response.status_code == HTTPStatus.TOO_MANY_REQUESTS + ): # pragma: no cover + # The Vuforia API returns a 429 response with no JSON body. + raise TooManyRequestsError(response=response) + + if ( + response.status_code >= HTTPStatus.INTERNAL_SERVER_ERROR + ): # pragma: no cover + raise ServerError(response=response) + + result_code = json.loads(s=response.text)["result_code"] + + if result_code == expected_result_code: + return response + + exception = { + "AuthenticationFailure": AuthenticationFailureError, + "BadImage": BadImageError, + "BadRequest": BadRequestError, + "DateRangeError": DateRangeError, + "Fail": FailError, + "ImageTooLarge": ImageTooLargeError, + "MetadataTooLarge": MetadataTooLargeError, + "ProjectHasNoAPIAccess": ProjectHasNoAPIAccessError, + "ProjectInactive": ProjectInactiveError, + "ProjectSuspended": ProjectSuspendedError, + "RequestQuotaReached": RequestQuotaReachedError, + "RequestTimeTooSkewed": RequestTimeTooSkewedError, + "TargetNameExist": TargetNameExistError, + "TargetQuotaReached": TargetQuotaReachedError, + "TargetStatusNotSuccess": TargetStatusNotSuccessError, + "TargetStatusProcessing": TargetStatusProcessingError, + "UnknownTarget": UnknownTargetError, + }[result_code] + + raise exception(response=response) + + async def add_target( + self, + *, + name: str, + width: float, + image: _ImageType, + application_metadata: str | None, + active_flag: bool, + ) -> str: + """Add a target to a Vuforia Web Services database. + + See + https://developer.vuforia.com/library/web-api/cloud-targets-web-services-api#add + for parameter details. + + Args: + name: The name of the target. + width: The width of the target. + image: The image of the target. + active_flag: Whether or not the target is active for query. + application_metadata: The application metadata of the target. + This must be base64 encoded, for example by using:: + + base64.b64encode('input_string').decode('ascii') + + Returns: + The target ID of the new target. + + Raises: + ~vws.exceptions.vws_exceptions.AuthenticationFailureError: The + secret key is not correct. + ~vws.exceptions.vws_exceptions.BadImageError: There is a problem + with the given image. For example, it must be a JPEG or PNG + file in the grayscale or RGB color space. + ~vws.exceptions.vws_exceptions.FailError: There was an error with + the request. For example, the given access key does not match a + known database. + ~vws.exceptions.vws_exceptions.MetadataTooLargeError: The given + metadata is too large. The maximum size is 1 MB of data when + Base64 encoded. + ~vws.exceptions.vws_exceptions.ImageTooLargeError: The given image + is too large. + ~vws.exceptions.vws_exceptions.TargetNameExistError: A target with + the given ``name`` already exists. + ~vws.exceptions.vws_exceptions.ProjectInactiveError: The project is + inactive. + ~vws.exceptions.vws_exceptions.RequestTimeTooSkewedError: There is + an error with the time sent to Vuforia. + ~vws.exceptions.custom_exceptions.ServerError: There is an error + with Vuforia's servers. This has been seen to happen when the + given name includes a bad character. + ~vws.exceptions.vws_exceptions.TooManyRequestsError: Vuforia is + rate limiting access. + """ + image_data = _get_image_data(image=image) + image_data_encoded = base64.b64encode(s=image_data).decode( + encoding="ascii", + ) + + data = { + "name": name, + "width": width, + "image": image_data_encoded, + "active_flag": active_flag, + "application_metadata": application_metadata, + } + + content = json.dumps(obj=data).encode(encoding="utf-8") + + response = await self.make_request( + method=HTTPMethod.POST, + data=content, + request_path="/targets", + expected_result_code="TargetCreated", + content_type="application/json", + ) + + return str(object=json.loads(s=response.text)["target_id"]) + + async def get_target_record(self, target_id: str) -> TargetStatusAndRecord: + """Get a given target's target record from the Target + Management System. + + See + https://developer.vuforia.com/library/web-api/cloud-targets-web-services-api#target-record. + + Args: + target_id: The ID of the target to get details of. + + Returns: + Response details of a target from Vuforia. + + Raises: + ~vws.exceptions.vws_exceptions.AuthenticationFailureError: The + secret key is not correct. + ~vws.exceptions.vws_exceptions.FailError: There was an error with + the request. For example, the given access key does not match a + known database. + ~vws.exceptions.vws_exceptions.UnknownTargetError: The given target + ID does not match a target in the database. + ~vws.exceptions.vws_exceptions.RequestTimeTooSkewedError: There is + an error with the time sent to Vuforia. + ~vws.exceptions.custom_exceptions.ServerError: There is an error + with Vuforia's servers. + ~vws.exceptions.vws_exceptions.TooManyRequestsError: Vuforia is + rate limiting access. + """ + response = await self.make_request( + method=HTTPMethod.GET, + data=b"", + request_path=f"/targets/{target_id}", + expected_result_code="Success", + content_type="application/json", + ) + + result_data = json.loads(s=response.text) + status = TargetStatuses(value=result_data["status"]) + target_record_dict = dict(result_data["target_record"]) + target_record = TargetRecord( + target_id=target_record_dict["target_id"], + active_flag=bool(target_record_dict["active_flag"]), + name=target_record_dict["name"], + width=float(target_record_dict["width"]), + tracking_rating=int(target_record_dict["tracking_rating"]), + reco_rating=target_record_dict["reco_rating"], + ) + return TargetStatusAndRecord( + status=status, + target_record=target_record, + ) + + async def wait_for_target_processed( + self, + *, + target_id: str, + seconds_between_requests: float = 0.2, + timeout_seconds: float = 60 * 5, + ) -> None: + """Wait up to five minutes (arbitrary) for a target to + get past the processing stage. + + Args: + target_id: The ID of the target to wait for. + seconds_between_requests: The number of seconds to + wait between requests made while polling the + target status. + We wait 0.2 seconds by default, rather than + less, than that to decrease the number of calls + made to the API, to decrease the likelihood of + hitting the request quota. + timeout_seconds: The maximum number of seconds to + wait for the target to be processed. + + Raises: + ~vws.exceptions.vws_exceptions.AuthenticationFailureError: The + secret key is not correct. + ~vws.exceptions.vws_exceptions.FailError: There was an error with + the request. For example, the given access key does not match a + known database. + ~vws.exceptions.custom_exceptions.TargetProcessingTimeoutError: The + target remained in the processing stage for more than + ``timeout_seconds`` seconds. + ~vws.exceptions.vws_exceptions.UnknownTargetError: The given target + ID does not match a target in the database. + ~vws.exceptions.vws_exceptions.RequestTimeTooSkewedError: There is + an error with the time sent to Vuforia. + ~vws.exceptions.custom_exceptions.ServerError: There is an error + with Vuforia's servers. + ~vws.exceptions.vws_exceptions.TooManyRequestsError: Vuforia is + rate limiting access. + """ + start_time = asyncio.get_event_loop().time() + while True: + report = await self.get_target_summary_report( + target_id=target_id, + ) + if report.status != TargetStatuses.PROCESSING: + # Guard against the target still being seen as + # processing by other endpoints due to eventual + # consistency. + await asyncio.sleep( + delay=seconds_between_requests, + ) + return + + elapsed_time = asyncio.get_event_loop().time() - start_time + if elapsed_time > timeout_seconds: # pragma: no cover + raise TargetProcessingTimeoutError + + await asyncio.sleep( + delay=seconds_between_requests, + ) + + async def list_targets(self) -> list[str]: + """List target IDs. + + See + https://developer.vuforia.com/library/web-api/cloud-targets-web-services-api#details-list. + + Returns: + The IDs of all targets in the database. + + Raises: + ~vws.exceptions.vws_exceptions.AuthenticationFailureError: The + secret key is not correct. + ~vws.exceptions.vws_exceptions.FailError: There was an error with + the request. For example, the given access key does not match a + known database. + ~vws.exceptions.vws_exceptions.RequestTimeTooSkewedError: There is + an error with the time sent to Vuforia. + ~vws.exceptions.custom_exceptions.ServerError: There is an error + with Vuforia's servers. + ~vws.exceptions.vws_exceptions.TooManyRequestsError: Vuforia is + rate limiting access. + """ + response = await self.make_request( + method=HTTPMethod.GET, + data=b"", + request_path="/targets", + expected_result_code="Success", + content_type="application/json", + ) + + return list(json.loads(s=response.text)["results"]) + + async def get_target_summary_report( + self, target_id: str + ) -> TargetSummaryReport: + """Get a summary report for a target. + + See + https://developer.vuforia.com/library/web-api/cloud-targets-web-services-api#summary-report. + + Args: + target_id: The ID of the target to get a summary + report for. + + Returns: + Details of the target. + + Raises: + ~vws.exceptions.vws_exceptions.AuthenticationFailureError: The + secret key is not correct. + ~vws.exceptions.vws_exceptions.FailError: There was an error with + the request. For example, the given access key does not match a + known database. + ~vws.exceptions.vws_exceptions.UnknownTargetError: The given target + ID does not match a target in the database. + ~vws.exceptions.vws_exceptions.RequestTimeTooSkewedError: There is + an error with the time sent to Vuforia. + ~vws.exceptions.custom_exceptions.ServerError: There is an error + with Vuforia's servers. + ~vws.exceptions.vws_exceptions.TooManyRequestsError: Vuforia is + rate limiting access. + """ + response = await self.make_request( + method=HTTPMethod.GET, + data=b"", + request_path=f"/summary/{target_id}", + expected_result_code="Success", + content_type="application/json", + ) + + result_data = dict(json.loads(s=response.text)) + return TargetSummaryReport( + status=TargetStatuses( + value=result_data["status"], + ), + database_name=result_data["database_name"], + target_name=result_data["target_name"], + upload_date=date.fromisoformat( + result_data["upload_date"], + ), + active_flag=bool(result_data["active_flag"]), + tracking_rating=int( + result_data["tracking_rating"], + ), + total_recos=int(result_data["total_recos"]), + current_month_recos=int( + result_data["current_month_recos"], + ), + previous_month_recos=int( + result_data["previous_month_recos"], + ), + ) + + async def get_database_summary_report( + self, + ) -> DatabaseSummaryReport: + """Get a summary report for the database. + + See + https://developer.vuforia.com/library/web-api/cloud-targets-web-services-api#summary-report. + + Returns: + Details of the database. + + Raises: + ~vws.exceptions.vws_exceptions.AuthenticationFailureError: The + secret key is not correct. + ~vws.exceptions.vws_exceptions.FailError: There was an error with + the request. For example, the given access key does not match a + known database. + ~vws.exceptions.vws_exceptions.RequestTimeTooSkewedError: There is + an error with the time sent to Vuforia. + ~vws.exceptions.custom_exceptions.ServerError: There is an error + with Vuforia's servers. + ~vws.exceptions.vws_exceptions.TooManyRequestsError: Vuforia is + rate limiting access. + """ + response = await self.make_request( + method=HTTPMethod.GET, + data=b"", + request_path="/summary", + expected_result_code="Success", + content_type="application/json", + ) + + response_data = dict(json.loads(s=response.text)) + return DatabaseSummaryReport( + active_images=int(response_data["active_images"]), + current_month_recos=int( + response_data["current_month_recos"], + ), + failed_images=int(response_data["failed_images"]), + inactive_images=int( + response_data["inactive_images"], + ), + name=str(object=response_data["name"]), + previous_month_recos=int( + response_data["previous_month_recos"], + ), + processing_images=int( + response_data["processing_images"], + ), + reco_threshold=int( + response_data["reco_threshold"], + ), + request_quota=int(response_data["request_quota"]), + request_usage=int(response_data["request_usage"]), + target_quota=int(response_data["target_quota"]), + total_recos=int(response_data["total_recos"]), + ) + + async def delete_target(self, target_id: str) -> None: + """Delete a given target. + + See + https://developer.vuforia.com/library/web-api/cloud-targets-web-services-api#delete. + + Args: + target_id: The ID of the target to delete. + + Raises: + ~vws.exceptions.vws_exceptions.AuthenticationFailureError: The + secret key is not correct. + ~vws.exceptions.vws_exceptions.FailError: There was an error with + the request. For example, the given access key does not match a + known database. + ~vws.exceptions.vws_exceptions.UnknownTargetError: The given target + ID does not match a target in the database. + ~vws.exceptions.vws_exceptions.TargetStatusProcessingError: The + given target is in the processing state. + ~vws.exceptions.vws_exceptions.RequestTimeTooSkewedError: There is + an error with the time sent to Vuforia. + ~vws.exceptions.custom_exceptions.ServerError: There is an error + with Vuforia's servers. + ~vws.exceptions.vws_exceptions.TooManyRequestsError: Vuforia is + rate limiting access. + """ + await self.make_request( + method=HTTPMethod.DELETE, + data=b"", + request_path=f"/targets/{target_id}", + expected_result_code="Success", + content_type="application/json", + ) + + async def get_duplicate_targets(self, target_id: str) -> list[str]: + """Get targets which may be considered duplicates of a + given target. + + See + https://developer.vuforia.com/library/web-api/cloud-targets-web-services-api#check. + + Args: + target_id: The ID of the target to delete. + + Returns: + The target IDs of duplicate targets. + + Raises: + ~vws.exceptions.vws_exceptions.AuthenticationFailureError: The + secret key is not correct. + ~vws.exceptions.vws_exceptions.FailError: There was an error with + the request. For example, the given access key does not match a + known database. + ~vws.exceptions.vws_exceptions.UnknownTargetError: The given target + ID does not match a target in the database. + ~vws.exceptions.vws_exceptions.ProjectInactiveError: The project is + inactive. + ~vws.exceptions.vws_exceptions.RequestTimeTooSkewedError: There is + an error with the time sent to Vuforia. + ~vws.exceptions.custom_exceptions.ServerError: There is an error + with Vuforia's servers. + ~vws.exceptions.vws_exceptions.TooManyRequestsError: Vuforia is + rate limiting access. + """ + response = await self.make_request( + method=HTTPMethod.GET, + data=b"", + request_path=f"/duplicates/{target_id}", + expected_result_code="Success", + content_type="application/json", + ) + + return list( + json.loads(s=response.text)["similar_targets"], + ) + + async def update_target( + self, + *, + target_id: str, + name: str | None = None, + width: float | None = None, + image: _ImageType | None = None, + active_flag: bool | None = None, + application_metadata: str | None = None, + ) -> None: + """Update a target in a Vuforia Web Services database. + + See + https://developer.vuforia.com/library/web-api/cloud-targets-web-services-api#update + for parameter details. + + Args: + target_id: The ID of the target to update. + name: The name of the target. + width: The width of the target. + image: The image of the target. + active_flag: Whether or not the target is active + for query. + application_metadata: The application metadata of + the target. + This must be base64 encoded, for example by + using:: + + base64.b64encode('input_string').decode('ascii') + + Giving ``None`` will not change the application + metadata. + + Raises: + ~vws.exceptions.vws_exceptions.AuthenticationFailureError: The + secret key is not correct. + ~vws.exceptions.vws_exceptions.BadImageError: There is a problem + with the given image. For example, it must be a JPEG or PNG + file in the grayscale or RGB color space. + ~vws.exceptions.vws_exceptions.FailError: There was an error with + the request. For example, the given access key does not match a + known database. + ~vws.exceptions.vws_exceptions.MetadataTooLargeError: The given + metadata is too large. The maximum size is 1 MB of data when + Base64 encoded. + ~vws.exceptions.vws_exceptions.ImageTooLargeError: The given image + is too large. + ~vws.exceptions.vws_exceptions.TargetNameExistError: A target with + the given ``name`` already exists. + ~vws.exceptions.vws_exceptions.ProjectInactiveError: The project is + inactive. + ~vws.exceptions.vws_exceptions.RequestTimeTooSkewedError: There is + an error with the time sent to Vuforia. + ~vws.exceptions.custom_exceptions.ServerError: There is an error + with Vuforia's servers. + ~vws.exceptions.vws_exceptions.TooManyRequestsError: Vuforia is + rate limiting access. + """ + data: dict[str, str | bool | float | int] = {} + + if name is not None: + data["name"] = name + + if width is not None: + data["width"] = width + + if image is not None: + image_data = _get_image_data(image=image) + image_data_encoded = base64.b64encode( + s=image_data, + ).decode(encoding="ascii") + data["image"] = image_data_encoded + + if active_flag is not None: + data["active_flag"] = active_flag + + if application_metadata is not None: + data["application_metadata"] = application_metadata + + content = json.dumps(obj=data).encode(encoding="utf-8") + + await self.make_request( + method=HTTPMethod.PUT, + data=content, + request_path=f"/targets/{target_id}", + expected_result_code="Success", + content_type="application/json", + ) diff --git a/src/vws/transports.py b/src/vws/transports.py index 7980c578..6bf9ceb9 100644 --- a/src/vws/transports.py +++ b/src/vws/transports.py @@ -1,6 +1,7 @@ """HTTP transport implementations for VWS clients.""" -from typing import Protocol, runtime_checkable +from collections.abc import Awaitable +from typing import Protocol, Self, runtime_checkable import httpx import requests @@ -156,3 +157,122 @@ def __call__( tell_position=len(content), content=content, ) + + +@runtime_checkable +class AsyncTransport(Protocol): + """Protocol for async HTTP transports used by VWS clients. + + An async transport is a callable that makes an HTTP request + and returns a ``Response``. + """ + + def __call__( + self, + *, + method: str, + url: str, + headers: dict[str, str], + data: bytes, + request_timeout: float | tuple[float, float], + ) -> Awaitable[Response]: + """Make an async HTTP request. + + Args: + method: The HTTP method (e.g. "GET", "POST"). + url: The full URL to request. + headers: Headers to send with the request. + data: The request body as bytes. + request_timeout: The timeout for the request. A float + sets both the connect and read timeouts. A + (connect, read) tuple sets them individually. + + Returns: + A Response populated from the HTTP response. + """ + ... # pylint: disable=unnecessary-ellipsis + + +@beartype(conf=BeartypeConf(is_pep484_tower=True)) +class AsyncHTTPXTransport: + """Async HTTP transport using the ``httpx`` library. + + This is the default transport for async VWS clients. + A single ``httpx.AsyncClient`` is reused across requests + for connection pooling. + """ + + def __init__(self) -> None: + """Create an ``AsyncHTTPXTransport``.""" + self._client = httpx.AsyncClient() + + async def aclose(self) -> None: + """Close the underlying ``httpx.AsyncClient``.""" + await self._client.aclose() + + async def __aenter__(self) -> Self: + """Enter the async context manager.""" + return self + + async def __aexit__(self, *_args: object) -> None: + """Exit the async context manager and close the client.""" + await self.aclose() + + async def __call__( + self, + *, + method: str, + url: str, + headers: dict[str, str], + data: bytes, + request_timeout: float | tuple[float, float], + ) -> Response: + """Make an async HTTP request using ``httpx``. + + Args: + method: The HTTP method. + url: The full URL. + headers: Request headers. + data: The request body. + request_timeout: The request timeout. + + Returns: + A Response populated from the httpx response. + """ + if isinstance(request_timeout, tuple): + connect_timeout, read_timeout = request_timeout + httpx_timeout = httpx.Timeout( + connect=connect_timeout, + read=read_timeout, + write=None, + pool=None, + ) + else: + httpx_timeout = httpx.Timeout( + connect=request_timeout, + read=request_timeout, + write=None, + pool=None, + ) + + httpx_response = await self._client.request( + method=method, + url=url, + headers=headers, + content=data, + timeout=httpx_timeout, + follow_redirects=True, + ) + + content = bytes(httpx_response.content) + request_content = httpx_response.request.content + + return Response( + text=httpx_response.text, + url=str(object=httpx_response.url), + status_code=httpx_response.status_code, + headers=dict(httpx_response.headers), + request_body=bytes(request_content) or None, + tell_position=len(content), + content=content, + ) diff --git a/tests/conftest.py b/tests/conftest.py index eb8b3a57..820908ba 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,16 +1,24 @@ """Configuration, plugins and fixtures for `pytest`.""" import io -from collections.abc import Generator +from collections.abc import AsyncGenerator, Generator from pathlib import Path from typing import BinaryIO, Literal import pytest +import pytest_asyncio from mock_vws import MockVWS from mock_vws.database import CloudDatabase, VuMarkDatabase from mock_vws.target import VuMarkTarget -from vws import VWS, CloudRecoService, VuMarkService +from vws import ( + VWS, + AsyncCloudRecoService, + AsyncVuMarkService, + AsyncVWS, + CloudRecoService, + VuMarkService, +) @pytest.fixture(name="_mock_database") @@ -70,6 +78,49 @@ def cloud_reco_client(*, _mock_database: CloudDatabase) -> CloudRecoService: ) +@pytest_asyncio.fixture +async def async_vws_client( + *, + _mock_database: CloudDatabase, +) -> AsyncGenerator[AsyncVWS]: + """An async VWS client which connects to a mock database.""" + async with AsyncVWS( + server_access_key=_mock_database.server_access_key, + server_secret_key=_mock_database.server_secret_key, + ) as client: + yield client + + +@pytest_asyncio.fixture +async def async_cloud_reco_client( + *, + _mock_database: CloudDatabase, +) -> AsyncGenerator[AsyncCloudRecoService]: + """An async ``CloudRecoService`` client which connects to a mock + database. + """ + async with AsyncCloudRecoService( + client_access_key=_mock_database.client_access_key, + client_secret_key=_mock_database.client_secret_key, + ) as client: + yield client + + +@pytest_asyncio.fixture +async def async_vumark_service_client( + *, + _mock_vumark_database: VuMarkDatabase, +) -> AsyncGenerator[AsyncVuMarkService]: + """An async ``VuMarkService`` client which connects to a mock VuMark + database. + """ + async with AsyncVuMarkService( + server_access_key=_mock_vumark_database.server_access_key, + server_secret_key=_mock_vumark_database.server_secret_key, + ) as client: + yield client + + @pytest.fixture(name="image_file", params=["r+b", "rb"]) def fixture_image_file( *, diff --git a/tests/test_async_cloud_reco_exceptions.py b/tests/test_async_cloud_reco_exceptions.py new file mode 100644 index 00000000..cbe2b22c --- /dev/null +++ b/tests/test_async_cloud_reco_exceptions.py @@ -0,0 +1,120 @@ +"""Tests for exceptions raised when using the +AsyncCloudRecoService. +""" + +import io +import uuid +from http import HTTPStatus + +import pytest +from mock_vws import MockVWS +from mock_vws.database import CloudDatabase +from mock_vws.states import States + +from vws import AsyncCloudRecoService +from vws.exceptions.cloud_reco_exceptions import ( + AuthenticationFailureError, + InactiveProjectError, + MaxNumResultsOutOfRangeError, +) +from vws.exceptions.custom_exceptions import ( + RequestEntityTooLargeError, +) + + +@pytest.mark.asyncio +async def test_too_many_max_results( + *, + async_cloud_reco_client: AsyncCloudRecoService, + high_quality_image: io.BytesIO, +) -> None: + """A ``MaxNumResultsOutOfRange`` error is raised if the given + ``max_num_results`` is out of range. + """ + with pytest.raises( + expected_exception=MaxNumResultsOutOfRangeError, + ) as exc: + await async_cloud_reco_client.query( + image=high_quality_image, + max_num_results=51, + ) + + expected_value = ( + "Integer out of range (51) in form data part " + "'max_result'. " + "Accepted range is from 1 to 50 (inclusive)." + ) + assert str(object=exc.value) == exc.value.response.text == expected_value + + +@pytest.mark.asyncio +async def test_image_too_large( + *, + async_cloud_reco_client: AsyncCloudRecoService, + png_too_large: io.BytesIO | io.BufferedRandom, +) -> None: + """A ``RequestEntityTooLarge`` exception is raised if an + image which is too large is given. + """ + with pytest.raises( + expected_exception=RequestEntityTooLargeError, + ) as exc: + await async_cloud_reco_client.query( + image=png_too_large, + ) + + assert ( + exc.value.response.status_code == HTTPStatus.REQUEST_ENTITY_TOO_LARGE + ) + + +@pytest.mark.asyncio +async def test_authentication_failure( + high_quality_image: io.BytesIO, +) -> None: + """An ``AuthenticationFailure`` exception is raised when the + client secret key is incorrect. + """ + database = CloudDatabase() + async_cloud_reco_client = AsyncCloudRecoService( + client_access_key=database.client_access_key, + client_secret_key=uuid.uuid4().hex, + ) + with MockVWS() as mock: + mock.add_cloud_database(cloud_database=database) + + with pytest.raises( + expected_exception=AuthenticationFailureError, + ) as exc: + await async_cloud_reco_client.query( + image=high_quality_image, + ) + + assert exc.value.response.status_code == HTTPStatus.UNAUTHORIZED + + +@pytest.mark.asyncio +async def test_inactive_project( + high_quality_image: io.BytesIO, +) -> None: + """An ``InactiveProject`` exception is raised when querying + an inactive database. + """ + database = CloudDatabase(state=States.PROJECT_INACTIVE) + with MockVWS() as mock: + mock.add_cloud_database(cloud_database=database) + async_cloud_reco_client = AsyncCloudRecoService( + client_access_key=database.client_access_key, + client_secret_key=database.client_secret_key, + ) + + with pytest.raises( + expected_exception=InactiveProjectError, + ) as exc: + await async_cloud_reco_client.query( + image=high_quality_image, + ) + + response = exc.value.response + assert response.status_code == HTTPStatus.FORBIDDEN + assert response.tell_position != 0 diff --git a/tests/test_async_query.py b/tests/test_async_query.py new file mode 100644 index 00000000..2de1ce83 --- /dev/null +++ b/tests/test_async_query.py @@ -0,0 +1,225 @@ +"""Tests for the ``AsyncCloudRecoService`` querying functionality.""" + +import io +import uuid +from typing import BinaryIO + +import pytest +from mock_vws import MockVWS +from mock_vws.database import CloudDatabase + +from vws import AsyncCloudRecoService, AsyncVWS +from vws.include_target_data import CloudRecoIncludeTargetData + + +class TestQuery: + """Tests for making async image queries.""" + + @staticmethod + @pytest.mark.asyncio + async def test_no_matches( + *, + async_cloud_reco_client: AsyncCloudRecoService, + image: io.BytesIO | BinaryIO, + ) -> None: + """An empty list is returned if there are no matches.""" + result = await async_cloud_reco_client.query( + image=image, + ) + assert result == [] + + @staticmethod + @pytest.mark.asyncio + async def test_match( + *, + async_vws_client: AsyncVWS, + async_cloud_reco_client: AsyncCloudRecoService, + image: io.BytesIO | BinaryIO, + ) -> None: + """Details of matching targets are returned.""" + target_id = await async_vws_client.add_target( + name="x", + width=1, + image=image, + active_flag=True, + application_metadata=None, + ) + await async_vws_client.wait_for_target_processed( + target_id=target_id, + ) + [matching_target] = await async_cloud_reco_client.query( + image=image, + ) + assert matching_target.target_id == target_id + + +class TestCustomBaseVWQURL: + """Tests for using a custom base VWQ URL.""" + + @staticmethod + @pytest.mark.asyncio + async def test_custom_base_url( + image: io.BytesIO | BinaryIO, + ) -> None: + """It is possible to query a target in a database under + a custom VWQ URL. + """ + base_vwq_url = "http://example.com" + with MockVWS(base_vwq_url=base_vwq_url) as mock: + database = CloudDatabase() + mock.add_cloud_database(cloud_database=database) + async_vws_client = AsyncVWS( + server_access_key=database.server_access_key, + server_secret_key=database.server_secret_key, + ) + + target_id = await async_vws_client.add_target( + name="x", + width=1, + image=image, + active_flag=True, + application_metadata=None, + ) + + await async_vws_client.wait_for_target_processed( + target_id=target_id, + ) + + async_cloud_reco_client = AsyncCloudRecoService( + client_access_key=database.client_access_key, + client_secret_key=database.client_secret_key, + base_vwq_url=base_vwq_url, + ) + + matches = await async_cloud_reco_client.query( + image=image, + ) + assert len(matches) == 1 + match = matches[0] + assert match.target_id == target_id + + +class TestMaxNumResults: + """Tests for the ``max_num_results`` parameter of + ``query``. + """ + + @staticmethod + @pytest.mark.asyncio + async def test_custom( + *, + async_vws_client: AsyncVWS, + async_cloud_reco_client: AsyncCloudRecoService, + image: io.BytesIO | BinaryIO, + ) -> None: + """It is possible to set a custom + ``max_num_results``. + """ + target_id = await async_vws_client.add_target( + name=uuid.uuid4().hex, + width=1, + image=image, + active_flag=True, + application_metadata=None, + ) + target_id_2 = await async_vws_client.add_target( + name=uuid.uuid4().hex, + width=1, + image=image, + active_flag=True, + application_metadata=None, + ) + target_id_3 = await async_vws_client.add_target( + name=uuid.uuid4().hex, + width=1, + image=image, + active_flag=True, + application_metadata=None, + ) + await async_vws_client.wait_for_target_processed( + target_id=target_id, + ) + await async_vws_client.wait_for_target_processed( + target_id=target_id_2, + ) + await async_vws_client.wait_for_target_processed( + target_id=target_id_3, + ) + max_num_results = 2 + matches = await async_cloud_reco_client.query( + image=image, + max_num_results=max_num_results, + ) + assert len(matches) == max_num_results + + +class TestIncludeTargetData: + """Tests for the ``include_target_data`` parameter of + ``query``. + """ + + @staticmethod + @pytest.mark.asyncio + async def test_none( + *, + async_vws_client: AsyncVWS, + async_cloud_reco_client: AsyncCloudRecoService, + image: io.BytesIO | BinaryIO, + ) -> None: + """When ``CloudRecoIncludeTargetData.NONE`` is given, + target data is not returned in any match. + """ + target_id = await async_vws_client.add_target( + name=uuid.uuid4().hex, + width=1, + image=image, + active_flag=True, + application_metadata=None, + ) + await async_vws_client.wait_for_target_processed( + target_id=target_id, + ) + [match] = await async_cloud_reco_client.query( + image=image, + include_target_data=(CloudRecoIncludeTargetData.NONE), + ) + assert match.target_data is None + + @staticmethod + @pytest.mark.asyncio + async def test_all( + *, + async_vws_client: AsyncVWS, + async_cloud_reco_client: AsyncCloudRecoService, + image: io.BytesIO | BinaryIO, + ) -> None: + """When ``CloudRecoIncludeTargetData.ALL`` is given, + target data is returned in all matches. + """ + target_id = await async_vws_client.add_target( + name=uuid.uuid4().hex, + width=1, + image=image, + active_flag=True, + application_metadata=None, + ) + target_id_2 = await async_vws_client.add_target( + name=uuid.uuid4().hex, + width=1, + image=image, + active_flag=True, + application_metadata=None, + ) + await async_vws_client.wait_for_target_processed( + target_id=target_id, + ) + await async_vws_client.wait_for_target_processed( + target_id=target_id_2, + ) + top_match, second_match = await async_cloud_reco_client.query( + image=image, + max_num_results=2, + include_target_data=(CloudRecoIncludeTargetData.ALL), + ) + assert top_match.target_data is not None + assert second_match.target_data is not None diff --git a/tests/test_async_vws.py b/tests/test_async_vws.py new file mode 100644 index 00000000..c08cf78a --- /dev/null +++ b/tests/test_async_vws.py @@ -0,0 +1,500 @@ +"""Tests for async helper functions for managing a Vuforia database.""" + +import base64 +import io +import uuid +from typing import BinaryIO + +import pytest +from mock_vws import MockVWS +from mock_vws.database import CloudDatabase + +from vws import AsyncCloudRecoService, AsyncVuMarkService, AsyncVWS +from vws.exceptions.custom_exceptions import ( + TargetProcessingTimeoutError, +) +from vws.reports import ( + DatabaseSummaryReport, + TargetRecord, + TargetStatuses, +) +from vws.vumark_accept import VuMarkAccept + + +class TestAddTarget: + """Tests for adding a target.""" + + @staticmethod + @pytest.mark.asyncio + @pytest.mark.parametrize( + argnames="application_metadata", + argvalues=[None, b"a"], + ) + @pytest.mark.parametrize( + argnames="active_flag", + argvalues=[True, False], + ) + async def test_add_target( + *, + async_vws_client: AsyncVWS, + image: io.BytesIO | BinaryIO, + application_metadata: bytes | None, + async_cloud_reco_client: AsyncCloudRecoService, + active_flag: bool, + ) -> None: + """No exception is raised when adding one target.""" + name = "x" + width = 1 + if application_metadata is None: + encoded_metadata = None + else: + encoded_metadata_bytes = base64.b64encode( + s=application_metadata, + ) + encoded_metadata = encoded_metadata_bytes.decode( + encoding="utf-8", + ) + + target_id = await async_vws_client.add_target( + name=name, + width=width, + image=image, + application_metadata=encoded_metadata, + active_flag=active_flag, + ) + target_record = ( + await async_vws_client.get_target_record( + target_id=target_id, + ) + ).target_record + assert target_record.name == name + assert target_record.width == width + assert target_record.active_flag is active_flag + await async_vws_client.wait_for_target_processed( + target_id=target_id, + ) + matching_targets = await async_cloud_reco_client.query( + image=image, + ) + if active_flag: + [matching_target] = matching_targets + assert matching_target.target_id == target_id + assert matching_target.target_data is not None + query_metadata = matching_target.target_data.application_metadata + assert query_metadata == encoded_metadata + else: + assert matching_targets == [] + + @staticmethod + @pytest.mark.asyncio + async def test_add_two_targets( + *, + async_vws_client: AsyncVWS, + image: io.BytesIO | BinaryIO, + ) -> None: + """No exception is raised when adding two targets with + different names. + + This demonstrates that the image seek position is not + changed. + """ + for name in ("a", "b"): + await async_vws_client.add_target( + name=name, + width=1, + image=image, + active_flag=True, + application_metadata=None, + ) + + +class TestCustomBaseVWSURL: + """Tests for using a custom base VWS URL.""" + + @staticmethod + @pytest.mark.asyncio + async def test_custom_base_url( + image: io.BytesIO | BinaryIO, + ) -> None: + """It is possible to add a target to a database under a + custom VWS URL. + """ + base_vws_url = "http://example.com" + with MockVWS(base_vws_url=base_vws_url) as mock: + database = CloudDatabase() + mock.add_cloud_database(cloud_database=database) + async_vws_client = AsyncVWS( + server_access_key=database.server_access_key, + server_secret_key=database.server_secret_key, + base_vws_url=base_vws_url, + ) + + await async_vws_client.add_target( + name="x", + width=1, + image=image, + active_flag=True, + application_metadata=None, + ) + + +class TestListTargets: + """Tests for listing targets.""" + + @staticmethod + @pytest.mark.asyncio + async def test_list_targets( + *, + async_vws_client: AsyncVWS, + image: io.BytesIO | BinaryIO, + ) -> None: + """It is possible to get a list of target IDs.""" + id_1 = await async_vws_client.add_target( + name="x", + width=1, + image=image, + active_flag=True, + application_metadata=None, + ) + id_2 = await async_vws_client.add_target( + name="a", + width=1, + image=image, + active_flag=True, + application_metadata=None, + ) + targets = await async_vws_client.list_targets() + assert sorted(targets) == sorted([id_1, id_2]) + + +class TestDelete: + """Test for deleting a target.""" + + @staticmethod + @pytest.mark.asyncio + async def test_delete_target( + *, + async_vws_client: AsyncVWS, + image: io.BytesIO | BinaryIO, + ) -> None: + """It is possible to delete a target.""" + target_id = await async_vws_client.add_target( + name="x", + width=1, + image=image, + active_flag=True, + application_metadata=None, + ) + + await async_vws_client.wait_for_target_processed( + target_id=target_id, + ) + targets = await async_vws_client.list_targets() + assert target_id in targets + await async_vws_client.delete_target( + target_id=target_id, + ) + targets = await async_vws_client.list_targets() + assert target_id not in targets + + +class TestGetTargetSummaryReport: + """Tests for getting a summary report for a target.""" + + @staticmethod + @pytest.mark.asyncio + async def test_get_target_summary_report( + *, + async_vws_client: AsyncVWS, + image: io.BytesIO | BinaryIO, + ) -> None: + """Details of a target are returned by + ``get_target_summary_report``. + """ + target_name = uuid.uuid4().hex + target_id = await async_vws_client.add_target( + name=target_name, + width=1, + image=image, + active_flag=True, + application_metadata=None, + ) + + report = await async_vws_client.get_target_summary_report( + target_id=target_id, + ) + + assert report.target_name == target_name + assert report.active_flag is True + assert report.total_recos == 0 + + +class TestGetDatabaseSummaryReport: + """Tests for getting a summary report for a database.""" + + @staticmethod + @pytest.mark.asyncio + async def test_get_target( + async_vws_client: AsyncVWS, + ) -> None: + """Details of a database are returned by + ``get_database_summary_report``. + """ + report = await async_vws_client.get_database_summary_report() + + assert isinstance(report, DatabaseSummaryReport) + assert report.active_images == 0 + + +class TestGetTargetRecord: + """Tests for getting a record of a target.""" + + @staticmethod + @pytest.mark.asyncio + async def test_get_target_record( + *, + async_vws_client: AsyncVWS, + image: io.BytesIO | BinaryIO, + ) -> None: + """Details of a target are returned by + ``get_target_record``. + """ + target_id = await async_vws_client.add_target( + name="x", + width=1, + image=image, + active_flag=True, + application_metadata=None, + ) + + result = await async_vws_client.get_target_record( + target_id=target_id, + ) + expected_target_record = TargetRecord( + target_id=target_id, + active_flag=True, + name="x", + width=1, + tracking_rating=-1, + reco_rating="", + ) + + assert result.target_record == expected_target_record + assert result.status == TargetStatuses.PROCESSING + + +class TestWaitForTargetProcessed: + """Tests for waiting for a target to be processed.""" + + @staticmethod + @pytest.mark.asyncio + async def test_wait_for_target_processed( + *, + async_vws_client: AsyncVWS, + image: io.BytesIO | BinaryIO, + ) -> None: + """It is possible to wait until a target is processed.""" + target_id = await async_vws_client.add_target( + name="x", + width=1, + image=image, + active_flag=True, + application_metadata=None, + ) + report = await async_vws_client.get_target_summary_report( + target_id=target_id, + ) + assert report.status == TargetStatuses.PROCESSING + await async_vws_client.wait_for_target_processed( + target_id=target_id, + ) + report = await async_vws_client.get_target_summary_report( + target_id=target_id, + ) + assert report.status != TargetStatuses.PROCESSING + + @staticmethod + @pytest.mark.asyncio + async def test_custom_timeout( + image: io.BytesIO | BinaryIO, + ) -> None: + """It is possible to set a maximum timeout.""" + with MockVWS(processing_time_seconds=0.5) as mock: + database = CloudDatabase() + mock.add_cloud_database(cloud_database=database) + async_vws_client = AsyncVWS( + server_access_key=database.server_access_key, + server_secret_key=database.server_secret_key, + ) + + target_id = await async_vws_client.add_target( + name="x", + width=1, + image=image, + active_flag=True, + application_metadata=None, + ) + + with pytest.raises( + expected_exception=(TargetProcessingTimeoutError), + ): + await async_vws_client.wait_for_target_processed( + target_id=target_id, + timeout_seconds=0.1, + ) + + +class TestGetDuplicateTargets: + """Tests for getting duplicate targets.""" + + @staticmethod + @pytest.mark.asyncio + async def test_get_duplicate_targets( + *, + async_vws_client: AsyncVWS, + image: io.BytesIO | BinaryIO, + ) -> None: + """It is possible to get the IDs of similar targets.""" + target_id = await async_vws_client.add_target( + name="x", + width=1, + image=image, + active_flag=True, + application_metadata=None, + ) + similar_target_id = await async_vws_client.add_target( + name="a", + width=1, + image=image, + active_flag=True, + application_metadata=None, + ) + + await async_vws_client.wait_for_target_processed( + target_id=target_id, + ) + await async_vws_client.wait_for_target_processed( + target_id=similar_target_id, + ) + duplicates = await async_vws_client.get_duplicate_targets( + target_id=target_id, + ) + assert duplicates == [similar_target_id] + + +class TestUpdateTarget: + """Tests for updating a target.""" + + @staticmethod + @pytest.mark.asyncio + async def test_update_target( + *, + async_vws_client: AsyncVWS, + async_cloud_reco_client: AsyncCloudRecoService, + image: io.BytesIO | BinaryIO, + different_high_quality_image: io.BytesIO, + ) -> None: + """It is possible to update a target.""" + target_id = await async_vws_client.add_target( + name="x", + width=1, + image=image, + active_flag=True, + application_metadata=None, + ) + await async_vws_client.wait_for_target_processed( + target_id=target_id, + ) + [matching_target] = await async_cloud_reco_client.query( + image=image, + ) + assert matching_target.target_id == target_id + query_target_data = matching_target.target_data + assert query_target_data is not None + assert query_target_data.application_metadata is None + + new_name = uuid.uuid4().hex + new_width = 2.0 + new_application_metadata = base64.b64encode( + s=b"a", + ).decode(encoding="ascii") + await async_vws_client.update_target( + target_id=target_id, + name=new_name, + width=new_width, + active_flag=True, + image=different_high_quality_image, + application_metadata=new_application_metadata, + ) + + await async_vws_client.wait_for_target_processed( + target_id=target_id, + ) + target_details = await async_vws_client.get_target_record( + target_id=target_id, + ) + assert target_details.target_record.name == new_name + assert target_details.target_record.active_flag + + @staticmethod + @pytest.mark.asyncio + async def test_no_fields_given( + *, + async_vws_client: AsyncVWS, + image: io.BytesIO | BinaryIO, + ) -> None: + """It is possible to give no update fields.""" + target_id = await async_vws_client.add_target( + name="x", + width=1, + image=image, + active_flag=True, + application_metadata=None, + ) + await async_vws_client.wait_for_target_processed( + target_id=target_id, + ) + await async_vws_client.update_target( + target_id=target_id, + ) + + +class TestGenerateVumarkInstance: + """Tests for generating VuMark instances.""" + + @staticmethod + @pytest.mark.asyncio + @pytest.mark.parametrize( + argnames=("accept", "expected_prefix"), + argvalues=[ + pytest.param( + VuMarkAccept.PNG, + b"\x89PNG\r\n\x1a\n", + id="png", + ), + pytest.param( + VuMarkAccept.SVG, + b"<", + id="svg", + ), + pytest.param( + VuMarkAccept.PDF, + b"%PDF", + id="pdf", + ), + ], + ) + async def test_generate_vumark_instance( + *, + async_vumark_service_client: AsyncVuMarkService, + vumark_target_id: str, + accept: VuMarkAccept, + expected_prefix: bytes, + ) -> None: + """The returned bytes match the requested format.""" + result = await async_vumark_service_client.generate_vumark_instance( + target_id=vumark_target_id, + instance_id="12345", + accept=accept, + ) + assert result.startswith(expected_prefix) diff --git a/tests/test_async_vws_exceptions.py b/tests/test_async_vws_exceptions.py new file mode 100644 index 00000000..cd35854f --- /dev/null +++ b/tests/test_async_vws_exceptions.py @@ -0,0 +1,304 @@ +"""Tests for VWS exceptions raised from async clients.""" + +import io +import uuid +from http import HTTPStatus + +import pytest +from mock_vws import MockVWS +from mock_vws.database import CloudDatabase +from mock_vws.states import States + +from vws import AsyncVuMarkService, AsyncVWS +from vws.exceptions.custom_exceptions import ( + ServerError, +) +from vws.exceptions.vws_exceptions import ( + AuthenticationFailureError, + BadImageError, + FailError, + ImageTooLargeError, + InvalidInstanceIdError, + MetadataTooLargeError, + ProjectInactiveError, + TargetNameExistError, + TargetStatusProcessingError, + UnknownTargetError, +) +from vws.vumark_accept import VuMarkAccept + + +@pytest.mark.asyncio +async def test_image_too_large( + *, + async_vws_client: AsyncVWS, + png_too_large: io.BytesIO | io.BufferedRandom, +) -> None: + """When giving an image which is too large, an + ``ImageTooLarge`` exception is raised. + """ + with pytest.raises( + expected_exception=ImageTooLargeError, + ) as exc: + await async_vws_client.add_target( + name="x", + width=1, + image=png_too_large, + active_flag=True, + application_metadata=None, + ) + + assert exc.value.response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY + + +@pytest.mark.asyncio +async def test_invalid_given_id( + async_vws_client: AsyncVWS, +) -> None: + """Giving an invalid ID causes an ``UnknownTarget`` + exception to be raised. + """ + target_id = "12345abc" + with pytest.raises( + expected_exception=UnknownTargetError, + ) as exc: + await async_vws_client.delete_target( + target_id=target_id, + ) + assert exc.value.response.status_code == HTTPStatus.NOT_FOUND + assert exc.value.target_id == target_id + + +@pytest.mark.asyncio +async def test_add_bad_name( + *, + async_vws_client: AsyncVWS, + high_quality_image: io.BytesIO, +) -> None: + """When a name with a bad character is given, a + ``ServerError`` exception is raised. + """ + max_char_value = 65535 + bad_name = chr(max_char_value + 1) + with pytest.raises( + expected_exception=ServerError, + ) as exc: + await async_vws_client.add_target( + name=bad_name, + width=1, + image=high_quality_image, + active_flag=True, + application_metadata=None, + ) + + assert exc.value.response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR + + +@pytest.mark.asyncio +async def test_fail(high_quality_image: io.BytesIO) -> None: + """A ``Fail`` exception is raised when the server access key + does not exist. + """ + with MockVWS(): + async_vws_client = AsyncVWS( + server_access_key=uuid.uuid4().hex, + server_secret_key=uuid.uuid4().hex, + ) + + with pytest.raises( + expected_exception=FailError, + ) as exc: + await async_vws_client.add_target( + name="x", + width=1, + image=high_quality_image, + active_flag=True, + application_metadata=None, + ) + + assert exc.value.response.status_code == HTTPStatus.BAD_REQUEST + + +@pytest.mark.asyncio +async def test_bad_image( + async_vws_client: AsyncVWS, +) -> None: + """A ``BadImage`` exception is raised when a non-image is + given. + """ + not_an_image = io.BytesIO(initial_bytes=b"Not an image") + with pytest.raises( + expected_exception=BadImageError, + ) as exc: + await async_vws_client.add_target( + name="x", + width=1, + image=not_an_image, + active_flag=True, + application_metadata=None, + ) + + assert exc.value.response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY + + +@pytest.mark.asyncio +async def test_target_name_exist( + *, + async_vws_client: AsyncVWS, + high_quality_image: io.BytesIO, +) -> None: + """A ``TargetNameExist`` exception is raised after adding + two targets with the same name. + """ + await async_vws_client.add_target( + name="x", + width=1, + image=high_quality_image, + active_flag=True, + application_metadata=None, + ) + with pytest.raises( + expected_exception=TargetNameExistError, + ) as exc: + await async_vws_client.add_target( + name="x", + width=1, + image=high_quality_image, + active_flag=True, + application_metadata=None, + ) + + assert exc.value.response.status_code == HTTPStatus.FORBIDDEN + assert exc.value.target_name == "x" + + +@pytest.mark.asyncio +async def test_project_inactive( + high_quality_image: io.BytesIO, +) -> None: + """A ``ProjectInactive`` exception is raised if adding a + target to an inactive database. + """ + database = CloudDatabase(state=States.PROJECT_INACTIVE) + with MockVWS() as mock: + mock.add_cloud_database(cloud_database=database) + async_vws_client = AsyncVWS( + server_access_key=database.server_access_key, + server_secret_key=database.server_secret_key, + ) + + with pytest.raises( + expected_exception=ProjectInactiveError, + ) as exc: + await async_vws_client.add_target( + name="x", + width=1, + image=high_quality_image, + active_flag=True, + application_metadata=None, + ) + + assert exc.value.response.status_code == HTTPStatus.FORBIDDEN + + +@pytest.mark.asyncio +async def test_target_status_processing( + *, + async_vws_client: AsyncVWS, + high_quality_image: io.BytesIO, +) -> None: + """A ``TargetStatusProcessing`` exception is raised if + trying to delete a target which is processing. + """ + target_id = await async_vws_client.add_target( + name="x", + width=1, + image=high_quality_image, + active_flag=True, + application_metadata=None, + ) + + with pytest.raises( + expected_exception=TargetStatusProcessingError, + ) as exc: + await async_vws_client.delete_target( + target_id=target_id, + ) + + assert exc.value.response.status_code == HTTPStatus.FORBIDDEN + assert exc.value.target_id == target_id + + +@pytest.mark.asyncio +async def test_metadata_too_large( + *, + async_vws_client: AsyncVWS, + high_quality_image: io.BytesIO, +) -> None: + """A ``MetadataTooLarge`` exception is raised if the metadata + given is too large. + """ + with pytest.raises( + expected_exception=MetadataTooLargeError, + ) as exc: + await async_vws_client.add_target( + name="x", + width=1, + image=high_quality_image, + active_flag=True, + application_metadata="a" * 1024 * 1024 * 10, + ) + + assert exc.value.response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY + + +@pytest.mark.asyncio +async def test_authentication_failure( + high_quality_image: io.BytesIO, +) -> None: + """An ``AuthenticationFailure`` exception is raised when the + server secret key is incorrect. + """ + database = CloudDatabase() + + async_vws_client = AsyncVWS( + server_access_key=database.server_access_key, + server_secret_key=uuid.uuid4().hex, + ) + + with MockVWS() as mock: + mock.add_cloud_database(cloud_database=database) + + with pytest.raises( + expected_exception=AuthenticationFailureError, + ) as exc: + await async_vws_client.add_target( + name="x", + width=1, + image=high_quality_image, + active_flag=True, + application_metadata=None, + ) + + assert exc.value.response.status_code == HTTPStatus.UNAUTHORIZED + + +@pytest.mark.asyncio +async def test_invalid_instance_id( + *, + async_vumark_service_client: AsyncVuMarkService, + vumark_target_id: str, +) -> None: + """An ``InvalidInstanceId`` exception is raised when an + empty instance ID is given. + """ + with pytest.raises( + expected_exception=InvalidInstanceIdError, + ) as exc: + await async_vumark_service_client.generate_vumark_instance( + target_id=vumark_target_id, + instance_id="", + accept=VuMarkAccept.PNG, + ) + + assert exc.value.response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY diff --git a/tests/test_transports.py b/tests/test_transports.py index 37eaad4e..b8a34f70 100644 --- a/tests/test_transports.py +++ b/tests/test_transports.py @@ -3,10 +3,12 @@ from http import HTTPStatus import httpx +import pytest import respx +from vws import AsyncCloudRecoService, AsyncVuMarkService, AsyncVWS from vws.response import Response -from vws.transports import HTTPXTransport +from vws.transports import AsyncHTTPXTransport, HTTPXTransport class TestHTTPXTransport: @@ -59,3 +61,146 @@ def test_tuple_timeout() -> None: assert route.called assert isinstance(response, Response) assert response.status_code == HTTPStatus.OK + + +class TestAsyncHTTPXTransport: + """Tests for ``AsyncHTTPXTransport``.""" + + @staticmethod + @pytest.mark.asyncio + @respx.mock + async def test_float_timeout() -> None: + """``AsyncHTTPXTransport`` works with a float timeout.""" + route = respx.post(url="https://example.com/test").mock( + return_value=httpx.Response( + status_code=HTTPStatus.OK, + text="OK", + ), + ) + transport = AsyncHTTPXTransport() + response = await transport( + method="POST", + url="https://example.com/test", + headers={"Content-Type": "text/plain"}, + data=b"hello", + request_timeout=30.0, + ) + assert route.called + assert isinstance(response, Response) + assert response.status_code == HTTPStatus.OK + assert response.text == "OK" + assert response.tell_position == len(b"OK") + + @staticmethod + @pytest.mark.asyncio + @respx.mock + async def test_tuple_timeout() -> None: + """``AsyncHTTPXTransport`` works with a (connect, read) + timeout tuple. + """ + route = respx.post(url="https://example.com/test").mock( + return_value=httpx.Response( + status_code=HTTPStatus.OK, + text="OK", + ), + ) + transport = AsyncHTTPXTransport() + response = await transport( + method="POST", + url="https://example.com/test", + headers={"Content-Type": "text/plain"}, + data=b"hello", + request_timeout=(5.0, 30.0), + ) + assert route.called + assert isinstance(response, Response) + assert response.status_code == HTTPStatus.OK + + @staticmethod + @pytest.mark.asyncio + @respx.mock + async def test_context_manager() -> None: + """``AsyncHTTPXTransport`` can be used as an async context + manager. + """ + route = respx.post(url="https://example.com/test").mock( + return_value=httpx.Response( + status_code=HTTPStatus.OK, + text="OK", + ), + ) + async with AsyncHTTPXTransport() as transport: + response = await transport( + method="POST", + url="https://example.com/test", + headers={"Content-Type": "text/plain"}, + data=b"hello", + request_timeout=30.0, + ) + assert route.called + assert isinstance(response, Response) + assert response.status_code == HTTPStatus.OK + + +class _NoCloseTransport: + """A minimal async transport without ``aclose``.""" + + async def __call__( # pragma: no cover + self, + *, + method: str, + url: str, + headers: dict[str, str], + data: bytes, + request_timeout: float | tuple[float, float], + ) -> Response: + """Not implemented.""" + raise NotImplementedError + + +_DUMMY_KEY = "x" + + +class TestAsyncClientAclose: + """Tests for ``aclose`` on async clients with transports that + lack ``aclose``. + """ + + @staticmethod + @pytest.mark.asyncio + async def test_vws_aclose_no_transport_aclose() -> None: + """``AsyncVWS.aclose`` works when the transport has no + ``aclose``. + """ + client = AsyncVWS( + server_access_key=_DUMMY_KEY, + server_secret_key=_DUMMY_KEY, + transport=_NoCloseTransport(), + ) + await client.aclose() + + @staticmethod + @pytest.mark.asyncio + async def test_cloud_reco_aclose_no_transport_aclose() -> None: + """``AsyncCloudRecoService.aclose`` works when the transport + has no ``aclose``. + """ + client = AsyncCloudRecoService( + client_access_key=_DUMMY_KEY, + client_secret_key=_DUMMY_KEY, + transport=_NoCloseTransport(), + ) + await client.aclose() + + @staticmethod + @pytest.mark.asyncio + async def test_vumark_aclose_no_transport_aclose() -> None: + """``AsyncVuMarkService.aclose`` works when the transport + has no ``aclose``. + """ + client = AsyncVuMarkService( + server_access_key=_DUMMY_KEY, + server_secret_key=_DUMMY_KEY, + transport=_NoCloseTransport(), + ) + await client.aclose()