From cd2cc450306f8265f149ec995e1dccc8f468ef16 Mon Sep 17 00:00:00 2001 From: Kevin Zheng Date: Mon, 16 Mar 2026 16:18:48 +0000 Subject: [PATCH 1/2] feat: Mutations Batcher shim --- google/cloud/bigtable/batcher.py | 283 ++------------ google/cloud/bigtable/data/_async/client.py | 2 - .../bigtable/data/_async/mutations_batcher.py | 2 +- tests/system/v2_client/test_data_api.py | 105 ++++++ tests/unit/v2_client/test_batcher.py | 356 ++++++++++-------- 5 files changed, 344 insertions(+), 404 deletions(-) diff --git a/google/cloud/bigtable/batcher.py b/google/cloud/bigtable/batcher.py index f9b85386d..7652a8461 100644 --- a/google/cloud/bigtable/batcher.py +++ b/google/cloud/bigtable/batcher.py @@ -13,14 +13,12 @@ # limitations under the License. """User friendly container for Google Cloud Bigtable MutationBatcher.""" -import threading import queue -import concurrent.futures import atexit -from google.api_core.exceptions import from_grpc_status -from dataclasses import dataclass +from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup +from google.cloud.bigtable.data.mutations import RowMutationEntry FLUSH_COUNT = 100 # after this many elements, send out the batch @@ -41,131 +39,6 @@ def __init__(self, message, exc): super().__init__(self.message) -class _MutationsBatchQueue(object): - """Private Threadsafe Queue to hold rows for batching.""" - - def __init__(self, max_mutation_bytes=MAX_MUTATION_SIZE, flush_count=FLUSH_COUNT): - """Specify the queue constraints""" - self._queue = queue.Queue() - self.total_mutation_count = 0 - self.total_size = 0 - self.max_mutation_bytes = max_mutation_bytes - self.flush_count = flush_count - - def get(self): - """ - Retrieve an item from the queue. Recalculate queue size. - - If the queue is empty, return None. - """ - try: - row = self._queue.get_nowait() - mutation_size = row.get_mutations_size() - self.total_mutation_count -= len(row._get_mutations()) - self.total_size -= mutation_size - return row - except queue.Empty: - return None - - def put(self, item): - """Insert an item to the queue. Recalculate queue size.""" - - mutation_count = len(item._get_mutations()) - - self._queue.put(item) - - self.total_size += item.get_mutations_size() - self.total_mutation_count += mutation_count - - def full(self): - """Check if the queue is full.""" - if ( - self.total_mutation_count >= self.flush_count - or self.total_size >= self.max_mutation_bytes - ): - return True - return False - - -@dataclass -class _BatchInfo: - """Keeping track of size of a batch""" - - mutations_count: int = 0 - rows_count: int = 0 - mutations_size: int = 0 - - -class _FlowControl(object): - def __init__( - self, - max_mutations=MAX_OUTSTANDING_ELEMENTS, - max_mutation_bytes=MAX_OUTSTANDING_BYTES, - ): - """Control the inflight requests. Keep track of the mutations, row bytes and row counts. - As requests to backend are being made, adjust the number of mutations being processed. - - If threshold is reached, block the flow. - Reopen the flow as requests are finished. - """ - self.max_mutations = max_mutations - self.max_mutation_bytes = max_mutation_bytes - self.inflight_mutations = 0 - self.inflight_size = 0 - self.event = threading.Event() - self.event.set() - self._lock = threading.Lock() - - def is_blocked(self): - """Returns True if: - - - inflight mutations >= max_mutations, or - - inflight bytes size >= max_mutation_bytes, or - """ - - return ( - self.inflight_mutations >= self.max_mutations - or self.inflight_size >= self.max_mutation_bytes - ) - - def control_flow(self, batch_info): - """ - Calculate the resources used by this batch - """ - - with self._lock: - self.inflight_mutations += batch_info.mutations_count - self.inflight_size += batch_info.mutations_size - self.set_flow_control_status() - - def wait(self): - """ - Wait until flow control pushback has been released. - It awakens as soon as `event` is set. - """ - self.event.wait() - - def set_flow_control_status(self): - """Check the inflight mutations and size. - - If values exceed the allowed threshold, block the event. - """ - if self.is_blocked(): - self.event.clear() # sleep - else: - self.event.set() # awaken the threads - - def release(self, batch_info): - """ - Release the resources. - Decrement the row size to allow enqueued mutations to be run. - """ - with self._lock: - self.inflight_mutations -= batch_info.mutations_count - self.inflight_size -= batch_info.mutations_size - self.set_flow_control_status() - - class MutationsBatcher(object): """A MutationsBatcher is used in batch cases where the number of mutations is large or unknown. It will store :class:`DirectRow` in memory until one of the @@ -214,29 +87,41 @@ def __init__( flush_interval=1, batch_completed_callback=None, ): - self._rows = _MutationsBatchQueue( - max_mutation_bytes=max_row_bytes, flush_count=flush_count - ) self.table = table - self._executor = concurrent.futures.ThreadPoolExecutor() - atexit.register(self.close) - self._timer = threading.Timer(flush_interval, self.flush) - self._timer.start() - self.flow_control = _FlowControl( - max_mutations=MAX_OUTSTANDING_ELEMENTS, - max_mutation_bytes=MAX_OUTSTANDING_BYTES, - ) - self.futures_mapping = {} - self.exceptions = queue.Queue() + self._flush_count = flush_count + self._max_row_bytes = max_row_bytes + self._flush_interval = flush_interval self._user_batch_completed_callback = batch_completed_callback + self._init_batcher() + atexit.register(self.close) + self._exceptions = queue.Queue() @property def flush_count(self): - return self._rows.flush_count + return self._flush_count @property def max_row_bytes(self): - return self._rows.max_mutation_bytes + return self._max_row_bytes + + def _init_batcher(self): + self._batcher = self.table._table_impl.mutations_batcher( + flush_interval=self._flush_interval, + flush_limit_mutation_count=self._flush_count, + flush_limit_bytes=self._max_row_bytes, + ) + self._batcher._user_batch_completed_callback = ( + self._user_batch_completed_callback + ) + + def _close_batcher(self): + try: + self._batcher.close() + except MutationsExceptionGroup as exc_group: + for error in exc_group.exceptions: + # Return the cause of the FailedMutationEntryError to the user, + # as this might be more what they're expecting. + self._exceptions.put(error.__cause__) def __enter__(self): """Starting the MutationsBatcher as a context manager""" @@ -260,10 +145,7 @@ def mutate(self, row): * :exc:`~.table._BigtableRetryableError` if any row returned a transient error. * :exc:`RuntimeError` if the number of responses doesn't match the number of rows that were retried """ - self._rows.put(row) - - if self._rows.full(): - self._flush_async() + self._batcher.append(RowMutationEntry(row.row_key, row._get_mutations())) def mutate_rows(self, rows): """Add multiple rows to the batch. If the current batch meets one of the size @@ -298,102 +180,8 @@ def flush(self): :raises: * :exc:`.batcherMutationsBatchError` if there's any error in the mutations. """ - rows_to_flush = [] - row = self._rows.get() - while row is not None: - rows_to_flush.append(row) - row = self._rows.get() - response = self._flush_rows(rows_to_flush) - return response - - def _flush_async(self): - """Sends the current batch to Cloud Bigtable asynchronously. - - :raises: - * :exc:`.batcherMutationsBatchError` if there's any error in the mutations. - """ - next_row = self._rows.get() - while next_row is not None: - # start a new batch - rows_to_flush = [next_row] - batch_info = _BatchInfo( - mutations_count=len(next_row._get_mutations()), - rows_count=1, - mutations_size=next_row.get_mutations_size(), - ) - # fill up batch with rows - next_row = self._rows.get() - while next_row is not None and self._row_fits_in_batch( - next_row, batch_info - ): - rows_to_flush.append(next_row) - batch_info.mutations_count += len(next_row._get_mutations()) - batch_info.rows_count += 1 - batch_info.mutations_size += next_row.get_mutations_size() - next_row = self._rows.get() - # send batch over network - # wait for resources to become available - self.flow_control.wait() - # once unblocked, submit the batch - # event flag will be set by control_flow to block subsequent thread, but not blocking this one - self.flow_control.control_flow(batch_info) - future = self._executor.submit(self._flush_rows, rows_to_flush) - # schedule release of resources from flow control - self.futures_mapping[future] = batch_info - future.add_done_callback(self._batch_completed_callback) - - def _batch_completed_callback(self, future): - """Callback for when the mutation has finished to clean up the current batch - and release items from the flow controller. - Raise exceptions if there's any. - Release the resources locked by the flow control and allow enqueued tasks to be run. - """ - processed_rows = self.futures_mapping[future] - self.flow_control.release(processed_rows) - del self.futures_mapping[future] - - def _row_fits_in_batch(self, row, batch_info): - """Checks if a row can fit in the current batch. - - :type row: class - :param row: :class:`~google.cloud.bigtable.row.DirectRow`. - - :type batch_info: :class:`_BatchInfo` - :param batch_info: Information about the current batch. - - :rtype: bool - :returns: True if the row can fit in the current batch. - """ - new_rows_count = batch_info.rows_count + 1 - new_mutations_count = batch_info.mutations_count + len(row._get_mutations()) - new_mutations_size = batch_info.mutations_size + row.get_mutations_size() - return ( - new_rows_count <= self.flush_count - and new_mutations_size <= self.max_row_bytes - and new_mutations_count <= self.flow_control.max_mutations - and new_mutations_size <= self.flow_control.max_mutation_bytes - ) - - def _flush_rows(self, rows_to_flush): - """Mutate the specified rows. - - :raises: - * :exc:`.batcherMutationsBatchError` if there's any error in the mutations. - """ - responses = [] - if len(rows_to_flush) > 0: - response = self.table.mutate_rows(rows_to_flush) - - if self._user_batch_completed_callback: - self._user_batch_completed_callback(response) - - for result in response: - if result.code != 0: - exc = from_grpc_status(result.code, result.message) - self.exceptions.put(exc) - responses.append(result) - - return responses + self._close_batcher() + self._init_batcher() def __exit__(self, exc_type, exc_value, exc_traceback): """Clean up resources. Flush and shutdown the ThreadPoolExecutor.""" @@ -406,9 +194,8 @@ def close(self): :raises: * :exc:`.batcherMutationsBatchError` if there's any error in the mutations. """ - self.flush() - self._executor.shutdown(wait=True) + self._close_batcher() atexit.unregister(self.close) - if self.exceptions.qsize() > 0: - exc = list(self.exceptions.queue) + if self._exceptions.qsize() > 0: + exc = list(self._exceptions.queue) raise MutationsBatchError("Errors in batch mutations.", exc=exc) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 61f932735..35fe42814 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -16,7 +16,6 @@ from __future__ import annotations from typing import ( - Callable, cast, Any, AsyncIterable, @@ -116,7 +115,6 @@ if TYPE_CHECKING: from google.cloud.bigtable.data._helpers import RowKeySamples from google.cloud.bigtable.data._helpers import ShardedQuery - from google.rpc import status_pb2 if CrossSync.is_async: from google.cloud.bigtable.data._async.mutations_batcher import ( diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 6d6aefc9a..fb70dddb6 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -14,7 +14,7 @@ # from __future__ import annotations -from typing import Callable, Optional, Sequence, TYPE_CHECKING, cast +from typing import Sequence, TYPE_CHECKING, cast import atexit import warnings from collections import deque diff --git a/tests/system/v2_client/test_data_api.py b/tests/system/v2_client/test_data_api.py index 6b46ba498..78d816c75 100644 --- a/tests/system/v2_client/test_data_api.py +++ b/tests/system/v2_client/test_data_api.py @@ -1241,3 +1241,108 @@ def callback(results): time.sleep(0.01) # ensure all mutations were sent assert len(all_results) == num_sent + + +def test_mutations_batcher_exceptions(data_table, rows_to_delete): + """Test the mutations batcher exception handling""" + import mock + from google.cloud.bigtable.batcher import MutationsBatcher, MutationsBatchError + from google.cloud.bigtable_v2 import MutateRowsResponse + from google.rpc import code_pb2, status_pb2 + + num_sent = 5 + + error_response = [ + MutateRowsResponse( + entries=[ + MutateRowsResponse.Entry( + index=i, + status=status_pb2.Status( + code=code_pb2.INTERNAL, + message="Test error", + ), + ) + for i in range(num_sent) + ] + ) + ] + + # Simulate only failures + with pytest.raises(MutationsBatchError): + with mock.patch.object( + data_table._instance._client.table_data_client, "mutate_rows" + ) as mutate_mock: + mutate_mock.side_effect = [error_response] * 100000 + with MutationsBatcher( + data_table, + flush_count=10, + flush_interval=1, + ) as batcher: + for i in range(num_sent): + row = data_table.direct_row("row{}".format(i)) + row.set_cell( + COLUMN_FAMILY_ID1, COL_NAME1, "val{}".format(i).encode("utf-8") + ) + rows_to_delete.append(row) + batcher.mutate(row) + batcher.flush() + + # Test that exceptions are only raised on close. + with mock.patch.object( + data_table._instance._client.table_data_client, "mutate_rows" + ) as mutate_mock: + mutate_mock.side_effect = [error_response] * 100000 + batcher = MutationsBatcher( + data_table, + flush_count=10, + flush_interval=1, + ) + for i in range(num_sent): + row = data_table.direct_row("row{}".format(i)) + row.set_cell( + COLUMN_FAMILY_ID1, COL_NAME1, "val{}".format(i).encode("utf-8") + ) + rows_to_delete.append(row) + batcher.mutate(row) + batcher.flush() + + with pytest.raises(MutationsBatchError): + batcher.close() + + +def test_mutations_batcher_manual_flush(data_table, rows_to_delete): + """Test the mutations batcher manual flush""" + import mock + from google.cloud.bigtable.batcher import MutationsBatcher + from google.rpc import status_pb2, code_pb2 + + num_batches = 5 + batch_size = 4 + callback = mock.MagicMock() + + with MutationsBatcher( + data_table, + flush_count=500, + flush_interval=5, + batch_completed_callback=callback, + ) as batcher: + for i in range(num_batches): + for j in range(batch_size): + num = i * batch_size + j + row = data_table.direct_row(f"row{num}".encode("utf-8")) + row.set_cell(COLUMN_FAMILY_ID1, COL_NAME1, f"val{num}".encode("utf-8")) + rows_to_delete.append(row) + batcher.mutate(row) + batcher.flush() + callback.assert_called_with( + [status_pb2.Status(code=code_pb2.OK)] * batch_size + ) + + # ensure all mutations were sent + rows = data_table.read_rows() + rows.consume_all() + for row_num in range(0, num_batches * batch_size): + row = rows.rows[f"row{row_num}".encode("utf-8")] + assert row.cells[COLUMN_FAMILY_ID1][COL_NAME1][ + 0 + ].value == f"val{row_num}".encode("utf-8") diff --git a/tests/unit/v2_client/test_batcher.py b/tests/unit/v2_client/test_batcher.py index fcf606972..80d0ac67c 100644 --- a/tests/unit/v2_client/test_batcher.py +++ b/tests/unit/v2_client/test_batcher.py @@ -20,60 +20,102 @@ from google.cloud.bigtable.row import DirectRow from google.cloud.bigtable.batcher import ( - _FlowControl, MutationsBatcher, MutationsBatchError, ) +from ._testing import _make_credentials + +PROJECT = "PROJECT" +INSTANCE_ID = "instance-id" TABLE_ID = "table-id" TABLE_NAME = "/tables/" + TABLE_ID -def test_mutation_batcher_constructor(): - table = _Table(TABLE_NAME) - with MutationsBatcher(table) as mutation_batcher: - assert table is mutation_batcher.table - - -def test_mutation_batcher_w_user_callback(): - table = _Table(TABLE_NAME) - - def callback_fn(response): - callback_fn.count = len(response) +@pytest.fixture +def _setup_batcher(): + from google.cloud.bigtable.client import Client + from google.cloud.bigtable.table import Table + + import google.cloud.bigtable.data._sync_autogen.mutations_batcher + + client = Client(project=PROJECT, credentials=_make_credentials()) + instance = client.instance(INSTANCE_ID) + + with mock.patch.object( + google.cloud.bigtable.data._sync_autogen.mutations_batcher.CrossSync._Sync_Impl, + "_MutateRowsOperation", + ) as operation_mock: + yield Table(TABLE_ID, instance=instance), operation_mock + + +@pytest.fixture +def _atexit_mock(): + atexit_mock = _AtexitMock() + with mock.patch.multiple( + "atexit", register=atexit_mock.register, unregister=atexit_mock.unregister + ): + yield atexit_mock + + +def test_mutations_batcher_constructor(_setup_batcher, _atexit_mock): + flush_count = 5 + flush_interval = 0.1 + max_row_bytes = 10000 + table, _ = _setup_batcher + with mock.patch.object( + table._table_impl, "mutations_batcher" + ) as batcher_impl_constructor: + with MutationsBatcher( + table, + flush_count=flush_count, + flush_interval=flush_interval, + max_row_bytes=max_row_bytes, + ) as mutation_batcher: + assert table is mutation_batcher.table + batcher_impl_constructor.assert_called_once_with( + flush_interval=flush_interval, + flush_limit_mutation_count=flush_count, + flush_limit_bytes=max_row_bytes, + ) + assert mutation_batcher.close in _atexit_mock._functions + + +def test_mutations_batcher_w_user_callback(_setup_batcher): + table, _ = _setup_batcher + + callback_fn = mock.Mock() + batch_size = 4 with MutationsBatcher( - table, flush_count=1, batch_completed_callback=callback_fn + table, flush_count=batch_size, batch_completed_callback=callback_fn ) as mutation_batcher: - rows = [ - DirectRow(row_key=b"row_key"), - DirectRow(row_key=b"row_key_2"), - DirectRow(row_key=b"row_key_3"), - DirectRow(row_key=b"row_key_4"), - ] + rows = [DirectRow(row_key=f"row_key_{i}".encode()) for i in range(batch_size)] + for row in rows: + row.delete() mutation_batcher.mutate_rows(rows) - assert callback_fn.count == 4 + assert len(callback_fn.call_args[0][0]) == batch_size -def test_mutation_batcher_mutate_row(): - table = _Table(TABLE_NAME) - with MutationsBatcher(table=table) as mutation_batcher: - rows = [ - DirectRow(row_key=b"row_key"), - DirectRow(row_key=b"row_key_2"), - DirectRow(row_key=b"row_key_3"), - DirectRow(row_key=b"row_key_4"), - ] +def test_mutations_batcher_mutate_row(_setup_batcher): + table, operation_mock = _setup_batcher + batch_size = 4 + + with MutationsBatcher(table, flush_count=batch_size) as mutation_batcher: + rows = [DirectRow(row_key=f"row_key_{i}".encode()) for i in range(batch_size)] + for row in rows: + row.delete() mutation_batcher.mutate_rows(rows) - assert table.mutation_calls == 1 + operation_mock.assert_called_once() -def test_mutation_batcher_mutate(): - table = _Table(TABLE_NAME) - with MutationsBatcher(table=table) as mutation_batcher: +def test_mutations_batcher_mutate(_setup_batcher): + table, operation_mock = _setup_batcher + with MutationsBatcher(table=table, flush_count=1) as mutation_batcher: row = DirectRow(row_key=b"row_key") row.set_cell("cf1", b"c1", 1) row.set_cell("cf1", b"c2", 2) @@ -82,47 +124,36 @@ def test_mutation_batcher_mutate(): mutation_batcher.mutate(row) - assert table.mutation_calls == 1 + operation_mock.assert_called_once() -def test_mutation_batcher_flush_w_no_rows(): - table = _Table(TABLE_NAME) +def test_mutations_batcher_manual_flush(_setup_batcher, _atexit_mock): + table, operation_mock = _setup_batcher with MutationsBatcher(table=table) as mutation_batcher: - mutation_batcher.flush() - - assert table.mutation_calls == 0 - + original_batcher_impl = mutation_batcher._batcher + assert original_batcher_impl._on_exit in _atexit_mock._functions -def test_mutation_batcher_mutate_w_max_flush_count(): - table = _Table(TABLE_NAME) - with MutationsBatcher(table=table, flush_count=3) as mutation_batcher: - row_1 = DirectRow(row_key=b"row_key_1") - row_2 = DirectRow(row_key=b"row_key_2") - row_3 = DirectRow(row_key=b"row_key_3") + row = DirectRow(row_key=b"row_key") + row.set_cell("cf1", b"c1", 1) + mutation_batcher.mutate(row) - mutation_batcher.mutate(row_1) - mutation_batcher.mutate(row_2) - mutation_batcher.mutate(row_3) + mutation_batcher.flush() - assert table.mutation_calls == 1 + operation_mock.assert_called_once() + assert mutation_batcher._batcher != original_batcher_impl + assert original_batcher_impl._on_exit not in _atexit_mock._functions -@mock.patch("google.cloud.bigtable.batcher.MAX_OUTSTANDING_ELEMENTS", new=3) -def test_mutation_batcher_mutate_w_max_mutations(): - table = _Table(TABLE_NAME) +def test_mutations_batcher_flush_w_no_rows(_setup_batcher): + table, operation_mock = _setup_batcher with MutationsBatcher(table=table) as mutation_batcher: - row = DirectRow(row_key=b"row_key") - row.set_cell("cf1", b"c1", 1) - row.set_cell("cf1", b"c2", 2) - row.set_cell("cf1", b"c3", 3) - - mutation_batcher.mutate(row) + mutation_batcher.flush() - assert table.mutation_calls == 1 + operation_mock.assert_not_called() -def test_mutation_batcher_mutate_w_max_row_bytes(): - table = _Table(TABLE_NAME) +def test_mutations_batcher_mutate_w_max_row_bytes(_setup_batcher): + table, operation_mock = _setup_batcher with MutationsBatcher( table=table, max_row_bytes=3 * 1024 * 1024 ) as mutation_batcher: @@ -136,11 +167,11 @@ def test_mutation_batcher_mutate_w_max_row_bytes(): mutation_batcher.mutate(row) - assert table.mutation_calls == 1 + operation_mock.assert_called_once() -def test_mutations_batcher_flushed_when_closed(): - table = _Table(TABLE_NAME) +def test_mutations_batcher_flushed_when_closed(_setup_batcher): + table, operation_mock = _setup_batcher mutation_batcher = MutationsBatcher(table=table, max_row_bytes=3 * 1024 * 1024) number_of_bytes = 1 * 1024 * 1024 @@ -151,15 +182,15 @@ def test_mutations_batcher_flushed_when_closed(): row.set_cell("cf1", b"c2", max_value) mutation_batcher.mutate(row) - assert table.mutation_calls == 0 + operation_mock.assert_not_called() mutation_batcher.close() - assert table.mutation_calls == 1 + operation_mock.assert_called_once() -def test_mutations_batcher_context_manager_flushed_when_closed(): - table = _Table(TABLE_NAME) +def test_mutations_batcher_context_manager_flushed_when_closed(_setup_batcher): + table, operation_mock = _setup_batcher with MutationsBatcher( table=table, max_row_bytes=3 * 1024 * 1024 ) as mutation_batcher: @@ -171,99 +202,118 @@ def test_mutations_batcher_context_manager_flushed_when_closed(): row.set_cell("cf1", b"c2", max_value) mutation_batcher.mutate(row) + operation_mock.assert_not_called() - assert table.mutation_calls == 1 + operation_mock.assert_called_once() -@mock.patch("google.cloud.bigtable.batcher.MutationsBatcher.flush") -def test_mutations_batcher_flush_interval(mocked_flush): - table = _Table(TABLE_NAME) +def test_mutations_batcher_flush_interval(_setup_batcher): + table, operation_mock = _setup_batcher flush_interval = 0.5 mutation_batcher = MutationsBatcher(table=table, flush_interval=flush_interval) - - assert mutation_batcher._timer.interval == flush_interval - mocked_flush.assert_not_called() + row = DirectRow(row_key=b"row_key") + row.set_cell("cf1", b"c1", b"1") + mutation_batcher.mutate(row) + operation_mock.assert_not_called() time.sleep(0.4) - mocked_flush.assert_not_called() + operation_mock.assert_not_called() - time.sleep(0.1) - mocked_flush.assert_called_once_with() + # Test could be flaky, so giving the thread some extra buffer time + time.sleep(0.25) + operation_mock.assert_called_once() mutation_batcher.close() -def test_mutations_batcher_response_with_error_codes(): - from google.rpc.status_pb2 import Status - - mocked_response = [Status(code=1), Status(code=5)] - - with mock.patch("tests.unit.v2_client.test_batcher._Table") as mocked_table: - table = mocked_table.return_value - mutation_batcher = MutationsBatcher(table=table) - - row1 = DirectRow(row_key=b"row_key") - row2 = DirectRow(row_key=b"row_key") - table.mutate_rows.return_value = mocked_response - - mutation_batcher.mutate_rows([row1, row2]) - with pytest.raises(MutationsBatchError) as exc: - mutation_batcher.close() - assert exc.value.message == "Errors in batch mutations." - assert len(exc.value.exc) == 2 - - assert exc.value.exc[0].message == mocked_response[0].message - assert exc.value.exc[1].message == mocked_response[1].message - - -def test_flow_control_event_is_set_when_not_blocked(): - flow_control = _FlowControl() - - flow_control.set_flow_control_status() - assert flow_control.event.is_set() - - -def test_flow_control_event_is_not_set_when_blocked(): - flow_control = _FlowControl() - - flow_control.inflight_mutations = flow_control.max_mutations - flow_control.inflight_size = flow_control.max_mutation_bytes - - flow_control.set_flow_control_status() - assert not flow_control.event.is_set() - - -@mock.patch("concurrent.futures.ThreadPoolExecutor.submit") -def test_flush_async_batch_count(mocked_executor_submit): - table = _Table(TABLE_NAME) - mutation_batcher = MutationsBatcher(table=table, flush_count=2) - - number_of_bytes = 1 * 1024 * 1024 - max_value = b"1" * number_of_bytes - for index in range(5): - row = DirectRow(row_key=f"row_key_{index}") - row.set_cell("cf1", b"c1", max_value) - mutation_batcher.mutate(row) - mutation_batcher._flush_async() - - # 3 batches submitted. 2 batches of 2 items, and the last one a single item batch. - assert mocked_executor_submit.call_count == 3 - - -class _Instance(object): - def __init__(self, client=None): - self._client = client - - -class _Table(object): - def __init__(self, name, client=None): - self.name = name - self._instance = _Instance(client) - self.mutation_calls = 0 - - def mutate_rows(self, rows): - from google.rpc.status_pb2 import Status - - self.mutation_calls += 1 - - return [Status(code=0) for _ in rows] +def test_mutations_batcher_response_with_error_codes(_setup_batcher): + from google.api_core import exceptions + from google.cloud.bigtable.data.exceptions import FailedMutationEntryError + from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup + + table, operation_mock = _setup_batcher + + causes = [ + exceptions.InternalServerError("Something happened"), + exceptions.DataLoss("Data loss"), + ] + excs = [ + FailedMutationEntryError( + failed_idx=i, failed_mutation_entry=mock.Mock(), cause=cause + ) + for i, cause in enumerate(causes) + ] + error = MutationsExceptionGroup(excs=excs, total_entries=len(excs)) + + operation_mock.return_value.start.side_effect = error + + mutations_batcher = MutationsBatcher(table=table) + row1 = DirectRow(row_key=b"row_key") + row1.set_cell("cf1", b"c1", b"1") + row2 = DirectRow(row_key=b"row_key_2") + row2.set_cell("cf1", b"c1", b"1") + mutations_batcher.mutate_rows([row1, row2]) + mutations_batcher.flush() + + with pytest.raises(MutationsBatchError) as raised_error: + mutations_batcher.close() + assert raised_error.value.message == "Errors in batch mutations." + assert len(raised_error.value.exc) == 2 + + assert raised_error.value.exc[0].message == causes[0].message + assert raised_error.value.exc[1].message == causes[1].message + + +def test_mutations_batcher_response_with_error_codes_multiple_flushes(_setup_batcher): + from google.api_core import exceptions + from google.cloud.bigtable.data.exceptions import FailedMutationEntryError + from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup + + table, operation_mock = _setup_batcher + + causes = [ + exceptions.InternalServerError("Something happened"), + exceptions.DataLoss("Data loss"), + ] + excs = [ + FailedMutationEntryError( + failed_idx=i, failed_mutation_entry=mock.Mock(), cause=cause + ) + for i, cause in enumerate(causes) + ] + error1 = MutationsExceptionGroup(excs=excs[0:1], total_entries=1) + error2 = MutationsExceptionGroup(excs=excs[1:2], total_entries=1) + + operation_mock.return_value.start.side_effect = error1 + + mutations_batcher = MutationsBatcher(table=table) + row1 = DirectRow(row_key=b"row_key") + row1.set_cell("cf1", b"c1", b"1") + mutations_batcher.mutate(row1) + mutations_batcher.flush() + + operation_mock.return_value.start.side_effect = error2 + + row2 = DirectRow(row_key=b"row_key_2") + row2.set_cell("cf1", b"c1", b"1") + mutations_batcher.mutate(row2) + mutations_batcher.flush() + + with pytest.raises(MutationsBatchError) as raised_error: + mutations_batcher.close() + assert raised_error.value.message == "Errors in batch mutations." + assert len(raised_error.value.exc) == 2 + + assert raised_error.value.exc[0].message == causes[0].message + assert raised_error.value.exc[1].message == causes[1].message + + +class _AtexitMock: + def __init__(self): + self._functions = set() + + def register(self, func): + self._functions.add(func) + + def unregister(self, func): + self._functions.remove(func) From f47e6063e12468073066b27903c43684eccfbd14 Mon Sep 17 00:00:00 2001 From: Kevin Zheng Date: Tue, 17 Mar 2026 18:11:00 +0000 Subject: [PATCH 2/2] mypy --- .../bigtable/data/_async/mutations_batcher.py | 6 ++++-- google/cloud/bigtable/data/_helpers.py | 21 +++++++++++-------- .../data/_sync_autogen/mutations_batcher.py | 6 ++++-- 3 files changed, 20 insertions(+), 13 deletions(-) diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index fb70dddb6..768991a1b 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -14,7 +14,7 @@ # from __future__ import annotations -from typing import Sequence, TYPE_CHECKING, cast +from typing import Callable, Optional, Sequence, TYPE_CHECKING, cast import atexit import warnings from collections import deque @@ -276,7 +276,9 @@ def __init__( self._newest_exceptions: deque[Exception] = deque( maxlen=self._exception_list_limit ) - self._user_batch_completed_callback = None + self._user_batch_completed_callback: Optional[ + Callable[[list[status_pb2.Status]], None] + ] = None # clean up on program exit atexit.register(self._on_exit) diff --git a/google/cloud/bigtable/data/_helpers.py b/google/cloud/bigtable/data/_helpers.py index 0f411f88a..02919c748 100644 --- a/google/cloud/bigtable/data/_helpers.py +++ b/google/cloud/bigtable/data/_helpers.py @@ -16,7 +16,7 @@ """ from __future__ import annotations -from typing import Callable, Sequence, List, Optional, Tuple, TYPE_CHECKING, Union +from typing import cast, Callable, Sequence, List, Optional, Tuple, TYPE_CHECKING, Union import time import enum from collections import namedtuple @@ -272,14 +272,17 @@ def _get_status(exc: Optional[Exception]) -> status_pb2.Status: Returns: status_pb2.Status: A Status proto object. """ - if ( - isinstance(exc, core_exceptions.GoogleAPICallError) - and exc.grpc_status_code is not None - ): - return status_pb2.Status( # type: ignore[unreachable] - code=exc.grpc_status_code.value[0], - message=exc.message, - details=exc.details, + if isinstance(exc, core_exceptions.GoogleAPICallError): + status_code = cast(Optional["grpc.StatusCode"], exc.grpc_status_code) + if status_code is not None: + return status_pb2.Status( + code=status_code.value[0], + message=exc.message, + details=exc.details, + ) + return status_pb2.Status( + code=code_pb2.Code.UNKNOWN, + message="An unknown error has occurred", ) return status_pb2.Status( diff --git a/google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py b/google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py index c72a79fc7..e17efefac 100644 --- a/google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py @@ -16,7 +16,7 @@ # This file is automatically generated by CrossSync. Do not edit manually. from __future__ import annotations -from typing import Sequence, TYPE_CHECKING, cast +from typing import Callable, Optional, Sequence, TYPE_CHECKING, cast import atexit import warnings from collections import deque @@ -238,7 +238,9 @@ def __init__( self._newest_exceptions: deque[Exception] = deque( maxlen=self._exception_list_limit ) - self._user_batch_completed_callback = None + self._user_batch_completed_callback: Optional[ + Callable[[list[status_pb2.Status]], None] + ] = None atexit.register(self._on_exit) def _timer_routine(self, interval: float | None) -> None: