From 84d86f89eafecf46e4e5481d7e01a7d63f4eab1a Mon Sep 17 00:00:00 2001 From: Jon Sequeira Date: Wed, 11 Mar 2026 11:18:57 -0700 Subject: [PATCH] fix #2771 and #2772 --- .../Invocation/InvocationTests.cs | 95 ++++++++++++++++ src/System.CommandLine.Tests/TestActions.cs | 14 ++- .../Invocation/InvocationPipeline.cs | 106 +++++++++++------- 3 files changed, 170 insertions(+), 45 deletions(-) diff --git a/src/System.CommandLine.Tests/Invocation/InvocationTests.cs b/src/System.CommandLine.Tests/Invocation/InvocationTests.cs index 73181c2479..c523c86147 100644 --- a/src/System.CommandLine.Tests/Invocation/InvocationTests.cs +++ b/src/System.CommandLine.Tests/Invocation/InvocationTests.cs @@ -436,6 +436,101 @@ public async Task Nonterminating_option_actions_handle_exceptions_and_return_an_ returnCode.Should().Be(1); } + [Theory] // https://github.com/dotnet/command-line-api/issues/2771 + [InlineData(true)] + [InlineData(false)] + public async Task Nonterminating_option_action_is_invoked_when_command_has_no_action(bool invokeAsync) + { + bool optionActionWasCalled = false; + SynchronousTestAction optionAction = new(_ => optionActionWasCalled = true, terminating: false); + + Option option = new("--test") + { + Action = optionAction + }; + RootCommand command = new() + { + option + }; + + ParseResult parseResult = command.Parse("--test"); + + if (invokeAsync) + { + await parseResult.InvokeAsync(); + } + else + { + parseResult.Invoke(); + } + + optionActionWasCalled.Should().BeTrue(); + } + + [Theory] // https://github.com/dotnet/command-line-api/issues/2772 + [InlineData(true)] + [InlineData(false)] + public async Task Nonterminating_option_action_return_value_is_propagated(bool invokeAsync) + { + SynchronousTestAction optionAction = new(_ => { }, terminating: false, returnValue: 42); + + Option option = new("--test") + { + Action = optionAction + }; + RootCommand command = new() + { + option + }; + command.SetAction(_ => { }); + + ParseResult parseResult = command.Parse("--test"); + + int result; + if (invokeAsync) + { + result = await parseResult.InvokeAsync(); + } + else + { + result = parseResult.Invoke(); + } + + result.Should().Be(42); + } + + [Theory] // https://github.com/dotnet/command-line-api/issues/2772 + [InlineData(true)] + [InlineData(false)] + public async Task When_preaction_and_command_action_both_return_nonzero_then_preaction_value_wins(bool invokeAsync) + { + SynchronousTestAction optionAction = new(_ => { }, terminating: false, returnValue: 42); + + Option option = new("--test") + { + Action = optionAction + }; + RootCommand command = new() + { + option + }; + command.SetAction(_ => 99); + + ParseResult parseResult = command.Parse("--test"); + + int result; + if (invokeAsync) + { + result = await parseResult.InvokeAsync(); + } + else + { + result = parseResult.Invoke(); + } + + result.Should().Be(42); + } + [Fact] public async Task Command_InvokeAsync_with_cancelation_token_invokes_command_handler() { diff --git a/src/System.CommandLine.Tests/TestActions.cs b/src/System.CommandLine.Tests/TestActions.cs index ab44b12926..629ab3d3ee 100644 --- a/src/System.CommandLine.Tests/TestActions.cs +++ b/src/System.CommandLine.Tests/TestActions.cs @@ -10,14 +10,17 @@ namespace System.CommandLine.Tests; public class SynchronousTestAction : SynchronousCommandLineAction { private readonly Action _invoke; + private readonly int _returnValue; public SynchronousTestAction( Action invoke, bool terminating = true, - bool clearsParseErrors = false) + bool clearsParseErrors = false, + int returnValue = 0) { ClearsParseErrors = clearsParseErrors; _invoke = invoke; + _returnValue = returnValue; Terminating = terminating; } @@ -28,21 +31,24 @@ public SynchronousTestAction( public override int Invoke(ParseResult parseResult) { _invoke(parseResult); - return 0; + return _returnValue; } } public class AsynchronousTestAction : AsynchronousCommandLineAction { private readonly Action _invoke; + private readonly int _returnValue; public AsynchronousTestAction( Action invoke, bool terminating = true, - bool clearsParseErrors = false) + bool clearsParseErrors = false, + int returnValue = 0) { ClearsParseErrors = clearsParseErrors; _invoke = invoke; + _returnValue = returnValue; Terminating = terminating; } @@ -53,6 +59,6 @@ public AsynchronousTestAction( public override Task InvokeAsync(ParseResult parseResult, CancellationToken cancellationToken = default) { _invoke(parseResult); - return Task.FromResult(0); + return Task.FromResult(_returnValue); } } \ No newline at end of file diff --git a/src/System.CommandLine/Invocation/InvocationPipeline.cs b/src/System.CommandLine/Invocation/InvocationPipeline.cs index c41d2686d6..7b929809d2 100644 --- a/src/System.CommandLine/Invocation/InvocationPipeline.cs +++ b/src/System.CommandLine/Invocation/InvocationPipeline.cs @@ -10,38 +10,52 @@ internal static class InvocationPipeline { internal static async Task InvokeAsync(ParseResult parseResult, CancellationToken cancellationToken) { - if (parseResult.Action is null) - { - return ReturnCodeForMissingAction(parseResult); - } - ProcessTerminationHandler? terminationHandler = null; using CancellationTokenSource cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); try { + int exitCode = 0; + if (parseResult.PreActions is not null) { for (int i = 0; i < parseResult.PreActions.Count; i++) { var action = parseResult.PreActions[i]; + int preActionResult; switch (action) { case SynchronousCommandLineAction syncAction: - syncAction.Invoke(parseResult); + preActionResult = syncAction.Invoke(parseResult); break; case AsynchronousCommandLineAction asyncAction: - await asyncAction.InvokeAsync(parseResult, cts.Token); + preActionResult = await asyncAction.InvokeAsync(parseResult, cts.Token); + break; + default: + preActionResult = 0; break; } + + if (exitCode == 0) + { + exitCode = preActionResult; + } } } + if (parseResult.Action is null) + { + return exitCode != 0 ? exitCode : ReturnCodeForMissingAction(parseResult); + } + + int actionResult; + switch (parseResult.Action) { case SynchronousCommandLineAction syncAction: - return syncAction.Invoke(parseResult); + actionResult = syncAction.Invoke(parseResult); + break; case AsynchronousCommandLineAction asyncAction: var startedInvocation = asyncAction.InvokeAsync(parseResult, cts.Token); @@ -55,7 +69,7 @@ internal static async Task InvokeAsync(ParseResult parseResult, Cancellatio if (terminationHandler is null) { - return await startedInvocation; + actionResult = await startedInvocation; } else { @@ -63,12 +77,15 @@ internal static async Task InvokeAsync(ParseResult parseResult, Cancellatio // In such cases, when CancelOnProcessTermination is configured and user presses Ctrl+C, // ProcessTerminationCompletionSource completes first, with the result equal to native exit code for given signal. Task firstCompletedTask = await Task.WhenAny(startedInvocation, terminationHandler.ProcessTerminationCompletionSource.Task); - return await firstCompletedTask; // return the result or propagate the exception + actionResult = await firstCompletedTask; // return the result or propagate the exception } + break; default: throw new ArgumentOutOfRangeException(nameof(parseResult.Action)); } + + return exitCode != 0 ? exitCode : actionResult; } catch (Exception ex) when (parseResult.InvocationConfiguration.EnableDefaultExceptionHandler) { @@ -82,48 +99,55 @@ internal static async Task InvokeAsync(ParseResult parseResult, Cancellatio internal static int Invoke(ParseResult parseResult) { - switch (parseResult.Action) + try { - case null: - return ReturnCodeForMissingAction(parseResult); + int exitCode = 0; - case SynchronousCommandLineAction syncAction: - try + if (parseResult.PreActions is not null) + { +#if DEBUG + for (var i = 0; i < parseResult.PreActions.Count; i++) { - if (parseResult.PreActions is not null) + var action = parseResult.PreActions[i]; + + if (action is not SynchronousCommandLineAction) { -#if DEBUG - for (var i = 0; i < parseResult.PreActions.Count; i++) - { - var action = parseResult.PreActions[i]; - - if (action is not SynchronousCommandLineAction) - { - parseResult.InvocationConfiguration.EnableDefaultExceptionHandler = false; - throw new Exception( - $"This should not happen. An instance of {nameof(AsynchronousCommandLineAction)} ({action}) was called within {nameof(InvocationPipeline)}.{nameof(Invoke)}. This is supposed to be detected earlier resulting in a call to {nameof(InvocationPipeline)}{nameof(InvokeAsync)}"); - } - } + parseResult.InvocationConfiguration.EnableDefaultExceptionHandler = false; + throw new Exception( + $"This should not happen. An instance of {nameof(AsynchronousCommandLineAction)} ({action}) was called within {nameof(InvocationPipeline)}.{nameof(Invoke)}. This is supposed to be detected earlier resulting in a call to {nameof(InvocationPipeline)}{nameof(InvokeAsync)}"); + } + } #endif - for (var i = 0; i < parseResult.PreActions.Count; i++) + for (var i = 0; i < parseResult.PreActions.Count; i++) + { + if (parseResult.PreActions[i] is SynchronousCommandLineAction syncPreAction) + { + int preActionResult = syncPreAction.Invoke(parseResult); + if (exitCode == 0) { - if (parseResult.PreActions[i] is SynchronousCommandLineAction syncPreAction) - { - syncPreAction.Invoke(parseResult); - } + exitCode = preActionResult; } } - - return syncAction.Invoke(parseResult); - } - catch (Exception ex) when (parseResult.InvocationConfiguration.EnableDefaultExceptionHandler) - { - return DefaultExceptionHandler(ex, parseResult); } + } + + switch (parseResult.Action) + { + case null: + return exitCode != 0 ? exitCode : ReturnCodeForMissingAction(parseResult); - default: - throw new InvalidOperationException($"{nameof(AsynchronousCommandLineAction)} called within non-async invocation."); + case SynchronousCommandLineAction syncAction: + int actionResult = syncAction.Invoke(parseResult); + return exitCode != 0 ? exitCode : actionResult; + + default: + throw new InvalidOperationException($"{nameof(AsynchronousCommandLineAction)} called within non-async invocation."); + } + } + catch (Exception ex) when (parseResult.InvocationConfiguration.EnableDefaultExceptionHandler) + { + return DefaultExceptionHandler(ex, parseResult); } }