diff --git a/src/aws_durable_execution_sdk_python/concurrency/models.py b/src/aws_durable_execution_sdk_python/concurrency/models.py
index 29ffeaf..a8137eb 100644
--- a/src/aws_durable_execution_sdk_python/concurrency/models.py
+++ b/src/aws_durable_execution_sdk_python/concurrency/models.py
@@ -114,6 +114,85 @@ def from_dict(
         completion_reason = CompletionReason(completion_reason_value)
         return cls(batch_items, completion_reason)
 
+    @staticmethod
+    def _get_completion_reason(
+        failure_count: int,
+        success_count: int,
+        completed_count: int,
+        total_count: int,
+        completion_config: CompletionConfig | None,
+    ) -> CompletionReason:
+        """
+        Determine completion reason based on completion counts.
+
+        Logic order:
+        1. Check failure tolerance FIRST (before checking if all completed)
+        2. Check if all completed
+        3. Check if minimum successful reached
+        4. Default to ALL_COMPLETED
+
+        Args:
+            failure_count: Number of failed items
+            success_count: Number of succeeded items
+            completed_count: Total completed (succeeded + failed)
+            total_count: Total number of items
+            completion_config: Optional completion configuration
+
+        Returns:
+            CompletionReason enum value
+        """
+        # STEP 1: Check tolerance first, before checking if all completed
+
+        # Handle fail-fast behavior (no completion config or empty completion config)
+        if completion_config is None:
+            if failure_count > 0:
+                return CompletionReason.FAILURE_TOLERANCE_EXCEEDED
+        else:
+            # Check if completion config has any criteria set
+            has_any_completion_criteria = (
+                completion_config.min_successful is not None
+                or completion_config.tolerated_failure_count is not None
+                or completion_config.tolerated_failure_percentage is not None
+            )
+
+            if not has_any_completion_criteria:
+                # Empty completion config - fail fast on any failure
+                if failure_count > 0:
+                    return CompletionReason.FAILURE_TOLERANCE_EXCEEDED
+            else:
+                # Check specific tolerance thresholds
+                if (
+                    completion_config.tolerated_failure_count is not None
+                    and failure_count > completion_config.tolerated_failure_count
+                ):
+                    return CompletionReason.FAILURE_TOLERANCE_EXCEEDED
+
+                if (
+                    completion_config.tolerated_failure_percentage is not None
+                    and total_count > 0
+                ):
+                    failure_percentage = (failure_count / total_count) * 100
+                    if (
+                        failure_percentage
+                        > completion_config.tolerated_failure_percentage
+                    ):
+                        return CompletionReason.FAILURE_TOLERANCE_EXCEEDED
+
+        # STEP 2: Check if all completed
+        if completed_count == total_count:
+            return CompletionReason.ALL_COMPLETED
+
+        # STEP 3: Check if minimum successful reached
+        if (
+            completion_config is not None
+            and completion_config.min_successful is not None
+            and success_count >= completion_config.min_successful
+        ):
+            return CompletionReason.MIN_SUCCESSFUL_REACHED
+
+        # STEP 4: Default
+        return CompletionReason.ALL_COMPLETED
+
     @classmethod
     def from_items(
         cls,
@@ -123,12 +202,8 @@
         """
         Infer completion reason based on batch item statuses and completion config.
 
-        This follows the same logic as the TypeScript implementation:
-        - If all items completed: ALL_COMPLETED
-        - If minSuccessful threshold met and not all completed: MIN_SUCCESSFUL_REACHED
-        - Otherwise: FAILURE_TOLERANCE_EXCEEDED
+        This follows the same logic as the TypeScript implementation.
""" - statuses = (item.status for item in items) counts = Counter(statuses) succeeded_count = counts.get(BatchItemStatus.SUCCEEDED, 0) @@ -138,18 +213,14 @@ def from_items( completed_count = succeeded_count + failed_count total_count = started_count + completed_count - # If all items completed (no started items), it's ALL_COMPLETED - if completed_count == total_count: - completion_reason = CompletionReason.ALL_COMPLETED - elif ( # If we have completion config and minSuccessful threshold is met - completion_config - and (min_successful := completion_config.min_successful) is not None - and succeeded_count >= min_successful - ): - completion_reason = CompletionReason.MIN_SUCCESSFUL_REACHED - else: - # Otherwise, assume failure tolerance was exceeded - completion_reason = CompletionReason.FAILURE_TOLERANCE_EXCEEDED + # Determine completion reason using the same logic as JavaScript SDK + completion_reason = cls._get_completion_reason( + failure_count=failed_count, + success_count=succeeded_count, + completed_count=completed_count, + total_count=total_count, + completion_config=completion_config, + ) return cls(items, completion_reason) diff --git a/tests/concurrency_test.py b/tests/concurrency_test.py index b94f39f..0aeb3ac 100644 --- a/tests/concurrency_test.py +++ b/tests/concurrency_test.py @@ -335,7 +335,7 @@ def test_batch_result_from_dict_infer_all_completed_all_succeeded(): def test_batch_result_from_dict_infer_failure_tolerance_exceeded_all_failed(): - """Test BatchResult from_dict infers FAILURE_TOLERANCE_EXCEEDED when all items failed.""" + """Test BatchResult from_dict infers completion reason when all items failed.""" error_data = { "message": "Test error", "type": "TestError", @@ -350,18 +350,17 @@ def test_batch_result_from_dict_infer_failure_tolerance_exceeded_all_failed(): # No completionReason provided } - # even if everything has failed, if we've completed all items, then we've finished as ALL_COMPLETED - # https://github.com/aws/aws-durable-execution-sdk-js/blob/f20396f24afa9d6539d8e5056ee851ac7ef62301/packages/aws-durable-execution-sdk-js/src/handlers/concurrent-execution-handler/concurrent-execution-handler.ts#L324-L335 + # With no completion config and failures, should fail-fast with patch( "aws_durable_execution_sdk_python.concurrency.models.logger" ) as mock_logger: result = BatchResult.from_dict(data) - assert result.completion_reason == CompletionReason.ALL_COMPLETED + assert result.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED mock_logger.warning.assert_called_once() def test_batch_result_from_dict_infer_all_completed_mixed_success_failure(): - """Test BatchResult from_dict infers ALL_COMPLETED when mix of success/failure but no started items.""" + """Test BatchResult from_dict infers completion reason with mix of success/failure.""" error_data = { "message": "Test error", "type": "TestError", @@ -377,12 +376,12 @@ def test_batch_result_from_dict_infer_all_completed_mixed_success_failure(): # No completionReason provided } - # the logic is that when \every item i: hasCompleted(i) then terminate due to all_completed + # With no config and with failures, fail-fast with patch( "aws_durable_execution_sdk_python.concurrency.models.logger" ) as mock_logger: result = BatchResult.from_dict(data) - assert result.completion_reason == CompletionReason.ALL_COMPLETED + assert result.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED mock_logger.warning.assert_called_once() @@ -440,25 +439,25 @@ def 
 
 def test_batch_result_infer_completion_reason_edge_cases():
     """Test _infer_completion_reason method with various edge cases."""
-    # Test with only started items
+    # Test with only started items and min_successful=0
     started_items = [
         BatchItem(0, BatchItemStatus.STARTED).to_dict(),
         BatchItem(1, BatchItemStatus.STARTED).to_dict(),
     ]
     items = {"all": started_items}
     batch = BatchResult.from_dict(items, CompletionConfig(0))  # SLF001
-    # this state is not possible with CompletionConfig(0)
+    # With min_successful=0 and no failures, should be MIN_SUCCESSFUL_REACHED
     assert batch.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED
 
-    # Test with only started items
+    # Test with only started items and no config
    started_items = [
         BatchItem(0, BatchItemStatus.STARTED).to_dict(),
         BatchItem(1, BatchItemStatus.STARTED).to_dict(),
     ]
     items = {"all": started_items}
     batch = BatchResult.from_dict(items)  # SLF001
-    # this state is not possible with CompletionConfig(0)
-    assert batch.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED
+    # With no config and no completed items, defaults to ALL_COMPLETED
+    assert batch.completion_reason == CompletionReason.ALL_COMPLETED
 
     # Test with only failed items
     failed_items = [
         BatchItem(
             0, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None)
         ),
         BatchItem(
             1, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None)
         ),
@@ -471,7 +470,8 @@ def test_batch_result_infer_completion_reason_edge_cases():
     ]
     failed_items = {"all": failed_items}
     batch = BatchResult.from_dict(failed_items)  # SLF001
-    assert batch.completion_reason == CompletionReason.ALL_COMPLETED
+    # With no config and failures, should fail fast
+    assert batch.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED
 
     # Test with only succeeded items
     succeeded_items = [
         BatchItem(0, BatchItemStatus.SUCCEEDED, "result1").to_dict(),
         BatchItem(1, BatchItemStatus.SUCCEEDED, "result2").to_dict(),
     ]
     items = {"all": succeeded_items}
@@ -482,7 +482,7 @@ def test_batch_result_infer_completion_reason_edge_cases():
     batch = BatchResult.from_dict(succeeded_items)  # SLF001
     assert batch.completion_reason == CompletionReason.ALL_COMPLETED
 
-    # Test with mixed but no started (all completed)
+    # Test with mixed but no started (all completed) with tolerance
     mixed_items = [
         BatchItem(0, BatchItemStatus.SUCCEEDED, "result1"),
         BatchItem(
@@ -490,7 +490,9 @@ def test_batch_result_infer_completion_reason_edge_cases():
             1, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None)
         ),
     ]
 
-    batch = BatchResult.from_items(mixed_items)  # SLF001
+    batch = BatchResult.from_items(
+        mixed_items, CompletionConfig(tolerated_failure_count=1)
+    )  # SLF001
     assert batch.completion_reason == CompletionReason.ALL_COMPLETED
 
 
@@ -1101,10 +1103,9 @@ def failure_callable():
     assert len(result.all) == 2
     assert result.all[0].status == BatchItemStatus.SUCCEEDED
     assert result.all[1].status == BatchItemStatus.FAILED
-    # WHEN all items complete, THEN completion reason is ALL_COMPLETED.
-    # we don't consider thresholds and limits.
-    # https://github.com/aws/aws-durable-execution-sdk-js/blob/ff8b72ef888dd47a840f36d4eb0ee84dd3b55a30/packages/aws-durable-execution-sdk-js/src/handlers/concurrent-execution-handler/concurrent-execution-handler.test.ts#L630-L655
-    assert result.completion_reason == CompletionReason.ALL_COMPLETED
+    # NEW BEHAVIOR: With an empty completion config (no criteria) and failures,
+    # should fail fast and return FAILURE_TOLERANCE_EXCEEDED
+    assert result.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED
 
 
 def test_concurrent_executor_execute_item_in_child_context():
@@ -1188,10 +1189,9 @@ def mock_run_in_child_context(func, name, config):
         return func(Mock())
 
     result = executor.execute(execution_state, mock_run_in_child_context)
-    # WHEN all items complete, THEN completion reason is ALL_COMPLETED.
-    # we don't consider thresholds and limits.
-    # https://github.com/aws/aws-durable-execution-sdk-js/blob/ff8b72ef888dd47a840f36d4eb0ee84dd3b55a30/packages/aws-durable-execution-sdk-js/src/handlers/concurrent-execution-handler/concurrent-execution-handler.test.ts#L630-L655
-    assert result.completion_reason == CompletionReason.ALL_COMPLETED
+    # NEW BEHAVIOR: With tolerated_failure_count=0 and 1 failure,
+    # tolerance is exceeded, so FAILURE_TOLERANCE_EXCEEDED
+    assert result.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED
 
 
 def test_single_task_suspend_bubbles_up():
@@ -1971,9 +1971,9 @@ def execute_item(self, child_context, executable):
     assert result.all[0].result is None
     assert result.all[0].error is None
     assert result.all[0].index == 0
-    # By default, if we've terminated the reasoning is failure tolerance exceeded
-    # according to the spec
-    assert result.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED
+    # NEW BEHAVIOR: With min_successful=1 and no completed items,
+    # defaults to ALL_COMPLETED
+    assert result.completion_reason == CompletionReason.ALL_COMPLETED
 
 
 def test_create_result_running_branch():
@@ -2009,9 +2008,8 @@ def execute_item(self, child_context, executable):
     assert result.all[0].result is None
     assert result.all[0].error is None
     assert result.all[0].index == 0
-    # By default, if we've terminated the reasoning is failure tolerance exceeded
-    # according to the spec
-    assert result.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED
+    # With min_successful=1 and no completed items, defaults to ALL_COMPLETED
+    assert result.completion_reason == CompletionReason.ALL_COMPLETED
 
 
 def test_create_result_suspended_branch():
@@ -2046,9 +2045,8 @@ def execute_item(self, child_context, executable):
     assert result.all[0].result is None
     assert result.all[0].error is None
     assert result.all[0].index == 0
-    # By default, if we've terminated the reasoning is failure tolerance exceeded
-    # according to the spec
-    assert result.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED
+    # With min_successful=1 and no completed items, defaults to ALL_COMPLETED
+    assert result.completion_reason == CompletionReason.ALL_COMPLETED
 
 
 def test_create_result_suspended_with_timeout_branch():
@@ -2084,9 +2082,8 @@ def execute_item(self, child_context, executable):
     assert result.all[0].result is None
     assert result.all[0].error is None
     assert result.all[0].index == 0
-    # By default, if we've terminated the reasoning is failure tolerance exceeded
-    # according to the spec
-    assert result.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED
+    # With min_successful=1 and no completed items, defaults to ALL_COMPLETED
+    assert result.completion_reason == CompletionReason.ALL_COMPLETED
 
 
 def test_create_result_mixed_statuses():
@@ -2310,9 +2307,8 @@ def execute_item(self, child_context, executable):
     assert all(item.status == BatchItemStatus.STARTED for item in result.all)
     assert all(item.result is None for item in result.all)
     assert all(item.error is None for item in result.all)
-    # With completion config min_successful=1 and no completed items,
-    # this should be FAILURE_TOLERANCE_EXCEEDED
-    assert result.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED
+    # With min_successful=1 and no completed items, defaults to ALL_COMPLETED
+    assert result.completion_reason == CompletionReason.ALL_COMPLETED
 
 
 def test_create_result_empty_executables():
@@ -2384,7 +2380,7 @@ def test_batch_result_from_dict_with_completion_config():
 
 
 def test_batch_result_from_dict_all_completed():
-    """Test BatchResult from_dict infers ALL_COMPLETED when all items are completed."""
+    """Test BatchResult from_dict infers completion reason when all items are completed."""
     data = {
         "all": [
             {"index": 0, "status": "SUCCEEDED", "result": "result1", "error": None},
@@ -2403,11 +2399,12 @@
         # No completionReason provided
     }
 
+    # With no config and failures, fail fast
     with patch(
         "aws_durable_execution_sdk_python.concurrency.models.logger"
     ) as mock_logger:
         result = BatchResult.from_dict(data)
-        assert result.completion_reason == CompletionReason.ALL_COMPLETED
+        assert result.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED
         mock_logger.warning.assert_called_once()
 
 
@@ -3153,3 +3150,149 @@ def test_timer_scheduler_fifo_ordering_with_same_timestamp():
 
 
 # endregion TimerScheduler edge cases with exact same reschedule time
+
+
+# region Completion Reason Inference Tests (from_items)
+
+
+def test_from_items_no_config_with_failures():
+    """Validates: Requirements 2.4 - Fail-fast with no config."""
+    items = [
+        BatchItem(0, BatchItemStatus.SUCCEEDED, result="ok"),
+        BatchItem(
+            1, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None)
+        ),
+    ]
+    result = BatchResult.from_items(items, completion_config=None)
+    assert result.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED
+
+
+def test_from_items_empty_config_with_failures():
+    """Validates: Requirements 2.5 - Fail-fast with empty config."""
+    items = [
+        BatchItem(0, BatchItemStatus.SUCCEEDED, result="ok"),
+        BatchItem(
+            1, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None)
+        ),
+    ]
+    config = CompletionConfig()  # All fields None
+    result = BatchResult.from_items(items, completion_config=config)
+    assert result.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED
+
+
+def test_from_items_tolerance_checked_before_all_completed():
+    """Validates: Requirements 2.1, 2.2 - Tolerance priority."""
+    items = [
+        BatchItem(0, BatchItemStatus.SUCCEEDED, result="ok"),
+        BatchItem(
+            1, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None)
+        ),
+        BatchItem(
+            2, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None)
+        ),
+    ]
+    config = CompletionConfig(tolerated_failure_count=1)
+    result = BatchResult.from_items(items, completion_config=config)
+    # All completed but tolerance exceeded - should return FAILURE_TOLERANCE_EXCEEDED
+    assert result.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED
+
+
+def test_from_items_all_completed_within_tolerance():
+    """Validates: Requirements 1.1 - All completed."""
+    items = [
+        BatchItem(0, BatchItemStatus.SUCCEEDED, result="ok"),
+        BatchItem(
+            1, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None)
+        ),
+    ]
+    config = CompletionConfig(tolerated_failure_count=1)
+    result = BatchResult.from_items(items, completion_config=config)
+    assert result.completion_reason == CompletionReason.ALL_COMPLETED
+
+
+def test_from_items_min_successful_reached():
+    """Validates: Requirements 1.3 - Min successful."""
+    items = [
+        BatchItem(0, BatchItemStatus.SUCCEEDED, result="ok"),
+        BatchItem(1, BatchItemStatus.SUCCEEDED, result="ok"),
+        BatchItem(2, BatchItemStatus.STARTED),
+    ]
+    config = CompletionConfig(min_successful=2)
+    result = BatchResult.from_items(items, completion_config=config)
+    assert result.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED
+
+
+def test_from_items_tolerance_count_exceeded():
+    """Validates: Requirements 1.2 - Tolerance count."""
+    items = [
+        BatchItem(
+            0, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None)
+        ),
+        BatchItem(
+            1, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None)
+        ),
+        BatchItem(2, BatchItemStatus.STARTED),
+    ]
+    config = CompletionConfig(tolerated_failure_count=1)
+    result = BatchResult.from_items(items, completion_config=config)
+    assert result.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED
+
+
+def test_from_items_tolerance_percentage_exceeded():
+    """Validates: Requirements 1.2 - Tolerance percentage."""
+    items = [
+        BatchItem(0, BatchItemStatus.SUCCEEDED, result="ok"),
+        BatchItem(
+            1, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None)
+        ),
+        BatchItem(
+            2, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None)
+        ),
+        BatchItem(
+            3, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None)
+        ),
+    ]
+    config = CompletionConfig(tolerated_failure_percentage=50.0)
+    # 3 failures out of 4 = 75% > 50%
+    result = BatchResult.from_items(items, completion_config=config)
+    assert result.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED
+
+
+def test_from_items_tolerance_priority_over_min_successful():
+    """Validates: Requirements 2.3 - Tolerance takes precedence."""
+    items = [
+        BatchItem(0, BatchItemStatus.SUCCEEDED, result="ok"),
+        BatchItem(1, BatchItemStatus.SUCCEEDED, result="ok"),
+        BatchItem(
+            2, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None)
+        ),
+        BatchItem(
+            3, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None)
+        ),
+    ]
+    config = CompletionConfig(min_successful=2, tolerated_failure_count=1)
+    # Min successful reached (2) but tolerance exceeded (2 > 1)
+    result = BatchResult.from_items(items, completion_config=config)
+    assert result.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED
+
+
+def test_from_items_empty_array():
+    """Validates: Edge case - empty items."""
+    items = []
+    result = BatchResult.from_items(items, completion_config=None)
+    assert result.completion_reason == CompletionReason.ALL_COMPLETED
+    assert result.total_count == 0
+
+
+def test_from_items_all_succeeded():
+    """Validates: All items succeeded."""
+    items = [
+        BatchItem(0, BatchItemStatus.SUCCEEDED, result="ok1"),
+        BatchItem(1, BatchItemStatus.SUCCEEDED, result="ok2"),
+    ]
+    result = BatchResult.from_items(items, completion_config=None)
+    assert result.completion_reason == CompletionReason.ALL_COMPLETED
+    assert result.success_count == 2
+
+
+# endregion Completion Reason Inference Tests (from_items)
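
---

Note on the precedence change (a reviewer sketch, not part of the patch): the core behavioral shift in this diff is that `_get_completion_reason` checks failure tolerance BEFORE the all-completed check, which is why several test expectations above flip from ALL_COMPLETED to FAILURE_TOLERANCE_EXCEEDED. The following is a minimal, self-contained Python illustration of that ordering. The module-level `get_completion_reason`, and the stripped-down `CompletionReason` and `CompletionConfig` stand-ins, are simplifications for this note only; the real implementations live in `aws_durable_execution_sdk_python/concurrency/models.py` and take `completion_config` as shown in the diff. Assumes Python 3.10+ for the `X | None` syntax.

from dataclasses import dataclass
from enum import Enum


class CompletionReason(Enum):
    ALL_COMPLETED = "ALL_COMPLETED"
    MIN_SUCCESSFUL_REACHED = "MIN_SUCCESSFUL_REACHED"
    FAILURE_TOLERANCE_EXCEEDED = "FAILURE_TOLERANCE_EXCEEDED"


@dataclass
class CompletionConfig:
    min_successful: int | None = None
    tolerated_failure_count: int | None = None
    tolerated_failure_percentage: float | None = None


def get_completion_reason(
    failure_count: int,
    success_count: int,
    completed_count: int,
    total_count: int,
    config: CompletionConfig | None,
) -> CompletionReason:
    # Step 1: tolerance first. No config, or a config with no criteria set,
    # means fail fast on any failure.
    if config is None or (
        config.min_successful is None
        and config.tolerated_failure_count is None
        and config.tolerated_failure_percentage is None
    ):
        if failure_count > 0:
            return CompletionReason.FAILURE_TOLERANCE_EXCEEDED
    else:
        # Explicit thresholds: count-based, then percentage-based.
        if (
            config.tolerated_failure_count is not None
            and failure_count > config.tolerated_failure_count
        ):
            return CompletionReason.FAILURE_TOLERANCE_EXCEEDED
        if config.tolerated_failure_percentage is not None and total_count > 0:
            pct = (failure_count / total_count) * 100
            if pct > config.tolerated_failure_percentage:
                return CompletionReason.FAILURE_TOLERANCE_EXCEEDED

    # Step 2: everything finished within tolerance.
    if completed_count == total_count:
        return CompletionReason.ALL_COMPLETED

    # Step 3: early completion via min_successful.
    if (
        config is not None
        and config.min_successful is not None
        and success_count >= config.min_successful
    ):
        return CompletionReason.MIN_SUCCESSFUL_REACHED

    # Step 4: default.
    return CompletionReason.ALL_COMPLETED


# Mirrors test_from_items_tolerance_priority_over_min_successful:
# 2 successes and 2 failures with min_successful=2, tolerated_failure_count=1.
# Tolerance wins over both all-completed and min-successful.
assert get_completion_reason(
    2, 2, 4, 4, CompletionConfig(min_successful=2, tolerated_failure_count=1)
) == CompletionReason.FAILURE_TOLERANCE_EXCEEDED

# Mirrors test_from_items_no_config_with_failures: no config + any failure
# fails fast, even though every item completed.
assert (
    get_completion_reason(1, 1, 2, 2, None)
    == CompletionReason.FAILURE_TOLERANCE_EXCEEDED
)

The design consequence worth calling out in review: a fully completed batch whose failures exceed tolerance now reports FAILURE_TOLERANCE_EXCEEDED rather than ALL_COMPLETED, so callers that branched on ALL_COMPLETED as a success signal see a behavior change.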