diff --git a/lib/mars/workflows/parallel.rb b/lib/mars/workflows/parallel.rb index 1038b64..01067b4 100644 --- a/lib/mars/workflows/parallel.rb +++ b/lib/mars/workflows/parallel.rb @@ -12,11 +12,8 @@ def initialize(name, steps:, aggregator: nil, **kwargs) def run(context) context = ensure_context(context) - errors = [] child_contexts = [] - results = execute_steps(context, errors, child_contexts) - - raise AggregateError, errors if errors.any? + results = execute_steps(context, child_contexts) context.merge(child_contexts) context.current_input = results @@ -27,21 +24,19 @@ def run(context) attr_reader :steps, :aggregator - def execute_steps(context, errors, child_contexts) - Async do |workflow| + def execute_steps(context, child_contexts) + Sync do |workflow| tasks = steps.map do |step| child_ctx = context.fork(state: step.state) child_contexts << child_ctx workflow.async do workflow_step(step, child_ctx) - rescue StandardError => e - errors << { error: e, step_name: step.name } end end tasks.map(&:wait) - end.result + end end def workflow_step(step, child_ctx) diff --git a/spec/mars/workflows/parallel_spec.rb b/spec/mars/workflows/parallel_spec.rb index 7b4afb8..e732af0 100644 --- a/spec/mars/workflows/parallel_spec.rb +++ b/spec/mars/workflows/parallel_spec.rb @@ -161,16 +161,41 @@ def run(context) = context.current_input end it "propagates errors from steps" do + Console.logger.level = Console::Logger::FATAL # avoid logging the errors for this test + add_step = add_step_class.new(5, name: "add") error_step = error_step_class.new("Step failed", name: "error_step_one") error_step_two = error_step_class.new("Step failed two", name: "error_step_two") workflow = described_class.new("error_workflow", steps: [add_step, error_step, error_step_two]) - expect { workflow.run(10) }.to raise_error( - MARS::Workflows::AggregateError, - "error_step_one: Step failed\nerror_step_two: Step failed two" - ) + expect { workflow.run(10) }.to raise_error(StandardError, "Step failed") + end + + context "when steps are parallel workflows" do + let(:flatten_sum_aggregator) do + MARS::Aggregator.new("Sum Aggregator", operation: ->(inputs) { inputs.flatten.sum }) + end + + it "executes nested parallel workflows correctly" do + add_five = add_step_class.new(5, name: "add_five") + multiply_three = multiply_step_class.new(3, name: "multiply_three") + inner_workflow1 = described_class.new("inner_workflow_1", steps: [add_five, multiply_three]) + + add_two = add_step_class.new(2, name: "add_two") + multiply_four = multiply_step_class.new(4, name: "multiply_four") + inner_workflow2 = described_class.new("inner_workflow_2", steps: [add_two, multiply_four]) + outer_workflow = described_class.new( + "outer_workflow", + steps: [inner_workflow1, inner_workflow2], + aggregator: flatten_sum_aggregator + ) + + # inner_workflow_1: 10 + 5 = 15, 10 * 3 = 30 + # inner_workflow_2: 10 + 2 = 12, 10 * 4 = 40 + # outer_workflow results: [15, 30, 12, 40] => sum = 97 + expect(outer_workflow.run(10)).to eq(97) + end end end