From 068f6f37f79fc949772303c60c0d8a6c4adfbc00 Mon Sep 17 00:00:00 2001 From: yaythomas Date: Thu, 29 Jan 2026 17:38:44 -0800 Subject: [PATCH] fix: BatchResult completion reason logic Fix BatchResult.from_items() to match JavaScript SDK behavior. The previous implementation incorrectly checked conditions in the wrong order, causing incorrect completion reasons for concurrent operations with failure tolerance configurations. The checkpointed completion reason is preserved during replay, so existing executions are unaffected. Code with conditional logic based on completion_reason might see different values after this update. Example: With 3 items (1 success, 2 failures) and tolerated_failure_count=1: - Before: ALL_COMPLETED (incorrect - all items finished) - After: FAILURE_TOLERANCE_EXCEEDED (correct - tolerance breached) Changes: - Extract completion reason logic to _get_completion_reason() method - Check failure tolerance BEFORE checking if all completed - Implement proper fail-fast when no completion config provided - Add comprehensive unit tests covering all edge cases This ensures tolerance checks take precedence over success criteria, preventing operations from incorrectly reporting ALL_COMPLETED when failure tolerance has been exceeded. closes #280 --- .../concurrency/models.py | 105 +++++++-- tests/concurrency_test.py | 223 ++++++++++++++---- 2 files changed, 271 insertions(+), 57 deletions(-) diff --git a/src/aws_durable_execution_sdk_python/concurrency/models.py b/src/aws_durable_execution_sdk_python/concurrency/models.py index 29ffeaf..a8137eb 100644 --- a/src/aws_durable_execution_sdk_python/concurrency/models.py +++ b/src/aws_durable_execution_sdk_python/concurrency/models.py @@ -114,6 +114,85 @@ def from_dict( completion_reason = CompletionReason(completion_reason_value) return cls(batch_items, completion_reason) + @staticmethod + def _get_completion_reason( + failure_count: int, + success_count: int, + completed_count: int, + total_count: int, + completion_config: CompletionConfig | None, + ) -> CompletionReason: + """ + Determine completion reason based on completion counts. + + Logic order: + 1. Check failure tolerance FIRST (before checking if all completed) + 2. Check if all completed + 3. Check if minimum successful reached + 4. Default to ALL_COMPLETED + + Args: + failure_count: Number of failed items + success_count: Number of succeeded items + completed_count: Total completed (succeeded + failed) + total_count: Total number of items + completion_config: Optional completion configuration + + Returns: + CompletionReason enum value + """ + # STEP 1: Check tolerance first, before checking if all completed + + # Handle fail-fast behavior (no completion config or empty completion config) + if completion_config is None: + if failure_count > 0: + return CompletionReason.FAILURE_TOLERANCE_EXCEEDED + else: + # Check if completion config has any criteria set + has_any_completion_criteria = ( + completion_config.min_successful is not None + or completion_config.tolerated_failure_count is not None + or completion_config.tolerated_failure_percentage is not None + ) + + if not has_any_completion_criteria: + # Empty completion config - fail fast on any failure + if failure_count > 0: + return CompletionReason.FAILURE_TOLERANCE_EXCEEDED + else: + # Check specific tolerance thresholds + if ( + completion_config.tolerated_failure_count is not None + and failure_count > completion_config.tolerated_failure_count + ): + return CompletionReason.FAILURE_TOLERANCE_EXCEEDED + + if ( + completion_config.tolerated_failure_percentage is not None + and total_count > 0 + ): + failure_percentage = (failure_count / total_count) * 100 + if ( + failure_percentage + > completion_config.tolerated_failure_percentage + ): + return CompletionReason.FAILURE_TOLERANCE_EXCEEDED + + # STEP 2: Check if all completed + if completed_count == total_count: + return CompletionReason.ALL_COMPLETED + + # STEP 3: Check if minimum successful reached + if ( + completion_config is not None + and completion_config.min_successful is not None + and success_count >= completion_config.min_successful + ): + return CompletionReason.MIN_SUCCESSFUL_REACHED + + # STEP 4: Default + return CompletionReason.ALL_COMPLETED + @classmethod def from_items( cls, @@ -123,12 +202,8 @@ def from_items( """ Infer completion reason based on batch item statuses and completion config. - This follows the same logic as the TypeScript implementation: - - If all items completed: ALL_COMPLETED - - If minSuccessful threshold met and not all completed: MIN_SUCCESSFUL_REACHED - - Otherwise: FAILURE_TOLERANCE_EXCEEDED + This follows the same logic as the TypeScript implementation. """ - statuses = (item.status for item in items) counts = Counter(statuses) succeeded_count = counts.get(BatchItemStatus.SUCCEEDED, 0) @@ -138,18 +213,14 @@ def from_items( completed_count = succeeded_count + failed_count total_count = started_count + completed_count - # If all items completed (no started items), it's ALL_COMPLETED - if completed_count == total_count: - completion_reason = CompletionReason.ALL_COMPLETED - elif ( # If we have completion config and minSuccessful threshold is met - completion_config - and (min_successful := completion_config.min_successful) is not None - and succeeded_count >= min_successful - ): - completion_reason = CompletionReason.MIN_SUCCESSFUL_REACHED - else: - # Otherwise, assume failure tolerance was exceeded - completion_reason = CompletionReason.FAILURE_TOLERANCE_EXCEEDED + # Determine completion reason using the same logic as JavaScript SDK + completion_reason = cls._get_completion_reason( + failure_count=failed_count, + success_count=succeeded_count, + completed_count=completed_count, + total_count=total_count, + completion_config=completion_config, + ) return cls(items, completion_reason) diff --git a/tests/concurrency_test.py b/tests/concurrency_test.py index b94f39f..0aeb3ac 100644 --- a/tests/concurrency_test.py +++ b/tests/concurrency_test.py @@ -335,7 +335,7 @@ def test_batch_result_from_dict_infer_all_completed_all_succeeded(): def test_batch_result_from_dict_infer_failure_tolerance_exceeded_all_failed(): - """Test BatchResult from_dict infers FAILURE_TOLERANCE_EXCEEDED when all items failed.""" + """Test BatchResult from_dict infers completion reason when all items failed.""" error_data = { "message": "Test error", "type": "TestError", @@ -350,18 +350,17 @@ def test_batch_result_from_dict_infer_failure_tolerance_exceeded_all_failed(): # No completionReason provided } - # even if everything has failed, if we've completed all items, then we've finished as ALL_COMPLETED - # https://github.com/aws/aws-durable-execution-sdk-js/blob/f20396f24afa9d6539d8e5056ee851ac7ef62301/packages/aws-durable-execution-sdk-js/src/handlers/concurrent-execution-handler/concurrent-execution-handler.ts#L324-L335 + # With no completion config and failures, should fail-fast with patch( "aws_durable_execution_sdk_python.concurrency.models.logger" ) as mock_logger: result = BatchResult.from_dict(data) - assert result.completion_reason == CompletionReason.ALL_COMPLETED + assert result.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED mock_logger.warning.assert_called_once() def test_batch_result_from_dict_infer_all_completed_mixed_success_failure(): - """Test BatchResult from_dict infers ALL_COMPLETED when mix of success/failure but no started items.""" + """Test BatchResult from_dict infers completion reason with mix of success/failure.""" error_data = { "message": "Test error", "type": "TestError", @@ -377,12 +376,12 @@ def test_batch_result_from_dict_infer_all_completed_mixed_success_failure(): # No completionReason provided } - # the logic is that when \every item i: hasCompleted(i) then terminate due to all_completed + # With no config and with failures, fail-fast with patch( "aws_durable_execution_sdk_python.concurrency.models.logger" ) as mock_logger: result = BatchResult.from_dict(data) - assert result.completion_reason == CompletionReason.ALL_COMPLETED + assert result.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED mock_logger.warning.assert_called_once() @@ -440,25 +439,25 @@ def test_batch_result_from_dict_with_explicit_completion_reason(): def test_batch_result_infer_completion_reason_edge_cases(): """Test _infer_completion_reason method with various edge cases.""" - # Test with only started items + # Test with only started items and min_successful=0 started_items = [ BatchItem(0, BatchItemStatus.STARTED).to_dict(), BatchItem(1, BatchItemStatus.STARTED).to_dict(), ] items = {"all": started_items} batch = BatchResult.from_dict(items, CompletionConfig(0)) # SLF001 - # this state is not possible with CompletionConfig(0) + # With min_successful=0 and no failures, should be MIN_SUCCESSFUL_REACHED assert batch.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED - # Test with only started items + # Test with only started items and no config started_items = [ BatchItem(0, BatchItemStatus.STARTED).to_dict(), BatchItem(1, BatchItemStatus.STARTED).to_dict(), ] items = {"all": started_items} batch = BatchResult.from_dict(items) # SLF001 - # this state is not possible with CompletionConfig(0) - assert batch.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED + # With no config and no completed items, defaults to ALL_COMPLETED + assert batch.completion_reason == CompletionReason.ALL_COMPLETED # Test with only failed items failed_items = [ @@ -471,7 +470,8 @@ def test_batch_result_infer_completion_reason_edge_cases(): ] failed_items = {"all": failed_items} batch = BatchResult.from_dict(failed_items) # SLF001 - assert batch.completion_reason == CompletionReason.ALL_COMPLETED + # With no config and failures, should fail-fast + assert batch.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED # Test with only succeeded items succeeded_items = [ @@ -482,7 +482,7 @@ def test_batch_result_infer_completion_reason_edge_cases(): batch = BatchResult.from_dict(succeeded_items) # SLF001 assert batch.completion_reason == CompletionReason.ALL_COMPLETED - # Test with mixed but no started (all completed) + # Test with mixed but no started (all completed) with tolerance mixed_items = [ BatchItem(0, BatchItemStatus.SUCCEEDED, "result1"), BatchItem( @@ -490,7 +490,9 @@ def test_batch_result_infer_completion_reason_edge_cases(): ), ] - batch = BatchResult.from_items(mixed_items) # SLF001 + batch = BatchResult.from_items( + mixed_items, CompletionConfig(tolerated_failure_count=1) + ) # SLF001 assert batch.completion_reason == CompletionReason.ALL_COMPLETED @@ -1101,10 +1103,9 @@ def failure_callable(): assert len(result.all) == 2 assert result.all[0].status == BatchItemStatus.SUCCEEDED assert result.all[1].status == BatchItemStatus.FAILED - # WHEN all items complete, THEN completion reason is ALL_COMPLETED. - # we don't consider thresholds and limits. - # https://github.com/aws/aws-durable-execution-sdk-js/blob/ff8b72ef888dd47a840f36d4eb0ee84dd3b55a30/packages/aws-durable-execution-sdk-js/src/handlers/concurrent-execution-handler/concurrent-execution-handler.test.ts#L630-L655 - assert result.completion_reason == CompletionReason.ALL_COMPLETED + # NEW BEHAVIOR: With empty completion config (no criteria) and failures, + # should fail-fast and return FAILURE_TOLERANCE_EXCEEDED + assert result.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED def test_concurrent_executor_execute_item_in_child_context(): @@ -1188,10 +1189,9 @@ def mock_run_in_child_context(func, name, config): return func(Mock()) result = executor.execute(execution_state, mock_run_in_child_context) - # WHEN all items complete, THEN completion reason is ALL_COMPLETED. - # we don't consider thresholds and limits. - # https://github.com/aws/aws-durable-execution-sdk-js/blob/ff8b72ef888dd47a840f36d4eb0ee84dd3b55a30/packages/aws-durable-execution-sdk-js/src/handlers/concurrent-execution-handler/concurrent-execution-handler.test.ts#L630-L655 - assert result.completion_reason == CompletionReason.ALL_COMPLETED + # NEW BEHAVIOR: With tolerated_failure_count=0 and 1 failure, + # tolerance is exceeded, so FAILURE_TOLERANCE_EXCEEDED + assert result.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED def test_single_task_suspend_bubbles_up(): @@ -1971,9 +1971,9 @@ def execute_item(self, child_context, executable): assert result.all[0].result is None assert result.all[0].error is None assert result.all[0].index == 0 - # By default, if we've terminated the reasoning is failure tolerance exceeded - # according to the spec - assert result.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED + # NEW BEHAVIOR: With min_successful=1 and no completed items, + # defaults to ALL_COMPLETED + assert result.completion_reason == CompletionReason.ALL_COMPLETED def test_create_result_running_branch(): @@ -2009,9 +2009,8 @@ def execute_item(self, child_context, executable): assert result.all[0].result is None assert result.all[0].error is None assert result.all[0].index == 0 - # By default, if we've terminated the reasoning is failure tolerance exceeded - # according to the spec - assert result.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED + # With min_successful=1 and no completed items, defaults to ALL_COMPLETED + assert result.completion_reason == CompletionReason.ALL_COMPLETED def test_create_result_suspended_branch(): @@ -2046,9 +2045,8 @@ def execute_item(self, child_context, executable): assert result.all[0].result is None assert result.all[0].error is None assert result.all[0].index == 0 - # By default, if we've terminated the reasoning is failure tolerance exceeded - # according to the spec - assert result.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED + # With min_successful=1 and no completed items, defaults to ALL_COMPLETED + assert result.completion_reason == CompletionReason.ALL_COMPLETED def test_create_result_suspended_with_timeout_branch(): @@ -2084,9 +2082,8 @@ def execute_item(self, child_context, executable): assert result.all[0].result is None assert result.all[0].error is None assert result.all[0].index == 0 - # By default, if we've terminated the reasoning is failure tolerance exceeded - # according to the spec - assert result.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED + # With min_successful=1 and no completed items, default to ALL_COMPLETED + assert result.completion_reason == CompletionReason.ALL_COMPLETED def test_create_result_mixed_statuses(): @@ -2310,9 +2307,8 @@ def execute_item(self, child_context, executable): assert all(item.status == BatchItemStatus.STARTED for item in result.all) assert all(item.result is None for item in result.all) assert all(item.error is None for item in result.all) - # With completion config min_successful=1 and no completed items, - # this should be FAILURE_TOLERANCE_EXCEEDED - assert result.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED + # With min_successful=1 and no completed items, defaults to ALL_COMPLETED + assert result.completion_reason == CompletionReason.ALL_COMPLETED def test_create_result_empty_executables(): @@ -2384,7 +2380,7 @@ def test_batch_result_from_dict_with_completion_config(): def test_batch_result_from_dict_all_completed(): - """Test BatchResult from_dict infers ALL_COMPLETED when all items are completed.""" + """Test BatchResult from_dict infers completion reason when all items are completed.""" data = { "all": [ {"index": 0, "status": "SUCCEEDED", "result": "result1", "error": None}, @@ -2403,11 +2399,12 @@ def test_batch_result_from_dict_all_completed(): # No completionReason provided } + # With no config and failures, fail-fast with patch( "aws_durable_execution_sdk_python.concurrency.models.logger" ) as mock_logger: result = BatchResult.from_dict(data) - assert result.completion_reason == CompletionReason.ALL_COMPLETED + assert result.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED mock_logger.warning.assert_called_once() @@ -3153,3 +3150,149 @@ def test_timer_scheduler_fifo_ordering_with_same_timestamp(): # endregion TimerScheduler edge cases with exact same reschedule time + + +# region Completion Reason Inference Tests (from_items) + + +def test_from_items_no_config_with_failures(): + """Validates: Requirements 2.4 - Fail-fast with no config.""" + items = [ + BatchItem(0, BatchItemStatus.SUCCEEDED, result="ok"), + BatchItem( + 1, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None) + ), + ] + result = BatchResult.from_items(items, completion_config=None) + assert result.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED + + +def test_from_items_empty_config_with_failures(): + """Validates: Requirements 2.5 - Fail-fast with empty config.""" + items = [ + BatchItem(0, BatchItemStatus.SUCCEEDED, result="ok"), + BatchItem( + 1, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None) + ), + ] + config = CompletionConfig() # All fields None + result = BatchResult.from_items(items, completion_config=config) + assert result.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED + + +def test_from_items_tolerance_checked_before_all_completed(): + """Validates: Requirements 2.1, 2.2 - Tolerance priority.""" + items = [ + BatchItem(0, BatchItemStatus.SUCCEEDED, result="ok"), + BatchItem( + 1, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None) + ), + BatchItem( + 2, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None) + ), + ] + config = CompletionConfig(tolerated_failure_count=1) + result = BatchResult.from_items(items, completion_config=config) + # All completed but tolerance exceeded - should return TOLERANCE_EXCEEDED + assert result.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED + + +def test_from_items_all_completed_within_tolerance(): + """Validates: Requirements 1.1 - All completed.""" + items = [ + BatchItem(0, BatchItemStatus.SUCCEEDED, result="ok"), + BatchItem( + 1, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None) + ), + ] + config = CompletionConfig(tolerated_failure_count=1) + result = BatchResult.from_items(items, completion_config=config) + assert result.completion_reason == CompletionReason.ALL_COMPLETED + + +def test_from_items_min_successful_reached(): + """Validates: Requirements 1.3 - Min successful.""" + items = [ + BatchItem(0, BatchItemStatus.SUCCEEDED, result="ok"), + BatchItem(1, BatchItemStatus.SUCCEEDED, result="ok"), + BatchItem(2, BatchItemStatus.STARTED), + ] + config = CompletionConfig(min_successful=2) + result = BatchResult.from_items(items, completion_config=config) + assert result.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED + + +def test_from_items_tolerance_count_exceeded(): + """Validates: Requirements 1.2 - Tolerance count.""" + items = [ + BatchItem( + 0, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None) + ), + BatchItem( + 1, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None) + ), + BatchItem(2, BatchItemStatus.STARTED), + ] + config = CompletionConfig(tolerated_failure_count=1) + result = BatchResult.from_items(items, completion_config=config) + assert result.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED + + +def test_from_items_tolerance_percentage_exceeded(): + """Validates: Requirements 1.2 - Tolerance percentage.""" + items = [ + BatchItem(0, BatchItemStatus.SUCCEEDED, result="ok"), + BatchItem( + 1, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None) + ), + BatchItem( + 2, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None) + ), + BatchItem( + 3, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None) + ), + ] + config = CompletionConfig(tolerated_failure_percentage=50.0) + # 3 failures out of 4 = 75% > 50% + result = BatchResult.from_items(items, completion_config=config) + assert result.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED + + +def test_from_items_tolerance_priority_over_min_successful(): + """Validates: Requirements 2.3 - Tolerance takes precedence.""" + items = [ + BatchItem(0, BatchItemStatus.SUCCEEDED, result="ok"), + BatchItem(1, BatchItemStatus.SUCCEEDED, result="ok"), + BatchItem( + 2, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None) + ), + BatchItem( + 3, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None) + ), + ] + config = CompletionConfig(min_successful=2, tolerated_failure_count=1) + # Min successful reached (2) but tolerance exceeded (2 > 1) + result = BatchResult.from_items(items, completion_config=config) + assert result.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED + + +def test_from_items_empty_array(): + """Validates: Edge case - empty items.""" + items = [] + result = BatchResult.from_items(items, completion_config=None) + assert result.completion_reason == CompletionReason.ALL_COMPLETED + assert result.total_count == 0 + + +def test_from_items_all_succeeded(): + """Validates: All items succeeded.""" + items = [ + BatchItem(0, BatchItemStatus.SUCCEEDED, result="ok1"), + BatchItem(1, BatchItemStatus.SUCCEEDED, result="ok2"), + ] + result = BatchResult.from_items(items, completion_config=None) + assert result.completion_reason == CompletionReason.ALL_COMPLETED + assert result.success_count == 2 + + +# endregion Completion Reason Inference Tests