From e0f059d77989ed6989b7418aab6f95caf4ec7179 Mon Sep 17 00:00:00 2001 From: Anfimov Dima Date: Mon, 1 Jun 2026 12:29:46 +0200 Subject: [PATCH 1/2] feat: implement broker with aiobotocore --- .github/workflows/code-check.yml | 2 +- docker-compose.yml | 13 -- pyproject.toml | 12 +- src/taskiq_sqs/aws.py | 28 --- src/taskiq_sqs/broker.py | 329 +++++++++++++++------------- src/taskiq_sqs/constants.py | 8 +- src/taskiq_sqs/exceptions.py | 8 + src/taskiq_sqs/queue.py | 31 +++ tests/conftest.py | 58 ++++- tests/test_broker.py | 10 - tests/test_broker_initialization.py | 23 ++ tests/test_broker_kick.py | 26 +++ tests/test_broker_listen.py | 22 ++ uv.lock | 179 --------------- 14 files changed, 351 insertions(+), 398 deletions(-) delete mode 100644 src/taskiq_sqs/aws.py create mode 100644 src/taskiq_sqs/queue.py delete mode 100644 tests/test_broker.py create mode 100644 tests/test_broker_initialization.py create mode 100644 tests/test_broker_kick.py create mode 100644 tests/test_broker_listen.py diff --git a/.github/workflows/code-check.yml b/.github/workflows/code-check.yml index 39e1b56..c3d3004 100644 --- a/.github/workflows/code-check.yml +++ b/.github/workflows/code-check.yml @@ -28,7 +28,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] # eventually add `windows-latest` and `macos-latest` - python-version: ["3.10", "3.11", "3.12"] + python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] services: ministack: image: ministackorg/ministack:1.3.53 diff --git a/docker-compose.yml b/docker-compose.yml index ab61b25..9d021f4 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -22,19 +22,6 @@ services: networks: - taskiq-sqs-network - redis: - image: bitnamilegacy/redis:7.4.2 - environment: - ALLOW_EMPTY_PASSWORD: "yes" # pragma: allowlist secret - healthcheck: - test: ["CMD", "redis-cli", "ping"] - interval: 5s - timeout: 5s - retries: 3 - start_period: 10s - ports: - - 6379:6379 - networks: taskiq-sqs-network: driver: bridge diff --git a/pyproject.toml b/pyproject.toml index 73e4619..c11e57b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,6 @@ name = "taskiq-sqs" version = "0.0.11" description = "SQS Broker for TaskIQ" -urls = {source = "https://github.com/taskiq-python/taskiq-sqs"} readme = "README.md" licence = "MIT" license = {file = "LICENSE"} @@ -32,14 +31,13 @@ keywords = ["taskiq", "broker", "aws", "sqs"] requires-python = ">=3.10" dependencies = [ "taskiq>=0.12.1", - "asyncer~=0.0.5", - "boto3~=1.34.34", "aiobotocore>=2.13.3", ] -# [project.urls] -# "Bug Tracker" = "https://github.com/taskiq-python/taskiq-sqs/issues" -# "Repository" = "https://github.com/taskiq-python/taskiq-sqs/" +[project.urls] +"Source" = "https://github.com/taskiq-python/taskiq-sqs" +"Bug Tracker" = "https://github.com/taskiq-python/taskiq-sqs/issues" +"Repository" = "https://github.com/taskiq-python/taskiq-sqs/" [dependency-groups] dev = [ @@ -59,8 +57,6 @@ lint = [ ] types = [ "mypy>=2.1.0", - "mypy-boto3-sqs>=1.34.101", - "boto3-stubs[essential]>=1.34.84", "types-aiobotocore[essential]>=3.7.0", ] examples = [ diff --git a/src/taskiq_sqs/aws.py b/src/taskiq_sqs/aws.py deleted file mode 100644 index b9f2bc6..0000000 --- a/src/taskiq_sqs/aws.py +++ /dev/null @@ -1,28 +0,0 @@ -import asyncio -import json -import os - -import urllib3 # boto3 peer dep (v1) - - -ECS_CONTAINER_METADATA_URI = "http://169.254.170.2" - - -class InvalidEnvironmentError(Exception): - """Exception for cases where credentials to AWS is not present.""" - - -async def get_container_credentials() -> dict[str, str]: - """Fetches the ECS task role credentials provided by the metadata service.""" - if not (relative_uri := os.environ.get("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI")): - raise InvalidEnvironmentError( - "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI not defined. This may not be an ECS container.", - ) - - http = urllib3.PoolManager() - resp = await asyncio.to_thread( - http.request, - "GET", - f"{ECS_CONTAINER_METADATA_URI}{relative_uri}", - ) - return json.loads(resp.data) diff --git a/src/taskiq_sqs/broker.py b/src/taskiq_sqs/broker.py index 5862440..4dc5edf 100644 --- a/src/taskiq_sqs/broker.py +++ b/src/taskiq_sqs/broker.py @@ -1,211 +1,226 @@ -import asyncio +import contextlib import logging -from collections import defaultdict -from collections.abc import AsyncGenerator, Callable, Mapping +from collections.abc import AsyncGenerator, Awaitable, Callable, Generator from datetime import datetime, timezone from typing import TYPE_CHECKING -import boto3 -from asyncer import asyncify +from aiobotocore.session import get_session from botocore.exceptions import ClientError from taskiq import AsyncBroker -from taskiq.abc.result_backend import AsyncResultBackend from taskiq.acks import AckableMessage from taskiq.message import BrokerMessage -from taskiq_sqs.aws import get_container_credentials +from taskiq_sqs import constants from taskiq_sqs.exceptions import BrokerInitError +from taskiq_sqs.queue import SQSQueue if TYPE_CHECKING: - from mypy_boto3_sqs.service_resource import Queue, SQSServiceResource + from types_aiobotocore_sqs.client import SQSClient + from types_aiobotocore_sqs.type_defs import ( + GetQueueUrlResultTypeDef, + MessageTypeDef, + SendMessageRequestTypeDef, + ) logger = logging.getLogger(__name__) -def stamp() -> int: # noqa: D103 - return int(datetime.now(tz=timezone.utc).timestamp()) - - class SQSBroker(AsyncBroker): """AWS SQS TaskIQ broker.""" - def __init__( # noqa: D107 + def __init__( # noqa: PLR0913 self, - sqs_queue_url: str, - wait_time_seconds: int = 0, # Used for long polling - max_number_of_messages: int = 1, # size of batch to receive from the queue - result_backend: AsyncResultBackend | None = None, - task_id_generator: Callable[[], str] | None = None, - sqs_region_override: str | None = None, + queue_name: str, + endpoint_url: str | None = None, + aws_region_name: str = constants.AWS_DEFAULT_REGION, + aws_access_key_id: str | None = None, + aws_secret_access_key: str | None = None, + wait_time_seconds: int = 0, + max_number_of_messages: int = 1, force_ecs_container_credentials: bool = False, ) -> None: - super().__init__(result_backend, task_id_generator) - - if not sqs_queue_url or not sqs_queue_url.startswith("http"): - raise BrokerInitError(details="A valid SQS queue url is required") - - # NOTE: This bypasses the normal order of operations for boto3 auth and - # goes straight to using the ECS role creds from the metadata - # service. This can be useful in edge cases where there are higher - # priority credentials you do not want to use for this service. - self.force_ecs_container_credentials = force_ecs_container_credentials - self.sqs_region_override = sqs_region_override - self.sqs_queue_url = sqs_queue_url - self._sqs: SQSServiceResource | None = None - self._sqs_queue: Queue | None = None - self._creds_expiration: datetime | None = None + """Initialize the SQS broker. + + :param: queue_name: The name of the SQS queue. + :param: endpoint_url: The SQS endpoint URL. + :param aws_region_name: The AWS region name. + :param aws_access_key_id: The AWS access key ID. + :param aws_secret_access_key: The AWS secret access key. + :param: wait_time_seconds: The wait time used for long polling. + :param: max_number_of_messages: Size of batch to receive from the queue. + :param: force_ecs_container_credentials: This bypasses the normal order of operations for boto3 auth and + goes straight to using the ECS role creds from the metadata service. This can be useful in edge cases + where there are higher priority credentials you do not want to use for this service. + """ + super().__init__() + + self._aws_region = aws_region_name + self._aws_access_key_id = aws_access_key_id + self._aws_secret_access_key = aws_secret_access_key + self._aws_endpoint_url = endpoint_url + + self._session = get_session() + self._startup_called = False + + self._sqs_queue_url: str | None = None + + if max_number_of_messages > constants.MAX_NUMBER_OF_MESSAGES or max_number_of_messages < 1: + raise BrokerInitError(details="MaxNumberOfMessages can be no greater than 10 or less than 1") + self._max_number_of_messages = max_number_of_messages - if max_number_of_messages > 10: # noqa: PLR2004 - raise BrokerInitError(details="MaxNumberOfMessages can be no greater than 10") + if wait_time_seconds > constants.MAX_WAIT_TIME_SECONDS or wait_time_seconds < 0: + raise BrokerInitError(details="WaitTimeSeconds can be no greater than 20 or less than 0") + self._wait_time_seconds = wait_time_seconds - self.wait_time_seconds = max(wait_time_seconds, 0) - self.max_number_of_messages = max(max_number_of_messages, 1) + try: + self._default_queue: SQSQueue = SQSQueue( + name=queue_name, + max_number_of_messages=self._max_number_of_messages, + wait_time_seconds=self._wait_time_seconds, + ) + except ValueError as error: + raise BrokerInitError(details="Invalid default queue configuration.") from error + + self._force_ecs_container_credentials = force_ecs_container_credentials + self._creds_expiration: datetime | None = None + + @contextlib.contextmanager + def _handle_exceptions(self) -> Generator[None, None, None]: + """Handle exceptions raised by the SQS client.""" + try: + yield + except ClientError as e: + error = e.response.get("Error", {}) + code = error.get("Code") + error_message = error.get("Message") + if code == "AWS.SimpleQueueService.NonExistentQueue": + raise BrokerInitError( + details=f"Queue not found {self._default_queue.name}", + ) from e + elif code in ["InvalidParameterValue", "NoSuchBucket"]: + raise BrokerInitError(details=error_message or "") from e + else: + raise BrokerInitError(details=code or "") from e @property def _sqs_credentials_expired(self) -> datetime | bool | None: return self._creds_expiration and self._creds_expiration < datetime.now(tz=timezone.utc) - async def _sqs_client(self) -> "SQSServiceResource": - if self._sqs and not self._sqs_credentials_expired: - return self._sqs + async def _get_sqs_client(self) -> "SQSClient": + self._client_context_creator = self._session.create_client( + "sqs", + region_name=self._aws_region, + endpoint_url=self._aws_endpoint_url, + aws_access_key_id=self._aws_access_key_id, + aws_secret_access_key=self._aws_secret_access_key, + ) + return await self._client_context_creator.__aenter__() + + async def _close_client(self) -> None: + """Closes the SQS/S3 client.""" + await self._client_context_creator.__aexit__(None, None, None) - creds: Mapping[str, str] = defaultdict(None) + async def _get_queue_url(self) -> str: + if not self._sqs_queue_url: + with self._handle_exceptions(): + queue_result: GetQueueUrlResultTypeDef = await self._sqs_client.get_queue_url( + QueueName=self._default_queue.name, + ) + self._sqs_queue_url = queue_result["QueueUrl"] + return self._sqs_queue_url - if self.force_ecs_container_credentials: - creds = await get_container_credentials() - # NOTE: This is probably not an optional prop in the response - if creds.get("Expiration"): - self._creds_expiration = datetime.fromisoformat(creds["Expiration"]) + async def startup(self) -> None: + """Starts the SQS broker and checks that queue exists.""" + self._startup_called = True + self._sqs_client = await self._get_sqs_client() - return boto3.resource( - "sqs", - region_name=self.sqs_region_override, - aws_access_key_id=creds.get("AccessKeyId"), - aws_secret_access_key=creds.get("SecretAccessKey"), - aws_session_token=creds.get("Token"), - ) + queue_url = await self._get_queue_url() + logger.info("Resolved queue '%s' URL: %s", self._default_queue.name, queue_url) - async def _get_queue(self) -> "Queue": - if self._sqs_queue and not self._sqs_credentials_expired: - return self._sqs_queue + await super().startup() - sqs = await self._sqs_client() - self._sqs_queue = await asyncify(sqs.get_queue_by_name)( - QueueName=self.sqs_queue_url.split("/")[-1], - ) + async def shutdown(self) -> None: + """Shuts down the SQS broker.""" + await self._close_client() + await super().shutdown() - if not self._sqs_queue: - raise BrokerInitError(details="SQS queue not found") + async def _build_kick_kwargs( + self, + message: BrokerMessage, + ) -> "SendMessageRequestTypeDef": + """Build the kwargs for the SQS client kick method. - return self._sqs_queue + This function can be extended by the end user to + add additional kwargs in the message delivery. + :param message: BrokerMessage object. + """ + kwargs: SendMessageRequestTypeDef = { + "QueueUrl": await self._get_queue_url(), + "MessageBody": message.message.decode("utf-8"), + } + return kwargs - async def kick( + async def _send_message( self, message: BrokerMessage, ) -> None: - """This method is used to kick tasks out from current program. + """Send a single message. - Using this method tasks are sent to - workers. + :param message: + """ + kwargs = await self._build_kick_kwargs(message) + with self._handle_exceptions(): + await self._sqs_client.send_message(**kwargs) - You don't need to send broker message. It's helper for brokers, - please send only bytes from message.message. + async def kick(self, message: BrokerMessage) -> None: + """Kick tasks out from current program to configured SQS queue. - :param message: name of a task. + :param message: BrokerMessage object. """ - queue = await self._get_queue() - # Must be explicitly set as a label to a unix timestamp - expiry = message.labels.pop("sqs_expiry", 0) + await self._send_message(message) - try: - await asyncify(queue.send_message)( - # SQS structured message attributes - MessageAttributes={ - "expiry": { - "StringValue": str(expiry), - "DataType": "Number", - }, - }, - MessageBody=message.message.decode("utf-8"), - MessageGroupId=message.task_name, - ) - except Exception: - # taskiq suppresses the original exception, but it wold be good to know about - logger.exception("Unhandled exception in SQSBroker") - raise + def _build_ack_function( + self, + queue_url: str, + receipt_handle: str, + ) -> Callable[[], Awaitable[None]]: + """ + This method is used to build an ack for the message. - async def listen(self) -> AsyncGenerator[bytes | AckableMessage, None]: - """This function listens to new messages and yields them. + :param queue_url: queue url where the message is located + :param receipt_handle: message to build ack for. + """ - This it the main point for workers. - This function is used to get new tasks from the network. + async def ack() -> None: + with self._handle_exceptions(): + await self._sqs_client.delete_message( + QueueUrl=queue_url, + ReceiptHandle=receipt_handle, + ) - If your broker support acknowledgement, then you - should wrap your message in AckableMessage dataclass. + return ack - If your messages was wrapped in AckableMessage dataclass, - taskiq will call ack when finish processing message. + async def listen(self) -> AsyncGenerator[AckableMessage, None]: + """ + This function listens to new messages and yields them. - :yield: incoming messages. - :return: nothing. + :yield: incoming AckableMessages. """ - # TODO: Consider using AckableMessage and confirm with the queue to reduce lost messages + queue_url = await self._get_queue_url() + while True: - no_backoff = False - queue = await self._get_queue() - - try: - for message in await asyncify(queue.receive_messages)( - MessageAttributeNames=[".*"], - # If there's competition on this queue (multiple processes of workers pulling - # from the same queue), and processing takes longer than the visibility timeout, - # multiple workers may end up processing the same message. - MaxNumberOfMessages=self.max_number_of_messages, - # Use long poling. - WaitTimeSeconds=self.wait_time_seconds, - ): - try: - if message.message_attributes and (expiry_typed := message.message_attributes.get("expiry")): - expiry = int(expiry_typed.get("StringValue", 0)) - now = stamp() - if 0 < expiry < now: - logger.warning( - "Message expired %s seconds ago. Skipping.", - now - expiry, - ) - await asyncify(message.delete)() - no_backoff = True - continue - except TypeError: - # Ignore weird expiries. Not critical. - pass - - yield message.body.encode("utf-8") - - try: - await asyncify(message.delete)() - except ClientError as err: - if "receipt handle has expired" in str(err): - # while not ideal, we shouldn't die on this - logger.exception( - "Message receipt handle has expired. This could indicate duplicate" - "processing or tasks being processed late.", - ) - else: - raise - - no_backoff = True - except ClientError as err: - # Creds will get refreshed when _get_queue() is called again - if "ExpiredToken" in str(err): - logger.warning("ECS credentials expired.") - continue - else: - raise - - sleepdur = 0.01 if no_backoff else 1 - logger.debug("No messages on queue. Broker is sleeping for %d seconds...", sleepdur) - await asyncio.sleep(sleepdur) - no_backoff = False + results = await self._sqs_client.receive_message( + QueueUrl=queue_url, + MaxNumberOfMessages=self._max_number_of_messages, + WaitTimeSeconds=self._wait_time_seconds, + ) + messages: list[MessageTypeDef] = results["Messages"] + + for message in messages: + if (body := message.get("Body")) and (receipt_handle := message.get("ReceiptHandle")): + yield AckableMessage( + data=body.encode("utf-8"), + ack=self._build_ack_function(queue_url, receipt_handle), + ) diff --git a/src/taskiq_sqs/constants.py b/src/taskiq_sqs/constants.py index 7e9175a..1a14e5d 100644 --- a/src/taskiq_sqs/constants.py +++ b/src/taskiq_sqs/constants.py @@ -1 +1,7 @@ -AWS_DEFAULT_REGION = "us-east-1" +from typing import Final + + +AWS_DEFAULT_REGION: Final[str] = "us-east-1" + +MAX_WAIT_TIME_SECONDS: Final[int] = 20 +MAX_NUMBER_OF_MESSAGES: Final[int] = 10 diff --git a/src/taskiq_sqs/exceptions.py b/src/taskiq_sqs/exceptions.py index b731162..1c8c864 100644 --- a/src/taskiq_sqs/exceptions.py +++ b/src/taskiq_sqs/exceptions.py @@ -12,6 +12,13 @@ class BrokerInitError(BaseTaskiqSQSError): details: str +class InvalidEnvironmentError(BaseTaskiqSQSError): + """Error in case something wrong with environment variables.""" + + __template__ = "Something wrong with env: {details}" + details: str + + class ResultBackendError(BaseTaskiqSQSError): """Base error for all taskiq-aio-sqs broker exceptions.""" @@ -25,6 +32,7 @@ class BucketNotFoundError(BaseTaskiqSQSError): __template__ = "Bucket '{bucket_name}' not found during initialization and declare=False" bucket_name: str + class ResultIsMissingError(BaseTaskiqSQSError): """Error if there is no result when we trying to get it.""" diff --git a/src/taskiq_sqs/queue.py b/src/taskiq_sqs/queue.py new file mode 100644 index 0000000..920b9ee --- /dev/null +++ b/src/taskiq_sqs/queue.py @@ -0,0 +1,31 @@ +from collections.abc import Mapping +from dataclasses import dataclass, field +from typing import Any + + +@dataclass(slots=True, kw_only=True, frozen=True) +class SQSQueue: + """Per-queue SQS configuration for SQSBroker. + + Attributes: + name: The SQS queue name (or "queue-name.fifo" for FIFO queues). + is_fifo: Whether this is a FIFO queue (default: False). + max_number_of_messages: Maximum messages to retrieve per poll (1-10, + default: 1). + wait_time_seconds: Long polling wait time in seconds (0-20, default: 0). + visibility_timeout: Optional visibility timeout (in seconds) for received + messages. While a message is being processed, it remains invisible to + other consumers. + options: Optional mapping of additional SQS queue attributes. + """ + + name: str + max_number_of_messages: int = 1 + wait_time_seconds: int = 0 + options: Mapping[str, Any] = field(default_factory=dict) + + def __str__(self) -> str: # noqa: D105 + return self.name + + def __hash__(self) -> int: # noqa: D105 + return hash(self.name) diff --git a/tests/conftest.py b/tests/conftest.py index bf4cc56..f22286e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,13 @@ +import uuid from collections.abc import AsyncGenerator from typing import TYPE_CHECKING, Any, TypedDict import pytest from aiobotocore.session import get_session +from taskiq import BrokerMessage +from types_aiobotocore_sqs.client import SQSClient -from taskiq_sqs import S3ResultBackend +from taskiq_sqs import S3ResultBackend, SQSBroker from taskiq_sqs.bucket import S3Bucket @@ -13,6 +16,7 @@ ENDPOINT_URL = "http://localhost:4566" TEST_BUCKET = "test-bucket" +QUEUE_NAME = "test-queue" class AWSCredentials(TypedDict): @@ -79,3 +83,55 @@ async def s3_backend( assert backend._s3_client yield backend await backend.shutdown() + + +@pytest.fixture +async def sqs_client(aws_credentials: AWSCredentials) -> AsyncGenerator[SQSClient, Any]: + client_context = get_session().create_client( + "sqs", + endpoint_url=aws_credentials["endpoint_url"], + aws_access_key_id=aws_credentials["aws_access_key_id"], + aws_secret_access_key=aws_credentials["aws_secret_access_key"], + region_name=aws_credentials["aws_region_name"], + ) + yield await client_context.__aenter__() + await client_context.__aexit__(None, None, None) + + +@pytest.fixture +async def sqs_queue(sqs_client: SQSClient) -> AsyncGenerator[str, Any]: + queue_name = f"{QUEUE_NAME}-{uuid.uuid4().hex}" + response = await sqs_client.create_queue(QueueName=queue_name) + queue_url = response["QueueUrl"] + yield queue_url + await sqs_client.delete_queue(QueueUrl=queue_url) + + +def _queue_name_from_url(queue_url: str) -> str: + return queue_url.rsplit("/", maxsplit=1)[-1] + + +@pytest.fixture +async def sqs_broker( + aws_credentials: AWSCredentials, + sqs_queue: str, +) -> AsyncGenerator[SQSBroker, Any]: + broker = SQSBroker( + queue_name=_queue_name_from_url(sqs_queue), + **aws_credentials, + ) + await broker.startup() + assert broker._sqs_client + assert broker._sqs_queue_url + yield broker + await broker.shutdown() + + +@pytest.fixture +def broker_message() -> BrokerMessage: + return BrokerMessage( + task_id="test_task", + task_name="test_task", + message=b"test_message", + labels={}, + ) diff --git a/tests/test_broker.py b/tests/test_broker.py deleted file mode 100644 index 8f5f0c2..0000000 --- a/tests/test_broker.py +++ /dev/null @@ -1,10 +0,0 @@ -from taskiq_sqs import SQSBroker - - -class TestInitParameters: - async def test_initialization_logic(self) -> None: - broker = SQSBroker("http://localhost:4566/000000000000/my-queue") - assert broker.sqs_queue_url == "http://localhost:4566/000000000000/my-queue" - assert broker.force_ecs_container_credentials is False - assert broker.sqs_region_override is None - assert broker._sqs_queue is None diff --git a/tests/test_broker_initialization.py b/tests/test_broker_initialization.py new file mode 100644 index 0000000..07d8d5b --- /dev/null +++ b/tests/test_broker_initialization.py @@ -0,0 +1,23 @@ +import pytest + +from tests.conftest import AWSCredentials + +from taskiq_sqs import SQSBroker +from taskiq_sqs.exceptions import BrokerInitError + + +@pytest.mark.asyncio +async def test_get_queue_url_client_error(aws_credentials: AWSCredentials) -> None: + broker = SQSBroker(queue_name="nonexistent-queue", **aws_credentials) + with pytest.raises(BrokerInitError): + await broker.startup() + + +@pytest.mark.asyncio +async def test_max_number_of_messages_error(aws_credentials: AWSCredentials) -> None: + with pytest.raises(BrokerInitError): + SQSBroker( + queue_name="nonexistent-queue", + max_number_of_messages=15, + **aws_credentials, + ) diff --git a/tests/test_broker_kick.py b/tests/test_broker_kick.py new file mode 100644 index 0000000..d13079a --- /dev/null +++ b/tests/test_broker_kick.py @@ -0,0 +1,26 @@ +import pytest +from taskiq import BrokerMessage + +from taskiq_sqs import SQSBroker +from taskiq_sqs.exceptions import BrokerInitError + + +async def test_when_kick_called__than_message_should_be_published_to_queue( + sqs_broker: SQSBroker, + sqs_queue: str, + broker_message: BrokerMessage, +) -> None: + await sqs_broker.kick(broker_message) + response = await sqs_broker._sqs_client.receive_message(QueueUrl=sqs_queue) + assert "Messages" in response + assert len(response["Messages"]) == 1 + assert response["Messages"][0]["Body"] == "test_message" + + +async def test_when_during_kick_queue_not_found__then_should_raise_an_error( + sqs_broker: SQSBroker, + broker_message: BrokerMessage, +) -> None: + sqs_broker._sqs_queue_url = "nonexistent-queue" + with pytest.raises(BrokerInitError): + await sqs_broker.kick(broker_message) diff --git a/tests/test_broker_listen.py b/tests/test_broker_listen.py new file mode 100644 index 0000000..f6b9e17 --- /dev/null +++ b/tests/test_broker_listen.py @@ -0,0 +1,22 @@ +from taskiq_sqs import SQSBroker + + +async def test_when_listen__than_we_should_delete_message_from_queue( + sqs_broker: SQSBroker, sqs_queue: str, +) -> None: + await sqs_broker._sqs_client.send_message( + QueueUrl=sqs_queue, + MessageBody="test_message", + ) + + messages = [] + async for message in sqs_broker.listen(): + messages.append(message) + await message.ack() + break + + assert len(messages) == 1 + assert messages[0].data == b"test_message" + + response = await sqs_broker._sqs_client.receive_message(QueueUrl=sqs_queue) + assert "Messages" not in response diff --git a/uv.lock b/uv.lock index 33983ff..f4cdc9f 100644 --- a/uv.lock +++ b/uv.lock @@ -244,20 +244,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fe/ba/e2081de779ca30d473f21f5b30e0e737c438205440784c7dfc81efc2b029/async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c", size = 6233, upload-time = "2024-11-06T16:41:37.9Z" }, ] -[[package]] -name = "asyncer" -version = "0.0.17" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "anyio" }, - { name = "sniffio" }, - { name = "typing-extensions", marker = "python_full_version < '3.15'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/d2/4c/62b6044679e08788322bbd0dee5b487a6f7f60bb4e2bd45617ff0d94d1e3/asyncer-0.0.17.tar.gz", hash = "sha256:8a41e185e7ec2ecd583c269d72907a0f9f832e744b6c7474aeb21e349c4becf4", size = 19516, upload-time = "2026-02-21T16:35:54.068Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/2a/c5/b72735a095b4b3170b34150e89314a40dd0450d0eb6746b331cb664479d1/asyncer-0.0.17-py3-none-any.whl", hash = "sha256:b0055950e094fb84fd8d21611c7e7b6f5715ddcb57c522c058f64c20badd1438", size = 9252, upload-time = "2026-02-21T16:35:55.022Z" }, -] - [[package]] name = "attrs" version = "26.1.0" @@ -291,45 +277,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/05/a4/a26d5b25671d27e03afb5401a0be5899d94ff8fab6a698b1ac5be3ec29ef/bandit-1.9.4-py3-none-any.whl", hash = "sha256:f89ffa663767f5a0585ea075f01020207e966a9c0f2b9ef56a57c7963a3f6f8e", size = 134741, upload-time = "2026-02-25T06:44:13.694Z" }, ] -[[package]] -name = "boto3" -version = "1.34.162" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "botocore" }, - { name = "jmespath" }, - { name = "s3transfer" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/74/c1/f80cfbe564c89cdb080cd9ac2079ce05a2fac869bf8fbc45929ed3190da9/boto3-1.34.162.tar.gz", hash = "sha256:873f8f5d2f6f85f1018cbb0535b03cceddc7b655b61f66a0a56995238804f41f", size = 108585, upload-time = "2024-08-15T19:25:38.714Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/65/41/faa5081761be3bac3999f912996c14c4dc9d06eab86c234bd6441f54bd64/boto3-1.34.162-py3-none-any.whl", hash = "sha256:d6f6096bdab35a0c0deff469563b87d184a28df7689790f7fe7be98502b7c590", size = 139174, upload-time = "2024-08-15T19:25:35.384Z" }, -] - -[[package]] -name = "boto3-stubs" -version = "1.34.162" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "botocore-stubs" }, - { name = "types-s3transfer" }, - { name = "typing-extensions", marker = "python_full_version < '3.12'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/0b/6d/e8a5eebf5bf6f495b10d14e10af04290f731d6375e75c53166a307502dda/boto3_stubs-1.34.162.tar.gz", hash = "sha256:6d60b7b9652e1c99f3caba00779e1b94ba7062b0431147a00543af8b1f5252f4", size = 88880, upload-time = "2024-08-15T19:32:52.51Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/13/d5/ce007019b026d96667278fa42d753315f6721bd615dbeac4ca91180a20f3/boto3_stubs-1.34.162-py3-none-any.whl", hash = "sha256:47c651272782a2e894082087eeaeb87a7e809e7e282748560cf39c155031abef", size = 56738, upload-time = "2024-08-15T19:32:49.269Z" }, -] - -[package.optional-dependencies] -essential = [ - { name = "mypy-boto3-cloudformation" }, - { name = "mypy-boto3-dynamodb" }, - { name = "mypy-boto3-ec2" }, - { name = "mypy-boto3-lambda" }, - { name = "mypy-boto3-rds" }, - { name = "mypy-boto3-s3" }, - { name = "mypy-boto3-sqs" }, -] - [[package]] name = "botocore" version = "1.34.162" @@ -828,90 +775,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0d/2a/13ca1f292f6db1b98ff495ef3467736b331621c5917cad984b7043e7348d/mypy-2.1.0-py3-none-any.whl", hash = "sha256:a663814603a5c563fb87a4f96fb473eeb30d1f5a4885afcf44f9db000a366289", size = 2693302, upload-time = "2026-05-11T18:31:29.246Z" }, ] -[[package]] -name = "mypy-boto3-cloudformation" -version = "1.34.111" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "typing-extensions", marker = "python_full_version < '3.12'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/59/c3/f48efbcc17fb03fb167993028889be0bfbb582720e3eaa719786c5c53085/mypy_boto3_cloudformation-1.34.111.tar.gz", hash = "sha256:a02e201d1a9d9a8fb4db5b942d5c537a4e8861c611f0d986126674ac557cb9e8", size = 57941, upload-time = "2024-05-22T19:32:33.697Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/3f/f3/c18601b0f21080c4b6183d924ae0052bce0d792ef97b1cbdccfb6d535313/mypy_boto3_cloudformation-1.34.111-py3-none-any.whl", hash = "sha256:526e928c504fa2880b1774aa10629a04fe0ec70ed2864ab3d3f7772386a1a925", size = 70105, upload-time = "2024-05-22T19:32:26.298Z" }, -] - -[[package]] -name = "mypy-boto3-dynamodb" -version = "1.34.148" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "typing-extensions", marker = "python_full_version < '3.12'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/03/8b/79ce3b60d347b3c9a1571b587a36f7c0c042ece1d39377a07cefbf072a4d/mypy_boto3_dynamodb-1.34.148.tar.gz", hash = "sha256:c85489b92cbbbe4f6997070372022df914d4cb8eb707fdc73aa18ce6ba25c578", size = 50077, upload-time = "2024-07-24T19:47:48.859Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/5d/74/2a13f51997fd06b403e0109ded7976a451b76ef1569b8803b41c6b47139e/mypy_boto3_dynamodb-1.34.148-py3-none-any.whl", hash = "sha256:f1a7aabff5c6e926b9b272df87251c9d6dfceb4c1fb159fb5a2df52062cd7e87", size = 60366, upload-time = "2024-07-24T19:47:42.264Z" }, -] - -[[package]] -name = "mypy-boto3-ec2" -version = "1.34.159" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "typing-extensions", marker = "python_full_version < '3.12'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/01/7f/3f1bb06b88168e6c27487e8b42d09d1093322cdfd47d1fbcfce206ef25f3/mypy_boto3_ec2-1.34.159.tar.gz", hash = "sha256:b9badb833dd01e2076c445b3b8609ec4842221620dc8f701dc146b8ceff05283", size = 408995, upload-time = "2024-08-12T19:33:55.996Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/b4/85/b27eeea58de39738e2954b455b95383c988974c6dda006da0e640087d431/mypy_boto3_ec2-1.34.159-py3-none-any.whl", hash = "sha256:d155c4295cd38750bf50adf9540951187f8f05800cd6e6b8fd2058ff0eeccfb4", size = 401497, upload-time = "2024-08-12T19:33:52.374Z" }, -] - -[[package]] -name = "mypy-boto3-lambda" -version = "1.34.77" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "typing-extensions", marker = "python_full_version < '3.12'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/21/b3/49903a707edbe4f67ba5c4876288041e3841553c3c99b0172bc6682c18f5/mypy-boto3-lambda-1.34.77.tar.gz", hash = "sha256:7b81d2a5604fb592e92fe0b284ecd259de071703360a33b71c9b54df46d81c9c", size = 43125, upload-time = "2024-04-03T19:33:10.645Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/4a/f1/52da9a9148885ba3e3b9f0b49a2f6c638a1ef997397bdbf26e1547bb5cb6/mypy_boto3_lambda-1.34.77-py3-none-any.whl", hash = "sha256:e21022d2eef12aa731af80790410afdba9412b056339823252813bae2adbf553", size = 50206, upload-time = "2024-04-03T19:32:53.355Z" }, -] - -[[package]] -name = "mypy-boto3-rds" -version = "1.34.152" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "typing-extensions", marker = "python_full_version < '3.12'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/37/c1/3e03c44ba84906be50786f324bc0488ee5bf36d60906e225611a7d3cfae6/mypy_boto3_rds-1.34.152.tar.gz", hash = "sha256:a3e25da87116e4b7ec4f1419a35fd3c7491f1cf631d9467cc835bc9c5c23fabe", size = 94343, upload-time = "2024-08-01T19:48:31.582Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/52/78/ea038ad37f311c3e0ea9ddf82f9b209dbcbb7b4721a7c7667f550384da24/mypy_boto3_rds-1.34.152-py3-none-any.whl", hash = "sha256:71106812e6e6a89daa99f9b4534c580456336373e6dccb45b652e4d221c7beea", size = 101026, upload-time = "2024-08-01T19:48:01.465Z" }, -] - -[[package]] -name = "mypy-boto3-s3" -version = "1.34.162" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "typing-extensions", marker = "python_full_version < '3.12'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/88/5a/432a718a7472b83191b3b58064472ca75f7b014eff0252985e45b2542463/mypy_boto3_s3-1.34.162.tar.gz", hash = "sha256:7e2fbda0fbd97a17a172a503bade7c4a2615d5ebf6fa532c274b8020bb3c6894", size = 75773, upload-time = "2024-08-15T19:32:44.858Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/06/21/f4b6928de4ed0097ff0d69f114d9dff7a343e8a6ace3745866420d0bcaa4/mypy_boto3_s3-1.34.162-py3-none-any.whl", hash = "sha256:c7ab11369041a62c7d7f4c6dd1d3aab53470339df4b1e1da94df88914c25be29", size = 83945, upload-time = "2024-08-15T19:32:41.795Z" }, -] - -[[package]] -name = "mypy-boto3-sqs" -version = "1.34.121" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "typing-extensions", marker = "python_full_version < '3.12'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/81/75/83be93078eb1e78f7f7d04e33c2c7e313c81420660d3fb7a28522487e07a/mypy_boto3_sqs-1.34.121.tar.gz", hash = "sha256:bdbc623235ffc8127cb8753f49323f74a919df552247b0b2caaf85cf9bb495b8", size = 22200, upload-time = "2024-06-06T19:34:01.045Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ad/18/96a8c01ff1b9663810263a06f8b84bca8ecca18da75b23bedd9f01f1b1a6/mypy_boto3_sqs-1.34.121-py3-none-any.whl", hash = "sha256:e92aefacfa08e7094b79002576ef261e4075f5af9c25219fc47fb8452f53fc5f", size = 33040, upload-time = "2024-06-06T19:33:53.264Z" }, -] - [[package]] name = "mypy-extensions" version = "1.1.0" @@ -1404,18 +1267,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4e/b2/920464c907b191e37469d477a1aa8bc048b8f36c4c1610dfa4ab87b39e18/ruff-0.15.15-py3-none-win_arm64.whl", hash = "sha256:3c8ceca6792f38196b8f589bc92eccd03eef286602da92e5dc05cc42ef6441b7", size = 11138498, upload-time = "2026-05-28T14:16:38.425Z" }, ] -[[package]] -name = "s3transfer" -version = "0.10.4" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "botocore" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/c0/0a/1cdbabf9edd0ea7747efdf6c9ab4e7061b085aa7f9bfc36bb1601563b069/s3transfer-0.10.4.tar.gz", hash = "sha256:29edc09801743c21eb5ecbc617a152df41d3c287f67b615f73e5f750583666a7", size = 145287, upload-time = "2024-11-20T21:06:05.981Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/66/05/7957af15543b8c9799209506df4660cba7afc4cf94bfb60513827e96bed6/s3transfer-0.10.4-py3-none-any.whl", hash = "sha256:244a76a24355363a68164241438de1b72f8781664920260c48465896b712a41e", size = 83175, upload-time = "2024-11-20T21:06:03.961Z" }, -] - [[package]] name = "six" version = "1.17.0" @@ -1425,15 +1276,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050, upload-time = "2024-12-04T17:35:26.475Z" }, ] -[[package]] -name = "sniffio" -version = "1.3.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a2/87/a6771e1546d97e7e041b6ae58d80074f81b7d5121207425c964ddf5cfdbd/sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc", size = 20372, upload-time = "2024-02-25T23:20:04.057Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235, upload-time = "2024-02-25T23:20:01.196Z" }, -] - [[package]] name = "stevedore" version = "5.8.0" @@ -1476,17 +1318,13 @@ version = "0.0.11" source = { editable = "." } dependencies = [ { name = "aiobotocore" }, - { name = "asyncer" }, - { name = "boto3" }, { name = "taskiq" }, ] [package.dev-dependencies] dev = [ { name = "bandit" }, - { name = "boto3-stubs", extra = ["essential"] }, { name = "mypy" }, - { name = "mypy-boto3-sqs" }, { name = "prek" }, { name = "pytest" }, { name = "pytest-asyncio" }, @@ -1507,26 +1345,20 @@ test = [ { name = "pytest-asyncio" }, ] types = [ - { name = "boto3-stubs", extra = ["essential"] }, { name = "mypy" }, - { name = "mypy-boto3-sqs" }, { name = "types-aiobotocore", extra = ["essential"] }, ] [package.metadata] requires-dist = [ { name = "aiobotocore", specifier = ">=2.13.3" }, - { name = "asyncer", specifier = "~=0.0.5" }, - { name = "boto3", specifier = "~=1.34.34" }, { name = "taskiq", specifier = ">=0.12.1" }, ] [package.metadata.requires-dev] dev = [ { name = "bandit", specifier = ">=1.9.4" }, - { name = "boto3-stubs", extras = ["essential"], specifier = ">=1.34.84" }, { name = "mypy", specifier = ">=2.1.0" }, - { name = "mypy-boto3-sqs", specifier = ">=1.34.101" }, { name = "prek", specifier = ">=0.4.3" }, { name = "pytest", specifier = ">=9.0.3" }, { name = "pytest-asyncio", specifier = ">=0.23.8" }, @@ -1545,9 +1377,7 @@ test = [ { name = "pytest-asyncio", specifier = ">=0.23.8" }, ] types = [ - { name = "boto3-stubs", extras = ["essential"], specifier = ">=1.34.84" }, { name = "mypy", specifier = ">=2.1.0" }, - { name = "mypy-boto3-sqs", specifier = ">=1.34.101" }, { name = "types-aiobotocore", extras = ["essential"], specifier = ">=3.7.0" }, ] @@ -1722,15 +1552,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a0/a6/704fbf052edf2497d09292e5f9555b400594924f1154d6c86f2173c2fc11/types_awscrt-0.33.0-py3-none-any.whl", hash = "sha256:95adb57388e1cacc6e7e96fb7ddc735e60096a6151930640bdbe496d9400493a", size = 45687, upload-time = "2026-05-25T06:56:02.549Z" }, ] -[[package]] -name = "types-s3transfer" -version = "0.16.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/fe/64/42689150509eb3e6e82b33ee3d89045de1592488842ddf23c56957786d05/types_s3transfer-0.16.0.tar.gz", hash = "sha256:b4636472024c5e2b62278c5b759661efeb52a81851cde5f092f24100b1ecb443", size = 13557, upload-time = "2025-12-08T08:13:09.928Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/98/27/e88220fe6274eccd3bdf95d9382918716d312f6f6cef6a46332d1ee2feff/types_s3transfer-0.16.0-py3-none-any.whl", hash = "sha256:1c0cd111ecf6e21437cb410f5cddb631bfb2263b77ad973e79b9c6d0cb24e0ef", size = 19247, upload-time = "2025-12-08T08:13:08.426Z" }, -] - [[package]] name = "typing-extensions" version = "4.15.0" From cd021e4cf62623d73621d0e0348bb59fc9073f05 Mon Sep 17 00:00:00 2001 From: Anfimov Dima Date: Mon, 1 Jun 2026 13:04:10 +0200 Subject: [PATCH 2/2] fix: update example and fix issue in cases there is no messages --- examples/example_broker.py | 36 +++++++++++++++++++++++++++--------- src/taskiq_sqs/broker.py | 2 +- 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/examples/example_broker.py b/examples/example_broker.py index 1dc6984..1659ac4 100644 --- a/examples/example_broker.py +++ b/examples/example_broker.py @@ -8,8 +8,8 @@ import asyncio -import boto3 import dotenv +from aiobotocore.session import get_session from taskiq_sqs import S3Bucket, S3ResultBackend, SQSBroker @@ -17,24 +17,42 @@ dotenv.load_dotenv() QUEUE_NAME = "my-queue" -QUEUE_URL = f"http://localhost:4566/000000000000/{QUEUE_NAME}" - - -boto3.client("sqs").create_queue(QueueName=QUEUE_NAME) - -broker = SQSBroker(QUEUE_URL, sqs_region_override="us-east-1").with_result_backend( - S3ResultBackend(bucket=S3Bucket(name="response-bucket")) +ENDPOINT_URL = "http://localhost:4566" +AWS_REGION = "us-east-1" + + +broker = SQSBroker( + queue_name=QUEUE_NAME, + endpoint_url=ENDPOINT_URL, + aws_region_name=AWS_REGION, +).with_result_backend( + S3ResultBackend( + bucket=S3Bucket(name="response-bucket"), + endpoint_url=ENDPOINT_URL, + aws_region_name=AWS_REGION, + ), ) @broker.task() async def i_love_aws() -> None: """I hope my cloud bill doesn't get too high!""" - await asyncio.sleep(5.5) + await asyncio.sleep(2) print("Hello there!") +async def ensure_queue_exists() -> None: + session = get_session() + async with session.create_client( + "sqs", + region_name=AWS_REGION, + endpoint_url=ENDPOINT_URL, + ) as sqs: + await sqs.create_queue(QueueName=QUEUE_NAME) + + async def main() -> None: + await ensure_queue_exists() await broker.startup() task = await i_love_aws.kiq() print(await task.wait_result()) diff --git a/src/taskiq_sqs/broker.py b/src/taskiq_sqs/broker.py index 4dc5edf..dfcbbd9 100644 --- a/src/taskiq_sqs/broker.py +++ b/src/taskiq_sqs/broker.py @@ -216,7 +216,7 @@ async def listen(self) -> AsyncGenerator[AckableMessage, None]: MaxNumberOfMessages=self._max_number_of_messages, WaitTimeSeconds=self._wait_time_seconds, ) - messages: list[MessageTypeDef] = results["Messages"] + messages: list[MessageTypeDef] = results.get("Messages", []) for message in messages: if (body := message.get("Body")) and (receipt_handle := message.get("ReceiptHandle")):