diff --git a/docs/source/api-reference.rst b/docs/source/api-reference.rst index c94d810e..3fc7441c 100644 --- a/docs/source/api-reference.rst +++ b/docs/source/api-reference.rst @@ -20,3 +20,7 @@ API Reference .. automodule:: vws.response :undoc-members: :members: + +.. automodule:: vws.transports + :undoc-members: + :members: diff --git a/pyproject.toml b/pyproject.toml index 4d14b8f9..4fd78dfa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ dynamic = [ ] dependencies = [ "beartype>=0.22.9", + "httpx>=0.28.0", "requests>=2.32.3", "urllib3>=2.2.3", "vws-auth-tools>=2024.7.12", @@ -357,6 +358,8 @@ ignore_path = [ # Ideally we would limit the paths to the source code where we want to ignore names, # but Vulture does not enable this. ignore_names = [ + # Public API classes imported by users from vws.transports + "HTTPXTransport", # pytest configuration "pytest_collect_file", "pytest_collection_modifyitems", diff --git a/spelling_private_dict.txt b/spelling_private_dict.txt index 2329ef7f..8831d3d3 100644 --- a/spelling_private_dict.txt +++ b/spelling_private_dict.txt @@ -54,6 +54,7 @@ hmac html http https +httpx iff io issuecomment diff --git a/src/vws/_vws_request.py b/src/vws/_vws_request.py index 2fc7bd3b..797860dc 100644 --- a/src/vws/_vws_request.py +++ b/src/vws/_vws_request.py @@ -2,11 +2,11 @@ API. """ -import requests from beartype import BeartypeConf, beartype from vws_auth_tools import authorization_header, rfc_1123_date from vws.response import Response +from vws.transports import Transport @beartype(conf=BeartypeConf(is_pep484_tower=True)) @@ -21,27 +21,30 @@ def target_api_request( base_vws_url: str, request_timeout_seconds: float | tuple[float, float], extra_headers: dict[str, str], + transport: Transport, ) -> Response: """Make a request to the Vuforia Target API. - This uses `requests` to make a request against https://vws.vuforia.com. - Args: content_type: The content type of the request. server_access_key: A VWS server access key. server_secret_key: A VWS server secret key. - method: The HTTP method which will be used in the request. - data: The request body which will be used in the request. - request_path: The path to the endpoint which will be used in the + method: The HTTP method which will be used in the + request. + data: The request body which will be used in the request. + request_path: The path to the endpoint which will be + used in the request. base_vws_url: The base URL for the VWS API. - request_timeout_seconds: The timeout for the request, as used by - ``requests.request``. This can be a float to set both the - connect and read timeouts, or a (connect, read) tuple. - extra_headers: Additional headers to include in the request. + request_timeout_seconds: The timeout for the request. + This can be a float to set both the connect and + read timeouts, or a (connect, read) tuple. + extra_headers: Additional headers to include in the + request. + transport: The HTTP transport to use for the request. Returns: - The response to the request made by `requests`. + The response to the request. """ date_string = rfc_1123_date() @@ -64,20 +67,10 @@ def target_api_request( url = base_vws_url.rstrip("/") + request_path - requests_response = requests.request( + return transport( method=method, url=url, headers=headers, data=data, timeout=request_timeout_seconds, ) - - return Response( - text=requests_response.text, - url=requests_response.url, - status_code=requests_response.status_code, - headers=dict(requests_response.headers), - request_body=requests_response.request.body, - tell_position=requests_response.raw.tell(), - content=bytes(requests_response.content), - ) diff --git a/src/vws/query.py b/src/vws/query.py index f5046bd1..fcca282d 100644 --- a/src/vws/query.py +++ b/src/vws/query.py @@ -6,7 +6,6 @@ from http import HTTPMethod, HTTPStatus from typing import Any, BinaryIO -import requests from beartype import BeartypeConf, beartype from urllib3.filepost import encode_multipart_formdata from vws_auth_tools import authorization_header, rfc_1123_date @@ -24,7 +23,7 @@ ) from vws.include_target_data import CloudRecoIncludeTargetData from vws.reports import QueryResult, TargetData -from vws.response import Response +from vws.transports import RequestsTransport, Transport _ImageType = io.BytesIO | BinaryIO @@ -50,21 +49,26 @@ def __init__( client_secret_key: str, base_vwq_url: str = "https://cloudreco.vuforia.com", request_timeout_seconds: float | tuple[float, float] = 30.0, + transport: Transport | None = None, ) -> None: """ Args: client_access_key: A VWS client access key. client_secret_key: A VWS client secret key. base_vwq_url: The base URL for the VWQ API. - request_timeout_seconds: The timeout for each HTTP request, as - used by ``requests.request``. This can be a float to set - both the connect and read timeouts, or a (connect, read) - tuple. + request_timeout_seconds: The timeout for each + HTTP request. This can be a float to set both + the connect and read timeouts, or a + (connect, read) tuple. + transport: The HTTP transport to use for + requests. Defaults to + ``RequestsTransport()``. """ self._client_access_key = client_access_key self._client_secret_key = client_secret_key self._base_vwq_url = base_vwq_url self._request_timeout_seconds = request_timeout_seconds + self._transport = transport or RequestsTransport() def query( self, @@ -143,22 +147,13 @@ def query( "Content-Type": content_type_header, } - requests_response = requests.request( + response = self._transport( method=method, url=self._base_vwq_url.rstrip("/") + request_path, headers=headers, data=content, timeout=self._request_timeout_seconds, ) - response = Response( - text=requests_response.text, - url=requests_response.url, - status_code=requests_response.status_code, - headers=dict(requests_response.headers), - request_body=requests_response.request.body, - tell_position=requests_response.raw.tell(), - content=bytes(requests_response.content), - ) if response.status_code == HTTPStatus.REQUEST_ENTITY_TOO_LARGE: raise RequestEntityTooLargeError(response=response) diff --git a/src/vws/transports.py b/src/vws/transports.py new file mode 100644 index 00000000..9e4c550c --- /dev/null +++ b/src/vws/transports.py @@ -0,0 +1,158 @@ +"""HTTP transport implementations for VWS clients.""" + +from typing import Protocol, runtime_checkable + +import httpx +import requests +from beartype import BeartypeConf, beartype + +from vws.response import Response + + +@runtime_checkable +class Transport(Protocol): + """Protocol for HTTP transports used by VWS clients. + + A transport is a callable that makes an HTTP request and + returns a ``Response``. + """ + + def __call__( + self, + *, + method: str, + url: str, + headers: dict[str, str], + data: bytes, + timeout: float | tuple[float, float], + ) -> Response: + """Make an HTTP request. + + Args: + method: The HTTP method (e.g. "GET", "POST"). + url: The full URL to request. + headers: Headers to send with the request. + data: The request body as bytes. + timeout: The timeout for the request. A float + sets both the connect and read timeouts. A + (connect, read) tuple sets them individually. + + Returns: + A Response populated from the HTTP response. + """ + ... # pylint: disable=unnecessary-ellipsis + + +@beartype(conf=BeartypeConf(is_pep484_tower=True)) +class RequestsTransport: + """HTTP transport using the ``requests`` library. + + This is the default transport. + """ + + def __call__( + self, + *, + method: str, + url: str, + headers: dict[str, str], + data: bytes, + timeout: float | tuple[float, float], + ) -> Response: + """Make an HTTP request using ``requests``. + + Args: + method: The HTTP method. + url: The full URL. + headers: Request headers. + data: The request body. + timeout: The request timeout. + + Returns: + A Response populated from the requests response. + """ + requests_response = requests.request( + method=method, + url=url, + headers=headers, + data=data, + timeout=timeout, + ) + + return Response( + text=requests_response.text, + url=requests_response.url, + status_code=requests_response.status_code, + headers=dict(requests_response.headers), + request_body=requests_response.request.body, + tell_position=requests_response.raw.tell(), + content=bytes(requests_response.content), + ) + + +@beartype(conf=BeartypeConf(is_pep484_tower=True)) +class HTTPXTransport: + """HTTP transport using the ``httpx`` library. + + Use this transport for environments where ``httpx`` is + preferred over ``requests``. + """ + + def __call__( + self, + *, + method: str, + url: str, + headers: dict[str, str], + data: bytes, + timeout: float | tuple[float, float], + ) -> Response: + """Make an HTTP request using ``httpx``. + + Args: + method: The HTTP method. + url: The full URL. + headers: Request headers. + data: The request body. + timeout: The request timeout. + + Returns: + A Response populated from the httpx response. + """ + if isinstance(timeout, tuple): + connect_timeout, read_timeout = timeout + httpx_timeout = httpx.Timeout( + connect=connect_timeout, + read=read_timeout, + write=None, + pool=None, + ) + else: + httpx_timeout = httpx.Timeout( + connect=timeout, + read=timeout, + write=None, + pool=None, + ) + + httpx_response = httpx.request( + method=method, + url=url, + headers=headers, + content=data, + timeout=httpx_timeout, + follow_redirects=True, + ) + + content = bytes(httpx_response.content) + request_content = httpx_response.request.content + + return Response( + text=httpx_response.text, + url=str(object=httpx_response.url), + status_code=httpx_response.status_code, + headers=dict(httpx_response.headers), + request_body=bytes(request_content) or None, + tell_position=len(content), + content=content, + ) diff --git a/src/vws/vumark_service.py b/src/vws/vumark_service.py index a23ac33f..47d68d94 100644 --- a/src/vws/vumark_service.py +++ b/src/vws/vumark_service.py @@ -20,6 +20,7 @@ TooManyRequestsError, UnknownTargetError, ) +from vws.transports import RequestsTransport, Transport from vws.vumark_accept import VuMarkAccept @@ -29,25 +30,31 @@ class VuMarkService: def __init__( self, + *, server_access_key: str, server_secret_key: str, base_vws_url: str = "https://vws.vuforia.com", request_timeout_seconds: float | tuple[float, float] = 30.0, + transport: Transport | None = None, ) -> None: """ Args: server_access_key: A VWS server access key. server_secret_key: A VWS server secret key. base_vws_url: The base URL for the VWS API. - request_timeout_seconds: The timeout for each HTTP request, as - used by ``requests.request``. This can be a float to set - both the connect and read timeouts, or a (connect, read) - tuple. + request_timeout_seconds: The timeout for each + HTTP request. This can be a float to set both + the connect and read timeouts, or a + (connect, read) tuple. + transport: The HTTP transport to use for + requests. Defaults to + ``RequestsTransport()``. """ self._server_access_key = server_access_key self._server_secret_key = server_secret_key self._base_vws_url = base_vws_url self._request_timeout_seconds = request_timeout_seconds + self._transport = transport or RequestsTransport() def generate_vumark_instance( self, @@ -109,6 +116,7 @@ def generate_vumark_instance( base_vws_url=self._base_vws_url, request_timeout_seconds=self._request_timeout_seconds, extra_headers={"Accept": accept}, + transport=self._transport, ) if ( diff --git a/src/vws/vws.py b/src/vws/vws.py index 50b4f5e7..01315e6a 100644 --- a/src/vws/vws.py +++ b/src/vws/vws.py @@ -43,6 +43,7 @@ TargetSummaryReport, ) from vws.response import Response +from vws.transports import RequestsTransport, Transport _ImageType = io.BytesIO | BinaryIO @@ -68,21 +69,26 @@ def __init__( server_secret_key: str, base_vws_url: str = "https://vws.vuforia.com", request_timeout_seconds: float | tuple[float, float] = 30.0, + transport: Transport | None = None, ) -> None: """ Args: server_access_key: A VWS server access key. server_secret_key: A VWS server secret key. base_vws_url: The base URL for the VWS API. - request_timeout_seconds: The timeout for each HTTP request, as - used by ``requests.request``. This can be a float to set - both the connect and read timeouts, or a (connect, read) - tuple. + request_timeout_seconds: The timeout for each + HTTP request. This can be a float to set both + the connect and read timeouts, or a + (connect, read) tuple. + transport: The HTTP transport to use for + requests. Defaults to + ``RequestsTransport()``. """ self._server_access_key = server_access_key self._server_secret_key = server_secret_key self._base_vws_url = base_vws_url self._request_timeout_seconds = request_timeout_seconds + self._transport = transport or RequestsTransport() def make_request( self, @@ -96,29 +102,31 @@ def make_request( ) -> Response: """Make a request to the Vuforia Target API. - This uses `requests` to make a request against Vuforia. - Args: - method: The HTTP method which will be used in the request. - data: The request body which will be used in the request. - request_path: The path to the endpoint which will be used in the + method: The HTTP method which will be used in + the request. + data: The request body which will be used in the request. - expected_result_code: See "VWS API Result Codes" on + request_path: The path to the endpoint which + will be used in the request. + expected_result_code: See + "VWS API Result Codes" on https://developer.vuforia.com/library/web-api/cloud-targets-web-services-api. content_type: The content type of the request. - extra_headers: Additional headers to include in the request. + extra_headers: Additional headers to include in + the request. Returns: - The response to the request made by `requests`. + The response to the request. Raises: - ~vws.exceptions.custom_exceptions.ServerError: There is an error - with Vuforia's servers. - ~vws.exceptions.vws_exceptions.TooManyRequestsError: Vuforia is - rate limiting access. - json.JSONDecodeError: The server did not respond with valid JSON. - This may happen if the server address is not a valid Vuforia - server. + ~vws.exceptions.custom_exceptions.ServerError: + There is an error with Vuforia's servers. + ~vws.exceptions.vws_exceptions.TooManyRequestsError: + Vuforia is rate limiting access. + json.JSONDecodeError: The server did not respond + with valid JSON. This may happen if the + server address is not a valid Vuforia server. """ response = target_api_request( content_type=content_type, @@ -130,6 +138,7 @@ def make_request( base_vws_url=self._base_vws_url, request_timeout_seconds=self._request_timeout_seconds, extra_headers=extra_headers or {}, + transport=self._transport, ) if ( diff --git a/tests/test_transports.py b/tests/test_transports.py new file mode 100644 index 00000000..7b77107c --- /dev/null +++ b/tests/test_transports.py @@ -0,0 +1,61 @@ +"""Tests for HTTP transport implementations.""" + +from http import HTTPStatus + +import httpx +import respx + +from vws.response import Response +from vws.transports import HTTPXTransport + + +class TestHTTPXTransport: + """Tests for ``HTTPXTransport``.""" + + @staticmethod + @respx.mock + def test_float_timeout() -> None: + """``HTTPXTransport`` works with a float timeout.""" + route = respx.post(url="https://example.com/test").mock( + return_value=httpx.Response( + status_code=HTTPStatus.OK, + text="OK", + ), + ) + transport = HTTPXTransport() + response = transport( + method="POST", + url="https://example.com/test", + headers={"Content-Type": "text/plain"}, + data=b"hello", + timeout=30.0, + ) + assert route.called + assert isinstance(response, Response) + assert response.status_code == HTTPStatus.OK + assert response.text == "OK" + assert response.tell_position == len(b"OK") + + @staticmethod + @respx.mock + def test_tuple_timeout() -> None: + """``HTTPXTransport`` works with a (connect, read) timeout + tuple. + """ + route = respx.post(url="https://example.com/test").mock( + return_value=httpx.Response( + status_code=HTTPStatus.OK, + text="OK", + ), + ) + transport = HTTPXTransport() + response = transport( + method="POST", + url="https://example.com/test", + headers={"Content-Type": "text/plain"}, + data=b"hello", + timeout=(5.0, 30.0), + ) + assert route.called + assert isinstance(response, Response) + assert response.status_code == HTTPStatus.OK