From cc8da2fd2a8ed3125e8e802002be5014b1d80336 Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Fri, 10 Oct 2025 18:10:00 -0500 Subject: [PATCH 1/9] Extract AsyncHelper from SqlUtil.cs into the utilities namespace # Conflicts: # src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj # src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.netcore.cs # src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs # src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj # src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.netfx.cs # src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs # src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs # src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.NonQuery.cs # src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs # src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Xml.cs # src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnection.cs # src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs # src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs --- .../Connection/SqlConnectionInternal.cs | 12 +- .../Microsoft/Data/SqlClient/SqlBulkCopy.cs | 3 +- .../Data/SqlClient/SqlCommand.Encryption.cs | 1 + .../Data/SqlClient/SqlCommand.NonQuery.cs | 75 ++--- .../Data/SqlClient/SqlCommand.Reader.cs | 2 +- .../Data/SqlClient/SqlCommand.Xml.cs | 61 +++-- .../Microsoft/Data/SqlClient/SqlConnection.cs | 4 +- .../src/Microsoft/Data/SqlClient/SqlUtil.cs | 244 ----------------- .../src/Microsoft/Data/SqlClient/TdsParser.cs | 4 +- .../Data/SqlClient/TdsParserStateObject.cs | 37 ++- .../Data/SqlClient/Utilities/AsyncHelper.cs | 256 ++++++++++++++++++ 11 files changed, 357 insertions(+), 342 deletions(-) create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Connection/SqlConnectionInternal.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Connection/SqlConnectionInternal.cs index 295286b1e5..db80c05919 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Connection/SqlConnectionInternal.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Connection/SqlConnectionInternal.cs @@ -6,7 +6,6 @@ using System.Collections.Generic; using System.Data.Common; using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; using System.Globalization; using System.Security; using System.Text; @@ -14,12 +13,15 @@ using System.Threading.Tasks; using System.Transactions; using Microsoft.Data.Common; -using Microsoft.Data.Common.ConnectionString; using Microsoft.Data.ProviderBase; -using Microsoft.Data.SqlClient.Connection; using Microsoft.Data.SqlClient.ConnectionPool; -using IsolationLevel = System.Data.IsolationLevel; using Microsoft.Data.SqlClient.Internal; +using Microsoft.Data.SqlClient.Utilities; +using IsolationLevel = System.Data.IsolationLevel; + +#if NETFRAMEWORK +using Microsoft.Data.Common.ConnectionString; +#endif namespace Microsoft.Data.SqlClient.Connection { @@ -400,7 +402,7 @@ internal SqlConnectionInternal( try { - // If we want to consider pool operations against the overall connect timeout, + // If we want to consider pool operations against the overall connect timeout, // use the provided timeout. Otherwise, start a fresh timeout to receive the full // connect timeout. _timeout = ResolveLoginTimeout(timeout, connectionOptions.ConnectTimeout); diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs index af2091d7f6..b5306bc0d8 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs @@ -17,6 +17,7 @@ using Microsoft.Data.Common; using Microsoft.Data.SqlClient.Connection; using Microsoft.Data.SqlClient.Internal; +using Microsoft.Data.SqlClient.Utilities; namespace Microsoft.Data.SqlClient { @@ -776,7 +777,7 @@ private string AnalyzeTargetAndCreateUpdateBulkCommand(BulkCopySimpleResultSet i // Remove it from our unmatched set. unmatchedColumns.Remove(localColumn.MappedDestinationColumn); - + // Check for column types that we refuse to bulk load, even // though we found a match. // diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Encryption.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Encryption.cs index fa5cfb5ab2..9f95cbe8d0 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Encryption.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Encryption.cs @@ -14,6 +14,7 @@ using System.Threading.Tasks; using Microsoft.Data.Common; using Microsoft.Data.SqlClient.Connection; +using Microsoft.Data.SqlClient.Utilities; namespace Microsoft.Data.SqlClient { diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.NonQuery.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.NonQuery.cs index 141416537f..41602d392b 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.NonQuery.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.NonQuery.cs @@ -12,6 +12,7 @@ using Microsoft.Data.ProviderBase; using Microsoft.Data.SqlClient.Connection; using Microsoft.Data.SqlClient.Internal; +using Microsoft.Data.SqlClient.Utilities; #if NETFRAMEWORK using System.Security.Permissions; @@ -82,11 +83,11 @@ public override int ExecuteNonQuery() #if NETFRAMEWORK SqlConnection.ExecutePermission.Demand(); #endif - + // Reset _pendingCancel upon entry into any Execute - used to synchronize state // between entry into Execute* API and the thread obtaining the stateObject. _pendingCancel = false; - + using var diagnosticScope = s_diagnosticListener.CreateCommandScope(this, _transaction); using var eventScope = SqlClientEventScope.Create($"SqlCommand.ExecuteNonQuery | API | Object Id {ObjectID}"); @@ -150,9 +151,9 @@ public override Task ExecuteNonQueryAsync(CancellationToken cancellationTok IsProviderRetriable ? InternalExecuteNonQueryWithRetryAsync(cancellationToken) : InternalExecuteNonQueryAsync(cancellationToken); - + #endregion - + #region Private Methods // @TODO: This can be inlined into InternalExecuteNonQueryAsync before restructuring into async pathway @@ -164,7 +165,7 @@ private IAsyncResult BeginExecuteNonQueryAsync(AsyncCallback callback, object st $"Activity Id {ActivityCorrelator.Current}, " + $"Client Connection Id {_activeConnection?.ClientConnectionId}, " + $"Command Text '{CommandText}'"); - + return BeginExecuteNonQueryInternal( CommandBehavior.Default, callback, @@ -247,7 +248,7 @@ private IAsyncResult BeginExecuteNonQueryInternal( // When we use query caching for parameter encryption we need to retry on specific errors. // In these cases finalize the call internally and trigger a retry when needed. - // @TODO: store this method call in a variable, it's faaaaar too big to be used in an if statement + // @TODO: store this method call in a variable, it's faaaaar too big to be used in an if statement if ( !TriggerInternalEndAndRetryIfNecessary( behavior, @@ -270,7 +271,7 @@ private IAsyncResult BeginExecuteNonQueryInternal( { globalCompletion = localCompletion; } - + // Add callback after work is done to avoid overlapping Begin/End methods if (callback is not null) { @@ -317,31 +318,31 @@ private void CleanupAfterExecuteNonQueryAsync(Task task, TaskCompletionSour if (task.IsFaulted) { Exception e = task.Exception?.InnerException; - + s_diagnosticListener.WriteCommandError(operationId, this, _transaction, e); - + source.SetException(e); } else if (task.IsCanceled) { s_diagnosticListener.WriteCommandAfter(operationId, this, _transaction); - + source.SetCanceled(); } else { // Task successful s_diagnosticListener.WriteCommandAfter(operationId, this, _transaction); - + source.SetResult(task.Result); } } - + // @TODO: This can be inlined into InternalExecuteNonQueryAsync before restructuring into async pathway private int EndExecuteNonQueryAsync(IAsyncResult asyncResult) { Debug.Assert(!_internalEndExecuteInitiated || _stateObj == null); - + SqlClientEventSource.Log.TryCorrelationTraceEvent( "SqlCommand.EndExecuteNonQueryAsync | Info | Correlation | " + $"Object Id {ObjectID}, " + @@ -357,7 +358,7 @@ private int EndExecuteNonQueryAsync(IAsyncResult asyncResult) ReliablePutStateObject(); throw asyncException.InnerException; } - + ThrowIfReconnectionHasBeenCanceled(); // lock on _stateObj prevents races with close/cancel. // If we have already initiated the End call internally, we have already done that, so @@ -369,7 +370,7 @@ private int EndExecuteNonQueryAsync(IAsyncResult asyncResult) return EndExecuteNonQueryInternal(asyncResult); } } - + return EndExecuteNonQueryInternal(asyncResult); } @@ -413,7 +414,7 @@ private int EndExecuteNonQueryInternal(IAsyncResult asyncResult) WriteEndExecuteEvent(success, sqlExceptionNumber, isSynchronous: false); } } - + // @TODO: Return int? private object InternalEndExecuteNonQuery( IAsyncResult asyncResult, @@ -426,17 +427,17 @@ private object InternalEndExecuteNonQuery( $"Client Connection Id {_activeConnection?.ClientConnectionId}, " + $"MARS={_activeConnection?.Parser.MARSOn}, " + $"AsyncCommandInProgress={_activeConnection?.AsyncCommandInProgress}"); - + VerifyEndExecuteState((Task)asyncResult, endMethod); WaitForAsyncResults(asyncResult, isInternal); - + // If column encryption is enabled, also check the state after waiting for the task. // It would be better to do this for all cases, but avoiding for compatibility reasons. if (IsColumnEncryptionEnabled) { VerifyEndExecuteState((Task)asyncResult, endMethod, fullCheckForColumnEncryption: true); } - + bool processFinallyBlock = true; try { @@ -508,13 +509,13 @@ private object InternalEndExecuteNonQuery( PutStateObject(); } } - + Debug.Assert(_stateObj == null, "non-null state object in EndExecuteNonQuery"); - + return _rowsAffected; // @TODO: CER Exception Handling was removed here (see GH#3581) } - + // @TODO: Restructure to make this a sync-only method private Task InternalExecuteNonQuery( TaskCompletionSource completion, @@ -536,7 +537,7 @@ private Task InternalExecuteNonQuery( SqlStatistics statistics = Statistics; _rowsAffected = -1; - + // @TODO: Break into smaller methods ("full" and "simple") // This function may throw for an invalid connection @@ -630,16 +631,16 @@ private Task InternalExecuteNonQueryAsync(CancellationToken cancellationTok #if NETFRAMEWORK SqlConnection.ExecutePermission.Demand(); #endif - + SqlClientEventSource.Log.TryCorrelationTraceEvent( "SqlCommand.InternalExecuteNonQueryAsync | API | Correlation | " + $"Object Id {ObjectID}, " + $"Activity Id {ActivityCorrelator.Current}, " + $"Client Connection Id {_activeConnection?.ClientConnectionId}, " + $"Command Text '{CommandText}'"); - + Guid operationId = s_diagnosticListener.WriteCommandBefore(this, _transaction); - + // Connection can be used as state in RegisterForConnectionCloseNotification continuation // to avoid an allocation so use it as the state value if possible but it can be changed if // you need it for a more important piece of data that justifies the tuple allocation later @@ -698,14 +699,14 @@ private Task InternalExecuteNonQueryAsync(CancellationToken cancellationTok catch (Exception e) { s_diagnosticListener.WriteCommandError(operationId, this, _transaction, e); - + source.SetException(e); context.Dispose(); } return returnedTask; } - + private Task InternalExecuteNonQueryWithRetry( // @TODO: Task is ignored bool sendToPipe, int timeout, @@ -725,7 +726,7 @@ private Task InternalExecuteNonQueryWithRetry( // @TODO: Task is ignored asyncWrite, isRetry, methodName)); - + usedCache = innerUsedCache; return result; } @@ -735,7 +736,7 @@ private Task InternalExecuteNonQueryWithRetryAsync(CancellationToken cancel sender: this, function: () => InternalExecuteNonQueryAsync(cancellationToken), cancellationToken); - + // @TODO: Sort args, drop TDS from name // @TODO: Restructure to make this the common method for sync and async methods (not InternalExecuteNonQuery) private Task RunExecuteNonQueryTds(string methodName, bool isAsync, int timeout, bool asyncWrite) @@ -755,7 +756,7 @@ private Task RunExecuteNonQueryTds(string methodName, bool isAsync, int timeout, TaskCompletionSource completion = new TaskCompletionSource(); _activeConnection.RegisterWaitingForReconnect(completion.Task); _reconnectionCompletionSource = completion; - + // Basically, this RunExecuteNonQueryTds onto the end of the reconnection RunExecuteNonQueryTdsSetupReconnnectContinuation( methodName, @@ -842,12 +843,12 @@ private Task RunExecuteNonQueryTds(string methodName, bool isAsync, int timeout, return null; } - + /// /// Since we use CompareExchange, we cannot make the reconnect success continuation static. /// Thus, we cannot use the "WithState" continuation helper. If this was part of /// RunExecuteNonQueryTds, we would be allocating the lambda each time. So, we make this a - /// separate method. + /// separate method. /// // @TODO: Sort args, fix name private void RunExecuteNonQueryTdsSetupReconnnectContinuation( @@ -865,7 +866,7 @@ private void RunExecuteNonQueryTdsSetupReconnnectContinuation( timeout, static () => SQL.CR_ReconnectTimeout(), timeoutCts.Token); - + AsyncHelper.ContinueTask( reconnectTask, completion, @@ -875,7 +876,7 @@ private void RunExecuteNonQueryTdsSetupReconnnectContinuation( { return; } - + Interlocked.CompareExchange(ref _reconnectionCompletionSource, null, completion); timeoutCts.Cancel(); @@ -899,14 +900,14 @@ private void RunExecuteNonQueryTdsSetupReconnnectContinuation( } }); } - + #endregion internal sealed class ExecuteNonQueryAsyncCallContext : AAsyncCallContext { public SqlCommand Command => _owner; - + public Guid OperationId { get; set; } public TaskCompletionSource TaskCompletionSource => _source; diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs index a98d82118b..9a9615d5ff 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs @@ -13,10 +13,10 @@ using Microsoft.Data.ProviderBase; using Microsoft.Data.SqlClient.Connection; using Microsoft.Data.SqlClient.Internal; +using Microsoft.Data.SqlClient.Utilities; #if NETFRAMEWORK using System.Security.Permissions; -using Microsoft.Data.SqlClient.Utilities; #endif namespace Microsoft.Data.SqlClient diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Xml.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Xml.cs index 2be6166e8b..e39a794453 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Xml.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Xml.cs @@ -13,6 +13,7 @@ using Microsoft.Data.SqlClient.Connection; using Microsoft.Data.SqlClient.Server; using Microsoft.Data.SqlClient.Internal; +using Microsoft.Data.SqlClient.Utilities; #if NETFRAMEWORK using System.Security.Permissions; @@ -43,7 +44,7 @@ public IAsyncResult BeginExecuteXmlReader(AsyncCallback callback, object stateOb #if NETFRAMEWORK SqlConnection.ExecutePermission.Demand(); #endif - + SqlClientEventSource.Log.TryCorrelationTraceEvent( "SqlCommand.BeginExecuteXmlReader | API | Correlation | " + $"Object Id {ObjectID}, " + @@ -77,18 +78,18 @@ public XmlReader EndExecuteXmlReader(IAsyncResult asyncResult) $"Command Text '{CommandText}'"); } } - + /// public XmlReader ExecuteXmlReader() { #if NETFRAMEWORK SqlConnection.ExecutePermission.Demand(); #endif - + // Reset _pendingCancel upon entry into any Execute - used to synchronize state // between entry into Execute* API and the thread obtaining the stateObject. _pendingCancel = false; - + using var diagnosticScope = s_diagnosticListener.CreateCommandScope(this, _transaction); using var eventScope = SqlClientEventScope.Create($"SqlCommand.ExecuteXmlReader | API | Object Id {ObjectID}"); @@ -142,11 +143,11 @@ public Task ExecuteXmlReaderAsync(CancellationToken cancellationToken IsProviderRetriable ? InternalExecuteXmlReaderWithRetryAsync(cancellationToken) : InternalExecuteXmlReaderAsync(cancellationToken); - + #endregion - + #region Private Methods - + private static XmlReader CompleteXmlReader(SqlDataReader dataReader, bool isAsync) { XmlReader xmlReader = null; @@ -189,7 +190,7 @@ private IAsyncResult BeginExecuteXmlReaderAsync(AsyncCallback callback, object s $"Activity Id {ActivityCorrelator.Current}, " + $"Client Connection Id {_activeConnection?.ClientConnectionId}, " + $"Command Text '{CommandText}'"); - + return BeginExecuteXmlReaderInternal( CommandBehavior.SequentialAccess, callback, @@ -215,7 +216,7 @@ private IAsyncResult BeginExecuteXmlReaderInternal( // Reset _pendingCancel upon entry into any Execute - used to synchronize state // between entry into Execute* API and the thread obtaining the stateObject. _pendingCancel = false; - + // Special case - done outside of try/catches to prevent putting a stateObj back // into pool when we should not. ValidateAsyncCommand(); @@ -247,7 +248,7 @@ private IAsyncResult BeginExecuteXmlReaderInternal( asyncWrite, isRetry); - // @TODO: NonQuery pathway has the continueTaskWithState block inside this try. One or the other seems wrong + // @TODO: NonQuery pathway has the continueTaskWithState block inside this try. One or the other seems wrong } catch (Exception e) when (ADP.IsCatchableOrSecurityExceptionType(e)) { @@ -352,21 +353,21 @@ private void CleanupAfterExecuteXmlReaderAsync( if (task.IsFaulted) { Exception e = task.Exception?.InnerException; - + s_diagnosticListener.WriteCommandError(operationId, this, _transaction, e); - + source.SetException(e); } else if (task.IsCanceled) { s_diagnosticListener.WriteCommandAfter(operationId, this, _transaction); - + source.SetCanceled(); } else { s_diagnosticListener.WriteCommandAfter(operationId, this, _transaction); - + source.SetResult(task.Result); } } @@ -374,7 +375,7 @@ private void CleanupAfterExecuteXmlReaderAsync( private XmlReader EndExecuteXmlReaderAsync(IAsyncResult asyncResult) { Debug.Assert(!_internalEndExecuteInitiated || _stateObj is null); - + SqlClientEventSource.Log.TryCorrelationTraceEvent( "SqlCommand.EndExecuteXmlReaderAsync | API | Correlation | " + $"Object Id {ObjectID}, " + @@ -389,9 +390,9 @@ private XmlReader EndExecuteXmlReaderAsync(IAsyncResult asyncResult) ReliablePutStateObject(); throw asyncException.InnerException; } - + ThrowIfReconnectionHasBeenCanceled(); - + // Locking _stateObj prevents races with close/cancel. // If we have already initiated the End call internally, we have already done that, so // no point doing it again. @@ -405,7 +406,7 @@ private XmlReader EndExecuteXmlReaderAsync(IAsyncResult asyncResult) return EndExecuteXmlReaderInternal(asyncResult); } - + private XmlReader EndExecuteXmlReaderInternal(IAsyncResult asyncResult) { bool success = false; @@ -443,27 +444,27 @@ private XmlReader EndExecuteXmlReaderInternal(IAsyncResult asyncResult) WriteEndExecuteEvent(success, sqlExceptionNumber, isSynchronous: false); } } - + private Task InternalExecuteXmlReaderAsync(CancellationToken cancellationToken) { #if NETFRAMEWORK SqlConnection.ExecutePermission.Demand(); #endif - + SqlClientEventSource.Log.TryCorrelationTraceEvent( "SqlCommand.InternalExecuteXmlReaderAsync | API | Correlation | " + $"Object Id {ObjectID}, " + $"Activity Id {ActivityCorrelator.Current}, " + $"Client Connection Id {_activeConnection?.ClientConnectionId}, " + $"Command Text '{CommandText}'"); - + Guid operationId = s_diagnosticListener.WriteCommandBefore(this, _transaction); - + // Connection can be used as state in RegisterForConnectionCloseNotification continuation // to avoid an allocation so use it as the state value if possible but it can be changed if // you need it for a more important piece of data that justifies the tuple allocation later TaskCompletionSource source = new TaskCompletionSource(_activeConnection); - + CancellationTokenRegistration registration = new CancellationTokenRegistration(); if (cancellationToken.CanBeCanceled) { @@ -476,7 +477,7 @@ private Task InternalExecuteXmlReaderAsync(CancellationToken cancella registration = cancellationToken.Register(callback: s_cancelIgnoreFailure, state: this); } - // @TODO: This can be cleaned up to lines if InnerConnection is always SqlInternalConnection + // @TODO: This can be cleaned up to lines if InnerConnection is always SqlInternalConnection ExecuteXmlReaderAsyncCallContext context = null; if (_activeConnection?.InnerConnection is SqlConnectionInternal sqlInternalConnection) { @@ -516,34 +517,34 @@ private Task InternalExecuteXmlReaderAsync(CancellationToken cancella TaskCompletionSource source = context.TaskCompletionSource; context.Dispose(); - + command.CleanupAfterExecuteXmlReaderAsync(task, source, operationId); }, scheduler: TaskScheduler.Default); } - catch (Exception e) + catch (Exception e) { s_diagnosticListener.WriteCommandError(operationId, this, _transaction, e); - + source.SetException(e); } return returnedTask; } - + private Task InternalExecuteXmlReaderWithRetryAsync(CancellationToken cancellationToken) => RetryLogicProvider.ExecuteAsync( sender: this, () => InternalExecuteXmlReaderAsync(cancellationToken), cancellationToken); - + #endregion internal sealed class ExecuteXmlReaderAsyncCallContext : AAsyncCallContext { public SqlCommand Command => _owner; - + public Guid OperationId { get; set; } public TaskCompletionSource TaskCompletionSource => _source; diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnection.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnection.cs index 6a4d7b8972..b27b7081c6 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnection.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnection.cs @@ -23,8 +23,10 @@ using Microsoft.Data.SqlClient.Connection; using Microsoft.Data.SqlClient.ConnectionPool; using Microsoft.Data.SqlClient.Diagnostics; -using Microsoft.SqlServer.Server; using Microsoft.Data.SqlClient.Internal; +using Microsoft.Data.SqlClient.Utilities; +using Microsoft.SqlServer.Server; + #if NETFRAMEWORK using System.Runtime.CompilerServices; using System.Security.Permissions; diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlUtil.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlUtil.cs index 96312a491d..64a3b2f21f 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlUtil.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlUtil.cs @@ -46,250 +46,6 @@ internal static ArgumentOutOfRangeException InvalidMinAndMaxPair(string minParam => new ArgumentOutOfRangeException(minParamName, StringsHelper.GetString(Strings.SqlRetryLogic_InvalidMinMaxPair, minValue, maxValue, minParamName, maxParamName)); } - internal static class AsyncHelper - { - internal static Task CreateContinuationTask( - Task task, - Action onSuccess, - Action onFailure = null) - { - if (task == null) - { - onSuccess(); - return null; - } - else - { - TaskCompletionSource completion = new TaskCompletionSource(); - ContinueTaskWithState( - task, - completion, - state: Tuple.Create(onSuccess, onFailure, completion), - onSuccess: static (object state) => - { - var parameters = (Tuple, TaskCompletionSource>)state; - Action success = parameters.Item1; - TaskCompletionSource taskCompletionSource = parameters.Item3; - success(); - taskCompletionSource.SetResult(null); - }, - onFailure: static (Exception exception, object state) => - { - var parameters = (Tuple, TaskCompletionSource>)state; - Action failure = parameters.Item2; - failure?.Invoke(exception); - } - ); - return completion.Task; - } - } - - internal static Task CreateContinuationTaskWithState(Task task, object state, Action onSuccess, Action onFailure = null) - { - if (task == null) - { - onSuccess(state); - return null; - } - else - { - var completion = new TaskCompletionSource(); - ContinueTaskWithState(task, completion, state, - onSuccess: (object continueState) => - { - onSuccess(continueState); - completion.SetResult(null); - }, - onFailure: onFailure - ); - return completion.Task; - } - } - - internal static Task CreateContinuationTask( - Task task, - Action onSuccess, - T1 arg1, - T2 arg2, - Action onFailure = null) - { - return CreateContinuationTask(task, () => onSuccess(arg1, arg2), onFailure); - } - - internal static void ContinueTask(Task task, - TaskCompletionSource completion, - Action onSuccess, - Action onFailure = null, - Action onCancellation = null, - Func exceptionConverter = null) - { - task.ContinueWith( - tsk => - { - if (tsk.Exception != null) - { - Exception exc = tsk.Exception.InnerException; - if (exceptionConverter != null) - { - exc = exceptionConverter(exc); - } - try - { - onFailure?.Invoke(exc); - } - finally - { - completion.TrySetException(exc); - } - } - else if (tsk.IsCanceled) - { - try - { - onCancellation?.Invoke(); - } - finally - { - completion.TrySetCanceled(); - } - } - else - { - try - { - onSuccess(); - } - // @TODO: CER Exception Handling was removed here (see GH#3581) - catch (Exception e) - { - completion.SetException(e); - } - } - }, TaskScheduler.Default - ); - } - - // the same logic as ContinueTask but with an added state parameter to allow the caller to avoid the use of a closure - // the parameter allocation cannot be avoided here and using closure names is clearer than Tuple numbered properties - internal static void ContinueTaskWithState(Task task, - TaskCompletionSource completion, - object state, - Action onSuccess, - Action onFailure = null, - Action onCancellation = null, - Func exceptionConverter = null) - { - task.ContinueWith( - (Task tsk, object state2) => - { - if (tsk.Exception != null) - { - Exception exc = tsk.Exception.InnerException; - if (exceptionConverter != null) - { - exc = exceptionConverter(exc); - } - - try - { - onFailure?.Invoke(exc, state2); - } - finally - { - completion.TrySetException(exc); - } - } - else if (tsk.IsCanceled) - { - try - { - onCancellation?.Invoke(state2); - } - finally - { - completion.TrySetCanceled(); - } - } - else - { - try - { - onSuccess(state2); - } - // @TODO: CER Exception Handling was removed here (see GH#3581) - catch (Exception e) - { - completion.SetException(e); - } - } - }, - state: state, - scheduler: TaskScheduler.Default - ); - } - - internal static void WaitForCompletion(Task task, int timeout, Action onTimeout = null, bool rethrowExceptions = true) - { - try - { - task.Wait(timeout > 0 ? (1000 * timeout) : Timeout.Infinite); - } - catch (AggregateException ae) - { - if (rethrowExceptions) - { - Debug.Assert(ae.InnerExceptions.Count == 1, "There is more than one exception in AggregateException"); - ExceptionDispatchInfo.Capture(ae.InnerException).Throw(); - } - } - if (!task.IsCompleted) - { - task.ContinueWith(static t => { var ignored = t.Exception; }); //Ensure the task does not leave an unobserved exception - onTimeout?.Invoke(); - } - } - - internal static void SetTimeoutException(TaskCompletionSource completion, int timeout, Func onFailure, CancellationToken ctoken) - { - if (timeout > 0) - { - Task.Delay(timeout * 1000, ctoken).ContinueWith( - (Task task) => - { - if (!task.IsCanceled && !completion.Task.IsCompleted) - { - completion.TrySetException(onFailure()); - } - } - ); - } - } - - internal static void SetTimeoutExceptionWithState( - TaskCompletionSource completion, - int timeout, - object state, - Func onFailure, - CancellationToken cancellationToken) - { - if (timeout <= 0) - { - return; - } - - Task.Delay(timeout * 1000, cancellationToken).ContinueWith( - (task, innerState) => - { - if (!task.IsCanceled && !completion.Task.IsCompleted) - { - completion.TrySetException(onFailure(innerState)); - } - }, - state: state, - cancellationToken: CancellationToken.None); - } - } - internal static class SQL { // The class SQL defines the exceptions that are specific to the SQL Adapter. diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs index bb625d563d..e0a5e7cd8a 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -25,15 +25,13 @@ using Microsoft.Data.SqlClient.DataClassification; using Microsoft.Data.SqlClient.LocalDb; using Microsoft.Data.SqlClient.Server; +using Microsoft.Data.SqlClient.Internal; using Microsoft.Data.SqlClient.Utilities; using Microsoft.SqlServer.Server; -using Microsoft.Data.SqlClient.Internal; #if NETFRAMEWORK using System.Runtime.CompilerServices; -#if _WINDOWS using Interop.Windows.Sni; -#endif using Microsoft.Data.SqlTypes; #endif diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs index 45ed29c2ce..a1bcec983f 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs @@ -17,6 +17,7 @@ using Microsoft.Data.ProviderBase; using Microsoft.Data.SqlClient.ManagedSni; using Microsoft.Data.SqlClient.Internal; +using Microsoft.Data.SqlClient.Utilities; #if NETFRAMEWORK using System.Runtime.ConstrainedExecution; @@ -24,10 +25,6 @@ namespace Microsoft.Data.SqlClient { -#if NETFRAMEWORK - using RuntimeHelpers = System.Runtime.CompilerServices.RuntimeHelpers; -#endif - sealed internal class LastIOTimer { internal long _value; @@ -189,7 +186,7 @@ private enum SnapshotStatus private readonly TimerCallback _onTimeoutAsync; private readonly WeakReference _cancellationOwner = new WeakReference(null); - // Below 2 properties are used to enforce timeout delays in code to + // Below 2 properties are used to enforce timeout delays in code to // reproduce issues related to theadpool starvation and timeout delay. // It should always be set to false by default, and only be enabled during testing. internal bool _enforceTimeoutDelay = false; @@ -2178,7 +2175,7 @@ internal TdsOperationStatus TryReadStringWithEncoding(int length, System.Text.En buf = TryTakeSnapshotStorage() as byte[]; Debug.Assert(buf != null || !isContinuing, "if continuing stored buffer must be present to contain previous data to continue from"); Debug.Assert(buf == null || buf.Length == length, "stored buffer length must be null or must have been created with the correct length"); - + if (buf != null) { startOffset = GetSnapshotTotalSize(); @@ -2191,7 +2188,7 @@ internal TdsOperationStatus TryReadStringWithEncoding(int length, System.Text.En } TdsOperationStatus result = TryReadByteArray(buf, length, out _, startOffset, canContinue); - + if (result != TdsOperationStatus.Done) { if (result == TdsOperationStatus.NeedMoreData) @@ -3485,7 +3482,7 @@ internal TdsOperationStatus TryReadNetworkPacket() while (_inBytesRead == 0) { // a partial packet must have taken the packet data so we - // need to read more data to complete the packet, but we + // need to read more data to complete the packet, but we // can't return NeedMoreData in sync mode so we have to // spin fetching more data here until we have something // that the caller can read @@ -3772,7 +3769,7 @@ private void OnTimeoutAsync(object state) TimeoutState timeoutState = (TimeoutState)state; if (timeoutState.IdentityValue == _timeoutIdentityValue) { - // the return value is not useful here because no choice is going to be made using it + // the return value is not useful here because no choice is going to be made using it // we only want to make this call to set the state knowing that it will be seen later OnTimeoutCore(TimeoutState.Running, TimeoutState.ExpiredAsync); } @@ -3935,7 +3932,7 @@ internal void ReadSni(TaskCompletionSource completion) { Debug.Assert(completion != null, "Async on but null asyncResult passed"); - // if the state is currently stopped then change it to running and allocate a new identity value from + // if the state is currently stopped then change it to running and allocate a new identity value from // the identity source. The identity value is used to correlate timer callback events to the currently // running timeout and prevents a late timer callback affecting a result it does not relate to int previousTimeoutState = Interlocked.CompareExchange(ref _timeoutState, TimeoutState.Running, TimeoutState.Stopped); @@ -4372,7 +4369,7 @@ private void WriteBytesSetupContinuation(byte[] array, int len, TaskCompletionSo /// packet parsing. /// /// - internal string DumpBuffer() + internal string DumpBuffer() { StringBuilder buffer = new StringBuilder(128); buffer.AppendLine("dumping buffer"); @@ -4381,7 +4378,7 @@ internal string DumpBuffer() int cc = 0; // character counter int i; buffer.AppendLine("used buffer:"); - for (i=0; i< _inBytesUsed; i++) + for (i=0; i< _inBytesUsed; i++) { if (cc==16) { buffer.AppendLine(); @@ -4390,16 +4387,16 @@ internal string DumpBuffer() buffer.AppendFormat("{0,-2:X2} ", _inBuff[i]); cc++; } - if (cc>0) + if (cc>0) { buffer.AppendLine(); } cc = 0; buffer.AppendLine("unused buffer:"); - for (i=_inBytesUsed; i<_inBytesRead; i++) + for (i=_inBytesUsed; i<_inBytesRead; i++) { - if (cc==16) + if (cc==16) { buffer.AppendLine(); cc = 0; @@ -4407,13 +4404,13 @@ internal string DumpBuffer() buffer.AppendFormat("{0,-2:X2} ", _inBuff[i]); cc++; } - if (cc>0) + if (cc>0) { buffer.AppendLine(); } return buffer.ToString(); } - + internal void SetSnapshot() { StateSnapshot snapshot = _snapshot; @@ -4444,7 +4441,7 @@ internal void ResetSnapshot() } /// - /// Returns true if the state object is in the state of continuing from a previously stored snapshot packet + /// Returns true if the state object is in the state of continuing from a previously stored snapshot packet /// meaning that consumers should resume from the point where they last needed more data instead of beginning /// to process packets in the snapshot from the beginning again /// @@ -4539,7 +4536,7 @@ internal int GetSnapshotPacketID() /// /// sets a value on the snapshot to allow the ContinueEnabled property to return true.
/// this function should be called only by functions that explicitly support the snapshot status - /// status + /// status ///
internal void RequestContinue(bool value) { @@ -4862,7 +4859,7 @@ partial void SetDebugDataHashImpl() { Hash = null; } - + } partial void CheckDebugDataHashImpl() diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs new file mode 100644 index 0000000000..3885d9da9e --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs @@ -0,0 +1,256 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Diagnostics; +using System.Runtime.ExceptionServices; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Data.SqlClient.Utilities +{ + internal static class AsyncHelper + { + internal static void ContinueTask(Task task, + TaskCompletionSource completion, + Action onSuccess, + Action onFailure = null, + Action onCancellation = null, + Func exceptionConverter = null) + { + task.ContinueWith( + tsk => + { + if (tsk.Exception != null) + { + Exception exc = tsk.Exception.InnerException; + if (exceptionConverter != null) + { + exc = exceptionConverter(exc); + } + try + { + onFailure?.Invoke(exc); + } + finally + { + completion.TrySetException(exc); + } + } + else if (tsk.IsCanceled) + { + try + { + onCancellation?.Invoke(); + } + finally + { + completion.TrySetCanceled(); + } + } + else + { + try + { + onSuccess(); + } + // @TODO: CER Exception Handling was removed here (see GH#3581) + catch (Exception e) + { + completion.SetException(e); + } + } + }, TaskScheduler.Default + ); + } + + // the same logic as ContinueTask but with an added state parameter to allow the caller to avoid the use of a closure + // the parameter allocation cannot be avoided here and using closure names is clearer than Tuple numbered properties + internal static void ContinueTaskWithState(Task task, + TaskCompletionSource completion, + object state, + Action onSuccess, + Action onFailure = null, + Action onCancellation = null, + Func exceptionConverter = null) + { + task.ContinueWith( + (Task tsk, object state2) => + { + if (tsk.Exception != null) + { + Exception exc = tsk.Exception.InnerException; + if (exceptionConverter != null) + { + exc = exceptionConverter(exc); + } + + try + { + onFailure?.Invoke(exc, state2); + } + finally + { + completion.TrySetException(exc); + } + } + else if (tsk.IsCanceled) + { + try + { + onCancellation?.Invoke(state2); + } + finally + { + completion.TrySetCanceled(); + } + } + else + { + try + { + onSuccess(state2); + } + // @TODO: CER Exception Handling was removed here (see GH#3581) + catch (Exception e) + { + completion.SetException(e); + } + } + }, + state: state, + scheduler: TaskScheduler.Default + ); + } + + internal static Task CreateContinuationTask( + Task task, + Action onSuccess, + Action onFailure = null) + { + if (task == null) + { + onSuccess(); + return null; + } + else + { + TaskCompletionSource completion = new TaskCompletionSource(); + ContinueTaskWithState( + task, + completion, + state: Tuple.Create(onSuccess, onFailure, completion), + onSuccess: static (object state) => + { + var parameters = (Tuple, TaskCompletionSource>)state; + Action success = parameters.Item1; + TaskCompletionSource taskCompletionSource = parameters.Item3; + success(); + taskCompletionSource.SetResult(null); + }, + onFailure: static (Exception exception, object state) => + { + var parameters = (Tuple, TaskCompletionSource>)state; + Action failure = parameters.Item2; + failure?.Invoke(exception); + } + ); + return completion.Task; + } + } + + internal static Task CreateContinuationTask( + Task task, + Action onSuccess, + T1 arg1, + T2 arg2, + Action onFailure = null) + { + return CreateContinuationTask(task, () => onSuccess(arg1, arg2), onFailure); + } + + internal static Task CreateContinuationTaskWithState(Task task, object state, Action onSuccess, Action onFailure = null) + { + if (task == null) + { + onSuccess(state); + return null; + } + else + { + var completion = new TaskCompletionSource(); + ContinueTaskWithState(task, completion, state, + onSuccess: (object continueState) => + { + onSuccess(continueState); + completion.SetResult(null); + }, + onFailure: onFailure + ); + return completion.Task; + } + } + + internal static void SetTimeoutException(TaskCompletionSource completion, int timeout, Func onFailure, CancellationToken ctoken) + { + if (timeout > 0) + { + Task.Delay(timeout * 1000, ctoken).ContinueWith( + (Task task) => + { + if (!task.IsCanceled && !completion.Task.IsCompleted) + { + completion.TrySetException(onFailure()); + } + } + ); + } + } + + internal static void SetTimeoutExceptionWithState( + TaskCompletionSource completion, + int timeout, + object state, + Func onFailure, + CancellationToken cancellationToken) + { + if (timeout <= 0) + { + return; + } + + Task.Delay(timeout * 1000, cancellationToken).ContinueWith( + (task, innerState) => + { + if (!task.IsCanceled && !completion.Task.IsCompleted) + { + completion.TrySetException(onFailure(innerState)); + } + }, + state: state, + cancellationToken: CancellationToken.None); + } + + internal static void WaitForCompletion(Task task, int timeout, Action onTimeout = null, bool rethrowExceptions = true) + { + try + { + task.Wait(timeout > 0 ? (1000 * timeout) : Timeout.Infinite); + } + catch (AggregateException ae) + { + if (rethrowExceptions) + { + Debug.Assert(ae.InnerExceptions.Count == 1, "There is more than one exception in AggregateException"); + ExceptionDispatchInfo.Capture(ae.InnerException).Throw(); + } + } + if (!task.IsCompleted) + { + task.ContinueWith(static t => { var ignored = t.Exception; }); //Ensure the task does not leave an unobserved exception + onTimeout?.Invoke(); + } + } + } +} From e20fc9b2dca6f61ee711b2f521ce9ee7cb9ce638 Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Tue, 2 Jun 2026 13:57:31 -0500 Subject: [PATCH 2/9] Implement AsyncHelpers with generic state, add unit tests --- .../Data/SqlClient/Utilities/AsyncHelper.cs | 781 ++++++++-- .../tests/Directory.Packages.props | 1 + .../Microsoft.Data.SqlClient.UnitTests.csproj | 2 + .../SqlClient/Utilities/AsyncHelperTest.cs | 1353 +++++++++++++++++ .../UnitTests/Utilities/MockExtensions.cs | 64 + 5 files changed, 2081 insertions(+), 120 deletions(-) create mode 100644 src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs create mode 100644 src/Microsoft.Data.SqlClient/tests/UnitTests/Utilities/MockExtensions.cs diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs index 3885d9da9e..9b4080d62f 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs @@ -7,250 +7,791 @@ using System.Runtime.ExceptionServices; using System.Threading; using System.Threading.Tasks; +using Microsoft.Data.SqlClient.Internal; + +#nullable enable namespace Microsoft.Data.SqlClient.Utilities { + /// + /// Provides helpers for interacting with asynchronous tasks. + /// + /// + /// These helpers mainly provide continuation and timeout functionality. They utilize + /// at their core, and as such are fairly antiquated + /// implementations. If possible these methods should be utilized less and async/await native + /// constructs should be used. + /// internal static class AsyncHelper { - internal static void ContinueTask(Task task, - TaskCompletionSource completion, + /// + /// Continues a task and signals failure of the continuation via the provided + /// . + /// + /// + /// If the completes in the following states, these + /// actions will be taken: + /// * With exception + /// * is called (if provided) + /// * Will try to set exception on + /// * Cancelled + /// * is called (if provided) + /// * Will try to set cancelled on + /// * Successfully + /// * is called + /// * IF an exception is thrown during execution of , the + /// helper will try to set an exception on the . + /// * is *not* with result on success. This + /// is to allow the task completion source to be continued even more after this current + /// continuation. + /// + /// Task to continue with provided callbacks + /// + /// Completion source used to track completion of the continuation, see remarks for details + /// + /// Callback to execute on successful completion of the task + /// Callback to execute on failure of the task (optional) + /// Callback to execute on cancellation of the task (optional) + internal static void ContinueTask( + Task taskToContinue, + TaskCompletionSource taskCompletionSource, Action onSuccess, - Action onFailure = null, - Action onCancellation = null, - Func exceptionConverter = null) + Action? onFailure = null, + Action? onCancellation = null) { - task.ContinueWith( - tsk => + ContinuationState continuationState = new ContinuationState( + OnCancellation: onCancellation, + OnFailure: onFailure, + OnSuccess: onSuccess, + TaskCompletionSource: taskCompletionSource); + + Task continuationTask = taskToContinue.ContinueWith( + static (tsk, continuationState2) => { + ContinuationState typedState = (ContinuationState)continuationState2!; + if (tsk.Exception != null) { - Exception exc = tsk.Exception.InnerException; - if (exceptionConverter != null) - { - exc = exceptionConverter(exc); - } + Exception innerException = tsk.Exception.InnerException ?? tsk.Exception; try { - onFailure?.Invoke(exc); + typedState.OnFailure?.Invoke(innerException); } finally { - completion.TrySetException(exc); + typedState.TaskCompletionSource.TrySetException(innerException); } } else if (tsk.IsCanceled) { try { - onCancellation?.Invoke(); + typedState.OnCancellation?.Invoke(); } finally { - completion.TrySetCanceled(); + typedState.TaskCompletionSource.TrySetCanceled(); } } else { try { - onSuccess(); + typedState.OnSuccess(); } // @TODO: CER Exception Handling was removed here (see GH#3581) catch (Exception e) { - completion.SetException(e); + typedState.TaskCompletionSource.TrySetException(e); } } - }, TaskScheduler.Default - ); + }, + state: continuationState, + scheduler: TaskScheduler.Default); + + // Explicitly follow up by observing any exception thrown during continuation + ObserveContinuationException(continuationTask); } - // the same logic as ContinueTask but with an added state parameter to allow the caller to avoid the use of a closure - // the parameter allocation cannot be avoided here and using closure names is clearer than Tuple numbered properties - internal static void ContinueTaskWithState(Task task, - TaskCompletionSource completion, - object state, - Action onSuccess, - Action onFailure = null, - Action onCancellation = null, - Func exceptionConverter = null) + /// + /// Continues a task and signals failure of the continuation via the provided + /// . This overload provides a single state object + /// to the callbacks. + /// + /// + /// When possible, use static lambdas for callbacks. + /// + /// If the completes in the following states, these + /// actions will be taken: + /// * With exception + /// * is called (if provided) + /// * Will try to set exception on + /// * Cancelled + /// * is called (if provided) + /// * Will try to set cancelled on + /// * Successfully + /// * is called + /// * IF an exception is thrown during execution of , the + /// helper will try to set an exception on the . + /// * is *not* with result on success. This + /// is to allow the task completion source to be continued even more after this current + /// continuation. + /// + /// Type of the state object to provide to the callbacks + /// Task to continue with provided callbacks + /// + /// Completion source used to track completion of the continuation, see remarks for details + /// + /// State object to provide to callbacks + /// Callback to execute on successful completion of the task + /// Callback to execute on failure of the task (optional) + /// Callback to execute on cancellation of the task (optional) + internal static void ContinueTaskWithState( + Task taskToContinue, + TaskCompletionSource taskCompletionSource, + TState state, + Action onSuccess, + Action? onFailure = null, + Action? onCancellation = null) { - task.ContinueWith( - (Task tsk, object state2) => + ContinuationState continuationState = new( + OnCancellation: onCancellation, + OnFailure: onFailure, + OnSuccess: onSuccess, + State: state, + TaskCompletionSource: taskCompletionSource); + + Task continuationTask = taskToContinue.ContinueWith( + static (task, continuationState2) => { - if (tsk.Exception != null) + ContinuationState typedState2 = (ContinuationState)continuationState2!; + + if (task.Exception is not null) + { + Exception innerException = task.Exception.InnerException ?? task.Exception; + try + { + typedState2.OnFailure?.Invoke(typedState2.State, innerException); + } + finally + { + typedState2.TaskCompletionSource.TrySetException(innerException); + } + } + else if (task.IsCanceled) { - Exception exc = tsk.Exception.InnerException; - if (exceptionConverter != null) + try + { + typedState2.OnCancellation?.Invoke(typedState2.State); + } + finally { - exc = exceptionConverter(exc); + typedState2.TaskCompletionSource.TrySetCanceled(); } + } + else + { + try + { + typedState2.OnSuccess(typedState2.State); + // @TODO: The one unpleasant thing with this code is that the TCS is not set completed and left to the caller to do or not do (which is more unpleasant) + } + catch (Exception e) + { + typedState2.TaskCompletionSource.TrySetException(e); + } + } + }, + state: continuationState, + scheduler: TaskScheduler.Default); + // Explicitly follow up by observing any exception thrown during continuation + ObserveContinuationException(continuationTask); + } + + /// + /// Continues a task and signals failure of the continuation via the provided + /// . This overload provides two state objects to + /// the callbacks. + /// + /// + /// When possible, use static lambdas for callbacks. + /// + /// If the completes in the following states, these + /// actions will be taken: + /// * With exception + /// * is called (if provided) + /// * Will try to set exception on + /// * Cancelled + /// * is called (if provided) + /// * Will try to set cancelled on + /// * Successfully + /// * is called + /// * IF an exception is thrown during execution of , the + /// helper will try to set an exception on the . + /// * is *not* with result on success. This + /// is to allow the task completion source to be continued even more after this + /// current continuation. + /// + /// Task to continue with provided callbacks + /// + /// Completion source used to track completion of the continuation, see remarks for details + /// + /// Type of the first state object to provide to callbacks + /// Type of the second state object to provide to callbacks + /// First state object to provide to callbacks + /// Second state object to provide to callbacks + /// Callback to execute on successful completion of the task + /// Callback to execute on failure of the task (optional) + /// Callback to execute on cancellation of the task (optional) + internal static void ContinueTaskWithState( + Task taskToContinue, + TaskCompletionSource taskCompletionSource, + TState1 state1, + TState2 state2, + Action onSuccess, + Action? onFailure = null, + Action? onCancellation = null) + { + ContinuationState continuationState = new( + OnCancellation: onCancellation, + OnFailure: onFailure, + OnSuccess: onSuccess, + State1: state1, + State2: state2, + TaskCompletionSource: taskCompletionSource); + + Task continuationTask = taskToContinue.ContinueWith( + static (task, continuationState2) => + { + ContinuationState typedState2 = + (ContinuationState)continuationState2!; + + if (task.Exception is not null) + { + Exception innerException = task.Exception.InnerException ?? task.Exception; try { - onFailure?.Invoke(exc, state2); + typedState2.OnFailure?.Invoke(typedState2.State1, typedState2.State2, innerException); } finally { - completion.TrySetException(exc); + typedState2.TaskCompletionSource.TrySetException(innerException); } } - else if (tsk.IsCanceled) + else if (task.IsCanceled) { try { - onCancellation?.Invoke(state2); + typedState2.OnCancellation?.Invoke(typedState2.State1, typedState2.State2); } finally { - completion.TrySetCanceled(); + typedState2.TaskCompletionSource.TrySetCanceled(); } } else { try { - onSuccess(state2); + typedState2.OnSuccess(typedState2.State1, typedState2.State2); } - // @TODO: CER Exception Handling was removed here (see GH#3581) catch (Exception e) { - completion.SetException(e); + typedState2.TaskCompletionSource.TrySetException(e); } } }, - state: state, - scheduler: TaskScheduler.Default - ); + state: continuationState, + scheduler: TaskScheduler.Default); + + // Explicitly follow up by observing any exception thrown during continuation + ObserveContinuationException(continuationTask); } - internal static Task CreateContinuationTask( - Task task, + /// + /// Continues a task and returns the continuation task. + /// + /// + /// When possible, use static lambdas for callbacks. + /// + /// If the completes in the following states, these + /// actions will be taken: + /// * With exception + /// * is called (if provided) + /// * The task will be completed with an exception. + /// * Cancelled + /// * is called (if provided) + /// * The task will be completed as cancelled. + /// * Successfully + /// * is called + /// * IF an exception is thrown during execution of , the + /// task will be completed with the exception. + /// * The task will be completed as successful. + /// + /// + /// Task to continue with provided callbacks, if null, null will be returned. + /// + /// Callback to execute on successful completion of the task + /// Callback to execute on failure of the task (optional) + /// Callback to execute on cancellation of the task (optional) + internal static Task? CreateContinuationTask( + Task? taskToContinue, Action onSuccess, - Action onFailure = null) + Action? onFailure = null, + Action? onCancellation = null) { - if (task == null) + if (taskToContinue is null) { + // This is a remnant of ye olde async/sync code that return null tasks when + // executing synchronously. It's still desirable that the onSuccess executes + // regardless of whether the preceding action was synchronous or asynchronous. onSuccess(); return null; } - else - { - TaskCompletionSource completion = new TaskCompletionSource(); - ContinueTaskWithState( - task, - completion, - state: Tuple.Create(onSuccess, onFailure, completion), - onSuccess: static (object state) => + + // @TODO: Can totally use a non-generic TaskCompletionSource + TaskCompletionSource taskCompletionSource = new(); + ContinuationState continuationState = new( + OnCancellation: onCancellation, + OnFailure: onFailure, + OnSuccess: onSuccess, + TaskCompletionSource: taskCompletionSource); + + Task continuationTask = taskToContinue.ContinueWith( + static (task, continuationState2) => + { + ContinuationState typedState = (ContinuationState)continuationState2!; + if (task.Exception is not null) { - var parameters = (Tuple, TaskCompletionSource>)state; - Action success = parameters.Item1; - TaskCompletionSource taskCompletionSource = parameters.Item3; - success(); - taskCompletionSource.SetResult(null); - }, - onFailure: static (Exception exception, object state) => + Exception innerException = task.Exception.InnerException ?? task.Exception; + try + { + typedState.OnFailure?.Invoke(innerException); + } + finally + { + typedState.TaskCompletionSource.TrySetException(innerException); + } + } + else if (task.IsCanceled) { - var parameters = (Tuple, TaskCompletionSource>)state; - Action failure = parameters.Item2; - failure?.Invoke(exception); + try + { + typedState.OnCancellation?.Invoke(); + } + finally + { + typedState.TaskCompletionSource.TrySetCanceled(); + } } - ); - return completion.Task; - } - } + else + { + try + { + typedState.OnSuccess(); + typedState.TaskCompletionSource.TrySetResult(null); + } + catch (Exception e) + { + typedState.TaskCompletionSource.TrySetException(e); + } + } + }, + state: continuationState, + scheduler: TaskScheduler.Default); - internal static Task CreateContinuationTask( - Task task, - Action onSuccess, - T1 arg1, - T2 arg2, - Action onFailure = null) - { - return CreateContinuationTask(task, () => onSuccess(arg1, arg2), onFailure); + // Explicitly follow up by observing any exception thrown during continuation + ObserveContinuationException(continuationTask); + + return taskCompletionSource.Task; } - internal static Task CreateContinuationTaskWithState(Task task, object state, Action onSuccess, Action onFailure = null) + /// + /// Continues a task and returns the continuation task. This overload allows a state object + /// to be passed into the callbacks. + /// + /// + /// When possible, use static lambdas for callbacks. + /// + /// If the completes in the following states, these + /// actions will be taken: + /// * With exception + /// * is called (if provided) + /// * The task will be completed with an exception. + /// * Cancelled + /// * is called (if provided) + /// * The task will be completed as cancelled. + /// * Successfully + /// * is called + /// * IF an exception is thrown during execution of , the + /// task will be completed with the exception. + /// * The task will be completed as successful. + /// + /// Type of the state object to pass to callbacks + /// + /// Task to continue with provided callbacks, if null, null will be returned. + /// + /// State object to pass to the callbacks + /// Callback to execute on successful completion of the task + /// Callback to execute on failure of the task (optional) + /// Callback to execute on cancellation of the task (optional) + internal static Task? CreateContinuationTaskWithState( + Task? taskToContinue, + TState state, + Action onSuccess, + Action? onFailure = null, + Action? onCancellation = null) { - if (task == null) + if (taskToContinue is null) { + // This is a remnant of ye olde async/sync code that return null tasks when + // executing synchronously. It's still desirable that the onSuccess executes + // regardless of whether the preceding action was synchronous or asynchronous. onSuccess(state); return null; } - else - { - var completion = new TaskCompletionSource(); - ContinueTaskWithState(task, completion, state, - onSuccess: (object continueState) => + + // @TODO: Can totally use a non-generic TaskCompletionSource + TaskCompletionSource taskCompletionSource = new(); + ContinuationState continuationState = new( + OnCancellation: onCancellation, + OnFailure: onFailure, + OnSuccess: onSuccess, + State: state, + TaskCompletionSource: taskCompletionSource); + + Task continuationTask = taskToContinue.ContinueWith( + static (task, continuationState2) => + { + ContinuationState typedState2 = (ContinuationState)continuationState2!; + + if (task.Exception is not null) { - onSuccess(continueState); - completion.SetResult(null); - }, - onFailure: onFailure - ); - return completion.Task; - } + Exception innerException = task.Exception.InnerException ?? task.Exception; + try + { + typedState2.OnFailure?.Invoke(typedState2.State, innerException); + } + finally + { + typedState2.TaskCompletionSource.TrySetException(innerException); + } + } + else if (task.IsCanceled) + { + try + { + typedState2.OnCancellation?.Invoke(typedState2.State); + } + finally + { + typedState2.TaskCompletionSource.TrySetCanceled(); + } + } + else + { + try + { + typedState2.OnSuccess(typedState2.State); + typedState2.TaskCompletionSource.TrySetResult(null); + } + catch (Exception e) + { + typedState2.TaskCompletionSource.TrySetException(e); + } + } + + }, + state: continuationState, + scheduler: TaskScheduler.Default); + + // Explicitly follow up by observing any exception thrown during continuation + ObserveContinuationException(continuationTask); + + return taskCompletionSource.Task; } - internal static void SetTimeoutException(TaskCompletionSource completion, int timeout, Func onFailure, CancellationToken ctoken) + /// + /// Continues a task and returns the continuation task. This overload allows two state + /// objects to be passed into the callbacks. + /// + /// + /// When possible, use static lambdas for callbacks. + /// + /// If the completes in the following states, these + /// actions will be taken: + /// * With exception + /// * is called (if provided) + /// * The task will be completed with an exception. + /// * Cancelled + /// * is called (if provided) + /// * The task will be completed as cancelled. + /// * Successfully + /// * is called + /// * IF an exception is thrown during execution of , the + /// task will be completed with the exception. + /// * The task will be completed as successful. + /// + /// Type of the first state object to pass to callbacks + /// Type of the second state object to pass to callbacks + /// + /// Task to continue with provided callbacks, if null, null will be returned. + /// + /// First state object to pass to the callbacks + /// Second state object to pass to the callbacks + /// Callback to execute on successful completion of the task + /// Callback to execute on failure of the task (optional) + /// Callback to execute on cancellation of the task (optional) + internal static Task? CreateContinuationTaskWithState( + Task? taskToContinue, + TState1 state1, + TState2 state2, + Action onSuccess, + Action? onFailure = null, + Action? onCancellation = null) { - if (timeout > 0) + if (taskToContinue is null) { - Task.Delay(timeout * 1000, ctoken).ContinueWith( - (Task task) => + // This is a remnant of ye olde async/sync code that return null tasks when + // executing synchronously. It's still desirable that the onSuccess executes + // regardless of whether the preceding action was synchronous or asynchronous. + onSuccess(state1, state2); + return null; + } + + // @TODO: Can totally use a non-generic TaskCompletionSource + TaskCompletionSource taskCompletionSource = new(); + ContinuationState continuationState = new( + OnCancellation: onCancellation, + OnFailure: onFailure, + OnSuccess: onSuccess, + State1: state1, + State2: state2, + TaskCompletionSource: taskCompletionSource); + + Task continuationTask = taskToContinue.ContinueWith( + static (task, continuationState2) => + { + ContinuationState typedState2 = + (ContinuationState)continuationState2!; + + if (task.Exception is not null) { - if (!task.IsCanceled && !completion.Task.IsCompleted) + Exception innerException = task.Exception.InnerException ?? task.Exception; + try { - completion.TrySetException(onFailure()); + typedState2.OnFailure?.Invoke(typedState2.State1, typedState2.State2, innerException); + } + finally + { + typedState2.TaskCompletionSource.TrySetException(innerException); } } - ); + else if (task.IsCanceled) + { + try + { + typedState2.OnCancellation?.Invoke(typedState2.State1, typedState2.State2); + } + finally + { + typedState2.TaskCompletionSource.TrySetCanceled(); + } + } + else + { + try + { + typedState2.OnSuccess(typedState2.State1, typedState2.State2); + typedState2.TaskCompletionSource.TrySetResult(null); + } + catch (Exception e) + { + typedState2.TaskCompletionSource.TrySetException(e); + } + } + + }, + state: continuationState, + scheduler: TaskScheduler.Default); + + // Explicitly follow up by observing any exception thrown during continuation + ObserveContinuationException(continuationTask); + + return taskCompletionSource.Task; + } + + /// + /// Executes a timeout task in parallel with the provided + /// . If the timeout completes before the task + /// completion source, the provided is executed and the + /// exception returned is set as the exception that completes the task completion source. + /// + /// Task to execute with a timeout + /// Number of seconds to wait until timing out the task + /// + /// Callback to execute when the task does not complete within the allotted time. The + /// exception returned by the callback is set on the . + /// + /// Cancellation token to prematurely cancel timeout + internal static void SetTimeoutException( + TaskCompletionSource taskCompletionSource, + int timeoutInSeconds, + Func onTimeout, + CancellationToken cancellationToken) + { + if (timeoutInSeconds <= 0) + { + return; } + + Task.Delay(TimeSpan.FromSeconds(timeoutInSeconds), cancellationToken) + .ContinueWith( + task => + { + // If the timeout ran to completion AND the task to complete did not complete + // then the timeout expired first, run the timeout handler + if (!task.IsCanceled && !taskCompletionSource.Task.IsCompleted) + { + taskCompletionSource.TrySetException(onTimeout()); + } + }, + cancellationToken: CancellationToken.None); } - internal static void SetTimeoutExceptionWithState( - TaskCompletionSource completion, - int timeout, - object state, - Func onFailure, + /// + /// Executes a timeout task in parallel with the provided + /// . If the timeout completes before the task + /// completion source, the provided is executed and the + /// exception returned is set as the exception that completes the task completion source. + /// This overload provides a state object to the timeout callback. + /// + /// Task to execute with a timeout + /// Number of seconds to wait until timing out the task + /// State object to pass to the callback + /// + /// Callback to execute when the task does not complete within the allotted time. The + /// exception returned by the callback is set on the . + /// + /// Cancellation token to prematurely cancel timeout + internal static void SetTimeoutExceptionWithState( + TaskCompletionSource taskCompletionSource, + int timeoutInSeconds, + TState state, + Func onTimeout, CancellationToken cancellationToken) { - if (timeout <= 0) + if (timeoutInSeconds <= 0) { return; } - Task.Delay(timeout * 1000, cancellationToken).ContinueWith( - (task, innerState) => - { - if (!task.IsCanceled && !completion.Task.IsCompleted) + Task.Delay(TimeSpan.FromSeconds(timeoutInSeconds), cancellationToken) + .ContinueWith( + (task, state2) => { - completion.TrySetException(onFailure(innerState)); - } - }, - state: state, - cancellationToken: CancellationToken.None); + // If the timeout ran to completion AND the task to complete did not complete + // then the timeout expired first, run the timeout handler + if (!task.IsCanceled && !taskCompletionSource.Task.IsCompleted) + { + taskCompletionSource.TrySetException(onTimeout((TState)state2!)); + } + }, + state: state, + cancellationToken: CancellationToken.None); } - internal static void WaitForCompletion(Task task, int timeout, Action onTimeout = null, bool rethrowExceptions = true) + /// + /// Waits for a maximum of seconds for completion of + /// the provided . + /// + /// Task to execute with a timeout + /// Number of seconds to wait until timing out the task + /// + /// Callback to execute when the task does not complete within the allotted time. + /// + /// + /// If true, the inner exception of any raised + /// during execution, including timeout of the task, will be rethrown. + /// + internal static void WaitForCompletion( + Task task, + int timeoutInSeconds, + Action? onTimeout = null, + bool rethrowExceptions = true) { try { - task.Wait(timeout > 0 ? (1000 * timeout) : Timeout.Infinite); + TimeSpan timeout = timeoutInSeconds > 0 + ? TimeSpan.FromSeconds(timeoutInSeconds) + : Timeout.InfiniteTimeSpan; + task.Wait(timeout); } catch (AggregateException ae) { if (rethrowExceptions) { + Debug.Assert(ae.InnerException is not null, "Inner exception is null"); Debug.Assert(ae.InnerExceptions.Count == 1, "There is more than one exception in AggregateException"); - ExceptionDispatchInfo.Capture(ae.InnerException).Throw(); + ExceptionDispatchInfo.Capture(ae.InnerException!).Throw(); } } + if (!task.IsCompleted) { - task.ContinueWith(static t => { var ignored = t.Exception; }); //Ensure the task does not leave an unobserved exception + // Ensure the task does not leave an unobserved exception + task.ContinueWith(static t => { _ = t.Exception; }); onTimeout?.Invoke(); } } + + /// + /// This method is intended to be used within the above helpers to ensure that any + /// exceptions thrown during callbacks do not go unobserved. If these exceptions were + /// to go unobserved, they will trigger events to be raised by the default task scheduler. + /// Neither situation is ideal: + /// * If an application assigns a listener to this event, it will generate events that + /// should be reported to us. But, because it happens outside the stack that caused the + /// exception, most of the context of the exception is lost. Furthermore, the event is + /// triggered when the GC runs, so the event happens asynchronous to the action that + /// caused it. + /// * Adding this forced observation of the exception prevents applications from receiving + /// the event, effectively swallowing it. + /// * However, if we log the exception when we observe it, we can still log that the + /// unobserved exception happened without causing undue disruption to the application + /// or leaking resources and causing overhead by raising the event. + /// + private static void ObserveContinuationException(Task continuationTask) + { + continuationTask.ContinueWith( + static task => + { + SqlClientEventSource.Log.TryTraceEvent($"Unobserved task exception: {task.Exception}"); + return _ = task.Exception; + }, + TaskContinuationOptions.OnlyOnFaulted | TaskContinuationOptions.ExecuteSynchronously); + } + + private record ContinuationState( + Action? OnCancellation, + Action? OnFailure, + Action OnSuccess, + TaskCompletionSource TaskCompletionSource); + + private record ContinuationState( + Action? OnCancellation, + Action? OnFailure, + Action OnSuccess, + TState State, + TaskCompletionSource TaskCompletionSource); + + private record ContinuationState( + Action? OnCancellation, + Action? OnFailure, + Action OnSuccess, + TState1 State1, + TState2 State2, + TaskCompletionSource TaskCompletionSource); } } diff --git a/src/Microsoft.Data.SqlClient/tests/Directory.Packages.props b/src/Microsoft.Data.SqlClient/tests/Directory.Packages.props index 7925952d11..4f3dbacb21 100644 --- a/src/Microsoft.Data.SqlClient/tests/Directory.Packages.props +++ b/src/Microsoft.Data.SqlClient/tests/Directory.Packages.props @@ -6,6 +6,7 @@ + diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft.Data.SqlClient.UnitTests.csproj b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft.Data.SqlClient.UnitTests.csproj index e09e49c967..12ea9c52d3 100644 --- a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft.Data.SqlClient.UnitTests.csproj +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft.Data.SqlClient.UnitTests.csproj @@ -98,6 +98,7 @@ --> + @@ -115,6 +116,7 @@ + diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs new file mode 100644 index 0000000000..7e6bb3f53d --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs @@ -0,0 +1,1353 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Diagnostics.CodeAnalysis; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Data.SqlClient.UnitTests.Utilities; +using Microsoft.Data.SqlClient.Utilities; +using Moq; +using Xunit; + +namespace Microsoft.Data.SqlClient.UnitTests.Microsoft.Data.SqlClient.Utilities +{ + public class AsyncHelperTest + { + // This timeout is set fairly high. The tests are expected to complete quickly, but are + // dependent on congestion of the thread pool. If the thread pool is congested, like on a + // full CI run, short timeouts may elapse even if the code under test would behave as + // expected. As such, we set a long timeout to ride out reasonable congestion on the + // thread pool, but still trigger a failure if the code under test hangs. + // @TODO: If suite-level timeouts are added, these timeouts can likely be removed. + private static readonly TimeSpan RunTimeout = TimeSpan.FromSeconds(30); + + #region ContinueTask + + [Fact] + public async Task ContinueTask_TaskCompletes() + { + // Arrange + // - Task to continue that completed successfully + Task taskToContinue = Task.CompletedTask; + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); + Mock> mockOnFailure = new(); + Mock mockOnCancellation = new(); + + // Note: We have to set up mockOnSuccess to set a result on the task completion source, + // since the AsyncHelper will not do it, and without that, we cannot reliably + // know when the continuation completed. We will use SetResult b/c it will throw + // if it has already been set. + Mock mockOnSuccess = new(); + mockOnSuccess.Setup(action => action()) + .Callback(() => taskCompletionSource.SetResult(0)); + + // Act + AsyncHelper.ContinueTask( + taskToContinue: taskToContinue, + taskCompletionSource: taskCompletionSource, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); + + // Assert + mockOnSuccess.Verify(action => action(), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + } + + [Fact] + public async Task ContinueTask_TaskCompletesHandlerThrows() + { + // Arrange + // - Task to continue that completed successfully + Task taskToContinue = Task.CompletedTask; + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); + + // - mockOnSuccess handler throws + Mock mockOnSuccess = new(); + mockOnSuccess.SetupThrows(); + + Mock> mockOnFailure = new(); + Mock mockOnCancellation = new(); + + // Act + AsyncHelper.ContinueTask( + taskToContinue, + taskCompletionSource, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); + + // Assert + Assert.Equal(TaskStatus.Faulted, taskCompletionSource.Task.Status); + mockOnSuccess.Verify(action => action(), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task ContinueTask_TaskCancels(bool handlerShouldThrow) + { + // Arrange + // - Task to continue that is cancelled + Task taskToContinue = GetCancelledTask(); + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); + + Mock mockOnCancellation = new(); + if (handlerShouldThrow) + { + mockOnCancellation.SetupThrows(); + } + + Mock mockOnSuccess = new(); + Mock> mockOnFailure = new(); + + // Act + AsyncHelper.ContinueTask( + taskToContinue, + taskCompletionSource, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); + + // Assert + // - taskCompletionSource should have been cancelled, regardless of mockOnCancellation throwing + Assert.Equal(TaskStatus.Canceled, taskCompletionSource.Task.Status); + + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.Verify(action => action(), Times.Once); + } + + [Fact] + public async Task ContinueTask_TaskCancelsNoHandler() + { + // Arrange + // - Task to continue that is cancelled + Task taskToContinue = GetCancelledTask(); + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); + Mock mockOnSuccess = new(); + Mock> mockOnFailure = new(); + + // Act + AsyncHelper.ContinueTask( + taskToContinue, + taskCompletionSource, + mockOnSuccess.Object, + mockOnFailure.Object, + onCancellation: null); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); + + // Assert + // - taskCompletionSource should have been cancelled + Assert.Equal(TaskStatus.Canceled, taskCompletionSource.Task.Status); + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.VerifyNeverCalled(); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task ContinueTask_TaskFaults(bool handlerShouldThrow) + { + // Arrange + // - Task to continue that is faulted + Task taskToContinue = Task.FromException(new Exception()); + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); + + Mock> mockOnFailure = new(); + if (handlerShouldThrow) + { + mockOnFailure.SetupThrows(); + } + + Mock mockOnSuccess = new(); + Mock mockOnCancellation = new(); + + // Act + AsyncHelper.ContinueTask( + taskToContinue, + taskCompletionSource, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); + + // Assert + // - taskCompletionSource should have been cancelled, regardless of mockOnSuccess throwing + Assert.Equal(TaskStatus.Faulted, taskCompletionSource.Task.Status); + + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.Verify(action => action(It.IsAny()), Times.Once); + mockOnCancellation.VerifyNeverCalled(); + } + + [Fact] + public async Task ContinueTask_TaskFaultsNoHandler() + { + // Arrange + // - Task to continue that is cancelled + Task taskToContinue = Task.FromException(new Exception()); + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); + Mock mockOnSuccess = new(); + Mock mockOnCancellation = new(); + + // Act + AsyncHelper.ContinueTask( + taskToContinue, + taskCompletionSource, + mockOnSuccess.Object, + onFailure: null, + mockOnCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); + + // Assert + // - taskCompletionSource should have been cancelled + Assert.Equal(TaskStatus.Faulted, taskCompletionSource.Task.Status); + + mockOnSuccess.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + } + + #endregion + + #region ContinueTaskWithState + + [Fact] + public async Task ContinueTaskWithState_1Generic_TaskCompletes() + { + // Arrange + // - Task to continue that completed successfully + Task taskToContinue = Task.CompletedTask; + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); + const int state1 = 123; + + Mock> mockOnFailure = new(); + Mock> mockOnCancellation = new(); + + // Note: We have to set up mockOnSuccess to set a result on the task completion source, + // since the AsyncHelper will not do it, and without that, we cannot reliably + // know when the continuation completed. We will use SetResult b/c it will throw + // if it has already been set. + Mock> mockOnSuccess = new(); + mockOnSuccess.Setup(action => action(state1)) + .Callback(_ => taskCompletionSource.SetResult(0)); + + // Act + AsyncHelper.ContinueTaskWithState( + taskToContinue: taskToContinue, + taskCompletionSource: taskCompletionSource, + state1, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); + + // Assert + mockOnSuccess.Verify(action => action(state1), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + } + + [Fact] + public async Task ContinueTaskWithState_1Generic_TaskCompletesHandlerThrows() + { + // Arrange + // - Task to continue that completed successfully + Task taskToContinue = Task.CompletedTask; + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); + const int state1 = 123; + + // - mockOnSuccess handler throws + Mock> mockOnSuccess = new(); + mockOnSuccess.Setup(action => action(It.IsAny())).Throws(); + + Mock> mockOnFailure = new(); + Mock> mockOnCancellation = new(); + + // Act + AsyncHelper.ContinueTaskWithState( + taskToContinue, + taskCompletionSource, + state1, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); + + // Assert + // - taskCompletionSource should have faulted + Assert.Equal(TaskStatus.Faulted, taskCompletionSource.Task.Status); + + mockOnSuccess.Verify(action => action(state1), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task ContinueTaskWithState_1Generic_TaskCancels(bool handlerShouldThrow) + { + // Arrange + // - Task to continue that was cancelled + Task taskToContinue = GetCancelledTask(); + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); + const int state1 = 123; + + Mock> mockOnCancellation = new(); + if (handlerShouldThrow) + { + mockOnCancellation.SetupThrows(); + } + + Mock> mockOnSuccess = new(); + Mock> mockOnFailure = new(); + + // Act + AsyncHelper.ContinueTaskWithState( + taskToContinue, + taskCompletionSource, + state1, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); + + // Assert + // - taskCompletionSource should have been cancelled, regardless of mockOnCancellation throwing + Assert.Equal(TaskStatus.Canceled, taskCompletionSource.Task.Status); + + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.Verify(action => action(state1), Times.Once); + } + + [Fact] + public async Task ContinueTaskWithState_1Generic_TaskCancelsNoHandler() + { + // Arrange + // - Task to continue that was cancelled + Task taskToContinue = GetCancelledTask(); + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); + const int state1 = 123; + + Mock> mockOnSuccess = new(); + Mock> mockOnFailure = new(); + + // Act + AsyncHelper.ContinueTaskWithState( + taskToContinue, + taskCompletionSource, + state1, + mockOnSuccess.Object, + mockOnFailure.Object, + onCancellation: null); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); + + // Assert + // - taskCompletionSource should have been cancelled + Assert.Equal(TaskStatus.Canceled, taskCompletionSource.Task.Status); + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.VerifyNeverCalled(); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task ContinueTaskWithState_1Generic_TaskFaults(bool handlerShouldThrow) + { + // Arrange + // - Task to continue that faulted + Task taskToContinue = Task.FromException(new Exception()); + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); + const int state1 = 123; + + Mock> mockOnFailure = new(); + if (handlerShouldThrow) + { + mockOnFailure.SetupThrows(); + } + + Mock> mockOnSuccess = new(); + Mock> mockOnCancellation = new(); + + // Act + AsyncHelper.ContinueTaskWithState( + taskToContinue, + taskCompletionSource, + state1, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); + + // Assert + // - taskCompletionSource should have been cancelled, regardless of mockOnSuccess throwing + Assert.Equal(TaskStatus.Faulted, taskCompletionSource.Task.Status); + + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.Verify(action => action(state1, It.IsAny()), Times.Once); + mockOnCancellation.VerifyNeverCalled(); + } + + [Fact] + public async Task ContinueTaskWithState_1Generic_TaskFaultsNoHandler() + { + // Arrange + // - Task to continue that faulted + Task taskToContinue = Task.FromException(new Exception()); + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); + const int state1 = 123; + + Mock> mockOnSuccess = new(); + Mock> mockOnCancellation = new(); + + // Act + AsyncHelper.ContinueTaskWithState( + taskToContinue, + taskCompletionSource, + state1, + mockOnSuccess.Object, + onFailure: null, + mockOnCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); + + // Assert + // - taskCompletionSource should have been cancelled + Assert.Equal(TaskStatus.Faulted, taskCompletionSource.Task.Status); + mockOnSuccess.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + } + + #endregion + + #region ContinueTaskWithState + + [Fact] + public async Task ContinueTaskWithState_2Generics_TaskCompletes() + { + // Arrange + // - Task to continue that completed successfully + Task taskToContinue = Task.CompletedTask; + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); + const int state1 = 123; + const int state2 = 234; + + Mock> mockOnFailure = new(); + Mock> mockOnCancellation = new(); + + // Note: We have to set up mockOnSuccess to set a result on the task completion source, + // since the AsyncHelper will not do it, and without that, we cannot reliably + // know when the continuation completed. We will use SetResult b/c it will throw + // if it has already been set. + Mock> mockOnSuccess = new(); + mockOnSuccess.Setup(action => action(state1, state2)) + .Callback((_, _) => taskCompletionSource.SetResult(0)); + + // Act + AsyncHelper.ContinueTaskWithState( + taskToContinue: taskToContinue, + taskCompletionSource: taskCompletionSource, + state1, + state2, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); + + // Assert + mockOnSuccess.Verify(action => action(state1, state2), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + } + + [Fact] + public async Task ContinueTaskWithState_2Generics_TaskCompletesHandlerThrows() + { + // Arrange + // - Task to continue that completed successfully + Task taskToContinue = Task.CompletedTask; + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); + const int state1 = 123; + const int state2 = 234; + + // - mockOnSuccess handler throws + Mock> mockOnSuccess = new(); + mockOnSuccess.Setup(o => o(It.IsAny(), It.IsAny())).Throws(); + + Mock> mockOnFailure = new(); + Mock> mockOnCancellation = new(); + + // Act + AsyncHelper.ContinueTaskWithState( + taskToContinue, + taskCompletionSource, + state1, + state2, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); + + // Assert + // - taskCompletionSource should have faulted + Assert.Equal(TaskStatus.Faulted, taskCompletionSource.Task.Status); + + // - mockOnSuccess was called with state obj + mockOnSuccess.Verify(action => action(state1, state2), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task ContinueTaskWithState_2Generics_TaskCancels(bool handlerShouldThrow) + { + // Arrange + // - Task to continue that was cancelled + Task taskToContinue = GetCancelledTask(); + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); + const int state1 = 123; + const int state2 = 234; + + Mock> mockOnCancellation = new(); + if (handlerShouldThrow) + { + mockOnCancellation.SetupThrows(); + } + + Mock> mockOnSuccess = new(); + Mock> mockOnFailure = new(); + + // Act + AsyncHelper.ContinueTaskWithState( + taskToContinue, + taskCompletionSource, + state1, + state2, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); + + // Assert + // - taskCompletionSource should have been cancelled, regardless of mockOnCancellation throwing + Assert.Equal(TaskStatus.Canceled, taskCompletionSource.Task.Status); + + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.Verify(action => action(state1, state2), Times.Once); + } + + [Fact] + public async Task ContinueTaskWithState_2Generics_TaskCancelsNoHandler() + { + // Arrange + // - Task to continue that was cancelled + Task taskToContinue = GetCancelledTask(); + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); + const int state1 = 123; + const int state2 = 234; + + Mock> mockOnSuccess = new(); + Mock> mockOnFailure = new(); + + // Act + AsyncHelper.ContinueTaskWithState( + taskToContinue, + taskCompletionSource, + state1, + state2, + mockOnSuccess.Object, + mockOnFailure.Object, + onCancellation: null); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); + + // Assert + // - taskCompletionSource should have been cancelled + Assert.Equal(TaskStatus.Canceled, taskCompletionSource.Task.Status); + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.VerifyNeverCalled(); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task ContinueTaskWithState_2Generics_TaskFaults(bool handlerShouldThrow) + { + // Arrange + // - Task to continue that faulted + Task taskToContinue = Task.FromException(new Exception()); + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); + const int state1 = 123; + const int state2 = 234; + + Mock> mockOnFailure = new(); + if (handlerShouldThrow) + { + mockOnFailure.SetupThrows(); + } + + Mock> mockOnSuccess = new(); + Mock> mockOnCancellation = new(); + + // Act + AsyncHelper.ContinueTaskWithState( + taskToContinue, + taskCompletionSource, + state1, + state2, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); + + // Assert + // - taskCompletionSource should have been cancelled, regardless of mockOnSuccess throwing + Assert.Equal(TaskStatus.Faulted, taskCompletionSource.Task.Status); + + mockOnSuccess.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + mockOnFailure.Verify(action => action(state1, state2, It.IsAny()), Times.Once); + } + + [Fact] + public async Task ContinueTaskWithState_2Generics_TaskFaultsNoHandler() + { + // Arrange + // - Task to continue that faulted + Task taskToContinue = Task.FromException(new Exception()); + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); + const int state1 = 123; + const int state2 = 234; + + Mock> mockOnSuccess = new(); + Mock> mockOnCancellation = new(); + + // Act + AsyncHelper.ContinueTaskWithState( + taskToContinue, + taskCompletionSource, + state1, + state2, + mockOnSuccess.Object, + onFailure: null, + mockOnCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); + + // Assert + // - taskCompletionSource should have been cancelled + Assert.Equal(TaskStatus.Faulted, taskCompletionSource.Task.Status); + mockOnSuccess.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + } + + #endregion + + #region CreateContinuationTask + + [Fact] + public void CreateContinuationTask_NullTask() + { + // Arrange + Mock mockOnSuccess = new(); + Mock> mockOnFailure = new(); + Mock mockOnCancellation = new(); + + // Act + Task? continuationTask = AsyncHelper.CreateContinuationTask( + taskToContinue: null, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + + // Assert + Assert.Null(continuationTask); + + mockOnSuccess.Verify(action => action(), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + } + + [Fact] + public async Task CreateContinuationTask_TaskCompletes() + { + // Arrange + // - Task to continue completed successfully + Task taskToContinue = Task.CompletedTask; + Mock mockOnSuccess = new(); + Mock> mockOnFailure = new(); + Mock mockOnCancellation = new(); + + // Act + Task? continuationTask = AsyncHelper.CreateContinuationTask( + taskToContinue, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(continuationTask, RunTimeout); + + // Assert + Assert.Equal(TaskStatus.RanToCompletion, continuationTask.Status); + + mockOnSuccess.Verify(action => action(), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + } + + [Fact] + public async Task CreateContinuationTask_TaskCompletesHandlerThrows() + { + // Arrange + // - Task to continue completed successfully + Task taskToContinue = Task.CompletedTask; + Mock> mockOnFailure = new(); + Mock mockOnCancellation = new(); + + // - mockOnSuccess handler throws + Mock mockOnSuccess = new(); + mockOnSuccess.SetupThrows(); + + // Act + Task? continuationTask = AsyncHelper.CreateContinuationTask( + taskToContinue, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(continuationTask, RunTimeout); + + // Assert + Assert.Equal(TaskStatus.Faulted, continuationTask.Status); + mockOnSuccess.Verify(action => action(), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CreateContinuationTask_TaskCancels(bool handlerShouldThrow) + { + // Arrange + // - Task to continue was cancelled + Task taskToContinue = GetCancelledTask(); + Mock> mockOnFailure = new(); + Mock mockOnSuccess = new(); + + Mock mockOnCancellation = new(); + if (handlerShouldThrow) + { + mockOnCancellation.SetupThrows(); + } + + // Act + Task? continuationTask = AsyncHelper.CreateContinuationTask( + taskToContinue, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(continuationTask, RunTimeout); + + // Assert + Assert.Equal(TaskStatus.Canceled, continuationTask.Status); + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.Verify(action => action(), Times.Once); + } + + [Fact] + public async Task CreateContinuationTask_TaskCancelsNoHandler() + { + // Arrange + // - Task to continue completed successfully + Task taskToContinue = GetCancelledTask(); + Mock> mockOnFailure = new(); + Mock mockOnSuccess = new(); + + // Act + Task? continuationTask = AsyncHelper.CreateContinuationTask( + taskToContinue, + mockOnSuccess.Object, + mockOnFailure.Object, + onCancellation: null); + await RunWithTimeout(continuationTask, RunTimeout); + + // Assert + Assert.Equal(TaskStatus.Canceled, continuationTask.Status); + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.VerifyNeverCalled(); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CreateContinuationTask_TaskFaults(bool handlerShouldThrow) + { + // Arrange + // - Task to continue faulted + Task taskToContinue = Task.FromException(new Exception()); + Mock mockOnSuccess = new(); + Mock mockOnCancellation = new(); + + Mock> mockOnFailure = new(); + if (handlerShouldThrow) + { + mockOnFailure.SetupThrows(); + } + + // Act + Task? continuationTask = AsyncHelper.CreateContinuationTask( + taskToContinue, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(continuationTask, RunTimeout); + + // Assert + Assert.Equal(TaskStatus.Faulted, continuationTask.Status); + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.Verify(action => action(It.IsAny()), Times.Once); + mockOnCancellation.VerifyNeverCalled(); + } + + [Fact] + public async Task CreateContinuationTask_TaskFaultsNoHandler() + { + // Arrange + // - Task to continue completed successfully + Task taskToContinue = Task.FromException(new Exception()); + Mock mockOnSuccess = new(); + Mock mockOnCancellation = new(); + + // Act + Task? continuationTask = AsyncHelper.CreateContinuationTask( + taskToContinue, + mockOnSuccess.Object, + onFailure: null, + onCancellation: null); + await RunWithTimeout(continuationTask, RunTimeout); + + // Assert + Assert.Equal(TaskStatus.Faulted, continuationTask.Status); + mockOnSuccess.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + } + + #endregion + + #region CreateContinuationTaskWithState + + [Fact] + public void CreateContinuationTaskWithState_1Generic_NullTask() + { + // Arrange + const int state1 = 123; + Mock> mockOnSuccess = new(); + Mock> mockOnFailure = new(); + Mock> mockOnCancellation = new(); + + // Act + Task? continuationTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue: null, + state1, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + + // Assert + Assert.Null(continuationTask); + + mockOnSuccess.Verify(action => action(state1), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + } + + [Fact] + public async Task CreateContinuationTaskWithState_1Generic_TaskCompletes() + { + // Arrange + // - Task to continue completed successfully + Task taskToContinue = Task.CompletedTask; + const int state1 = 123; + + Mock> mockOnSuccess = new(); + Mock> mockOnFailure = new(); + Mock> mockOnCancellation = new(); + + // Act + Task? continuationTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue, + state1, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(continuationTask, RunTimeout); + + // Assert + Assert.Equal(TaskStatus.RanToCompletion, continuationTask.Status); + mockOnSuccess.Verify(action => action(state1), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + } + + [Fact] + public async Task CreateContinuationTaskWithState_1Generic_TaskCompletesHandlerThrows() + { + // Arrange + // - Task to continue completed successfully + Task taskToContinue = Task.CompletedTask; + const int state1 = 123; + + Mock> mockOnFailure = new(); + Mock> mockOnCancellation = new(); + + // - mockOnSuccess handler throws + Mock> mockOnSuccess = new(); + mockOnSuccess.SetupThrows(); + + // Act + Task? continuationTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue, + state1, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(continuationTask, RunTimeout); + + // Assert + Assert.Equal(TaskStatus.Faulted, continuationTask.Status); + mockOnSuccess.Verify(action => action(state1), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CreateContinuationTaskWithState_1Generic_TaskCancels(bool handlerShouldThrow) + { + // Arrange + // - Task to continue was cancelled + Task taskToContinue = GetCancelledTask(); + const int state1 = 123; + + Mock> mockOnFailure = new(); + Mock> mockOnSuccess = new(); + + Mock> mockOnCancellation = new(); + if (handlerShouldThrow) + { + mockOnCancellation.SetupThrows(); + } + + // Act + Task? continuationTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue, + state1, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(continuationTask, RunTimeout); + + // Assert + Assert.Equal(TaskStatus.Canceled, continuationTask.Status); + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.Verify(action => action(state1), Times.Once); + } + + [Fact] + public async Task CreateContinuationTaskWithState_1Generic_TaskCancelsNoHandler() + { + // Arrange + // - Task to continue was cancelled + Task taskToContinue = GetCancelledTask(); + const int state1 = 123; + + Mock> mockOnFailure = new(); + Mock> mockOnSuccess = new(); + + // Act + Task? continuationTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue, + state1, + mockOnSuccess.Object, + mockOnFailure.Object, + onCancellation: null); + await RunWithTimeout(continuationTask, RunTimeout); + + // Assert + Assert.Equal(TaskStatus.Canceled, continuationTask.Status); + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.VerifyNeverCalled(); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CreateContinuationTaskWithState_1Generic_TaskFaults(bool handlerShouldThrow) + { + // Arrange + // - Task to continue faulted + Task taskToContinue = Task.FromException(new Exception()); + const int state1 = 123; + + Mock> mockOnSuccess = new(); + Mock> mockOnCancellation = new(); + + Mock> mockOnFailure = new(); + if (handlerShouldThrow) + { + mockOnFailure.SetupThrows(); + } + + // Act + Task? continuationTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue, + state1, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(continuationTask, RunTimeout); + + // Assert + Assert.Equal(TaskStatus.Faulted, continuationTask.Status); + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.Verify(action => action(state1, It.IsAny()), Times.Once); + mockOnCancellation.VerifyNeverCalled(); + } + + [Fact] + public async Task CreateContinuationTaskWithState_1Generic_TaskFaultsNoHandler() + { + // Arrange + // - Task to continue faulted + Task taskToContinue = Task.FromException(new Exception()); + const int state1 = 123; + + Mock> mockOnSuccess = new(); + Mock> mockOnCancellation = new(); + + // Act + Task? continuationTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue, + state1, + mockOnSuccess.Object, + onFailure: null, + onCancellation: null); + await RunWithTimeout(continuationTask, RunTimeout); + + // Assert + Assert.Equal(TaskStatus.Faulted, continuationTask.Status); + mockOnSuccess.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + } + + #endregion + + #region CreateContinuationTaskWithState + + [Fact] + public void CreateContinuationTaskWithState_2Generics_NullTask() + { + // Arrange + const int state1 = 123; + const int state2 = 234; + + Mock> mockOnSuccess = new(); + Mock> mockOnFailure = new(); + Mock> mockOnCancellation = new(); + + // Act + Task? continuationTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue: null, + state1, + state2, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + + // Assert + Assert.Null(continuationTask); + + mockOnSuccess.Verify(action => action(state1, state2), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + } + + [Fact] + public async Task CreateContinuationTaskWithState_2Generics_TaskCompletes() + { + // Arrange + // - Task to continue completed successfully + Task taskToContinue = Task.CompletedTask; + const int state1 = 123; + const int state2 = 234; + + Mock> mockOnSuccess = new(); + Mock> mockOnFailure = new(); + Mock> mockOnCancellation = new(); + + // Act + Task? continuationTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue, + state1, + state2, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(continuationTask, RunTimeout); + + // Assert + Assert.Equal(TaskStatus.RanToCompletion, continuationTask.Status); + mockOnSuccess.Verify(action => action(state1, state2), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + } + + [Fact] + public async Task CreateContinuationTaskWithState_2Generics_TaskCompletesHandlerThrows() + { + // Arrange + // - Task to continue completed successfully + Task taskToContinue = Task.CompletedTask; + const int state1 = 123; + const int state2 = 234; + + Mock> mockOnFailure = new(); + Mock> mockOnCancellation = new(); + + // - mockOnSuccess handler throws + Mock> mockOnSuccess = new(); + mockOnSuccess.SetupThrows(); + + // Act + Task? continuationTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue, + state1, + state2, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(continuationTask, RunTimeout); + + // Assert + Assert.Equal(TaskStatus.Faulted, continuationTask.Status); + mockOnSuccess.Verify(action => action(state1, state2), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CreateContinuationTaskWithState_2Generics_TaskCancels(bool handlerShouldThrow) + { + // Arrange + // - Task to continue was cancelled + Task taskToContinue = GetCancelledTask(); + const int state1 = 123; + const int state2 = 234; + + Mock> mockOnFailure = new(); + Mock> mockOnSuccess = new(); + + Mock> mockOnCancellation = new(); + if (handlerShouldThrow) + { + mockOnCancellation.SetupThrows(); + } + + // Act + Task? continuationTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue, + state1, + state2, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(continuationTask, RunTimeout); + + // Assert + Assert.Equal(TaskStatus.Canceled, continuationTask.Status); + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.Verify(action => action(state1, state2), Times.Once); + } + + [Fact] + public async Task CreateContinuationTaskWithState_2Generics_TaskCancelsNoHandler() + { + // Arrange + // - Task to continue was cancelled + Task taskToContinue = GetCancelledTask(); + const int state1 = 123; + const int state2 = 234; + + Mock> mockOnFailure = new(); + Mock> mockOnSuccess = new(); + + // Act + Task? continuationTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue, + state1, + state2, + mockOnSuccess.Object, + mockOnFailure.Object, + onCancellation: null); + await RunWithTimeout(continuationTask, RunTimeout); + + // Assert + Assert.Equal(TaskStatus.Canceled, continuationTask.Status); + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.VerifyNeverCalled(); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CreateContinuationTaskWithState_2Generics_TaskFaults(bool handlerShouldThrow) + { + // Arrange + // - Task to continue faulted + Task taskToContinue = Task.FromException(new Exception()); + const int state1 = 123; + const int state2 = 234; + + Mock> mockOnSuccess = new(); + Mock> mockOnCancellation = new(); + + Mock> mockOnFailure = new(); + if (handlerShouldThrow) + { + mockOnFailure.SetupThrows(); + } + + // Act + Task? continuationTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue, + state1, + state2, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(continuationTask, RunTimeout); + + // Assert + Assert.Equal(TaskStatus.Faulted, continuationTask.Status); + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.Verify(action => action(state1, state2, It.IsAny()), Times.Once); + mockOnCancellation.VerifyNeverCalled(); + } + + [Fact] + public async Task CreateContinuationTaskWithState_2Generics_TaskFaultsNoHandler() + { + // Arrange + // - Task to continue faulted + Task taskToContinue = Task.FromException(new Exception()); + const int state1 = 123; + const int state2 = 234; + + Mock> mockOnSuccess = new(); + Mock> mockOnCancellation = new(); + + // Act + Task? continuationTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue, + state1, + state2, + mockOnSuccess.Object, + onFailure: null, + onCancellation: null); + await RunWithTimeout(continuationTask, RunTimeout); + + // Assert + Assert.Equal(TaskStatus.Faulted, continuationTask.Status); + mockOnSuccess.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + } + + #endregion + + #region WaitForCompletion + + [Fact] + public void WaitForCompletion_DoesNotCreateUnobservedException() + { + // Arrange + // - Create a handler to capture any unhandled exception + Exception? unhandledException = null; + EventHandler handleUnobservedException = + (_, args) => unhandledException = args.Exception; + + // @TODO: Can we do this with a custom scheduler to avoid changing global state? + TaskScheduler.UnobservedTaskException += handleUnobservedException; + + try + { + // Act + // - Run task that will always time out + TaskCompletionSource tcs = new(); + AsyncHelper.WaitForCompletion( + tcs.Task, + timeoutInSeconds: 1, + onTimeout: null, + rethrowExceptions: true); + + // - Force collection of unobserved task + GC.Collect(); + GC.WaitForPendingFinalizers(); + + // Assert + // - Make sure no unobserved tasks happened + Assert.Null(unhandledException); + } + finally + { + // Cleanup + // - Remove the unobserved task handler + TaskScheduler.UnobservedTaskException -= handleUnobservedException; + } + } + + #endregion + + private static Task GetCancelledTask() + { + using CancellationTokenSource cts = new(); + cts.Cancel(); + + return Task.FromCanceled(cts.Token); + } + + private static TaskCompletionSource GetTaskCompletionSource() + => new(TaskCreationOptions.RunContinuationsAsynchronously); + + private static async Task RunWithTimeout([NotNull] Task? taskToRun, TimeSpan timeout) + { + if (taskToRun is null) + { + Assert.Fail("Expected non-null task for timeout"); + } + + Task winner = await Task.WhenAny(taskToRun, Task.Delay(timeout)); + if (winner != taskToRun) + { + Assert.Fail("Timeout elapsed."); + } + + // Force observation of any exception + _ = taskToRun.Exception; + } + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/Utilities/MockExtensions.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/Utilities/MockExtensions.cs new file mode 100644 index 0000000000..fbb4e5fd41 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/Utilities/MockExtensions.cs @@ -0,0 +1,64 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using Moq; + +namespace Microsoft.Data.SqlClient.UnitTests.Utilities +{ + public static class MockExtensions + { + public static void SetupThrows(this Mock mock) + where TException : Exception, new() + { + mock.Setup(action => action()) + .Throws(); + } + + public static void SetupThrows(this Mock> mock) + where TException : Exception, new() + { + mock.Setup(action => action(It.IsAny())) + .Throws(); + } + + public static void SetupThrows(this Mock> mock) + where TException : Exception, new() + { + mock.Setup(action => action(It.IsAny(), It.IsAny())) + .Throws(); + } + + public static void SetupThrows(this Mock> mock) + where TException : Exception, new() + { + mock.Setup(action => action(It.IsAny(), It.IsAny(), It.IsAny())) + .Throws(); + } + + public static void VerifyNeverCalled(this Mock mock) => + mock.Verify(action => action(), Times.Never); + + public static void VerifyNeverCalled(this Mock> mock) + { + mock.Verify( + action => action(It.IsAny()), + Times.Never); + } + + public static void VerifyNeverCalled(this Mock> mock) + { + mock.Verify( + action => action(It.IsAny(), It.IsAny()), + Times.Never); + } + + public static void VerifyNeverCalled(this Mock> mock) + { + mock.Verify( + action => action(It.IsAny(), It.IsAny(), It.IsAny()), + Times.Never); + } + } +} From ab424fabd7e12eb9b93bd525453eda27dab7dc27 Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Tue, 2 Jun 2026 17:13:05 -0500 Subject: [PATCH 3/9] Change existing async helper calls in SqlBulkCopy.cs --- .../Microsoft/Data/SqlClient/SqlBulkCopy.cs | 59 +++++++++++-------- 1 file changed, 34 insertions(+), 25 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs index b5306bc0d8..485eed2108 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs @@ -2856,18 +2856,17 @@ private Task CopyBatchesAsyncContinued(BulkCopySimpleResultSet internalResults, task, source, state: this, - onSuccess: (object state) => + onSuccess: state => { - SqlBulkCopy sqlBulkCopy = (SqlBulkCopy)state; - Task continuedTask = sqlBulkCopy.CopyBatchesAsyncContinuedOnSuccess(internalResults, updateBulkCommandText, cts, source); + Task continuedTask = state.CopyBatchesAsyncContinuedOnSuccess(internalResults, updateBulkCommandText, cts, source); if (continuedTask == null) { // Continuation finished sync, recall into CopyBatchesAsync to continue - sqlBulkCopy.CopyBatchesAsync(internalResults, updateBulkCommandText, cts, source); + state.CopyBatchesAsync(internalResults, updateBulkCommandText, cts, source); } }, - onFailure: static (Exception _, object state) => ((SqlBulkCopy)state).CopyBatchesAsyncContinuedOnError(cleanupParser: false), - onCancellation: static (object state) => ((SqlBulkCopy)state).CopyBatchesAsyncContinuedOnError(cleanupParser: true)); + onFailure: static (state, _) => state.CopyBatchesAsyncContinuedOnError(cleanupParser: false), + onCancellation: static state => state.CopyBatchesAsyncContinuedOnError(cleanupParser: true)); return source.Task; } @@ -2918,24 +2917,23 @@ private Task CopyBatchesAsyncContinuedOnSuccess(BulkCopySimpleResultSet internal writeTask, source, state: this, - onSuccess: (object state) => + onSuccess: state => { - SqlBulkCopy sqlBulkCopy = (SqlBulkCopy)state; try { - sqlBulkCopy.RunParser(); - sqlBulkCopy.CommitTransaction(); + state.RunParser(); + state.CommitTransaction(); } catch (Exception) { - sqlBulkCopy.CopyBatchesAsyncContinuedOnError(cleanupParser: false); + state.CopyBatchesAsyncContinuedOnError(cleanupParser: false); throw; } // Always call back into CopyBatchesAsync - sqlBulkCopy.CopyBatchesAsync(internalResults, updateBulkCommandText, cts, source); + state.CopyBatchesAsync(internalResults, updateBulkCommandText, cts, source); }, - onFailure: static (Exception _, object state) => ((SqlBulkCopy)state).CopyBatchesAsyncContinuedOnError(cleanupParser: false)); + onFailure: static (state, _ ) => state.CopyBatchesAsyncContinuedOnError(cleanupParser: false)); return source.Task; } } @@ -3190,21 +3188,20 @@ private void WriteToServerInternalRestAsync(CancellationToken cts, TaskCompletio // No need to cancel timer since SqlBulkCopy creates specific task source for reconnection. AsyncHelper.SetTimeoutExceptionWithState( - completion: cancellableReconnectTS, - timeout: BulkCopyTimeout, + taskCompletionSource: cancellableReconnectTS, + timeoutInSeconds: BulkCopyTimeout, state: _destinationTableName, - onFailure: static state => - SQL.BulkLoadInvalidDestinationTable((string)state, SQL.CR_ReconnectTimeout()), + onTimeout: static state => SQL.BulkLoadInvalidDestinationTable(state, SQL.CR_ReconnectTimeout()), cancellationToken: CancellationToken.None ); AsyncHelper.ContinueTaskWithState( - task: cancellableReconnectTS.Task, - completion: source, + taskToContinue:cancellableReconnectTS.Task, + taskCompletionSource: source, state: regReconnectCancel, - onSuccess: (object state) => + onSuccess: state => { - ((StrongBox)state).Value.Dispose(); + state.Value.Dispose(); if (_parserLock != null) { _parserLock.Release(); @@ -3214,10 +3211,22 @@ private void WriteToServerInternalRestAsync(CancellationToken cts, TaskCompletio _parserLock.Wait(canReleaseFromAnyThread: true); WriteToServerInternalRestAsync(cts, source); }, - onFailure: static (_, state) => ((StrongBox)state).Value.Dispose(), - onCancellation: static state => ((StrongBox)state).Value.Dispose(), - exceptionConverter: ex => SQL.BulkLoadInvalidDestinationTable(_destinationTableName, ex) - ); + onFailure: (regReconnectCancel2, exception) => + { + regReconnectCancel2.Value.Dispose(); + + // Convert exception and set it on the source + // Note: This is safe because the helper will only try to set the + // exception and b/c it is already set will pass without setting + // to the original exception. + Exception convertedException = SQL.BulkLoadInvalidDestinationTable( + _destinationTableName, + exception); + source.TrySetException(convertedException); + }, + onCancellation: static regReconnectCancel2 => + regReconnectCancel2.Value.Dispose()); + return; } else From 1608dfda5582c95ec9942312e8655474d4c91ab9 Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Tue, 2 Jun 2026 18:04:51 -0500 Subject: [PATCH 4/9] Change existing async helper calls in TdsParser.cs --- .../src/Microsoft/Data/SqlClient/TdsParser.cs | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs index e0a5e7cd8a..23c2dfc9dc 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -12284,11 +12284,11 @@ private Task GetTerminationTask(Task unterminatedWriteTask, object value, MetaTy } else { - return AsyncHelper.CreateContinuationTask( + return AsyncHelper.CreateContinuationTaskWithState( unterminatedWriteTask, - onSuccess: WriteInt, - arg1: 0, - arg2: stateObj); + state1: 0, + state2: stateObj, + onSuccess: WriteInt); } } else @@ -13245,11 +13245,11 @@ private Task WriteEncryptionMetadata(Task terminatedWriteTask, SqlColumnEncrypti else { // Otherwise, create a continuation task to write the encryption metadata after the previous write completes. - return AsyncHelper.CreateContinuationTask( + return AsyncHelper.CreateContinuationTaskWithState( terminatedWriteTask, - onSuccess: WriteEncryptionMetadata, - arg1: columnEncryptionParameterInfo, - arg2: stateObj); + state1: columnEncryptionParameterInfo, + state2: stateObj, + onSuccess: WriteEncryptionMetadata); } } From 80beeff8de905077bef896870f1a4aa9d9b4387b Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Tue, 2 Jun 2026 18:06:28 -0500 Subject: [PATCH 5/9] Change existing async helper calls in TdsParser.cs --- .../src/Microsoft/Data/SqlClient/TdsParserStateObject.cs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs index a1bcec983f..97acdcbea9 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs @@ -1269,13 +1269,12 @@ internal Task ExecuteFlush() else { return AsyncHelper.CreateContinuationTaskWithState( - task: writePacketTask, + taskToContinue: writePacketTask, state: this, - onSuccess: static (object state) => + onSuccess: static state => { - TdsParserStateObject stateObject = (TdsParserStateObject)state; - stateObject.HasPendingData = true; - stateObject._messageStatus = 0; + state.HasPendingData = true; + state._messageStatus = 0; } ); } From 734fa77fc1c55c511b281db1e3d8102ba7f1ab8f Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Tue, 2 Jun 2026 18:18:04 -0500 Subject: [PATCH 6/9] Change existing async helper calls in SqlCommand --- .../Data/SqlClient/SqlCommand.Encryption.cs | 23 +++++------- .../Data/SqlClient/SqlCommand.NonQuery.cs | 12 +++--- .../Data/SqlClient/SqlCommand.Reader.cs | 37 +++++++++---------- .../Data/SqlClient/SqlCommand.Xml.cs | 4 +- 4 files changed, 36 insertions(+), 40 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Encryption.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Encryption.cs index 9f95cbe8d0..3b0fc6e93f 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Encryption.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Encryption.cs @@ -253,23 +253,22 @@ private SqlDataReader GetParameterEncryptionDataReader( bool isRetry) { returnTask = AsyncHelper.CreateContinuationTaskWithState( - task: fetchInputParameterEncryptionInfoTask, + taskToContinue: fetchInputParameterEncryptionInfoTask, state: this, - onSuccess: state => + onSuccess: sqlCommand => { - SqlCommand command = (SqlCommand)state; bool processFinallyBlockAsync = true; bool decrementAsyncCountInFinallyBlockAsync = true; try { // Check for any exceptions on network write, before reading. - command.CheckThrowSNIException(); + sqlCommand.CheckThrowSNIException(); // If it is async, then TryFetchInputParameterEncryptionInfo -> // RunExecuteReaderTds would have incremented the async count. Decrement it // when we are about to complete async execute reader. - SqlConnectionInternal internalConnectionTds = command._activeConnection.GetOpenTdsConnection(); + SqlConnectionInternal internalConnectionTds = sqlCommand._activeConnection.GetOpenTdsConnection(); if (internalConnectionTds is not null) { internalConnectionTds.DecrementAsyncCount(); @@ -278,13 +277,13 @@ private SqlDataReader GetParameterEncryptionDataReader( // Complete executereader. // @TODO: If we can remove this reference, this could be a static lambda - describeParameterEncryptionDataReader = command.CompleteAsyncExecuteReader( + describeParameterEncryptionDataReader = sqlCommand.CompleteAsyncExecuteReader( isInternal: false, forDescribeParameterEncryption: true); - Debug.Assert(command._stateObj is null, "non-null state object in PrepareForTransparentEncryption."); + Debug.Assert(sqlCommand._stateObj is null, "non-null state object in PrepareForTransparentEncryption."); // Read the results of describe parameter encryption. - command.ReadDescribeEncryptionParameterResults( + sqlCommand.ReadDescribeEncryptionParameterResults( describeParameterEncryptionDataReader, describeParameterEncryptionRpcOriginalRpcMap, isRetry); @@ -304,7 +303,7 @@ private SqlDataReader GetParameterEncryptionDataReader( } finally { - command.PrepareTransparentEncryptionFinallyBlock( + sqlCommand.PrepareTransparentEncryptionFinallyBlock( closeDataReader: processFinallyBlockAsync, decrementAsyncCount: decrementAsyncCountInFinallyBlockAsync, clearDataStructures: processFinallyBlockAsync, @@ -313,11 +312,9 @@ private SqlDataReader GetParameterEncryptionDataReader( describeParameterEncryptionDataReader: describeParameterEncryptionDataReader); } }, - onFailure: static (exception, state) => + onFailure: static (sqlCommand, exception) => { - SqlCommand command = (SqlCommand)state; - command.CachedAsyncState?.ResetAsyncState(); - + sqlCommand.CachedAsyncState?.ResetAsyncState(); if (exception is not null) { throw exception; diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.NonQuery.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.NonQuery.cs index 41602d392b..ac9ccbbc6b 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.NonQuery.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.NonQuery.cs @@ -223,13 +223,13 @@ private IAsyncResult BeginExecuteNonQueryInternal( if (execNonQuery is not null) { AsyncHelper.ContinueTaskWithState( - task: execNonQuery, - completion: localCompletion, - state: Tuple.Create(this, localCompletion), - onSuccess: static state => + taskToContinue: execNonQuery, + taskCompletionSource: localCompletion, + state1: this, + state2: localCompletion, + onSuccess: static (sqlCommand, localCompletion) => { - var parameters = (Tuple>)state; - parameters.Item1.BeginExecuteNonQueryInternalReadStage(parameters.Item2); + sqlCommand.BeginExecuteNonQueryInternalReadStage(localCompletion); }); } else diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs index 9a9615d5ff..bfdbd40fd5 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs @@ -1599,18 +1599,18 @@ private Task RunExecuteReaderTdsSetupContinuation( { // @TODO: Why use the state version if we can't make this a static helper? return AsyncHelper.CreateContinuationTaskWithState( - task: writeTask, + taskToContinue: writeTask, state: _activeConnection, - onSuccess: state => + onSuccess: sqlConnection => { // This will throw if the connection is closed. // @TODO: So... can we have something that specifically does that? - ((SqlConnection)state).GetOpenTdsConnection(); + sqlConnection.GetOpenTdsConnection(); CachedAsyncState.SetAsyncReaderState(ds, runBehavior, optionSettings); }, - onFailure: static (exception, state) => + onFailure: static (sqlConnection, _) => { - ((SqlConnection)state).GetOpenTdsConnection().DecrementAsyncCount(); + sqlConnection.GetOpenTdsConnection().DecrementAsyncCount(); }); } @@ -1632,7 +1632,7 @@ private void RunExecuteReaderTdsSetupReconnectContinuation( AsyncHelper.SetTimeoutException( completion, timeout, - onFailure: static () => SQL.CR_ReconnectTimeout(), + onTimeout: static () => SQL.CR_ReconnectTimeout(), timeoutCts.Token); // @TODO: With an object to pass around we can use the state-based version @@ -1703,14 +1703,13 @@ private SqlDataReader RunExecuteReaderTdsWithTransparentParameterEncryption( // @TODO: This is a prime candidate for proper async-await execution TaskCompletionSource completion = new TaskCompletionSource(); AsyncHelper.ContinueTaskWithState( - task: describeParameterEncryptionTask, - completion: completion, + taskToContinue: describeParameterEncryptionTask, + taskCompletionSource: completion, state: this, - onSuccess: state => + onSuccess: sqlCommand => { - SqlCommand command = (SqlCommand)state; - command.GenerateEnclavePackage(); - command.RunExecuteReaderTds( + sqlCommand.GenerateEnclavePackage(); + sqlCommand.RunExecuteReaderTds( cmdBehavior, runBehavior, returnStream, @@ -1729,23 +1728,23 @@ private SqlDataReader RunExecuteReaderTdsWithTransparentParameterEncryption( else { AsyncHelper.ContinueTaskWithState( - task: subTask, - completion: completion, + taskToContinue: subTask, + taskCompletionSource: completion, state: completion, - onSuccess: static state => ((TaskCompletionSource)state).SetResult(null)); + onSuccess: static state => state.SetResult(null)); } }, - onFailure: static (exception, state) => + onFailure: static (sqlCommand, exception) => { - ((SqlCommand)state).CachedAsyncState?.ResetAsyncState(); + sqlCommand.CachedAsyncState?.ResetAsyncState(); if (exception is not null) { throw exception; } }, - onCancellation: static state => + onCancellation: static sqlCommand => { - ((SqlCommand)state).CachedAsyncState?.ResetAsyncState(); + sqlCommand.CachedAsyncState?.ResetAsyncState(); }); task = completion.Task; diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Xml.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Xml.cs index e39a794453..eba617ad59 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Xml.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Xml.cs @@ -262,8 +262,8 @@ private IAsyncResult BeginExecuteXmlReaderInternal( if (writeTask is not null) { AsyncHelper.ContinueTaskWithState( - task: writeTask, - completion: localCompletion, + taskToContinue: writeTask, + taskCompletionSource: localCompletion, state: Tuple.Create(this, localCompletion), onSuccess: static state => { From c0127ec52b96d32c4a69e0674a493a112b1aa83c Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Wed, 3 Jun 2026 12:19:29 -0500 Subject: [PATCH 7/9] Remove old tests --- .../tests/FunctionalTests/SqlHelperTest.cs | 62 ------------------- 1 file changed, 62 deletions(-) delete mode 100644 src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlHelperTest.cs diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlHelperTest.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlHelperTest.cs deleted file mode 100644 index 44286b8c0e..0000000000 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlHelperTest.cs +++ /dev/null @@ -1,62 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Reflection; -using System.Threading; -using System.Threading.Tasks; -using Xunit; - -namespace Microsoft.Data.SqlClient.Tests -{ - public class SqlHelperTest - { - private void TimeOutATask() - { - var sqlClientAssembly = Assembly.GetAssembly(typeof(SqlCommand)); - //We're using reflection to avoid exposing the internals - MethodInfo waitForCompletion = sqlClientAssembly.GetType("Microsoft.Data.SqlClient.AsyncHelper") - ?.GetMethod("WaitForCompletion", BindingFlags.Static | BindingFlags.NonPublic); - - Assert.False(waitForCompletion == null, "Running a test on SqlUtil.WaitForCompletion but could not find this method"); - TaskCompletionSource tcs = new TaskCompletionSource(); - waitForCompletion.Invoke(null, new object[] { tcs.Task, 1, null, true }); //Will time out as task uncompleted - tcs.SetException(new TimeoutException("Dummy timeout exception")); //Our task now completes with an error - } - - private Exception UnwrapException(Exception e) - { - return e?.InnerException != null ? UnwrapException(e.InnerException) : e; - } - - [Fact] - public void WaitForCompletion_DoesNotCreateUnobservedException() - { - var unobservedExceptionHappenedEvent = new AutoResetEvent(false); - Exception unhandledException = null; - void handleUnobservedException(object o, UnobservedTaskExceptionEventArgs a) - { unhandledException = a.Exception; unobservedExceptionHappenedEvent.Set(); } - - TaskScheduler.UnobservedTaskException += handleUnobservedException; - - try - { - TimeOutATask(); //Create the task in another function so the task has no reference remaining - GC.Collect(); //Force collection of unobserved task - GC.WaitForPendingFinalizers(); - - bool unobservedExceptionHappend = unobservedExceptionHappenedEvent.WaitOne(1); - if (unobservedExceptionHappend) //Save doing string interpolation in the happy case - { - var e = UnwrapException(unhandledException); - Assert.Fail($"Did not expect an unobserved exception, but found a {e?.GetType()} with message \"{e?.Message}\""); - } - } - finally - { - TaskScheduler.UnobservedTaskException -= handleUnobservedException; - } - } - } -} From e516e3b4a1b1cdf60e3ab33913c73501a70072e2 Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Thu, 4 Jun 2026 12:57:04 -0500 Subject: [PATCH 8/9] Comments from copilot round 1 --- .../src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs | 6 ++++-- .../Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs | 4 ++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs index 9b4080d62f..926a3b1d5e 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs @@ -768,9 +768,11 @@ private static void ObserveContinuationException(Task continuationTask) static task => { SqlClientEventSource.Log.TryTraceEvent($"Unobserved task exception: {task.Exception}"); - return _ = task.Exception; + _ = task.Exception; }, - TaskContinuationOptions.OnlyOnFaulted | TaskContinuationOptions.ExecuteSynchronously); + CancellationToken.None, + TaskContinuationOptions.OnlyOnFaulted | TaskContinuationOptions.ExecuteSynchronously, + TaskScheduler.Default); } private record ContinuationState( diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs index 7e6bb3f53d..ed5ac3b62f 100644 --- a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs @@ -1304,9 +1304,13 @@ public void WaitForCompletion_DoesNotCreateUnobservedException() onTimeout: null, rethrowExceptions: true); + // - Task has timed out, simulate faulting task completion source + tcs.SetException(new Exception("late failure")); + // - Force collection of unobserved task GC.Collect(); GC.WaitForPendingFinalizers(); + GC.Collect(); // Assert // - Make sure no unobserved tasks happened From 8a3ff940ea09f0f0c53739c7575a22822f4aa9b1 Mon Sep 17 00:00:00 2001 From: Ben Russell Date: Fri, 5 Jun 2026 14:04:20 -0500 Subject: [PATCH 9/9] Address feedback from wraith - including constraining state types to class types. --- .../Data/SqlClient/SqlCommand.Xml.cs | 8 ++-- .../src/Microsoft/Data/SqlClient/TdsParser.cs | 4 +- .../Data/SqlClient/Utilities/AsyncHelper.cs | 47 +++++++++++++++---- 3 files changed, 45 insertions(+), 14 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Xml.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Xml.cs index eba617ad59..447627375e 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Xml.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Xml.cs @@ -264,11 +264,11 @@ private IAsyncResult BeginExecuteXmlReaderInternal( AsyncHelper.ContinueTaskWithState( taskToContinue: writeTask, taskCompletionSource: localCompletion, - state: Tuple.Create(this, localCompletion), - onSuccess: static state => + state1: this, + state2: localCompletion, + onSuccess: static (sqlCommand, localCompletion) => { - var parameters = (Tuple>)state; - parameters.Item1.BeginExecuteXmlReaderInternalReadStage(parameters.Item2); + sqlCommand.BeginExecuteXmlReaderInternalReadStage(localCompletion); }); } else diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs index 23c2dfc9dc..13abb31a12 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -12286,9 +12286,9 @@ private Task GetTerminationTask(Task unterminatedWriteTask, object value, MetaTy { return AsyncHelper.CreateContinuationTaskWithState( unterminatedWriteTask, - state1: 0, + state1: this, state2: stateObj, - onSuccess: WriteInt); + onSuccess: static (parser, state) => parser.WriteInt(0, state)); } } else diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs index 926a3b1d5e..d41759fbce 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs @@ -137,7 +137,10 @@ internal static void ContinueTask( /// is to allow the task completion source to be continued even more after this current /// continuation. /// - /// Type of the state object to provide to the callbacks + /// + /// Type of the state object to provide to the callbacks, constrained to class types to + /// prevent accidental modification of pass-by-value types. + /// /// Task to continue with provided callbacks /// /// Completion source used to track completion of the continuation, see remarks for details @@ -153,6 +156,7 @@ internal static void ContinueTaskWithState( Action onSuccess, Action? onFailure = null, Action? onCancellation = null) + where TState : class { ContinuationState continuationState = new( OnCancellation: onCancellation, @@ -237,8 +241,14 @@ internal static void ContinueTaskWithState( /// /// Completion source used to track completion of the continuation, see remarks for details /// - /// Type of the first state object to provide to callbacks - /// Type of the second state object to provide to callbacks + /// + /// Type of the first state object to provide to the callbacks, constrained to class types + /// to prevent accidental modification of pass-by-value types. + /// + /// + /// Type of the second state object to provide to the callbacks, constrained to class types + /// to prevent accidental modification of pass-by-value types. + /// /// First state object to provide to callbacks /// Second state object to provide to callbacks /// Callback to execute on successful completion of the task @@ -252,6 +262,8 @@ internal static void ContinueTaskWithState( Action onSuccess, Action? onFailure = null, Action? onCancellation = null) + where TState1 : class + where TState2 : class { ContinuationState continuationState = new( OnCancellation: onCancellation, @@ -428,7 +440,10 @@ internal static void ContinueTaskWithState( /// task will be completed with the exception. /// * The task will be completed as successful. /// - /// Type of the state object to pass to callbacks + /// + /// Type of the state object to pass to callbacks, constrained to class types to prevent + /// accidental modification of pass-by-value types. + /// /// /// Task to continue with provided callbacks, if null, null will be returned. /// @@ -442,6 +457,7 @@ internal static void ContinueTaskWithState( Action onSuccess, Action? onFailure = null, Action? onCancellation = null) + where TState : class { if (taskToContinue is null) { @@ -533,8 +549,13 @@ internal static void ContinueTaskWithState( /// task will be completed with the exception. /// * The task will be completed as successful. /// - /// Type of the first state object to pass to callbacks - /// Type of the second state object to pass to callbacks + /// + /// Type of the first state object to pass to callbacks, constrained to class types to + /// prevent accidental modification of pass-by-value types. + /// + /// + /// Type of the second state object to pass to callbacks, constrained to class types to + /// prevent accidental modification of pass-by-value types. /// /// Task to continue with provided callbacks, if null, null will be returned. /// @@ -550,6 +571,8 @@ internal static void ContinueTaskWithState( Action onSuccess, Action? onFailure = null, Action? onCancellation = null) + where TState1 : class + where TState2 : class { if (taskToContinue is null) { @@ -667,6 +690,10 @@ internal static void SetTimeoutException( /// exception returned is set as the exception that completes the task completion source. /// This overload provides a state object to the timeout callback. /// + /// + /// Type of the state object to pass to callbacks, constrained to class types to prevent + /// accidental modification of pass-by-value types. + /// /// Task to execute with a timeout /// Number of seconds to wait until timing out the task /// State object to pass to the callback @@ -681,6 +708,7 @@ internal static void SetTimeoutExceptionWithState( TState state, Func onTimeout, CancellationToken cancellationToken) + where TState : class { if (timeoutInSeconds <= 0) { @@ -786,7 +814,8 @@ private record ContinuationState( Action? OnFailure, Action OnSuccess, TState State, - TaskCompletionSource TaskCompletionSource); + TaskCompletionSource TaskCompletionSource) + where TState : class; private record ContinuationState( Action? OnCancellation, @@ -794,6 +823,8 @@ private record ContinuationState( Action OnSuccess, TState1 State1, TState2 State2, - TaskCompletionSource TaskCompletionSource); + TaskCompletionSource TaskCompletionSource) + where TState1 : class + where TState2 : class; } }