diff --git a/.changes/next-release/feature-s3-15737.json b/.changes/next-release/feature-s3-15737.json new file mode 100644 index 000000000000..d0352f3c0692 --- /dev/null +++ b/.changes/next-release/feature-s3-15737.json @@ -0,0 +1,5 @@ +{ + "type": "feature", + "category": "``s3``", + "description": "Improve S3 performance for listing objects in transfer tasks (`#10293 <https://github.com/aws/aws-cli/issues/10293>`__)" +} diff --git a/awscli/customizations/s3/bucketlister.py b/awscli/customizations/s3/bucketlister.py new file mode 100644 index 000000000000..2c9985d02c3c --- /dev/null +++ b/awscli/customizations/s3/bucketlister.py @@ -0,0 +1,441 @@ +# Copyright 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import io +import logging +import threading +import xml.sax.handler +from collections import namedtuple +from dataclasses import dataclass +from typing import Optional +from xml.sax.handler import ContentHandler + +from dateutil.parser import parse +from dateutil.tz import tzlocal + +from awscli.compat import queue + +LOGGER = logging.getLogger(__name__) + + +_THREADED_BUCKET_LISTER_COMPLETE = object() +_ThreadedBucketListerError = namedtuple( + '_ThreadedBucketListerError', ['exception'] +) +_ThreadedBucketPage = namedtuple('_ThreadedBucketPage', ['contents']) + + +def _date_parser(date_string, tzinfo): + return parse(date_string).astimezone(tzinfo) + + +class _StopParsing(Exception): + pass + + +class BucketLister: + """List keys in a bucket.""" + + def __init__(self, client, date_parser=_date_parser): + self._client = client + self._date_parser = date_parser + self._local_tz = tzlocal() + + def _get_list_objects_v2_paginator_kwargs( + self, bucket, prefix=None, page_size=None, extra_args=None + ): + kwargs = { + 'Bucket': bucket, + 'PaginationConfig': {'PageSize': page_size}, + } + if prefix is not None: + kwargs['Prefix'] = prefix + if extra_args is not None: + kwargs.update(extra_args) + return kwargs + + def _yield_page_contents(self, bucket, contents): + if not contents: + return + for content in contents: + source_path = bucket + '/' + content['Key'] + content['LastModified'] = self._date_parser( + content['LastModified'], + self._local_tz, + ) + yield source_path, content + + def list_objects( + self, bucket, prefix=None, page_size=None, extra_args=None + ): + kwargs = self._get_list_objects_v2_paginator_kwargs( + bucket=bucket, + prefix=prefix, + page_size=page_size, + extra_args=extra_args, + ) + paginator = self._client.get_paginator('list_objects_v2') + pages = paginator.paginate(**kwargs) + for page in pages: + yield from self._yield_page_contents( + bucket=bucket, + contents=page.get('Contents', []), + ) + + +class ThreadedBucketLister(BucketLister): + """List keys in a bucket using a background producer thread.""" + + _BUFFER_WAIT_SECONDS = 0.1 + _MAX_PAGES_BUFFER = 10 + + def list_objects( + self, bucket, prefix=None, page_size=None, extra_args=None + ): + request_kwargs = self._get_list_objects_v2_request_kwargs( + bucket=bucket, + prefix=prefix, + page_size=page_size, + extra_args=extra_args, + ) + 
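# Bounded handoff queue: the producer blocks (polling stop_event) once + # it is _MAX_PAGES_BUFFER pages ahead of the consuming generator. + 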
page_queue = queue.Queue(maxsize=self._MAX_PAGES_BUFFER) + stop_event = threading.Event() + producer = threading.Thread( + target=self._run_producer, + kwargs={ + 'request_kwargs': request_kwargs, + 'page_queue': page_queue, + 'stop_event': stop_event, + }, + name='ThreadedBucketLister', + daemon=True, + ) + producer.start() + try: + while True: + next_item = page_queue.get(True) + if next_item is _THREADED_BUCKET_LISTER_COMPLETE: + break + if isinstance(next_item, _ThreadedBucketListerError): + raise next_item.exception + yield from self._yield_page_contents( + bucket=bucket, + contents=next_item.contents, + ) + finally: + stop_event.set() + producer.join() + + def _get_list_objects_v2_request_kwargs( + self, bucket, prefix=None, page_size=None, extra_args=None + ): + kwargs = {'Bucket': bucket} + if prefix is not None: + kwargs['Prefix'] = prefix + if page_size is not None: + kwargs['MaxKeys'] = page_size + if extra_args is not None: + kwargs.update(extra_args) + return kwargs + + def _put_page_queue_item(self, page_queue, stop_event, item): + # In Python 3.13, we have queue.shutdown() to avoid having to poll + # with a timeout like we do below. Until that's the min version + # supported, we need to handle this ourselves and avoid non-timeout + # put() calls. + while not stop_event.is_set(): + try: + page_queue.put(item, timeout=self._BUFFER_WAIT_SECONDS) + return True + except queue.Full: + continue + return False + + def _run_producer(self, request_kwargs, page_queue, stop_event): + quick_pager = _QuickPageListObjectsV2( + self._client, + stop_event=stop_event, + ) + unprocessed_pages = {} + next_page_number = 1 + try: + quick_pager.start_pagination(request_kwargs) + while not stop_event.is_set(): + available_page = quick_pager.next_completed_page( + timeout=self._BUFFER_WAIT_SECONDS + ) + if available_page is None: + continue + unprocessed_pages[available_page.page_number] = ( + available_page + ) + while next_page_number in unprocessed_pages: + next_page = unprocessed_pages.pop(next_page_number) + if next_page.exception is not None: + raise next_page.exception + page = next_page.page_response + if stop_event.is_set(): + return + contents = page.get('Contents', []) + next_page_number += 1 + if contents: + if not self._put_page_queue_item( + page_queue, + stop_event, + _ThreadedBucketPage(contents=contents), + ): + return + is_last_page = not page.get( + 'IsTruncated' + ) or not page.get('NextContinuationToken') + if is_last_page: + self._put_page_queue_item( + page_queue, + stop_event, + _THREADED_BUCKET_LISTER_COMPLETE, + ) + return + except Exception as e: + if not stop_event.is_set(): + self._put_page_queue_item( + page_queue, + stop_event, + _ThreadedBucketListerError(exception=e), + ) + finally: + quick_pager.shutdown() + + +@dataclass +class ListObjectsV2PageTask: + page_number: int + request_kwargs: dict + task_queue: queue.Queue + next_task_queue: queue.Queue + quick_page_scheduled: bool = False + + def create_next_page_request(self, next_request_kwargs): + return ListObjectsV2PageTask( + page_number=self.page_number + 1, + request_kwargs=next_request_kwargs, + task_queue=self.next_task_queue, + next_task_queue=self.task_queue, + quick_page_scheduled=False, + ) + + +@dataclass +class ListObjectsV2PageResponse: + page_number: int + # Either page_response or exception will be non-None. 
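+ # Pages can complete out of order across the two request threads; + # page_number lets the consumer re-sequence them.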
+ page_response: Optional[dict] = None + exception: Optional[Exception] = None + + +class _QuickPageListObjectsV2: + _BEFORE_PARSE_EVENT = 'before-parse.s3.ListObjectsV2' + _REQUEST_WORKER_COMPLETE = object() + _BUFFER_WAIT_SECONDS = 0.1 + _MAX_PAGES_BUFFER = 10 + + def __init__(self, client, stop_event): + self._client = client + self._stop_event = stop_event + self._task_queues = [queue.Queue(), queue.Queue()] + self._complete_page_queue = queue.Queue(maxsize=self._MAX_PAGES_BUFFER) + self._shutdown_triggered = False + self._thread_local = threading.local() + self._before_parse_unique_id = ( + f'awscli-threaded-bucket-lister-prefetch-before-parse-{id(self)}' + ) + self._threads = [ + threading.Thread( + target=self._thread_task_handler, + args=(self._task_queues[0],), + name='ThreadedBucketListerRequestA', + daemon=True, + ), + threading.Thread( + target=self._thread_task_handler, + args=(self._task_queues[1],), + name='ThreadedBucketListerRequestB', + daemon=True, + ), + ] + + def start_pagination(self, request_kwargs): + self._client.meta.events.register( + self._BEFORE_PARSE_EVENT, + self._on_before_parse, + unique_id=self._before_parse_unique_id, + ) + for thread in self._threads: + thread.start() + self._task_queues[0].put( + ListObjectsV2PageTask( + page_number=1, + request_kwargs=request_kwargs, + task_queue=self._task_queues[0], + next_task_queue=self._task_queues[1], + ) + ) + + def next_completed_page(self, timeout=None): + try: + return self._complete_page_queue.get(timeout=timeout) + except queue.Empty: + return None + + def _put_completed_page(self, completed_page): + while not self._stop_event.is_set(): + try: + self._complete_page_queue.put( + completed_page, + timeout=self._BUFFER_WAIT_SECONDS, + ) + return True + except queue.Full: + continue + return False + + def _thread_task_handler(self, task_queue): + while True: + task = task_queue.get() + if task is self._REQUEST_WORKER_COMPLETE: + return + if self._stop_event.is_set(): + return + self._process_list_objects_task(task) + + def _process_list_objects_task(self, task): + self._thread_local.current_task = task + try: + page = self._client.list_objects_v2(**task.request_kwargs) + except Exception as e: + self._put_completed_page( + ListObjectsV2PageResponse( + page_number=task.page_number, + page_response=None, + exception=e, + ) + ) + return + if not self._put_completed_page( + ListObjectsV2PageResponse( + page_number=task.page_number, + page_response=page, + ) + ): + return + if not task.quick_page_scheduled: + if not page.get('IsTruncated'): + return + next_continuation_token = page.get('NextContinuationToken') + if next_continuation_token is None: + return + next_request_kwargs = task.request_kwargs.copy() + next_request_kwargs['ContinuationToken'] = next_continuation_token + next_page_task = task.create_next_page_request(next_request_kwargs) + task.next_task_queue.put(next_page_task) + + def _on_before_parse(self, response_dict, **kwargs): + if self._shutdown_triggered or self._stop_event.is_set(): + return + task = getattr(self._thread_local, 'current_task', None) + if task is None: + return + if response_dict.get('status_code', 0) >= 300: + return + body = response_dict.get('body') + if not isinstance(body, bytes): + return + try: + next_token, is_truncated = _extract_next_continuation_token(body) + except Exception: + LOGGER.debug( + 'Unable to extract NextContinuationToken for background ' + 'prefetch.', + exc_info=True, + ) + return + if not is_truncated or next_token is None: + return + next_request_kwargs = 
task.request_kwargs.copy() + next_request_kwargs['ContinuationToken'] = next_token + next_page_task = task.create_next_page_request(next_request_kwargs) + task.next_task_queue.put(next_page_task) + task.quick_page_scheduled = True + + def _queue_completion_tasks(self): + for task_queue in self._task_queues: + task_queue.put(self._REQUEST_WORKER_COMPLETE) + + def shutdown(self): + if not self._shutdown_triggered: + self._shutdown_triggered = True + self._stop_event.set() + self._client.meta.events.unregister( + self._BEFORE_PARSE_EVENT, + self._on_before_parse, + unique_id=self._before_parse_unique_id, + ) + self._queue_completion_tasks() + for thread in self._threads: + thread.join() + + +class _NextContinuationTokenHandler(ContentHandler): + def __init__(self): + self.next_continuation_token = None + self.is_truncated = None + self._current_element = None + self._text_parts = [] + + def startElement(self, name, attrs): + del attrs + if name in ('NextContinuationToken', 'IsTruncated'): + self._current_element = name + self._text_parts = [] + + def characters(self, content): + if self._current_element is not None: + self._text_parts.append(content) + + def endElement(self, name): + if name != self._current_element: + return + + text = ''.join(self._text_parts) + self._current_element = None + self._text_parts = [] + + if name == 'NextContinuationToken': + self.next_continuation_token = text + self.is_truncated = True + raise _StopParsing() + + if name == 'IsTruncated': + self.is_truncated = text.lower() == 'true' + if self.is_truncated is False: + raise _StopParsing() + + +def _extract_next_continuation_token(body): + handler = _NextContinuationTokenHandler() + parser = xml.sax.make_parser() + parser.setFeature(xml.sax.handler.feature_namespaces, False) + parser.setContentHandler(handler) + try: + parser.parse(io.BytesIO(body)) + except _StopParsing: + pass + return handler.next_continuation_token, handler.is_truncated is True diff --git a/awscli/customizations/s3/constants.py b/awscli/customizations/s3/constants.py index 8f65bb7a07aa..8990f25a2706 100644 --- a/awscli/customizations/s3/constants.py +++ b/awscli/customizations/s3/constants.py @@ -15,3 +15,7 @@ AUTO_RESOLVE_TRANSFER_CLIENT = 'auto' CLASSIC_TRANSFER_CLIENT = 'classic' CRT_TRANSFER_CLIENT = 'crt' + +# Constants for bucket_lister configuration +SINGLE_BUCKET_LISTER = 'single' +THREADED_BUCKET_LISTER = 'threaded' diff --git a/awscli/customizations/s3/factory.py b/awscli/customizations/s3/factory.py index 50ba57294400..3f4229d4369f 100644 --- a/awscli/customizations/s3/factory.py +++ b/awscli/customizations/s3/factory.py @@ -15,6 +15,7 @@ import awscrt.s3 from botocore.client import Config from botocore.httpsession import DEFAULT_CA_BUNDLE +from botocore.parsers import ResponseParserFactory from s3transfer.crt import ( BotocoreCRTCredentialsWrapper, BotocoreCRTRequestSerializer, @@ -33,11 +34,41 @@ LOGGER = logging.getLogger(__name__) +def _identity(value): + return value + + class ClientFactory: + _RESPONSE_PARSER_FACTORY_COMPONENT = 'response_parser_factory' + def __init__(self, session): self._session = session def create_client(self, params, is_source_client=False): + create_client_kwargs = self._get_client_kwargs( + params, is_source_client=is_source_client + ) + return self._session.create_client('s3', **create_client_kwargs) + + def create_listing_client(self, params, is_source_client=False): + original_factory = self._session.get_component( + self._RESPONSE_PARSER_FACTORY_COMPONENT + ) + listing_factory = 
ResponseParserFactory() + listing_factory.set_parser_defaults(timestamp_parser=_identity) + self._session.register_component( + self._RESPONSE_PARSER_FACTORY_COMPONENT, listing_factory + ) + try: + return self.create_client( + params, is_source_client=is_source_client + ) + finally: + self._session.register_component( + self._RESPONSE_PARSER_FACTORY_COMPONENT, original_factory + ) + + def _get_client_kwargs(self, params, is_source_client=False): create_client_kwargs = {'verify': params['verify_ssl']} if params.get('sse') == 'aws:kms': create_client_kwargs['config'] = Config(signature_version='s3v4') @@ -50,7 +81,7 @@ def create_client(self, params, is_source_client=False): create_client_kwargs['region_name'] = region create_client_kwargs['endpoint_url'] = endpoint_url - return self._session.create_client('s3', **create_client_kwargs) + return create_client_kwargs class TransferManagerFactory: diff --git a/awscli/customizations/s3/filegenerator.py b/awscli/customizations/s3/filegenerator.py index 088c0b7381eb..0c5ee0f0b0ce 100644 --- a/awscli/customizations/s3/filegenerator.py +++ b/awscli/customizations/s3/filegenerator.py @@ -19,9 +19,9 @@ from dateutil.tz import tzlocal from awscli.compat import queue +from awscli.customizations.s3.bucketlister import BucketLister from awscli.customizations.s3.utils import ( EPOCH_TIME, - BucketLister, create_warning, find_bucket_key, find_dest_path_comp_key, @@ -134,6 +134,7 @@ class FileGenerator: under the same common prefix. The generator yields corresponding ``FileInfo`` objects to send to a ``Comparator`` or ``S3Handler``. """ + _DEFAULT_BUCKET_LISTER_CLS = BucketLister def __init__( self, @@ -143,8 +144,11 @@ def __init__( page_size=None, result_queue=None, request_parameters=None, + listing_client=None, + bucket_lister_cls=None, ): self._client = client + self._listing_client = listing_client self.operation_name = operation_name self.follow_symlinks = follow_symlinks self.page_size = page_size @@ -154,6 +158,9 @@ def __init__( self.request_parameters = {} if request_parameters is not None: self.request_parameters = request_parameters + if bucket_lister_cls is None: + bucket_lister_cls = self._DEFAULT_BUCKET_LISTER_CLS + self._bucket_lister_cls = bucket_lister_cls def call(self, files): """ @@ -355,7 +362,9 @@ def list_objects(self, s3_path, dir_op): if not dir_op and prefix: yield self._list_single_object(s3_path) else: - lister = BucketLister(self._client) + lister = self._bucket_lister_cls( + self._listing_client or self._client + ) extra_args = self.request_parameters.get('ListObjectsV2', {}) for key in lister.list_objects( bucket=bucket, diff --git a/awscli/customizations/s3/subcommands.py b/awscli/customizations/s3/subcommands.py index 03630f872cee..d5efabf08b76 100644 --- a/awscli/customizations/s3/subcommands.py +++ b/awscli/customizations/s3/subcommands.py @@ -14,7 +14,6 @@ import os import sys -from botocore.client import Config from botocore.useragent import register_feature_id from botocore.utils import ensure_boolean, is_s3express_bucket from dateutil.parser import parse @@ -23,7 +22,11 @@ from awscli.compat import queue from awscli.customizations.commands import BasicCommand from awscli.customizations.exceptions import ParamValidationError -from awscli.customizations.s3 import transferconfig +from awscli.customizations.s3 import constants, transferconfig +from awscli.customizations.s3.bucketlister import ( + BucketLister, + ThreadedBucketLister, +) from awscli.customizations.s3.comparator import Comparator from 
awscli.customizations.s3.factory import ( ClientFactory, @@ -58,6 +61,12 @@ LOGGER = logging.getLogger(__name__) +_BUCKET_LISTERS = { + constants.SINGLE_BUCKET_LISTER: BucketLister, + constants.THREADED_BUCKET_LISTER: ThreadedBucketLister, +} + + RECURSIVE = { 'name': 'recursive', 'action': 'store_true', @@ -1063,11 +1072,18 @@ def _run_main(self, parsed_args, parsed_globals): register_feature_id('S3_TRANSFER') self._convert_path_args(parsed_args) params = self._get_params(parsed_args, parsed_globals, self._session) - source_client, transfer_client = self._get_source_and_transfer_clients( - params=params - ) + ( + source_client, + transfer_client, + source_listing_client, + destination_listing_client, + ) = self._get_source_and_transfer_clients(params=params) + runtime_config = self._get_runtime_config() + bucket_lister_cls = _BUCKET_LISTERS[runtime_config['bucket_lister']] transfer_manager = self._get_transfer_manager( - params=params, botocore_transfer_client=transfer_client + params=params, + botocore_transfer_client=transfer_client, + runtime_config=runtime_config, ) cmd = CommandArchitecture( self._session, @@ -1076,6 +1092,9 @@ def _run_main(self, parsed_args, parsed_globals): transfer_manager, source_client, transfer_client, + source_listing_client, + destination_listing_client, + bucket_lister_cls, ) cmd.create_instructions() return cmd.run() @@ -1109,10 +1128,22 @@ def _get_source_and_transfer_clients(self, params): params, is_source_client=True ) transfer_client = client_factory.create_client(params) - return source_client, transfer_client + source_listing_client = client_factory.create_listing_client( + params, is_source_client=True + ) + destination_listing_client = client_factory.create_listing_client( + params + ) + return ( + source_client, + transfer_client, + source_listing_client, + destination_listing_client, + ) - def _get_transfer_manager(self, params, botocore_transfer_client): - runtime_config = self._get_runtime_config() + def _get_transfer_manager( + self, params, botocore_transfer_client, runtime_config + ): return TransferManagerFactory(self._session).create_transfer_manager( params=params, runtime_config=runtime_config, @@ -1367,6 +1398,9 @@ def __init__( transfer_manager, source_client, transfer_client, + source_listing_client, + destination_listing_client, + bucket_lister_cls, ): self.session = session self.cmd = cmd @@ -1375,6 +1409,9 @@ def __init__( self._transfer_manager = transfer_manager self._source_client = source_client self._client = transfer_client + self._source_listing_client = source_listing_client + self._destination_listing_client = destination_listing_client + self._bucket_lister_cls = bucket_lister_cls def create_instructions(self): """ @@ -1472,17 +1509,21 @@ def run(self): fgen_kwargs = { 'client': self._source_client, + 'listing_client': self._source_listing_client, 'operation_name': operation_name, 'follow_symlinks': self.parameters['follow_symlinks'], 'page_size': self.parameters['page_size'], 'result_queue': result_queue, + 'bucket_lister_cls': self._bucket_lister_cls, } rgen_kwargs = { 'client': self._client, + 'listing_client': self._destination_listing_client, 'operation_name': '', 'follow_symlinks': self.parameters['follow_symlinks'], 'page_size': self.parameters['page_size'], 'result_queue': result_queue, + 'bucket_lister_cls': self._bucket_lister_cls, } fgen_request_parameters = ( diff --git a/awscli/customizations/s3/transferconfig.py b/awscli/customizations/s3/transferconfig.py index 5502ea93b0cc..5ad6cc2831ed 100644 --- 
a/awscli/customizations/s3/transferconfig.py +++ b/awscli/customizations/s3/transferconfig.py @@ -35,6 +35,7 @@ 'should_stream': None, 'disk_throughput': None, 'direct_io': None, + 'bucket_lister': constants.SINGLE_BUCKET_LISTER, } @@ -68,7 +69,11 @@ class RuntimeConfig: constants.AUTO_RESOLVE_TRANSFER_CLIENT, constants.CLASSIC_TRANSFER_CLIENT, constants.CRT_TRANSFER_CLIENT, - ] + ], + 'bucket_lister': [ + constants.THREADED_BUCKET_LISTER, + constants.SINGLE_BUCKET_LISTER, + ], } CHOICE_ALIASES = { 'preferred_transfer_client': { diff --git a/awscli/customizations/s3/utils.py b/awscli/customizations/s3/utils.py index 2f63a1f09d2d..a0f643192dce 100644 --- a/awscli/customizations/s3/utils.py +++ b/awscli/customizations/s3/utils.py @@ -19,7 +19,6 @@ from collections import deque, namedtuple from datetime import datetime -from dateutil.parser import parse from dateutil.tz import tzlocal, tzutc from awscli.compat import bytes_print, queue @@ -401,41 +400,6 @@ class SetFileUtimeError(Exception): pass -def _date_parser(date_string): - return parse(date_string).astimezone(tzlocal()) - - -class BucketLister: - """List keys in a bucket.""" - - def __init__(self, client, date_parser=_date_parser): - self._client = client - self._date_parser = date_parser - - def list_objects( - self, bucket, prefix=None, page_size=None, extra_args=None - ): - kwargs = { - 'Bucket': bucket, - 'PaginationConfig': {'PageSize': page_size}, - } - if prefix is not None: - kwargs['Prefix'] = prefix - if extra_args is not None: - kwargs.update(extra_args) - - paginator = self._client.get_paginator('list_objects_v2') - pages = paginator.paginate(**kwargs) - for page in pages: - contents = page.get('Contents', []) - for content in contents: - source_path = bucket + '/' + content['Key'] - content['LastModified'] = self._date_parser( - content['LastModified'] - ) - yield source_path, content - - class PrintTask( namedtuple('PrintTask', ['message', 'error', 'total_parts', 'warning']) ): diff --git a/awscli/topics/s3-config.rst b/awscli/topics/s3-config.rst index 945c4d1e1535..45af3379e653 100644 --- a/awscli/topics/s3-config.rst +++ b/awscli/topics/s3-config.rst @@ -37,6 +37,8 @@ command set: and downloading data to and from Amazon S3. * ``io_chunksize`` - The maximum size of read parts that can be queued in-memory to be written for a download. +* ``bucket_lister`` - The bucket listing implementation to use when discovering + S3 objects for transfer commands. For experimental ``s3`` configuration values, see the the `Experimental Configuration Values <#experimental-configuration-values>`__ @@ -77,6 +79,7 @@ configuration:: multipart_threshold = 64MB multipart_chunksize = 16MB max_bandwidth = 50MB/s + bucket_lister = threaded use_accelerate_endpoint = true addressing_style = path @@ -93,6 +96,7 @@ could instead run these commands:: $ aws configure set default.s3.multipart_threshold 64MB $ aws configure set default.s3.multipart_chunksize 16MB $ aws configure set default.s3.max_bandwidth 50MB/s + $ aws configure set default.s3.bucket_lister threaded $ aws configure set default.s3.use_accelerate_endpoint true $ aws configure set default.s3.addressing_style path @@ -239,6 +243,22 @@ In cases where network IO is the bottleneck, it is recommended to configure ``max_concurrent_requests`` instead. +bucket_lister +------------- + +**Default** - ``single`` + +Determines the bucket listing implementation to use when the AWS CLI discovers +S3 objects for transfer commands. 
Valid choices are: + +* ``single`` - Use a single-threaded bucket lister. + This is the default behavior. + +* ``threaded`` - Use background producer-consumer threads to retrieve pages of + objects from S3. This may speed up transfer commands that list a large + number of objects including recursive downloads, sync, and S3 to S3 copies. + + use_accelerate_endpoint ----------------------- diff --git a/tests/unit/customizations/s3/test_bucketlister.py b/tests/unit/customizations/s3/test_bucketlister.py new file mode 100644 index 000000000000..06427402cbe6 --- /dev/null +++ b/tests/unit/customizations/s3/test_bucketlister.py @@ -0,0 +1,478 @@ +# Copyright 2013 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import threading + +from botocore.hooks import HierarchicalEmitter +from dateutil.parser import parse +from dateutil.tz import tzlocal + +from awscli.customizations.s3.bucketlister import ( + BucketLister, + ThreadedBucketLister, +) +from awscli.testutils import mock, unittest + + +class BaseBucketListTest: + LISTER_CLS = None + + def setUp(self): + self.client = mock.Mock() + self.emitter = HierarchicalEmitter() + self.client.meta.events = self.emitter + self.date_parser = mock.Mock() + self.date_parser.return_value = mock.sentinel.now + self.responses = [] + self._response_index = 0 + self.client.get_paginator.return_value.paginate.side_effect = ( + self.fake_paginate + ) + self.client.list_objects_v2.side_effect = self.fake_list_objects_v2 + + def fake_paginate(self, *args, **kwargs): + return self.responses + + def fake_list_objects_v2(self, **kwargs): + del kwargs + if self._response_index >= len(self.responses): + raise AssertionError('No more ListObjectsV2 responses configured') + response = self.responses[self._response_index] + self._response_index += 1 + return response + + def create_lister(self): + return self.LISTER_CLS(self.client, self.date_parser) + + def test_list_objects(self): + now = mock.sentinel.now + individual_response_elements = [ + { + 'LastModified': '2014-02-27T04:20:38.000Z', + 'Key': 'a', + 'Size': 1, + }, + { + 'LastModified': '2014-02-27T04:20:38.000Z', + 'Key': 'b', + 'Size': 2, + }, + { + 'LastModified': '2014-02-27T04:20:38.000Z', + 'Key': 'c', + 'Size': 3, + }, + ] + self.responses = [ + {'Contents': individual_response_elements[0:2]}, + {'Contents': [individual_response_elements[2]]}, + ] + lister = self.create_lister() + objects = list(lister.list_objects(bucket='foo')) + self.assertEqual( + objects, + [ + ('foo/a', individual_response_elements[0]), + ('foo/b', individual_response_elements[1]), + ('foo/c', individual_response_elements[2]), + ], + ) + for individual_response in individual_response_elements: + self.assertEqual(individual_response['LastModified'], now) + + def test_list_objects_passes_in_extra_args(self): + self.responses = [ + { + 'Contents': [ + { + 'LastModified': '2014-02-27T04:20:38.000Z', + 'Key': 'mykey', + 'Size': 3, + } + ] + } + ] + lister = self.create_lister() + list( + lister.list_objects( + bucket='mybucket', 
extra_args={'RequestPayer': 'requester'} + ) + ) + self.client.get_paginator.return_value.paginate.assert_called_with( + Bucket='mybucket', + PaginationConfig={'PageSize': None}, + RequestPayer='requester', + ) + + def test_list_objects_uses_local_tz_aware_datetimes_by_default(self): + timestamp = '2014-02-27T04:20:38.000Z' + self.responses = [ + { + 'Contents': [ + { + 'LastModified': timestamp, + 'Key': 'mykey', + 'Size': 3, + } + ] + } + ] + lister = self.LISTER_CLS(self.client) + + objects = list(lister.list_objects(bucket='mybucket')) + + last_modified = objects[0][1]['LastModified'] + self.assertEqual(last_modified, parse(timestamp).astimezone(tzlocal())) + self.assertIsNotNone(last_modified.tzinfo) + + +class TestBucketList(BaseBucketListTest, unittest.TestCase): + LISTER_CLS = BucketLister + + +class TestThreadedBucketList(BaseBucketListTest, unittest.TestCase): + LISTER_CLS = ThreadedBucketLister + + def _emit_before_parse(self, body): + self.emitter.emit( + 'before-parse.s3.ListObjectsV2', + operation_model=None, + response_dict={ + 'body': body, + 'headers': {}, + 'status_code': 200, + }, + customized_response_dict={}, + ) + + def test_list_objects(self): + now = mock.sentinel.now + individual_response_elements = [ + { + 'LastModified': '2014-02-27T04:20:38.000Z', + 'Key': 'a', + 'Size': 1, + }, + { + 'LastModified': '2014-02-27T04:20:38.000Z', + 'Key': 'b', + 'Size': 2, + }, + { + 'LastModified': '2014-02-27T04:20:38.000Z', + 'Key': 'c', + 'Size': 3, + }, + ] + self.responses = [ + { + 'Contents': individual_response_elements[0:2], + 'IsTruncated': True, + 'NextContinuationToken': 'token-2', + }, + { + 'Contents': [individual_response_elements[2]], + 'IsTruncated': False, + }, + ] + lister = self.create_lister() + objects = list(lister.list_objects(bucket='foo')) + self.assertEqual( + objects, + [ + ('foo/a', individual_response_elements[0]), + ('foo/b', individual_response_elements[1]), + ('foo/c', individual_response_elements[2]), + ], + ) + for individual_response in individual_response_elements: + self.assertEqual(individual_response['LastModified'], now) + + def test_list_objects_passes_in_extra_args(self): + self.responses = [ + { + 'Contents': [ + { + 'LastModified': '2014-02-27T04:20:38.000Z', + 'Key': 'mykey', + 'Size': 3, + } + ] + } + ] + lister = self.create_lister() + list( + lister.list_objects( + bucket='mybucket', extra_args={'RequestPayer': 'requester'} + ) + ) + self.client.list_objects_v2.assert_called_once_with( + Bucket='mybucket', RequestPayer='requester' + ) + + def test_list_objects_uses_page_size_as_max_keys(self): + self.responses = [ + { + 'Contents': [ + { + 'LastModified': '2014-02-27T04:20:38.000Z', + 'Key': 'mykey', + 'Size': 3, + } + ] + } + ] + lister = self.create_lister() + list(lister.list_objects(bucket='mybucket', page_size=25)) + self.client.list_objects_v2.assert_called_once_with( + Bucket='mybucket', MaxKeys=25 + ) + + def test_list_objects_prefetches_pages_in_background(self): + page_two_requested = threading.Event() + + def list_objects_v2(**kwargs): + continuation_token = kwargs.get('ContinuationToken') + if continuation_token is None: + self._emit_before_parse( + b'' + b'' + b'token-2' + b'' + b'true' + b'' + ) + self.assertTrue(page_two_requested.wait(timeout=1)) + return { + 'Contents': [ + { + 'LastModified': '2014-02-27T04:20:38.000Z', + 'Key': 'a', + 'Size': 1, + }, + { + 'LastModified': '2014-02-27T04:20:38.000Z', + 'Key': 'b', + 'Size': 2, + }, + ], + 'IsTruncated': True, + 'NextContinuationToken': 'token-2', + } + 
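# The second request must reuse the token prefetched from page one's body. + 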
self.assertEqual(continuation_token, 'token-2') + page_two_requested.set() + return { + 'Contents': [ + { + 'LastModified': '2014-02-27T04:20:38.000Z', + 'Key': 'c', + 'Size': 3, + } + ], + 'IsTruncated': False, + } + + self.client.list_objects_v2.side_effect = list_objects_v2 + objects = list( + ThreadedBucketLister( + self.client, self.date_parser + ).list_objects(bucket='foo') + ) + + self.assertTrue(page_two_requested.is_set()) + self.assertEqual( + objects, + [ + ('foo/a', mock.ANY), + ('foo/b', mock.ANY), + ('foo/c', mock.ANY), + ], + ) + + def test_list_objects_prefetches_pages_from_parsed_page(self): + page_two_requested = threading.Event() + allow_page_two = threading.Event() + + def list_objects_v2(**kwargs): + continuation_token = kwargs.get('ContinuationToken') + if continuation_token is None: + return { + 'Contents': [ + { + 'LastModified': '2014-02-27T04:20:38.000Z', + 'Key': 'a', + 'Size': 1, + }, + { + 'LastModified': '2014-02-27T04:20:38.000Z', + 'Key': 'b', + 'Size': 2, + }, + ], + 'IsTruncated': True, + 'NextContinuationToken': 'token-2', + } + self.assertEqual(continuation_token, 'token-2') + page_two_requested.set() + allow_page_two.wait(timeout=1) + return { + 'Contents': [ + { + 'LastModified': '2014-02-27T04:20:38.000Z', + 'Key': 'c', + 'Size': 3, + } + ], + 'IsTruncated': False, + } + + self.client.list_objects_v2.side_effect = list_objects_v2 + objects = ThreadedBucketLister( + self.client, self.date_parser + ).list_objects(bucket='foo') + + try: + self.assertEqual(next(objects), ('foo/a', mock.ANY)) + self.assertTrue(page_two_requested.wait(timeout=1)) + allow_page_two.set() + self.assertEqual( + list(objects), + [ + ('foo/b', mock.ANY), + ('foo/c', mock.ANY), + ], + ) + finally: + allow_page_two.set() + objects.close() + + def test_list_objects_propagates_background_exception(self): + class BackgroundError(Exception): + pass + + def list_objects_v2(**kwargs): + if 'ContinuationToken' not in kwargs: + return { + 'Contents': [ + { + 'LastModified': '2014-02-27T04:20:38.000Z', + 'Key': 'a', + 'Size': 1, + } + ], + 'IsTruncated': True, + 'NextContinuationToken': 'token-2', + } + raise BackgroundError('background failure') + + self.client.list_objects_v2.side_effect = list_objects_v2 + objects = ThreadedBucketLister( + self.client, self.date_parser + ).list_objects(bucket='foo') + + self.assertEqual(next(objects), ('foo/a', mock.ANY)) + with self.assertRaises(BackgroundError): + list(objects) + + def test_closing_lister_cleans_up_requester(self): + page_two_requested = threading.Event() + allow_page_two = threading.Event() + + def list_objects_v2(**kwargs): + if 'ContinuationToken' not in kwargs: + self._emit_before_parse( + b'' + b'' + b'token-2' + b'' + b'true' + b'' + ) + return { + 'Contents': [ + { + 'LastModified': '2014-02-27T04:20:38.000Z', + 'Key': 'a', + 'Size': 1, + } + ], + 'IsTruncated': True, + 'NextContinuationToken': 'token-2', + } + page_two_requested.set() + allow_page_two.wait(timeout=1) + return { + 'Contents': [ + { + 'LastModified': '2014-02-27T04:20:38.000Z', + 'Key': 'b', + 'Size': 2, + } + ], + 'IsTruncated': False, + } + + self.client.list_objects_v2.side_effect = list_objects_v2 + objects = ThreadedBucketLister( + self.client, self.date_parser + ).list_objects(bucket='foo') + + self.assertEqual(next(objects), ('foo/a', mock.ANY)) + self.assertTrue(page_two_requested.wait(timeout=1)) + allow_page_two.set() + objects.close() + + def test_closing_lister_stops_when_result_queue_is_full(self): + page_twenty_three_requested = threading.Event() + 
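# Stub that paginates forever; every page claims more results. + 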
+ def list_objects_v2(**kwargs): + continuation_token = kwargs.get('ContinuationToken') + if continuation_token is None: + page_number = 1 + else: + page_number = int(continuation_token.rsplit('-', 1)[1]) + if page_number == 23: + page_twenty_three_requested.set() + return { + 'Contents': [ + { + 'LastModified': '2014-02-27T04:20:38.000Z', + 'Key': f'key-{page_number}', + 'Size': page_number, + } + ], + 'IsTruncated': True, + 'NextContinuationToken': f'token-{page_number + 1}', + } + + self.client.list_objects_v2.side_effect = list_objects_v2 + objects = ThreadedBucketLister( + self.client, self.date_parser + ).list_objects(bucket='foo') + + self.assertEqual(next(objects)[0], 'foo/key-1') + self.assertTrue(page_twenty_three_requested.wait(timeout=2)) + + close_errors = [] + + def close_objects(): + try: + objects.close() + except Exception as e: + close_errors.append(e) + + close_thread = threading.Thread(target=close_objects, daemon=True) + close_thread.start() + close_thread.join(timeout=2) + + self.assertFalse(close_thread.is_alive()) + self.assertEqual(close_errors, []) diff --git a/tests/unit/customizations/s3/test_factory.py b/tests/unit/customizations/s3/test_factory.py index b577d5b4d1f4..ec95dc6810fe 100644 --- a/tests/unit/customizations/s3/test_factory.py +++ b/tests/unit/customizations/s3/test_factory.py @@ -94,6 +94,8 @@ def test_crt_get_optimized_platforms_match_expected_platforms(): class TestClientFactory(unittest.TestCase): def setUp(self): self.session = mock.Mock(Session) + self.original_factory = mock.Mock() + self.session.get_component.return_value = self.original_factory self.factory = ClientFactory(self.session) def test_create_client(self): @@ -153,6 +155,76 @@ def test_create_client_respects_source_region_for_copies(self): 's3', region_name='us-west-1', endpoint_url=None, verify=True ) + def test_create_listing_client_uses_temporary_identity_parser(self): + params = { + 'region': 'us-west-2', + 'endpoint_url': 'https://myendpoint', + 'verify_ssl': True, + } + temp_factory = mock.Mock() + with mock.patch( + 'awscli.customizations.s3.factory.ResponseParserFactory', + return_value=temp_factory, + ): + self.factory.create_listing_client(params=params) + + temp_factory.set_parser_defaults.assert_called_once() + timestamp_parser = temp_factory.set_parser_defaults.call_args[1][ + 'timestamp_parser' + ] + self.assertEqual(timestamp_parser('timestamp'), 'timestamp') + self.assertEqual( + self.session.register_component.call_args_list, + [ + mock.call('response_parser_factory', temp_factory), + mock.call('response_parser_factory', self.original_factory), + ], + ) + self.session.create_client.assert_called_with( + 's3', + region_name='us-west-2', + endpoint_url='https://myendpoint', + verify=True, + ) + + def test_create_listing_client_restores_factory_after_error(self): + params = { + 'region': 'us-west-2', + 'endpoint_url': None, + 'verify_ssl': None, + } + self.session.create_client.side_effect = RuntimeError('boom') + temp_factory = mock.Mock() + + with mock.patch( + 'awscli.customizations.s3.factory.ResponseParserFactory', + return_value=temp_factory, + ): + with self.assertRaisesRegex(RuntimeError, 'boom'): + self.factory.create_listing_client(params=params) + + self.assertEqual( + self.session.register_component.call_args_list[-1], + mock.call('response_parser_factory', self.original_factory), + ) + + def test_create_listing_client_respects_source_region_for_copies(self): + params = { + 'region': 'us-west-2', + 'endpoint_url': 'https://myendpoint', + 'verify_ssl': 
True, + 'source_region': 'us-west-1', + 'paths_type': 's3s3', + } + with mock.patch('awscli.customizations.s3.factory.ResponseParserFactory'): + self.factory.create_listing_client( + params, is_source_client=True + ) + + self.session.create_client.assert_called_with( + 's3', region_name='us-west-1', endpoint_url=None, verify=True + ) + class TestTransferManagerFactory(unittest.TestCase): def setUp(self): diff --git a/tests/unit/customizations/s3/test_filegenerator.py b/tests/unit/customizations/s3/test_filegenerator.py index 1b962ba8b5da..1c3b72522fd1 100644 --- a/tests/unit/customizations/s3/test_filegenerator.py +++ b/tests/unit/customizations/s3/test_filegenerator.py @@ -693,6 +693,47 @@ def test_s3_single_file_explicit_checksum_mode_overrides(self): call_kwargs = self.client.head_object.call_args[1] self.assertEqual(call_kwargs['ChecksumMode'], 'ENABLED') + def test_s3_directory_uses_configured_bucket_lister(self): + listing_client = mock.Mock() + bucket_lister = mock.Mock() + bucket_lister.return_value.list_objects.return_value = [] + file_gen = FileGenerator( + self.client, + '', + listing_client=listing_client, + bucket_lister_cls=bucket_lister, + ) + + list(file_gen.list_objects(self.bucket + '/', dir_op=True)) + + bucket_lister.assert_called_once_with(listing_client) + + def test_s3_single_file_still_uses_normal_client(self): + input_s3_file = { + 'src': {'path': self.file1, 'type': 's3'}, + 'dest': {'path': 'text1.txt', 'type': 'local'}, + 'dir_op': False, + 'use_src_name': False, + } + listing_client = mock.Mock() + self.client = mock.Mock() + self.client.meta.config.response_checksum_validation = 'when_required' + self.client.head_object.return_value = { + 'ContentLength': 100, + 'LastModified': '2014-01-09T20:45:49.000Z', + 'ETag': 'etag', + } + + files = list( + FileGenerator( + self.client, '', listing_client=listing_client + ).call(input_s3_file) + ) + + self.assertEqual(len(files), 1) + self.client.head_object.assert_called_once() + listing_client.head_object.assert_not_called() + def test_s3_directory(self): """ Generates s3 files under a common prefix. 
Also it ensures that diff --git a/tests/unit/customizations/s3/test_subcommands.py b/tests/unit/customizations/s3/test_subcommands.py index 0b99e6f1ff0d..b91ff0f96bb2 100644 --- a/tests/unit/customizations/s3/test_subcommands.py +++ b/tests/unit/customizations/s3/test_subcommands.py @@ -18,6 +18,9 @@ from awscli.compat import StringIO from awscli.customizations.exceptions import ParamValidationError +from awscli.customizations.s3.bucketlister import ( + BucketLister, +) from awscli.customizations.s3.s3 import S3 from awscli.customizations.s3.subcommands import ( CommandArchitecture, @@ -292,6 +295,8 @@ def setUp(self): self.transfer_manager = mock.Mock() self.source_client = mock.Mock() self.transfer_client = mock.Mock() + self.source_listing_client = mock.Mock() + self.destination_listing_client = mock.Mock() self.file_creator = FileCreator() self.loc_files = make_loc_files(self.file_creator) self.output = StringIO() @@ -310,7 +315,9 @@ def tearDown(self): super(CommandArchitectureTest, self).tearDown() clean_loc_files(self.file_creator) - def get_cmd_architecture(self, cmd, params): + def get_cmd_architecture( + self, cmd, params, bucket_lister_cls=BucketLister + ): return CommandArchitecture( session=self.session, cmd=cmd, @@ -318,6 +325,9 @@ def get_cmd_architecture(self, cmd, params): transfer_manager=self.transfer_manager, source_client=self.source_client, transfer_client=self.transfer_client, + source_listing_client=self.source_listing_client, + destination_listing_client=self.destination_listing_client, + bucket_lister_cls=bucket_lister_cls, ) def get_params(self, **override_kwargs): @@ -369,6 +379,146 @@ def test_create_instructions(self): ['file_generator', 'filters', 'file_info_builder', 's3_handler'], ) + def test_run_passes_source_and_destination_listing_clients(self): + class StopExecution(Exception): + pass + + params = self.get_params( + src='s3://source/', + dest='s3://dest/', + paths_type='s3s3', + follow_symlinks=True, + page_size=None, + request_payer=None, + case_conflict='ignore', + ) + cmd_arc = self.get_cmd_architecture('cp', params) + + with mock.patch( + 'awscli.customizations.s3.subcommands.FileFormat' + ) as file_format, mock.patch( + 'awscli.customizations.s3.subcommands.FileGenerator' + ) as file_generator: + file_format.return_value.format.side_effect = [ + { + 'src': {'path': 's3://source/', 'type': 's3'}, + 'dest': {'path': 's3://dest/', 'type': 's3'}, + 'dir_op': True, + }, + { + 'src': {'path': 's3://dest/', 'type': 's3'}, + 'dest': {'path': 's3://source/', 'type': 's3'}, + 'dir_op': True, + }, + ] + file_generator.side_effect = [mock.Mock(), StopExecution()] + + with self.assertRaises(StopExecution): + cmd_arc.run() + + self.assertEqual( + file_generator.call_args_list[0][1]['listing_client'], + self.source_listing_client, + ) + self.assertEqual( + file_generator.call_args_list[1][1]['listing_client'], + self.destination_listing_client, + ) + + def test_run_uses_single_bucket_lister_by_default(self): + class StopExecution(Exception): + pass + + params = self.get_params( + src='s3://source/', + dest='s3://dest/', + paths_type='s3s3', + follow_symlinks=True, + page_size=None, + request_payer=None, + case_conflict='ignore', + ) + cmd_arc = self.get_cmd_architecture('cp', params) + + with mock.patch( + 'awscli.customizations.s3.subcommands.FileFormat' + ) as file_format, mock.patch( + 'awscli.customizations.s3.subcommands.FileGenerator' + ) as file_generator: + file_format.return_value.format.side_effect = [ + { + 'src': {'path': 's3://source/', 'type': 
's3'}, + 'dest': {'path': 's3://dest/', 'type': 's3'}, + 'dir_op': True, + }, + { + 'src': {'path': 's3://dest/', 'type': 's3'}, + 'dest': {'path': 's3://source/', 'type': 's3'}, + 'dir_op': True, + }, + ] + file_generator.side_effect = [mock.Mock(), StopExecution()] + + with self.assertRaises(StopExecution): + cmd_arc.run() + + self.assertEqual( + file_generator.call_args_list[0][1]['bucket_lister_cls'], + BucketLister, + ) + self.assertEqual( + file_generator.call_args_list[1][1]['bucket_lister_cls'], + BucketLister, + ) + + def test_run_uses_single_bucket_lister_when_configured(self): + class StopExecution(Exception): + pass + + params = self.get_params( + src='s3://source/', + dest='s3://dest/', + paths_type='s3s3', + follow_symlinks=True, + page_size=None, + request_payer=None, + case_conflict='ignore', + ) + cmd_arc = self.get_cmd_architecture( + 'cp', params, bucket_lister_cls=BucketLister + ) + + with mock.patch( + 'awscli.customizations.s3.subcommands.FileFormat' + ) as file_format, mock.patch( + 'awscli.customizations.s3.subcommands.FileGenerator' + ) as file_generator: + file_format.return_value.format.side_effect = [ + { + 'src': {'path': 's3://source/', 'type': 's3'}, + 'dest': {'path': 's3://dest/', 'type': 's3'}, + 'dir_op': True, + }, + { + 'src': {'path': 's3://dest/', 'type': 's3'}, + 'dest': {'path': 's3://source/', 'type': 's3'}, + 'dir_op': True, + }, + ] + file_generator.side_effect = [mock.Mock(), StopExecution()] + + with self.assertRaises(StopExecution): + cmd_arc.run() + + self.assertEqual( + file_generator.call_args_list[0][1]['bucket_lister_cls'], + BucketLister, + ) + self.assertEqual( + file_generator.call_args_list[1][1]['bucket_lister_cls'], + BucketLister, + ) + def test_choose_sync_strategy_default(self): self.session = mock.Mock(self.session) cmd_arc = self.get_cmd_architecture('sync', self.get_params()) diff --git a/tests/unit/customizations/s3/test_transferconfig.py b/tests/unit/customizations/s3/test_transferconfig.py index 0deecf860997..9791eff0c2a6 100644 --- a/tests/unit/customizations/s3/test_transferconfig.py +++ b/tests/unit/customizations/s3/test_transferconfig.py @@ -88,6 +88,21 @@ def test_set_preferred_transfer_client(self, provided, resolved): runtime_config = self.build_config_with(**config_kwargs) assert runtime_config['preferred_transfer_client'] == resolved + @pytest.mark.parametrize( + 'provided,resolved', + [ + (None, 'single'), + ('threaded', 'threaded'), + ('single', 'single'), + ], + ) + def test_set_bucket_lister(self, provided, resolved): + config_kwargs = {} + if provided is not None: + config_kwargs['bucket_lister'] = provided + runtime_config = self.build_config_with(**config_kwargs) + assert runtime_config['bucket_lister'] == resolved + @pytest.mark.parametrize( 'config_name,provided,expected', [ @@ -151,6 +166,10 @@ def test_validates_preferred_transfer_client_choices(self): with pytest.raises(transferconfig.InvalidConfigError): self.build_config_with(preferred_transfer_client='not-supported') + def test_validates_bucket_lister_choices(self): + with pytest.raises(transferconfig.InvalidConfigError): + self.build_config_with(bucket_lister='not-supported') + @pytest.mark.parametrize( 'attr,val,expected', [ diff --git a/tests/unit/customizations/s3/test_utils.py b/tests/unit/customizations/s3/test_utils.py index 47446864cb1c..af4fb7a2cbde 100644 --- a/tests/unit/customizations/s3/test_utils.py +++ b/tests/unit/customizations/s3/test_utils.py @@ -17,7 +17,6 @@ import time import pytest -from botocore.hooks import 
HierarchicalEmitter from dateutil.tz import tzlocal from s3transfer.compat import seekable @@ -25,7 +24,6 @@ from awscli.customizations.exceptions import ParamValidationError from awscli.customizations.s3.utils import ( AppendFilter, - BucketLister, NonSeekableStream, RequestParamsMapper, S3PathResolver, @@ -507,82 +505,6 @@ def test_priority_attr_is_missing(self): self.assertIs(q.get(), a) -class TestBucketList(unittest.TestCase): - def setUp(self): - self.client = mock.Mock() - self.emitter = HierarchicalEmitter() - self.client.meta.events = self.emitter - self.date_parser = mock.Mock() - self.date_parser.return_value = mock.sentinel.now - self.responses = [] - - def fake_paginate(self, *args, **kwargs): - for response in self.responses: - self.emitter.emit('after-call.s3.ListObjectsV2', parsed=response) - return self.responses - - def test_list_objects(self): - now = mock.sentinel.now - self.client.get_paginator.return_value.paginate = self.fake_paginate - individual_response_elements = [ - { - 'LastModified': '2014-02-27T04:20:38.000Z', - 'Key': 'a', - 'Size': 1, - }, - { - 'LastModified': '2014-02-27T04:20:38.000Z', - 'Key': 'b', - 'Size': 2, - }, - { - 'LastModified': '2014-02-27T04:20:38.000Z', - 'Key': 'c', - 'Size': 3, - }, - ] - self.responses = [ - {'Contents': individual_response_elements[0:2]}, - {'Contents': [individual_response_elements[2]]}, - ] - lister = BucketLister(self.client, self.date_parser) - objects = list(lister.list_objects(bucket='foo')) - self.assertEqual( - objects, - [ - ('foo/a', individual_response_elements[0]), - ('foo/b', individual_response_elements[1]), - ('foo/c', individual_response_elements[2]), - ], - ) - for individual_response in individual_response_elements: - self.assertEqual(individual_response['LastModified'], now) - - def test_list_objects_passes_in_extra_args(self): - self.client.get_paginator.return_value.paginate.return_value = [ - { - 'Contents': [ - { - 'LastModified': '2014-02-27T04:20:38.000Z', - 'Key': 'mykey', - 'Size': 3, - } - ] - } - ] - lister = BucketLister(self.client, self.date_parser) - list( - lister.list_objects( - bucket='mybucket', extra_args={'RequestPayer': 'requester'} - ) - ) - self.client.get_paginator.return_value.paginate.assert_called_with( - Bucket='mybucket', - PaginationConfig={'PageSize': None}, - RequestPayer='requester', - ) - - class TestGetFileStat(unittest.TestCase): def test_get_file_stat(self): now = datetime.datetime.now(tzlocal())